sglang / Commits / 564a898a

Unverified commit 564a898a, authored Jul 13, 2024 by Liangsheng Yin and committed via GitHub on Jul 13, 2024.

Optimize mem indices mangement (#619)
parent 5d264a90

Showing 15 changed files with 254 additions and 181 deletions (+254 −181).
benchmark/latency_throughput/bench_one.py                    +6   −3
python/sglang/backend/runtime_endpoint.py                    +14  −4
python/sglang/bench_latency.py                               +0   −1
python/sglang/global_config.py                               +1   −0
python/sglang/lang/chat_template.py                          +2   −2
python/sglang/lang/ir.py                                     +3   −3
python/sglang/srt/managers/controller/cuda_graph_runner.py   +36  −12
python/sglang/srt/managers/controller/infer_batch.py         +28  −18
python/sglang/srt/managers/controller/model_runner.py        +15  −4
python/sglang/srt/managers/controller/radix_cache.py         +2   −1
python/sglang/srt/managers/controller/tp_worker.py           +3   −1
python/sglang/srt/memory_pool.py                             +16  −16
python/sglang/srt/models/minicpm.py                          +1   −8
python/sglang/srt/models/qwen2_moe.py                        +126 −107
python/sglang/srt/utils.py                                   +1   −1
benchmark/latency_throughput/bench_one.py

@@ -17,7 +17,8 @@ def run_one_batch_size(bs):
     if args.input_len:
         input_ids = [
             [int(x) for x in np.random.randint(0, high=16384, size=(args.input_len,))]
             for _ in range(bs)
         ]
     else:
         text = [f"{i, }" for i in range(bs)]

@@ -116,9 +117,11 @@ if __name__ == "__main__":
     parser.add_argument("--port", type=int, default=None)
     parser.add_argument("--backend", type=str, default="srt")
     parser.add_argument("--input-len", type=int, default=None)
-    parser.add_argument("--batch-size", type=int, nargs='*', default=[1])
+    parser.add_argument("--batch-size", type=int, nargs="*", default=[1])
     parser.add_argument("--max-tokens", type=int, default=256)
     parser.add_argument("--vllm-model-name", type=str, default="meta-llama/Meta-Llama-3-70B")
     args = parser.parse_args()

     if args.port is None:
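Note: the --batch-size flag keeps its nargs="*" behaviour, so one invocation can sweep several batch sizes. A minimal standalone sketch of how that flag parses (illustrative only, not the benchmark script itself):

import argparse

# nargs="*" collects every following value into a list, defaulting to [1];
# the benchmark then runs run_one_batch_size once per entry.
parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", type=int, nargs="*", default=[1])

args = parser.parse_args(["--batch-size", "1", "8", "32"])
assert args.batch_size == [1, 8, 32]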
python/sglang/backend/runtime_endpoint.py

@@ -12,7 +12,6 @@ from sglang.utils import http_request
 class RuntimeEndpoint(BaseBackend):
     def __init__(
         self,
         base_url: str,

@@ -38,7 +37,8 @@ class RuntimeEndpoint(BaseBackend):
         self.model_info = res.json()

         self.chat_template = get_chat_template_by_model_path(
-            self.model_info["model_path"])
+            self.model_info["model_path"]
+        )

     def get_model_name(self):
         return self.model_info["model_path"]

@@ -124,7 +124,12 @@ class RuntimeEndpoint(BaseBackend):
         else:
             raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")

-        for item in ["return_logprob", "logprob_start_len", "top_logprobs_num", "return_text_in_logprobs"]:
+        for item in [
+            "return_logprob",
+            "logprob_start_len",
+            "top_logprobs_num",
+            "return_text_in_logprobs",
+        ]:
             value = getattr(sampling_params, item, None)
             if value is not None:
                 data[item] = value

@@ -171,7 +176,12 @@ class RuntimeEndpoint(BaseBackend):
         (the same reformatting is applied to the identical loop in the second request-building path)
python/sglang/bench_latency.py

@@ -32,7 +32,6 @@ import logging
 import multiprocessing
 import time

 import numpy as np
 import torch
 import torch.distributed as dist
python/sglang/global_config.py

@@ -44,4 +44,5 @@ class GlobalConfig:
         # adjust_cache: Adjust the position embedding of KV cache.
         self.concate_and_append_mode = "no_adjust"


 global_config = GlobalConfig()
python/sglang/lang/chat_template.py

@@ -84,7 +84,7 @@ register_chat_template(
             "system": ("SYSTEM:", "\n"),
             "user": ("USER:", "\n"),
             "assistant": ("ASSISTANT:", "\n"),
-        }
+        },
     )
 )

@@ -177,7 +177,7 @@ register_chat_template(
             "assistant": ("", "<|im_end|>\n"),
         },
         style=ChatTemplateStyle.PLAIN,
-        stop_str=("<|im_end|>",)
+        stop_str=("<|im_end|>",),
     )
 )
python/sglang/lang/ir.py

@@ -24,9 +24,9 @@ class SglSamplingParams:
     presence_penalty: float = 0.0
     ignore_eos: bool = False
     return_logprob: Optional[bool] = None
-    logprob_start_len: Optional[int] = None,
-    top_logprobs_num: Optional[int] = None,
-    return_text_in_logprobs: Optional[bool] = None,
+    logprob_start_len: Optional[int] = (None,)
+    top_logprobs_num: Optional[int] = (None,)
+    return_text_in_logprobs: Optional[bool] = (None,)

     # for constrained generation, not included in to_xxx_kwargs
     dtype: Optional[str] = None
python/sglang/srt/managers/controller/cuda_graph_runner.py

@@ -8,7 +8,10 @@ from vllm.distributed.parallel_state import graph_capture
 from sglang.global_config import global_config
 from sglang.srt.layers.logits_processor import LogitProcessorOutput
-from sglang.srt.managers.controller.infer_batch import (
-    Batch, ForwardMode, InputMetadata, init_flashinfer_args
-)
+from sglang.srt.managers.controller.infer_batch import (
+    Batch,
+    ForwardMode,
+    InputMetadata,
+    init_flashinfer_args,
+)

@@ -24,18 +27,28 @@ class CudaGraphRunner:
         # Common inputs
         self.max_bs = max_batch_size_to_capture
         self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
-        self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.req_pool_indices = torch.zeros(
+            (self.max_bs,), dtype=torch.int32, device="cuda"
+        )
         self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
-        self.position_ids_offsets = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
-        self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.position_ids_offsets = torch.zeros(
+            (self.max_bs,), dtype=torch.int32, device="cuda"
+        )
+        self.out_cache_loc = torch.zeros(
+            (self.max_bs,), dtype=torch.int32, device="cuda"
+        )

         # FlashInfer inputs
-        self.flashinfer_workspace_buffer = self.model_runner.flashinfer_workspace_buffers[0]
+        self.flashinfer_workspace_buffer = (
+            self.model_runner.flashinfer_workspace_buffers[0]
+        )
         self.flashinfer_kv_indptr = torch.zeros(
             (self.max_bs + 1,), dtype=torch.int32, device="cuda"
         )
         self.flashinfer_kv_indices = torch.zeros(
-            (self.max_bs * model_runner.model_config.context_len,), dtype=torch.int32, device="cuda"
+            (self.max_bs * model_runner.model_config.context_len,),
+            dtype=torch.int32,
+            device="cuda",
         )
         self.flashinfer_kv_last_page_len = torch.ones(
             (self.max_bs,), dtype=torch.int32, device="cuda"

@@ -49,7 +62,12 @@ class CudaGraphRunner:
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
             for bs in batch_size_list:
-                graph, input_buffers, output_buffers, flashinfer_handler = self.capture_one_batch_size(bs)
+                (
+                    graph,
+                    input_buffers,
+                    output_buffers,
+                    flashinfer_handler,
+                ) = self.capture_one_batch_size(bs)
                 self.graphs[bs] = graph
                 self.input_buffers[bs] = input_buffers
                 self.output_buffers[bs] = output_buffers

@@ -71,17 +89,19 @@ class CudaGraphRunner:
         # FlashInfer inputs
         if not _grouped_size_compiled_for_decode_kernels(
             self.model_runner.model_config.num_attention_heads // self.model_runner.tp_size,
             self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size),
         ):
             use_tensor_cores = True
         else:
             use_tensor_cores = False
         flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
             self.flashinfer_workspace_buffer,
             "NHD",
             use_cuda_graph=True,
             use_tensor_cores=use_tensor_cores,
             paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
             paged_kv_indices_buffer=self.flashinfer_kv_indices,
             paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
         )

@@ -163,10 +183,14 @@ class CudaGraphRunner:
         else:
             output = LogitProcessorOutput(
                 next_token_logits=output.next_token_logits[:raw_bs],
                 next_token_logprobs=output.next_token_logprobs[:raw_bs] if output.next_token_logprobs is not None else None,
                 normalized_prompt_logprobs=None,
                 prefill_token_logprobs=None,
                 prefill_top_logprobs=None,
                 decode_top_logprobs=output.decode_top_logprobs[:raw_bs] if output.decode_top_logprobs is not None else None,
             )
         return output
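Note: the pattern in this file is to allocate every CUDA-graph input once at the largest capturable batch size and let each captured graph replay against a slice of those buffers. A minimal sketch of that idea, with simplified names and shapes (not the actual CudaGraphRunner code; pass device="cpu" to try it without a GPU):

import torch

class GraphInputBuffers:
    """Fixed-size input buffers shared by all captured batch sizes."""

    def __init__(self, max_bs: int, device: str = "cuda"):
        # Allocated once, sized for the largest batch that will be captured.
        self.input_ids = torch.zeros((max_bs,), dtype=torch.int32, device=device)
        self.seq_lens = torch.ones((max_bs,), dtype=torch.int32, device=device)

    def view_for(self, bs: int):
        # A graph captured at batch size `bs` always reads the first `bs`
        # rows, so smaller batches reuse the same memory instead of
        # allocating fresh tensors at replay time.
        return self.input_ids[:bs], self.seq_lens[:bs]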
python/sglang/srt/managers/controller/infer_batch.py

@@ -668,7 +668,9 @@ class Batch:
             sampled_index = torch.multinomial(probs_sort, num_samples=1)
         except RuntimeError as e:
             warnings.warn(f"Ignore errors in sampling: {e}")
             sampled_index = torch.ones(
                 probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device
             )
         batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
             -1
         )

@@ -749,8 +751,14 @@ class InputMetadata:
         skip_flashinfer_init=False,
     ):
         if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
-            init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens,
-                                 model_runner.flashinfer_decode_wrapper)
+            init_flashinfer_args(
+                forward_mode,
+                model_runner,
+                req_pool_indices,
+                seq_lens,
+                prefix_lens,
+                model_runner.flashinfer_decode_wrapper,
+            )

         batch_size = len(req_pool_indices)

@@ -807,16 +815,24 @@ class InputMetadata:
         )

         if model_runner.server_args.disable_flashinfer:
-            (ret.triton_max_seq_len,
-             ret.triton_max_extend_len,
-             ret.triton_start_loc,
-             ret.triton_prefix_lens) = init_triton_args(forward_mode, seq_lens, prefix_lens)
+            (
+                ret.triton_max_seq_len,
+                ret.triton_max_extend_len,
+                ret.triton_start_loc,
+                ret.triton_prefix_lens,
+            ) = init_triton_args(forward_mode, seq_lens, prefix_lens)

         return ret


-def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens,
-                         flashinfer_decode_wrapper):
+def init_flashinfer_args(
+    forward_mode,
+    model_runner,
+    req_pool_indices,
+    seq_lens,
+    prefix_lens,
+    flashinfer_decode_wrapper,
+):
     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
     num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
     head_dim = model_runner.model_config.head_dim

@@ -827,9 +843,7 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
     else:
         paged_kernel_lens = prefix_lens

-    kv_indptr = torch.zeros(
-        (batch_size + 1,), dtype=torch.int32, device="cuda"
-    )
+    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
     kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
     req_pool_indices_cpu = req_pool_indices.cpu().numpy()
     paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()

@@ -842,9 +856,7 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
         ],
         dim=0,
     ).contiguous()

-    kv_last_page_len = torch.ones(
-        (batch_size,), dtype=torch.int32, device="cuda"
-    )
+    kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")

     if forward_mode == ForwardMode.DECODE:
         flashinfer_decode_wrapper.end_forward()

@@ -859,9 +871,7 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
         )
     else:
         # extend part
-        qo_indptr = torch.zeros(
-            (batch_size + 1,), dtype=torch.int32, device="cuda"
-        )
+        qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
         qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)

         model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
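Note: init_flashinfer_args builds its paged-KV index pointer the standard CSR way, with entry i+1 holding the cumulative KV length of the first i requests, so kv_indptr[i]..kv_indptr[i+1] is request i's slice of the flattened index array. A small standalone sketch of that construction (made-up lengths, CPU tensors):

import torch

# Per-request KV lengths for a batch of 3 (example values).
paged_kernel_lens = torch.tensor([5, 2, 7], dtype=torch.int32)
batch_size = paged_kernel_lens.numel()

# CSR-style offsets into the flattened kv_indices array.
kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)

print(kv_indptr.tolist())  # [0, 5, 7, 14]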
python/sglang/srt/managers/controller/model_runner.py

@@ -16,7 +16,12 @@ from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 from sglang.global_config import global_config
-from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata, global_server_args_dict
+from sglang.srt.managers.controller.infer_batch import (
+    Batch,
+    ForwardMode,
+    InputMetadata,
+    global_server_args_dict,
+)
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (

@@ -83,7 +88,9 @@ class ModelRunner:
         # Set some global args
         global_server_args_dict["disable_flashinfer"] = server_args.disable_flashinfer
         global_server_args_dict["attention_reduce_in_fp32"] = server_args.attention_reduce_in_fp32

         # Load the model and create memory pool
         self.load_model()

@@ -217,7 +224,9 @@ class ModelRunner:
             self.flashinfer_workspace_buffers[1], "NHD"
         )
         self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-            self.flashinfer_workspace_buffers[0], "NHD", use_tensor_cores=use_tensor_cores
+            self.flashinfer_workspace_buffers[0],
+            "NHD",
+            use_tensor_cores=use_tensor_cores,
         )

     def init_cuda_graphs(self):

@@ -229,7 +238,9 @@ class ModelRunner:
         logger.info(f"[gpu_id={self.gpu_id}] Capture cuda graph begin.")
         batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 16)]
-        self.cuda_graph_runner = CudaGraphRunner(self, max_batch_size_to_capture=max(batch_size_list))
+        self.cuda_graph_runner = CudaGraphRunner(
+            self, max_batch_size_to_capture=max(batch_size_list)
+        )
         self.cuda_graph_runner.capture(batch_size_list)

     @torch.inference_mode()
python/sglang/srt/managers/controller/radix_cache.py

@@ -125,7 +125,8 @@ class RadixCache:
             if x.lock_ref > 0:
                 continue

-            num_evicted += evict_callback(x.value)
+            evict_callback(x.value)
+            num_evicted += len(x.value)
             self._delete_leaf(x)

             if len(x.parent.children) == 0:
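Note: with this change the eviction counter no longer depends on what the callback returns; the number of evicted tokens is simply the length of the leaf's stored token indices, so the KV-pool free routine is free to return nothing. A simplified illustration of the before/after counting (toy node and callback, not the real RadixCache types):

# `node.value` stands in for the token indices stored on a radix-tree leaf.
class Node:
    def __init__(self, value):
        self.value = value

def evict(leaves, evict_callback):
    num_evicted = 0
    for node in leaves:
        # Before: num_evicted += evict_callback(node.value), which required the
        # callback to report how many slots it freed.
        evict_callback(node.value)
        # After: count directly from the node itself.
        num_evicted += len(node.value)
    return num_evicted

freed = evict([Node([3, 4, 5]), Node([9])], evict_callback=lambda idx: None)
assert freed == 4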
python/sglang/srt/managers/controller/tp_worker.py

@@ -314,7 +314,9 @@ class ModelTpServer:
         self.forward_queue.append(req)

     def get_new_fill_batch(self) -> Optional[Batch]:
-        running_bs = len(self.running_batch.reqs) if self.running_batch is not None else 0
+        running_bs = (
+            len(self.running_batch.reqs) if self.running_batch is not None else 0
+        )
         if running_bs >= self.max_running_requests:
             return
python/sglang/srt/memory_pool.py

@@ -39,10 +39,12 @@ class ReqToTokenPool:
 class TokenToKVPool:
     def __init__(self, size, dtype, head_num, head_dim, layer_num):
         self.size = size

-        # mem_state is the reference counter.
-        # This can be promised:
-        # assert torch.all(mem_state <= 1) and torch.all(mem_state >= 0)
         # We also add one slot. This slot is used for writing dummy output from padded tokens.
-        self.mem_state = torch.zeros((self.size + 1,), dtype=torch.int16, device="cuda")
-        self.total_ref_ct = 0
+        self.mem_state = torch.zeros((self.size + 1,), dtype=torch.bool, device="cuda")
+        self.total_size = self.size
+        self.total_alloc = 0

         # [size, key/value, head_num, head_dim] for each layer
         self.kv_data = [

@@ -71,7 +73,9 @@ class TokenToKVPool:
         addition_size = need_size - buffer_len
         alloc_size = max(addition_size, self.prefetch_chunk_size)
-        select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size].to(torch.int32)
+        select_index = (
+            torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size].to(torch.int32)
+        )

         if select_index.shape[0] < addition_size:
             return None

@@ -105,26 +109,22 @@ class TokenToKVPool:
         return select_index.to(torch.int32), start_loc, start_loc + need_size

     def used_size(self):
-        return len(torch.nonzero(self.mem_state).squeeze(1))
+        return self.total_alloc

     def available_size(self):
-        return torch.sum(self.mem_state == 0).item() + len(self.prefetch_buffer)
+        return self.total_size - self.total_alloc + len(self.prefetch_buffer)

     def add_refs(self, token_index: torch.Tensor):
-        self.total_ref_ct += len(token_index)
-        self.mem_state[token_index] += 1
+        self.total_alloc += len(token_index)
+        self.mem_state[token_index] ^= True

     def dec_refs(self, token_index: torch.Tensor):
-        self.total_ref_ct -= len(token_index)
-        self.mem_state[token_index] -= 1
-
-        num_freed = torch.sum(self.mem_state[token_index] == 0)
-        return num_freed
+        self.total_alloc -= len(token_index)
+        self.mem_state[token_index] ^= True

     def clear(self):
         self.mem_state.fill_(0)
-        self.total_ref_ct = 0
+        self.total_alloc = 0

         # We also add one slot. This slot is used for writing dummy output from padded tokens.
-        self.add_refs(torch.tensor([0], dtype=torch.int32))
+        self.mem_state[0] = True
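Note: this file is the heart of the commit. The KV-token pool stops maintaining a per-slot int16 reference count and instead keeps a boolean used/free mask plus a single total_alloc counter, so used_size() and available_size() become O(1) bookkeeping instead of tensor-wide scans. A minimal CPU-only sketch of that scheme (class and method names are simplified, not a drop-in replacement for TokenToKVPool):

import torch

class BoolTokenPool:
    """Boolean-mask allocator in the spirit of the new TokenToKVPool."""

    def __init__(self, size: int):
        self.size = size
        # One extra slot reserved for dummy output from padded tokens.
        self.mem_state = torch.zeros((size + 1,), dtype=torch.bool)
        self.total_alloc = 0
        self.mem_state[0] = True  # reserve slot 0

    def alloc(self, need_size: int):
        free_index = torch.nonzero(~self.mem_state).squeeze(1)[:need_size]
        if free_index.shape[0] < need_size:
            return None  # out of memory
        self.mem_state[free_index] ^= True  # flip free -> used
        self.total_alloc += need_size
        return free_index

    def free(self, token_index: torch.Tensor):
        self.mem_state[token_index] ^= True  # flip used -> free
        self.total_alloc -= len(token_index)

    def used_size(self):
        # O(1): just read the counter instead of scanning the mask.
        return self.total_alloc

    def available_size(self):
        return self.size - self.total_alloc

pool = BoolTokenPool(size=8)
idx = pool.alloc(3)
assert pool.used_size() == 3
pool.free(idx)
assert pool.available_size() == 8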
python/sglang/srt/models/minicpm.py

@@ -5,12 +5,9 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,

@@ -31,7 +28,6 @@ from sglang.srt.managers.controller.model_runner import InputMetadata
 class MiniCPMMLP(nn.Module):
     def __init__(
         self,
         hidden_size: int,

@@ -67,7 +63,6 @@ class MiniCPMMLP(nn.Module):
 class MiniCPMAttention(nn.Module):
     def __init__(
         self,
         hidden_size: int,

@@ -152,7 +147,6 @@ class MiniCPMAttention(nn.Module):
 class MiniCPMDecoderLayer(nn.Module):
     def __init__(
         self,
         config,

@@ -217,7 +211,6 @@ class MiniCPMDecoderLayer(nn.Module):
 class MiniCPMModel(nn.Module):
     def __init__(
         self,
         config,

@@ -274,7 +267,7 @@ class MiniCPMForCausalLM(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
         self.num_experts = getattr(self.config, "num_experts", 0)
         self.quant_config = quant_config
         self.model = MiniCPMModel(config, quant_config=quant_config)
python/sglang/srt/models/qwen2_moe.py

@@ -8,24 +8,28 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
-from vllm.distributed import (get_tensor_model_parallel_world_size,
-                              tensor_model_parallel_all_reduce)
+from vllm.distributed import (
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
-                                               QKVParallelLinear,
-                                               ReplicatedLinear,
-                                               RowParallelLinear)
+from vllm.model_executor.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors, SamplerOutput

@@ -46,17 +50,20 @@ class Qwen2MoeMLP(nn.Module):
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
-            bias=False,
-            quant_config=quant_config)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           quant_config=quant_config,
-                                           reduce_results=reduce_results)
+            hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            reduce_results=reduce_results,
+        )
         if hidden_act != "silu":
-            raise ValueError(f"Unsupported activation: {hidden_act}. "
-                             "Only silu is supported for now.")
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. "
+                "Only silu is supported for now."
+            )
         self.act_fn = SiluAndMul()

     def forward(self, x):

@@ -79,20 +85,22 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         if self.tp_size > config.num_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {config.num_experts}.")
-
-        self.experts = FusedMoE(num_experts=config.num_experts,
-                                top_k=config.num_experts_per_tok,
-                                hidden_size=config.hidden_size,
-                                intermediate_size=config.moe_intermediate_size,
-                                reduce_results=False,
-                                renormalize=config.norm_topk_prob,
-                                quant_config=quant_config)
-
-        self.gate = ReplicatedLinear(config.hidden_size,
-                                     config.num_experts,
-                                     bias=False,
-                                     quant_config=None)
+                f"the number of experts {config.num_experts}."
+            )
+
+        self.experts = FusedMoE(
+            num_experts=config.num_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=config.norm_topk_prob,
+            quant_config=quant_config,
+        )
+
+        self.gate = ReplicatedLinear(
+            config.hidden_size, config.num_experts, bias=False, quant_config=None
+        )
         if config.shared_expert_intermediate_size > 0:
             self.shared_expert = Qwen2MoeMLP(
                 hidden_size=config.hidden_size,

@@ -103,9 +111,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             )
         else:
             self.shared_expert = None
-        self.shared_expert_gate = torch.nn.Linear(config.hidden_size,
-                                                  1,
-                                                  bias=False)
+        self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape

@@ -114,24 +120,24 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         if self.shared_expert is not None:
             shared_output = self.shared_expert(hidden_states)
             if self.shared_expert_gate is not None:
-                shared_output = F.sigmoid(
-                    self.shared_expert_gate(hidden_states)) * shared_output
+                shared_output = (
+                    F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output
+                )

         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(hidden_states=hidden_states,
-                                           router_logits=router_logits)
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states, router_logits=router_logits
+        )
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
         if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(
-                final_hidden_states)
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

         return final_hidden_states.view(num_tokens, hidden_dim)


 class Qwen2MoeAttention(nn.Module):
     def __init__(
         self,
         hidden_size: int,

@@ -190,17 +196,19 @@ class Qwen2MoeAttention(nn.Module):
             base=rope_theta,
             rope_scaling=rope_scaling,
         )
-        self.attn = RadixAttention(self.num_heads,
-                                   self.head_dim,
-                                   self.scaling,
-                                   num_kv_heads=self.num_kv_heads,
-                                   layer_id=layer_id)
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+        )

     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata
+        input_metadata: InputMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

@@ -223,8 +230,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
-        max_position_embeddings = getattr(config, "max_position_embeddings",
-                                          8192)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.self_attn = Qwen2MoeAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,

@@ -239,13 +245,13 @@ class Qwen2MoeDecoderLayer(nn.Module):
         # Note: Qwen/Qwen2-57B-A14B-Instruct does not have
         # `mlp_only_layers` in the config.
-        mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
-                           config.mlp_only_layers)
+        mlp_only_layers = (
+            [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers
+        )
         if (layer_id not in mlp_only_layers) and (
-                config.num_experts > 0 and
-                (layer_id + 1) % config.decoder_sparse_step == 0):
-            self.mlp = Qwen2MoeSparseMoeBlock(config=config,
-                                              quant_config=quant_config)
+            config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0
+        ):
+            self.mlp = Qwen2MoeSparseMoeBlock(config=config, quant_config=quant_config)
         else:
             self.mlp = Qwen2MoeMLP(
                 hidden_size=config.hidden_size,

@@ -253,10 +259,10 @@ class Qwen2MoeDecoderLayer(nn.Module):
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
             )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
         )

     def forward(
         self,

@@ -270,23 +276,20 @@ class Qwen2MoeDecoderLayer(nn.Module):
             residual = hidden_states
             hidden_states = self.input_layernorm(hidden_states)
         else:
-            hidden_states, residual = self.input_layernorm(
-                hidden_states, residual)
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata
+            input_metadata=input_metadata,
         )

         # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(
-            hidden_states, residual)
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
         return hidden_states, residual


 class Qwen2MoeModel(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,

@@ -301,13 +304,14 @@ class Qwen2MoeModel(nn.Module):
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = nn.ModuleList([
-            Qwen2MoeDecoderLayer(config,
-                                 layer_id,
-                                 cache_config,
-                                 quant_config=quant_config)
-            for layer_id in range(config.num_hidden_layers)
-        ])
+        self.layers = nn.ModuleList(
+            [
+                Qwen2MoeDecoderLayer(
+                    config, layer_id, cache_config, quant_config=quant_config
+                )
+                for layer_id in range(config.num_hidden_layers)
+            ]
+        )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

     def forward(

@@ -315,7 +319,7 @@ class Qwen2MoeModel(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-        input_embeds: torch.Tensor = None
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)

@@ -324,10 +328,9 @@ class Qwen2MoeModel(nn.Module):
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
-            hidden_states, residual = layer(positions,
-                                            hidden_states,
-                                            input_metadata,
-                                            residual)
+            hidden_states, residual = layer(
+                positions, hidden_states, input_metadata, residual
+            )
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

@@ -346,9 +349,9 @@ class Qwen2MoeForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.model = Qwen2MoeModel(config, cache_config, quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size,
-                                      config.hidden_size,
-                                      quant_config=quant_config)
+        self.lm_head = ParallelLMHead(
+            config.vocab_size, config.hidden_size, quant_config=quant_config
+        )
         self.logits_processor = LogitsProcessor(config)
         self.sampler = Sampler()

@@ -357,17 +360,22 @@ class Qwen2MoeForCausalLM(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-        input_embeds: torch.Tensor = None
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(input_ids, hidden_states, self.lm_head.weight,
-                                     input_metadata)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head.weight, input_metadata
+        )

     def compute_logits(
         self,
         input_ids: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata
+        input_metadata: InputMetadata,
     ) -> torch.Tensor:
-        logits = self.logits_processor(input_ids, hidden_states, self.lm_head.weight,
-                                       input_metadata)
+        logits = self.logits_processor(
+            input_ids, hidden_states, self.lm_head.weight, input_metadata
+        )
         return logits

     def sample(

@@ -391,11 +399,18 @@ class Qwen2MoeForCausalLM(nn.Module):
         expert_params_mapping = [
             # These are the weights for the experts
             # (param_name, weight_name, expert_id, shard_id)
-            ("experts.w13_weight" if weight_name in ["gate_proj", "up_proj"]
-             else "experts.w2_weight",
-             f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
-            for expert_id in range(self.config.num_experts)
-            for shard_id, weight_name in enumerate(["gate_proj", "down_proj", "up_proj"])
+            (
+                "experts.w13_weight"
+                if weight_name in ["gate_proj", "up_proj"]
+                else "experts.w2_weight",
+                f"experts.{expert_id}.{weight_name}.weight",
+                expert_id,
+                shard_id,
+            )
+            for expert_id in range(self.config.num_experts)
+            for shard_id, weight_name in enumerate(
+                ["gate_proj", "down_proj", "up_proj"]
+            )
         ]

         params_dict = dict(self.named_parameters())

@@ -433,11 +448,13 @@ class Qwen2MoeForCausalLM(nn.Module):
                 name = name.replace(weight_name, param_name)
                 param = params_dict[name]
                 weight_loader = param.weight_loader
-                weight_loader(param,
-                              loaded_weight,
-                              weight_name,
-                              shard_id=shard_id,
-                              expert_id=expert_id)
+                weight_loader(
+                    param,
+                    loaded_weight,
+                    weight_name,
+                    shard_id=shard_id,
+                    expert_id=expert_id,
+                )
                 break
             else:
                 # Skip loading extra bias for GPTQ models.

@@ -447,8 +464,10 @@ class Qwen2MoeForCausalLM(nn.Module):
                     continue
                 param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
+                weight_loader = getattr(
+                    param, "weight_loader", default_weight_loader
+                )
                 weight_loader(param, loaded_weight)


 EntryClass = Qwen2MoeForCausalLM
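Note: in the sparse-MoE forward above, the routed-expert output is combined with an always-on shared expert whose contribution is scaled per token by a sigmoid gate, i.e. F.sigmoid(shared_expert_gate(h)) * shared_expert(h) is added to the expert mixture. A toy, self-contained version of that combination (dimensions and the nn.Linear stand-ins are illustrative, not the real FusedMoE/Qwen2MoeMLP modules):

import torch
import torch.nn.functional as F
from torch import nn

hidden_size, num_tokens = 16, 4
hidden_states = torch.randn(num_tokens, hidden_size)

# Stand-ins for the routed experts and the always-on shared expert.
experts = nn.Linear(hidden_size, hidden_size)        # plays the role of FusedMoE
shared_expert = nn.Linear(hidden_size, hidden_size)  # plays the role of Qwen2MoeMLP
shared_expert_gate = nn.Linear(hidden_size, 1, bias=False)

routed_out = experts(hidden_states)
# A scalar sigmoid gate per token decides how much shared-expert output to mix in.
shared_out = F.sigmoid(shared_expert_gate(hidden_states)) * shared_expert(hidden_states)

final_hidden_states = routed_out + shared_out
print(final_hidden_states.shape)  # torch.Size([4, 16])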
python/sglang/srt/utils.py

@@ -474,9 +474,9 @@ def monkey_patch_vllm_dummy_weight_loader():
         DummyModelLoader,
         LoRAConfig,
         ModelConfig,
+        MultiModalConfig,
         ParallelConfig,
         SchedulerConfig,
-        MultiModalConfig,
         _initialize_model,
         initialize_dummy_weights,
         nn,