sglang / Commits / 564a898a

Commit 564a898a (Unverified)
Authored Jul 13, 2024 by Liangsheng Yin; committed by GitHub on Jul 13, 2024

Optimize mem indices management (#619)
parent 5d264a90

Showing 15 changed files with 254 additions and 181 deletions (+254, -181):

  benchmark/latency_throughput/bench_one.py                    +6    -3
  python/sglang/backend/runtime_endpoint.py                    +14   -4
  python/sglang/bench_latency.py                               +0    -1
  python/sglang/global_config.py                               +1    -0
  python/sglang/lang/chat_template.py                          +2    -2
  python/sglang/lang/ir.py                                     +3    -3
  python/sglang/srt/managers/controller/cuda_graph_runner.py   +36   -12
  python/sglang/srt/managers/controller/infer_batch.py         +28   -18
  python/sglang/srt/managers/controller/model_runner.py        +15   -4
  python/sglang/srt/managers/controller/radix_cache.py         +2    -1
  python/sglang/srt/managers/controller/tp_worker.py           +3    -1
  python/sglang/srt/memory_pool.py                             +16   -16
  python/sglang/srt/models/minicpm.py                          +1    -8
  python/sglang/srt/models/qwen2_moe.py                        +126  -107
  python/sglang/srt/utils.py                                   +1    -1
benchmark/latency_throughput/bench_one.py  (+6, -3)

@@ -17,7 +17,8 @@ def run_one_batch_size(bs):
     if args.input_len:
         input_ids = [
-            [int(x) for x in np.random.randint(0, high=16384, size=(args.input_len,))] for _ in range(bs)
+            [int(x) for x in np.random.randint(0, high=16384, size=(args.input_len,))]
+            for _ in range(bs)
         ]
     else:
         text = [f"{i,}" for i in range(bs)]

@@ -116,9 +117,11 @@ if __name__ == "__main__":
     parser.add_argument("--port", type=int, default=None)
     parser.add_argument("--backend", type=str, default="srt")
     parser.add_argument("--input-len", type=int, default=None)
-    parser.add_argument("--batch-size", type=int, nargs='*', default=[1])
+    parser.add_argument("--batch-size", type=int, nargs="*", default=[1])
     parser.add_argument("--max-tokens", type=int, default=256)
-    parser.add_argument("--vllm-model-name", type=str, default="meta-llama/Meta-Llama-3-70B")
+    parser.add_argument(
+        "--vllm-model-name", type=str, default="meta-llama/Meta-Llama-3-70B"
+    )
     args = parser.parse_args()

     if args.port is None:
python/sglang/backend/runtime_endpoint.py  (+14, -4)

@@ -12,7 +12,6 @@ from sglang.utils import http_request
 class RuntimeEndpoint(BaseBackend):
-
     def __init__(
         self,
         base_url: str,

@@ -38,7 +37,8 @@ class RuntimeEndpoint(BaseBackend):
         self.model_info = res.json()

-        self.chat_template = get_chat_template_by_model_path(self.model_info["model_path"])
+        self.chat_template = get_chat_template_by_model_path(
+            self.model_info["model_path"])

     def get_model_name(self):
         return self.model_info["model_path"]

@@ -124,7 +124,12 @@ class RuntimeEndpoint(BaseBackend):
         else:
             raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")

-        for item in ["return_logprob", "logprob_start_len", "top_logprobs_num", "return_text_in_logprobs"]:
+        for item in [
+            "return_logprob",
+            "logprob_start_len",
+            "top_logprobs_num",
+            "return_text_in_logprobs",
+        ]:
             value = getattr(sampling_params, item, None)
             if value is not None:
                 data[item] = value

@@ -171,7 +176,12 @@ class RuntimeEndpoint(BaseBackend):
         else:
             raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")

-        for item in ["return_logprob", "logprob_start_len", "top_logprobs_num", "return_text_in_logprobs"]:
+        for item in [
+            "return_logprob",
+            "logprob_start_len",
+            "top_logprobs_num",
+            "return_text_in_logprobs",
+        ]:
             value = getattr(sampling_params, item, None)
             if value is not None:
                 data[item] = value
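Note on the two identical hunks above: only the optional sampling fields that were actually set get copied into the request payload, so unset fields keep the server-side defaults. A small standalone illustration of that pattern (toy payload and stand-in object, not the real client):

data = {"text": "Hello", "sampling_params": {"max_new_tokens": 16}}

class DummyParams:  # stand-in for sampling_params
    return_logprob = True
    logprob_start_len = None
    top_logprobs_num = 5
    return_text_in_logprobs = None

for item in [
    "return_logprob",
    "logprob_start_len",
    "top_logprobs_num",
    "return_text_in_logprobs",
]:
    value = getattr(DummyParams, item, None)
    if value is not None:
        data[item] = value

print(data)  # only return_logprob and top_logprobs_num were added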
python/sglang/bench_latency.py  (+0, -1)

@@ -32,7 +32,6 @@ import logging
 import multiprocessing
 import time
-
 import numpy as np
 import torch
 import torch.distributed as dist
python/sglang/global_config.py  (+1, -0)

@@ -44,4 +44,5 @@ class GlobalConfig:
         # adjust_cache: Adjust the position embedding of KV cache.
         self.concate_and_append_mode = "no_adjust"

+
 global_config = GlobalConfig()
python/sglang/lang/chat_template.py  (+2, -2)

@@ -84,7 +84,7 @@ register_chat_template(
             "system": ("SYSTEM:", "\n"),
             "user": ("USER:", "\n"),
             "assistant": ("ASSISTANT:", "\n"),
-        }
+        },
     )
 )

@@ -177,7 +177,7 @@ register_chat_template(
             "assistant": ("", "<|im_end|>\n"),
         },
         style=ChatTemplateStyle.PLAIN,
-        stop_str=("<|im_end|>",)
+        stop_str=("<|im_end|>",),
     )
 )
python/sglang/lang/ir.py  (+3, -3)

@@ -24,9 +24,9 @@ class SglSamplingParams:
     presence_penalty: float = 0.0
     ignore_eos: bool = False
     return_logprob: Optional[bool] = None
-    logprob_start_len: Optional[int] = None,
-    top_logprobs_num: Optional[int] = None,
-    return_text_in_logprobs: Optional[bool] = None,
+    logprob_start_len: Optional[int] = (None,)
+    top_logprobs_num: Optional[int] = (None,)
+    return_text_in_logprobs: Optional[bool] = (None,)

     # for constrained generation, not included in to_xxx_kwargs
     dtype: Optional[str] = None
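Note on the hunk above: the reformatter is only making an existing pitfall visible, not introducing one. A trailing comma after an annotated class attribute turns the default value into a one-element tuple, so `= None,` was already equivalent to `= (None,)`. A standalone illustration (hypothetical class, not from the repo):

from typing import Optional

class Example:
    return_logprob: Optional[bool] = None      # default is None
    logprob_start_len: Optional[int] = None,   # trailing comma: default is (None,)
    top_logprobs_num: Optional[int] = (None,)  # the same thing, written explicitly

print(Example.return_logprob)     # None
print(Example.logprob_start_len)  # (None,)
print(Example.top_logprobs_num)   # (None,)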
python/sglang/srt/managers/controller/cuda_graph_runner.py  (+36, -12)

@@ -8,7 +8,10 @@ from vllm.distributed.parallel_state import graph_capture
 from sglang.global_config import global_config
 from sglang.srt.layers.logits_processor import LogitProcessorOutput
 from sglang.srt.managers.controller.infer_batch import (
-    Batch, ForwardMode, InputMetadata, init_flashinfer_args
+    Batch,
+    ForwardMode,
+    InputMetadata,
+    init_flashinfer_args,
 )

@@ -24,18 +27,28 @@ class CudaGraphRunner:
         # Common inputs
         self.max_bs = max_batch_size_to_capture
         self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
-        self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.req_pool_indices = torch.zeros(
+            (self.max_bs,), dtype=torch.int32, device="cuda"
+        )
         self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
-        self.position_ids_offsets = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
-        self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.position_ids_offsets = torch.zeros(
+            (self.max_bs,), dtype=torch.int32, device="cuda"
+        )
+        self.out_cache_loc = torch.zeros(
+            (self.max_bs,), dtype=torch.int32, device="cuda"
+        )

         # FlashInfer inputs
-        self.flashinfer_workspace_buffer = self.model_runner.flashinfer_workspace_buffers[0]
+        self.flashinfer_workspace_buffer = (
+            self.model_runner.flashinfer_workspace_buffers[0]
+        )
         self.flashinfer_kv_indptr = torch.zeros(
             (self.max_bs + 1,), dtype=torch.int32, device="cuda"
         )
         self.flashinfer_kv_indices = torch.zeros(
-            (self.max_bs * model_runner.model_config.context_len,), dtype=torch.int32, device="cuda"
+            (self.max_bs * model_runner.model_config.context_len,),
+            dtype=torch.int32,
+            device="cuda",
         )
         self.flashinfer_kv_last_page_len = torch.ones(
             (self.max_bs,), dtype=torch.int32, device="cuda"

@@ -49,7 +62,12 @@ class CudaGraphRunner:
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
             for bs in batch_size_list:
-                graph, input_buffers, output_buffers, flashinfer_handler = self.capture_one_batch_size(bs)
+                (
+                    graph,
+                    input_buffers,
+                    output_buffers,
+                    flashinfer_handler,
+                ) = self.capture_one_batch_size(bs)
                 self.graphs[bs] = graph
                 self.input_buffers[bs] = input_buffers
                 self.output_buffers[bs] = output_buffers

@@ -71,17 +89,19 @@ class CudaGraphRunner:
         # FlashInfer inputs
         if not _grouped_size_compiled_for_decode_kernels(
-            self.model_runner.model_config.num_attention_heads // self.model_runner.tp_size,
+            self.model_runner.model_config.num_attention_heads
+            // self.model_runner.tp_size,
             self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size),
         ):
             use_tensor_cores = True
         else:
             use_tensor_cores = False

         flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-            self.flashinfer_workspace_buffer, "NHD",
+            self.flashinfer_workspace_buffer,
+            "NHD",
             use_cuda_graph=True,
             use_tensor_cores=use_tensor_cores,
-            paged_kv_indptr_buffer=self.flashinfer_kv_indptr[:bs+1],
+            paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
             paged_kv_indices_buffer=self.flashinfer_kv_indices,
             paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
         )

@@ -163,10 +183,14 @@ class CudaGraphRunner:
         else:
             output = LogitProcessorOutput(
                 next_token_logits=output.next_token_logits[:raw_bs],
-                next_token_logprobs=output.next_token_logprobs[:raw_bs] if output.next_token_logprobs is not None else None,
+                next_token_logprobs=output.next_token_logprobs[:raw_bs]
+                if output.next_token_logprobs is not None
+                else None,
                 normalized_prompt_logprobs=None,
                 prefill_token_logprobs=None,
                 prefill_top_logprobs=None,
-                decode_top_logprobs=output.decode_top_logprobs[:raw_bs] if output.decode_top_logprobs is not None else None,
+                decode_top_logprobs=output.decode_top_logprobs[:raw_bs]
+                if output.decode_top_logprobs is not None
+                else None,
             )
         return output
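Note: the `[:raw_bs]` slices in the last hunk are one half of the CUDA-graph padding scheme: a decode batch is padded up to the nearest captured batch size, the fixed-shape graph is replayed, and the outputs are then cut back to the real batch size. A hedged, standalone sketch of that bookkeeping; `run_with_padding` and the lambda stand in for graph replay and are not sglang APIs:

import bisect
import torch

# Captured batch sizes, matching init_cuda_graphs in model_runner.py below.
CAPTURED_SIZES = [1, 2, 4] + [i * 8 for i in range(1, 16)]

def run_with_padding(run_graph, input_ids: torch.Tensor) -> torch.Tensor:
    raw_bs = input_ids.shape[0]
    # smallest captured size that still fits the batch
    padded_bs = CAPTURED_SIZES[bisect.bisect_left(CAPTURED_SIZES, raw_bs)]
    padded = torch.zeros(padded_bs, dtype=input_ids.dtype)
    padded[:raw_bs] = input_ids      # copy real tokens into the static buffer
    logits = run_graph(padded)       # fixed-shape execution (graph replay stand-in)
    return logits[:raw_bs]           # drop rows that came from padding

print(run_with_padding(lambda x: x.float() * 2, torch.tensor([5, 7, 9])))
# tensor([10., 14., 18.])  -- padded to bs=4 internally, sliced back to 3 rows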
python/sglang/srt/managers/controller/infer_batch.py  (+28, -18)

@@ -668,7 +668,9 @@ class Batch:
             sampled_index = torch.multinomial(probs_sort, num_samples=1)
         except RuntimeError as e:
             warnings.warn(f"Ignore errors in sampling: {e}")
-            sampled_index = torch.ones(probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device)
+            sampled_index = torch.ones(
+                probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device
+            )
         batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)

@@ -749,8 +751,14 @@ class InputMetadata:
         skip_flashinfer_init=False,
     ):
         if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
-            init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
-                                 prefix_lens, model_runner.flashinfer_decode_wrapper)
+            init_flashinfer_args(
+                forward_mode,
+                model_runner,
+                req_pool_indices,
+                seq_lens,
+                prefix_lens,
+                model_runner.flashinfer_decode_wrapper,
+            )

         batch_size = len(req_pool_indices)

@@ -807,16 +815,24 @@ class InputMetadata:
         )

         if model_runner.server_args.disable_flashinfer:
-            (ret.triton_max_seq_len, ret.triton_max_extend_len,
-             ret.triton_start_loc, ret.triton_prefix_lens) = init_triton_args(
-                forward_mode, seq_lens, prefix_lens)
+            (
+                ret.triton_max_seq_len,
+                ret.triton_max_extend_len,
+                ret.triton_start_loc,
+                ret.triton_prefix_lens,
+            ) = init_triton_args(forward_mode, seq_lens, prefix_lens)

         return ret


-def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
-                         prefix_lens, flashinfer_decode_wrapper):
+def init_flashinfer_args(
+    forward_mode,
+    model_runner,
+    req_pool_indices,
+    seq_lens,
+    prefix_lens,
+    flashinfer_decode_wrapper,
+):
     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
     num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
     head_dim = model_runner.model_config.head_dim

@@ -827,9 +843,7 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
     else:
         paged_kernel_lens = prefix_lens

-    kv_indptr = torch.zeros(
-        (batch_size + 1,), dtype=torch.int32, device="cuda"
-    )
+    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
     kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
     req_pool_indices_cpu = req_pool_indices.cpu().numpy()
     paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()

@@ -842,9 +856,7 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
         ],
         dim=0,
     ).contiguous()

-    kv_last_page_len = torch.ones(
-        (batch_size,), dtype=torch.int32, device="cuda"
-    )
+    kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")

     if forward_mode == ForwardMode.DECODE:
         flashinfer_decode_wrapper.end_forward()

@@ -859,9 +871,7 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
         )
     else:
         # extend part
-        qo_indptr = torch.zeros(
-            (batch_size + 1,), dtype=torch.int32, device="cuda"
-        )
+        qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
         qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)

         model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
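Note: the first hunk only rewraps the sampling fallback, but the pattern is worth spelling out: if torch.multinomial raises on a bad distribution, the batch falls back to a fixed position in the sorted probabilities instead of crashing, and torch.gather then maps sorted positions back to token ids. A simplified, standalone sketch (no top-p filtering, illustrative function name):

import warnings
import torch

def sample_with_fallback(probs: torch.Tensor) -> torch.Tensor:
    # sort descending so probs_idx maps sorted positions back to token ids
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    try:
        sampled_index = torch.multinomial(probs_sort, num_samples=1)
    except RuntimeError as e:
        warnings.warn(f"Ignore errors in sampling: {e}")
        # fall back to a fixed index in the sorted distribution
        sampled_index = torch.ones(
            probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device
        )
    return torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)

print(sample_with_fallback(torch.tensor([[0.1, 0.7, 0.2]])))  # one token id per row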
python/sglang/srt/managers/controller/model_runner.py  (+15, -4)

@@ -16,7 +16,12 @@ from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry

 from sglang.global_config import global_config
-from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata, global_server_args_dict
+from sglang.srt.managers.controller.infer_batch import (
+    Batch,
+    ForwardMode,
+    InputMetadata,
+    global_server_args_dict,
+)
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (

@@ -83,7 +88,9 @@ class ModelRunner:
         # Set some global args
         global_server_args_dict["disable_flashinfer"] = server_args.disable_flashinfer
-        global_server_args_dict["attention_reduce_in_fp32"] = server_args.attention_reduce_in_fp32
+        global_server_args_dict["attention_reduce_in_fp32"] = (
+            server_args.attention_reduce_in_fp32
+        )

         # Load the model and create memory pool
         self.load_model()

@@ -217,7 +224,9 @@ class ModelRunner:
             self.flashinfer_workspace_buffers[1], "NHD"
         )
         self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-            self.flashinfer_workspace_buffers[0], "NHD", use_tensor_cores=use_tensor_cores
+            self.flashinfer_workspace_buffers[0],
+            "NHD",
+            use_tensor_cores=use_tensor_cores,
        )

     def init_cuda_graphs(self):

@@ -229,7 +238,9 @@ class ModelRunner:
         logger.info(f"[gpu_id={self.gpu_id}] Capture cuda graph begin.")
         batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 16)]
-        self.cuda_graph_runner = CudaGraphRunner(self, max_batch_size_to_capture=max(batch_size_list))
+        self.cuda_graph_runner = CudaGraphRunner(
+            self, max_batch_size_to_capture=max(batch_size_list)
+        )
         self.cuda_graph_runner.capture(batch_size_list)

     @torch.inference_mode()
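For reference, the capture list built in init_cuda_graphs above expands to a few small batch sizes plus multiples of 8, and its maximum is what sizes the static buffers handed to CudaGraphRunner:

batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 16)]
print(batch_size_list)       # [1, 2, 4, 8, 16, 24, ..., 120]
print(max(batch_size_list))  # 120 -> max_batch_size_to_capture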
python/sglang/srt/managers/controller/radix_cache.py  (+2, -1)

@@ -125,7 +125,8 @@ class RadixCache:
             if x.lock_ref > 0:
                 continue

-            num_evicted += evict_callback(x.value)
+            evict_callback(x.value)
+            num_evicted += len(x.value)
             self._delete_leaf(x)

             if len(x.parent.children) == 0:
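Note: this ties into the memory-pool rewrite further down. The eviction callback (the KV pool's dec_refs) no longer reports how many slots it freed, so the cache now counts evicted tokens itself via len(x.value). A toy, standalone version of the adjusted loop (flat lists instead of the real radix tree):

def evict(leaves, num_tokens, evict_callback):
    """Evict leaf values until at least num_tokens KV slots are reclaimed."""
    num_evicted = 0
    for value in leaves:            # each value is the token-index list of a leaf
        if num_evicted >= num_tokens:
            break
        evict_callback(value)       # frees the KV slots; return value is ignored
        num_evicted += len(value)   # count locally instead of trusting the callback
    return num_evicted

print(evict([[3, 4, 5], [9], [10, 11]], 4, lambda v: None))  # 4 (= 3 + 1)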
python/sglang/srt/managers/controller/tp_worker.py  (+3, -1)

@@ -314,7 +314,9 @@ class ModelTpServer:
             self.forward_queue.append(req)

     def get_new_fill_batch(self) -> Optional[Batch]:
-        running_bs = len(self.running_batch.reqs) if self.running_batch is not None else 0
+        running_bs = (
+            len(self.running_batch.reqs) if self.running_batch is not None else 0
+        )
         if running_bs >= self.max_running_requests:
             return
python/sglang/srt/memory_pool.py  (+16, -16)

@@ -39,10 +39,12 @@ class ReqToTokenPool:
 class TokenToKVPool:
     def __init__(self, size, dtype, head_num, head_dim, layer_num):
         self.size = size

-        # mem_state is the reference counter.
-        # This can be promised:
-        # assert torch.all(mem_state <= 1) and torch.all(mem_state >= 0)
         # We also add one slot. This slot is used for writing dummy output from padded tokens.
-        self.mem_state = torch.zeros((self.size + 1,), dtype=torch.int16, device="cuda")
-        self.total_ref_ct = 0
+        self.mem_state = torch.zeros((self.size + 1,), dtype=torch.bool, device="cuda")
+        self.total_size = self.size
+        self.total_alloc = 0

         # [size, key/value, head_num, head_dim] for each layer
         self.kv_data = [

@@ -71,7 +73,9 @@ class TokenToKVPool:
         addition_size = need_size - buffer_len
         alloc_size = max(addition_size, self.prefetch_chunk_size)
-        select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size].to(torch.int32)
+        select_index = (
+            torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size].to(torch.int32)
+        )

         if select_index.shape[0] < addition_size:
             return None

@@ -105,26 +109,22 @@ class TokenToKVPool:
         return select_index.to(torch.int32), start_loc, start_loc + need_size

     def used_size(self):
-        return len(torch.nonzero(self.mem_state).squeeze(1))
+        return self.total_alloc

     def available_size(self):
-        return torch.sum(self.mem_state == 0).item() + len(self.prefetch_buffer)
+        return self.total_size - self.total_alloc + len(self.prefetch_buffer)

     def add_refs(self, token_index: torch.Tensor):
-        self.total_ref_ct += len(token_index)
-        self.mem_state[token_index] += 1
+        self.total_alloc += len(token_index)
+        self.mem_state[token_index] ^= True

     def dec_refs(self, token_index: torch.Tensor):
-        self.total_ref_ct -= len(token_index)
-        self.mem_state[token_index] -= 1
-
-        num_freed = torch.sum(self.mem_state[token_index] == 0)
-        return num_freed
+        self.total_alloc -= len(token_index)
+        self.mem_state[token_index] ^= True

     def clear(self):
         self.mem_state.fill_(0)
-        self.total_ref_ct = 0
+        self.total_alloc = 0

         # We also add one slot. This slot is used for writing dummy output from padded tokens.
-        self.add_refs(torch.tensor([0], dtype=torch.int32))
+        self.mem_state[0] = True
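This file is the core of the commit. Because every slot's reference count was already guaranteed to be 0 or 1, mem_state becomes a boolean occupancy mask toggled with XOR, and used_size/available_size read a scalar total_alloc counter instead of reducing over the whole GPU tensor. A simplified, standalone sketch of the new bookkeeping (prefetch buffering omitted; BooleanSlotPool is an illustrative name, not the real class):

import torch

class BooleanSlotPool:
    def __init__(self, size: int):
        # one extra slot (index 0) for dummy output of padded tokens, as in the diff
        self.mem_state = torch.zeros((size + 1,), dtype=torch.bool)
        self.total_size = size
        self.total_alloc = 0
        self.mem_state[0] = True

    def alloc(self, need_size: int):
        free_index = torch.nonzero(~self.mem_state).squeeze(1)[:need_size]
        if free_index.shape[0] < need_size:
            return None                        # out of KV slots
        self.add_refs(free_index)
        return free_index

    def add_refs(self, token_index: torch.Tensor):
        self.total_alloc += len(token_index)
        self.mem_state[token_index] ^= True    # False -> True (occupied)

    def dec_refs(self, token_index: torch.Tensor):
        self.total_alloc -= len(token_index)
        self.mem_state[token_index] ^= True    # True -> False (free)

    def available_size(self):
        return self.total_size - self.total_alloc  # O(1), no torch.sum over the mask

pool = BooleanSlotPool(8)
idx = pool.alloc(3)
print(idx.tolist(), pool.available_size())  # [1, 2, 3] 5
pool.dec_refs(idx)
print(pool.available_size())                # 8

The XOR toggle is valid precisely because the old invariant (0 <= mem_state <= 1) already held; a double free or double allocation of an index would now silently flip the bit back rather than being counted.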
python/sglang/srt/models/minicpm.py  (+1, -8)

@@ -5,12 +5,9 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
-from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,

@@ -31,7 +28,6 @@ from sglang.srt.managers.controller.model_runner import InputMetadata
 class MiniCPMMLP(nn.Module):
-
     def __init__(
         self,
         hidden_size: int,

@@ -67,7 +63,6 @@ class MiniCPMMLP(nn.Module):
 class MiniCPMAttention(nn.Module):
-
     def __init__(
         self,
         hidden_size: int,

@@ -152,7 +147,6 @@ class MiniCPMAttention(nn.Module):
 class MiniCPMDecoderLayer(nn.Module):
-
     def __init__(
         self,
         config,

@@ -217,7 +211,6 @@ class MiniCPMDecoderLayer(nn.Module):
 class MiniCPMModel(nn.Module):
-
     def __init__(
         self,
         config,

@@ -274,7 +267,7 @@ class MiniCPMForCausalLM(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
-
+        self.num_experts = getattr(self.config, "num_experts", 0)
         self.quant_config = quant_config
         self.model = MiniCPMModel(config, quant_config=quant_config)
python/sglang/srt/models/qwen2_moe.py  (+126, -107)

@@ -8,24 +8,28 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
-from vllm.distributed import (get_tensor_model_parallel_world_size,
-                              tensor_model_parallel_all_reduce)
+from vllm.distributed import (
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
-                                               QKVParallelLinear,
-                                               ReplicatedLinear,
-                                               RowParallelLinear)
+from vllm.model_executor.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead, VocabParallelEmbedding)
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors, SamplerOutput

@@ -34,8 +38,8 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.controller.model_runner import InputMetadata

-class Qwen2MoeMLP(nn.Module):

+class Qwen2MoeMLP(nn.Module):
     def __init__(
         self,
         hidden_size: int,

@@ -46,17 +50,20 @@ class Qwen2MoeMLP(nn.Module):
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
+            hidden_size,
+            [intermediate_size] * 2,
             bias=False,
-            quant_config=quant_config)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           quant_config=quant_config,
-                                           reduce_results=reduce_results)
+            quant_config=quant_config,
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            reduce_results=reduce_results,
+        )
         if hidden_act != "silu":
-            raise ValueError(f"Unsupported activation: {hidden_act}. "
-                             "Only silu is supported for now.")
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. "
+                "Only silu is supported for now."
+            )
         self.act_fn = SiluAndMul()

     def forward(self, x):

@@ -67,7 +74,6 @@ class Qwen2MoeMLP(nn.Module):
 class Qwen2MoeSparseMoeBlock(nn.Module):
-
     def __init__(
         self,
         config: PretrainedConfig,

@@ -79,20 +85,22 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         if self.tp_size > config.num_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {config.num_experts}.")
-        self.experts = FusedMoE(num_experts=config.num_experts,
-                                top_k=config.num_experts_per_tok,
-                                hidden_size=config.hidden_size,
-                                intermediate_size=config.moe_intermediate_size,
-                                reduce_results=False,
-                                renormalize=config.norm_topk_prob,
-                                quant_config=quant_config)
-        self.gate = ReplicatedLinear(config.hidden_size,
-                                     config.num_experts,
-                                     bias=False,
-                                     quant_config=None)
+                f"the number of experts {config.num_experts}."
+            )
+        self.experts = FusedMoE(
+            num_experts=config.num_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=config.norm_topk_prob,
+            quant_config=quant_config,
+        )
+        self.gate = ReplicatedLinear(
+            config.hidden_size, config.num_experts, bias=False, quant_config=None
+        )
         if config.shared_expert_intermediate_size > 0:
             self.shared_expert = Qwen2MoeMLP(
                 hidden_size=config.hidden_size,

@@ -103,9 +111,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             )
         else:
             self.shared_expert = None
-        self.shared_expert_gate = torch.nn.Linear(config.hidden_size,
-                                                  1,
-                                                  bias=False)
+        self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape

@@ -114,24 +120,24 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         if self.shared_expert is not None:
             shared_output = self.shared_expert(hidden_states)
             if self.shared_expert_gate is not None:
-                shared_output = F.sigmoid(
-                    self.shared_expert_gate(hidden_states)) * shared_output
+                shared_output = (
+                    F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output
+                )

         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(hidden_states=hidden_states,
-                                           router_logits=router_logits)
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states, router_logits=router_logits
+        )
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
         if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(
-                final_hidden_states)
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

         return final_hidden_states.view(num_tokens, hidden_dim)


 class Qwen2MoeAttention(nn.Module):
     def __init__(
         self,
         hidden_size: int,

@@ -190,17 +196,19 @@ class Qwen2MoeAttention(nn.Module):
             base=rope_theta,
             rope_scaling=rope_scaling,
         )
-        self.attn = RadixAttention(self.num_heads,
-                                   self.head_dim,
-                                   self.scaling,
-                                   num_kv_heads=self.num_kv_heads,
-                                   layer_id=layer_id)
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+        )

     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata
+        input_metadata: InputMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

@@ -211,7 +219,6 @@ class Qwen2MoeAttention(nn.Module):
 class Qwen2MoeDecoderLayer(nn.Module):
-
     def __init__(
         self,
         config: PretrainedConfig,

@@ -223,8 +230,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
-        max_position_embeddings = getattr(config, "max_position_embeddings",
-                                          8192)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.self_attn = Qwen2MoeAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,

@@ -239,13 +245,13 @@ class Qwen2MoeDecoderLayer(nn.Module):
         # Note: Qwen/Qwen2-57B-A14B-Instruct does not have
         # `mlp_only_layers` in the config.
-        mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
-                           config.mlp_only_layers)
+        mlp_only_layers = (
+            [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers
+        )
         if (layer_id not in mlp_only_layers) and (
-                config.num_experts > 0 and
-                (layer_id + 1) % config.decoder_sparse_step == 0):
-            self.mlp = Qwen2MoeSparseMoeBlock(config=config,
-                                              quant_config=quant_config)
+            config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0
+        ):
+            self.mlp = Qwen2MoeSparseMoeBlock(config=config, quant_config=quant_config)
         else:
             self.mlp = Qwen2MoeMLP(
                 hidden_size=config.hidden_size,

@@ -253,10 +259,10 @@ class Qwen2MoeDecoderLayer(nn.Module):
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
             )
-        self.input_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(config.hidden_size,
-                                                eps=config.rms_norm_eps)
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )

     def forward(
         self,

@@ -270,23 +276,20 @@ class Qwen2MoeDecoderLayer(nn.Module):
             residual = hidden_states
             hidden_states = self.input_layernorm(hidden_states)
         else:
-            hidden_states, residual = self.input_layernorm(
-                hidden_states, residual)
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata
+            input_metadata=input_metadata,
         )

         # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(
-            hidden_states, residual)
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
         return hidden_states, residual


 class Qwen2MoeModel(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,

@@ -301,13 +304,14 @@ class Qwen2MoeModel(nn.Module):
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = nn.ModuleList([
-            Qwen2MoeDecoderLayer(config, layer_id, cache_config,
-                                 quant_config=quant_config)
-            for layer_id in range(config.num_hidden_layers)
-        ])
+        self.layers = nn.ModuleList(
+            [
+                Qwen2MoeDecoderLayer(
+                    config, layer_id, cache_config, quant_config=quant_config
+                )
+                for layer_id in range(config.num_hidden_layers)
+            ]
+        )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

     def forward(

@@ -315,7 +319,7 @@ class Qwen2MoeModel(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-        input_embeds: torch.Tensor = None
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)

@@ -324,10 +328,9 @@ class Qwen2MoeModel(nn.Module):
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
-            hidden_states, residual = layer(positions,
-                                            hidden_states,
-                                            input_metadata,
-                                            residual)
+            hidden_states, residual = layer(
+                positions, hidden_states, input_metadata, residual
+            )
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

@@ -346,9 +349,9 @@ class Qwen2MoeForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.model = Qwen2MoeModel(config, cache_config, quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size,
-                                      config.hidden_size,
-                                      quant_config=quant_config)
+        self.lm_head = ParallelLMHead(
+            config.vocab_size, config.hidden_size, quant_config=quant_config
+        )
         self.logits_processor = LogitsProcessor(config)
         self.sampler = Sampler()

@@ -357,17 +360,22 @@ class Qwen2MoeForCausalLM(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-        input_embeds: torch.Tensor = None
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(input_ids, hidden_states, self.lm_head.weight, input_metadata)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head.weight, input_metadata
+        )

     def compute_logits(
         self,
         input_ids: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata
+        input_metadata: InputMetadata,
     ) -> torch.Tensor:
-        logits = self.logits_processor(input_ids, hidden_states, self.lm_head.weight, input_metadata)
+        logits = self.logits_processor(
+            input_ids, hidden_states, self.lm_head.weight, input_metadata
+        )
         return logits

     def sample(

@@ -391,11 +399,18 @@ class Qwen2MoeForCausalLM(nn.Module):
         expert_params_mapping = [
             # These are the weights for the experts
             # (param_name, weight_name, expert_id, shard_id)
-            ("experts.w13_weight" if weight_name in ["gate_proj", "up_proj"]
-             else "experts.w2_weight",
-             f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
+            (
+                "experts.w13_weight"
+                if weight_name in ["gate_proj", "up_proj"]
+                else "experts.w2_weight",
+                f"experts.{expert_id}.{weight_name}.weight",
+                expert_id,
+                shard_id,
+            )
             for expert_id in range(self.config.num_experts)
-            for shard_id, weight_name in enumerate(["gate_proj", "down_proj", "up_proj"])
+            for shard_id, weight_name in enumerate(
+                ["gate_proj", "down_proj", "up_proj"]
+            )
         ]

         params_dict = dict(self.named_parameters())

@@ -433,11 +448,13 @@ class Qwen2MoeForCausalLM(nn.Module):
                     name = name.replace(weight_name, param_name)
                     param = params_dict[name]
                     weight_loader = param.weight_loader
-                    weight_loader(param, loaded_weight, weight_name,
-                                  shard_id=shard_id, expert_id=expert_id)
+                    weight_loader(
+                        param,
+                        loaded_weight,
+                        weight_name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                    )
                     break
                 else:
                     # Skip loading extra bias for GPTQ models.

@@ -447,8 +464,10 @@ class Qwen2MoeForCausalLM(nn.Module):
                        continue

                    param = params_dict[name]
-                   weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                   weight_loader = getattr(
+                       param, "weight_loader", default_weight_loader
+                   )
                    weight_loader(param, loaded_weight)


 EntryClass = Qwen2MoeForCausalLM
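Most of the hunks above are formatter churn, but Qwen2MoeSparseMoeBlock.forward is worth a sketch: the routed expert output and a shared expert are combined, with the shared expert scaled by a per-token sigmoid gate. A toy, standalone version of that combination (single GPU, no tensor parallelism; the identity lambda stands in for FusedMoE):

import torch
import torch.nn.functional as F

def moe_with_shared_expert(hidden_states, routed_experts, shared_expert, shared_expert_gate):
    num_tokens, hidden_dim = hidden_states.shape
    hidden_states = hidden_states.view(-1, hidden_dim)
    shared_output = shared_expert(hidden_states)
    # per-token scalar gate in [0, 1] scales the shared expert's contribution
    shared_output = F.sigmoid(shared_expert_gate(hidden_states)) * shared_output
    final_hidden_states = routed_experts(hidden_states)  # stand-in for FusedMoE
    final_hidden_states = final_hidden_states + shared_output
    return final_hidden_states.view(num_tokens, hidden_dim)

d = 4
out = moe_with_shared_expert(
    torch.randn(3, d),
    routed_experts=lambda x: x,
    shared_expert=torch.nn.Linear(d, d),
    shared_expert_gate=torch.nn.Linear(d, 1, bias=False),
)
print(out.shape)  # torch.Size([3, 4])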
python/sglang/srt/utils.py  (+1, -1)

@@ -474,9 +474,9 @@ def monkey_patch_vllm_dummy_weight_loader():
         DummyModelLoader,
         LoRAConfig,
         ModelConfig,
-        MultiModalConfig,
         ParallelConfig,
         SchedulerConfig,
+        MultiModalConfig,
         _initialize_model,
         initialize_dummy_weights,
         nn,