sglang, commit 36d5acfc (unverified)
Authored Sep 30, 2024 by Lianmin Zheng; committed Sep 30, 2024 via GitHub.

Rename InputMetadata -> ForwardBatch (#1543)
Parent: 3f0fe08d
Changes: 44 files in the full commit; this page shows 20 changed files with 201 additions and 199 deletions (+201 / -199).
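The change is mechanical and uniform across every file below: each `input_metadata` identifier becomes `forward_batch`, and the `InputMetadata` class becomes `ForwardBatch`. A minimal self-contained sketch of the before/after call-site pattern (the classes here are toy stand-ins, not the real sglang types):

from dataclasses import dataclass
from enum import IntEnum, auto

class ForwardMode(IntEnum):
    EXTEND = auto()
    DECODE = auto()

    def is_decode(self) -> bool:
        return self == ForwardMode.DECODE

@dataclass
class ForwardBatch:          # formerly named InputMetadata
    forward_mode: ForwardMode
    batch_size: int

def forward(forward_batch: ForwardBatch) -> str:   # was: forward(input_metadata)
    # Dispatch on the batch's mode, exactly as the renamed call sites do.
    if forward_batch.forward_mode.is_decode():
        return "decode path"
    return "extend path"

print(forward(ForwardBatch(ForwardMode.DECODE, batch_size=1)))  # decode path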
Files changed:

docs/en/model_support.md                                 +1   -1
python/sglang/bench_latency.py                           +4   -4
python/sglang/srt/layers/attention_backend.py            +54  -54
python/sglang/srt/layers/logits_processor.py             +15  -15
python/sglang/srt/layers/pooler.py                       +3   -3
python/sglang/srt/layers/radix_attention.py              +3   -3
python/sglang/srt/lora/lora.py                           +1   -1
python/sglang/srt/lora/lora_manager.py                   +7   -7
python/sglang/srt/managers/schedule_batch.py             +3   -3
python/sglang/srt/managers/schedule_policy.py            +1   -1
python/sglang/srt/managers/scheduler.py                  +10  -10
python/sglang/srt/managers/tp_worker.py                  +8   -6
python/sglang/srt/model_executor/cuda_graph_runner.py    +13  -13
python/sglang/srt/model_executor/forward_batch_info.py   +3   -3
python/sglang/srt/model_executor/model_runner.py         +21  -21
python/sglang/srt/models/baichuan.py                     +10  -10
python/sglang/srt/models/chatglm.py                      +12  -12
python/sglang/srt/models/commandr.py                     +10  -10
python/sglang/srt/models/dbrx.py                         +12  -12
python/sglang/srt/models/deepseek.py                     +10  -10
docs/en/model_support.md

@@ -30,6 +30,6 @@ To port a model from vLLM to SGLang, you can compare these two files [SGLang Lla
 - Replace vllm's `LogitsProcessor` with SGLang's `LogitsProcessor`.
 - Replace other vLLM layers with SGLang layers (e.g., `RMSNorm`, `SiluAndMul`).
 - Remove `Sample`.
-- Change `forward()` functions, and add `input_metadata`.
+- Change `forward()` functions, and add `forward_batch`.
 - Add `EntryClass` at the end.
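To make that checklist concrete, here is a hedged, self-contained sketch of the ported-model skeleton it describes, after this commit's rename (so `forward()` takes a `forward_batch`). Everything here is a toy stand-in: the real `LogitsProcessor` and `ForwardBatch` live in the sglang modules shown in the diffs below, and a real model would run attention layers where the comment indicates.

import torch
import torch.nn as nn

class ForwardBatch:
    """Stand-in for sglang.srt.model_executor.forward_batch_info.ForwardBatch."""

class LogitsProcessor(nn.Module):
    """Stand-in; the real one lives in sglang.srt.layers.logits_processor."""
    def forward(self, input_ids, hidden_states, weight, forward_batch):
        return hidden_states @ weight.t()  # project hidden states to vocab logits

class ToyForCausalLM(nn.Module):
    def __init__(self, vocab_size=16, hidden_size=8):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.logits_processor = LogitsProcessor()

    def forward(self, input_ids, positions, forward_batch: ForwardBatch):
        hidden_states = self.embed_tokens(input_ids)  # real models run attention layers here
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, forward_batch
        )

EntryClass = ToyForCausalLM  # checklist step: add `EntryClass` at the end

model = ToyForCausalLM()
logits = model(torch.tensor([1, 2, 3]), torch.arange(3), ForwardBatch())
print(logits.shape)  # torch.Size([3, 16])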
python/sglang/bench_latency.py

@@ -225,16 +225,16 @@ def extend(reqs, model_runner):
         tree_cache=None,
     )
     batch.prepare_for_extend(model_runner.model_config.vocab_size)
-    input_metadata = batch.get_input_metadata()
-    logits_output = model_runner.forward(input_metadata)
+    forward_batch = batch.get_forward_batch()
+    logits_output = model_runner.forward(forward_batch)
     next_token_ids = model_runner.sample(logits_output, batch).tolist()
     return next_token_ids, logits_output.next_token_logits, batch


 def decode(input_token_ids, batch, model_runner):
     batch.prepare_for_decode(input_token_ids)
-    input_metadata = batch.get_input_metadata()
-    logits_output = model_runner.forward(input_metadata)
+    forward_batch = batch.get_forward_batch()
+    logits_output = model_runner.forward(forward_batch)
     next_token_ids = model_runner.sample(logits_output, batch).tolist()
     return next_token_ids, logits_output.next_token_logits
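For context, a sketch of how these two helpers compose into a simple generation loop, assuming the `extend` and `decode` functions above are in scope; the loop itself is illustrative, not the file's exact benchmarking code.

def generate(reqs, model_runner, max_new_tokens=8):
    # Prefill all requests once, then decode one token per step.
    next_token_ids, _, batch = extend(reqs, model_runner)
    generated = [next_token_ids]
    for _ in range(max_new_tokens - 1):
        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
        generated.append(next_token_ids)
    return generated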
python/sglang/srt/layers/attention_backend.py

@@ -16,7 +16,7 @@ import torch.nn as nn
 from sglang.global_config import global_config
 from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_hip

 if TYPE_CHECKING:
@@ -37,7 +37,7 @@ class AttentionBackend(ABC):
     """The base class of attention backends"""

     @abstractmethod
-    def init_forward_metadata(self, input_metadata: InputMetadata):
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init the metadata for a forward pass."""
         raise NotImplementedError()
@@ -61,18 +61,18 @@ class AttentionBackend(ABC):
         """Get the fill value for padded seq lens. Typically, it is 0 or 1."""
         raise NotImplementedError()

-    def forward(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+    def forward(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
         """Run forward on an attention layer."""
-        if input_metadata.forward_mode.is_decode():
-            return self.forward_decode(q, k, v, layer, input_metadata)
+        if forward_batch.forward_mode.is_decode():
+            return self.forward_decode(q, k, v, layer, forward_batch)
         else:
-            return self.forward_extend(q, k, v, layer, input_metadata)
+            return self.forward_extend(q, k, v, layer, forward_batch)

-    def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
         """Run a forward for decode."""
         raise NotImplementedError()

-    def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
         """Run a forward for extend."""
         raise NotImplementedError()
@@ -131,31 +131,31 @@ class FlashInferAttnBackend(AttentionBackend):
         self.forward_metadata = None
         self.cuda_graph_metadata = {}

-    def init_forward_metadata(self, input_metadata: InputMetadata):
-        if input_metadata.forward_mode.is_decode():
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        if forward_batch.forward_mode.is_decode():
             prefix_lens = None
             use_ragged = False
             extend_no_prefix = False
             total_num_tokens = None
         else:
-            prefix_lens = input_metadata.extend_prefix_lens
+            prefix_lens = forward_batch.extend_prefix_lens

             # Some heuristics to check whether to use ragged forward
             use_ragged = False
             if (
-                torch.sum(input_metadata.seq_lens).item() >= 4096
+                torch.sum(forward_batch.seq_lens).item() >= 4096
                 and self.model_runner.sliding_window_size is None
             ):
                 use_ragged = True

-            total_num_tokens = torch.sum(input_metadata.seq_lens).item()
-            extend_no_prefix = not torch.any(input_metadata.extend_prefix_lens).item()
+            total_num_tokens = torch.sum(forward_batch.seq_lens).item()
+            extend_no_prefix = not torch.any(forward_batch.extend_prefix_lens).item()

         update_flashinfer_indices(
-            input_metadata.forward_mode,
+            forward_batch.forward_mode,
             self.model_runner,
-            input_metadata.req_pool_indices,
-            input_metadata.seq_lens,
+            forward_batch.req_pool_indices,
+            forward_batch.seq_lens,
             prefix_lens,
             use_ragged=use_ragged,
         )
@@ -248,7 +248,7 @@ class FlashInferAttnBackend(AttentionBackend):
     def get_cuda_graph_seq_len_fill_value(self):
         return 0

-    def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
         if not isinstance(self.prefill_wrapper_paged, list):
             prefill_wrapper_paged = self.prefill_wrapper_paged
         else:
@@ -264,12 +264,12 @@ class FlashInferAttnBackend(AttentionBackend):
         if not use_ragged:
             if k is not None:
                 assert v is not None
-                input_metadata.token_to_kv_pool.set_kv_buffer(
-                    layer.layer_id, input_metadata.out_cache_loc, k, v
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer.layer_id, forward_batch.out_cache_loc, k, v
                 )

             o = prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                 causal=True,
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
@@ -290,7 +290,7 @@ class FlashInferAttnBackend(AttentionBackend):
             else:
                 o2, s2 = prefill_wrapper_paged.forward_return_lse(
                     q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                    input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                    forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                     causal=False,
                     sm_scale=layer.scaling,
                     logits_soft_cap=layer.logit_cap,
@@ -298,13 +298,13 @@ class FlashInferAttnBackend(AttentionBackend):
                 o, _ = merge_state(o1, s1, o2, s2)

-            input_metadata.token_to_kv_pool.set_kv_buffer(
-                layer.layer_id, input_metadata.out_cache_loc, k, v
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer.layer_id, forward_batch.out_cache_loc, k, v
             )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)

-    def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
         use_ragged, extend_no_prefix, total_num_tokens, decode_wrapper = (
             self.forward_metadata
         )
@@ -317,13 +317,13 @@ class FlashInferAttnBackend(AttentionBackend):
         if k is not None:
             assert v is not None
-            input_metadata.token_to_kv_pool.set_kv_buffer(
-                layer.layer_id, input_metadata.out_cache_loc, k, v
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer.layer_id, forward_batch.out_cache_loc, k, v
             )

         o = decode_wrapper.forward(
             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-            input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+            forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
             sm_scale=layer.scaling,
             logits_soft_cap=layer.logit_cap,
         )
@@ -358,26 +358,26 @@ class TritonAttnBackend(AttentionBackend):
         self.cuda_graph_max_seq_len = model_runner.model_config.context_len

-    def init_forward_metadata(self, input_metadata: InputMetadata):
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""

-        if input_metadata.forward_mode.is_decode():
-            start_loc = torch.zeros_like(input_metadata.seq_lens, dtype=torch.int32)
-            start_loc[1:] = torch.cumsum(input_metadata.seq_lens[:-1], dim=0)
+        if forward_batch.forward_mode.is_decode():
+            start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
+            start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)

-            total_num_tokens = torch.sum(input_metadata.seq_lens).item()
+            total_num_tokens = torch.sum(forward_batch.seq_lens).item()
             attn_logits = torch.empty(
                 (self.num_head, total_num_tokens),
                 dtype=self.reduce_dtype,
                 device="cuda",
             )

-            max_seq_len = torch.max(input_metadata.seq_lens).item()
+            max_seq_len = torch.max(forward_batch.seq_lens).item()
             max_extend_len = None
         else:
             start_loc = attn_logits = max_seq_len = None
-            prefix_lens = input_metadata.extend_prefix_lens
-            max_extend_len = torch.max(input_metadata.seq_lens - prefix_lens).item()
+            prefix_lens = forward_batch.extend_prefix_lens
+            max_extend_len = torch.max(forward_batch.seq_lens - prefix_lens).item()

         self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
@@ -415,15 +415,15 @@ class TritonAttnBackend(AttentionBackend):
     def get_cuda_graph_seq_len_fill_value(self):
         return 1

-    def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
         # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:
             o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
         else:
             o = torch.empty_like(q)

-        input_metadata.token_to_kv_pool.set_kv_buffer(
-            layer.layer_id, input_metadata.out_cache_loc, k, v
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            layer.layer_id, forward_batch.out_cache_loc, k, v
         )

         start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
@@ -432,20 +432,20 @@ class TritonAttnBackend(AttentionBackend):
             k.contiguous(),
             v.contiguous(),
             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
-            input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
-            input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
-            input_metadata.req_to_token_pool.req_to_token,
-            input_metadata.req_pool_indices,
-            input_metadata.seq_lens,
-            input_metadata.extend_seq_lens,
-            input_metadata.extend_start_loc,
+            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.req_pool_indices,
+            forward_batch.seq_lens,
+            forward_batch.extend_seq_lens,
+            forward_batch.extend_start_loc,
             max_extend_len,
             layer.scaling,
             layer.logit_cap,
         )
         return o

-    def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.
         q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
@@ -458,19 +458,19 @@ class TritonAttnBackend(AttentionBackend):
         start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata

-        input_metadata.token_to_kv_pool.set_kv_buffer(
-            layer.layer_id, input_metadata.out_cache_loc, k, v
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            layer.layer_id, forward_batch.out_cache_loc, k, v
         )

         self.decode_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
-            input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
-            input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
+            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
-            input_metadata.req_to_token_pool.req_to_token,
-            input_metadata.req_pool_indices,
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.req_pool_indices,
             start_loc,
-            input_metadata.seq_lens,
+            forward_batch.seq_lens,
             attn_logits,
             max_seq_len,
             layer.scaling,
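The core pattern this file keeps after the rename is the single `forward()` entry point that routes on the batch's mode. A self-contained sketch of that dispatch, with stand-in classes only (the real backends above add KV-cache writes and kernel calls):

from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import IntEnum, auto

class ForwardMode(IntEnum):
    EXTEND = auto()
    DECODE = auto()

    def is_decode(self) -> bool:
        return self == ForwardMode.DECODE

@dataclass
class ForwardBatch:
    forward_mode: ForwardMode

class AttentionBackend(ABC):
    def forward(self, q, k, v, layer, forward_batch: ForwardBatch):
        # One entry point; decode batches and extend (prefill) batches use
        # different kernels underneath.
        if forward_batch.forward_mode.is_decode():
            return self.forward_decode(q, k, v, layer, forward_batch)
        return self.forward_extend(q, k, v, layer, forward_batch)

    @abstractmethod
    def forward_decode(self, q, k, v, layer, forward_batch): ...

    @abstractmethod
    def forward_extend(self, q, k, v, layer, forward_batch): ...

class EchoBackend(AttentionBackend):
    def forward_decode(self, q, k, v, layer, forward_batch):
        return "decode kernel"

    def forward_extend(self, q, k, v, layer, forward_batch):
        return "extend kernel"

print(EchoBackend().forward(None, None, None, None, ForwardBatch(ForwardMode.DECODE)))  # decode kernel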
python/sglang/srt/layers/logits_processor.py

@@ -25,7 +25,7 @@ from vllm.distributed import (
     tensor_model_parallel_all_gather,
 )

-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode


 @dataclasses.dataclass
@@ -61,26 +61,26 @@ class LogitsMetadata:
     extend_logprob_pruned_lens_cpu: Optional[List[int]] = None

     @classmethod
-    def from_input_metadata(cls, input_metadata: InputMetadata):
-        return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
-        if input_metadata.forward_mode.is_extend():
+    def from_forward_batch(cls, forward_batch: ForwardBatch):
+        return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
+        if forward_batch.forward_mode.is_extend():
             extend_logprob_pruned_lens_cpu = [
                 extend_len - start_len
                 for extend_len, start_len in zip(
-                    input_metadata.extend_seq_lens,
-                    input_metadata.extend_logprob_start_lens_cpu,
+                    forward_batch.extend_seq_lens,
+                    forward_batch.extend_logprob_start_lens_cpu,
                 )
             ]
         else:
             extend_logprob_pruned_lens_cpu = None
         return cls(
-            forward_mode=input_metadata.forward_mode,
-            top_logprobs_nums=input_metadata.top_logprobs_nums,
-            return_logprob=input_metadata.return_logprob,
+            forward_mode=forward_batch.forward_mode,
+            top_logprobs_nums=forward_batch.top_logprobs_nums,
+            return_logprob=forward_batch.return_logprob,
             return_top_logprob=return_top_logprob,
-            extend_seq_lens=input_metadata.extend_seq_lens,
-            extend_seq_lens_cpu=input_metadata.extend_seq_lens_cpu,
-            extend_logprob_start_lens_cpu=input_metadata.extend_logprob_start_lens_cpu,
+            extend_seq_lens=forward_batch.extend_seq_lens,
+            extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
+            extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
             extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
         )
@@ -162,10 +162,10 @@ class LogitsProcessor(nn.Module):
         input_ids,
         hidden_states,
         weight,
-        logits_metadata: Union[LogitsMetadata, InputMetadata],
+        logits_metadata: Union[LogitsMetadata, ForwardBatch],
     ):
-        if isinstance(logits_metadata, InputMetadata):
-            logits_metadata = LogitsMetadata.from_input_metadata(logits_metadata)
+        if isinstance(logits_metadata, ForwardBatch):
+            logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
         assert isinstance(logits_metadata, LogitsMetadata)

         # Get the last hidden states and last logits for the next token prediction
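A small worked example of the pruned-length computation in `from_forward_batch` above: each request contributes only the extend tokens at or past its logprob start offset. The values below are made up for illustration.

extend_seq_lens = [5, 3, 7]                 # tokens extended per request
extend_logprob_start_lens_cpu = [2, 0, 4]   # per-request logprob start offsets
extend_logprob_pruned_lens_cpu = [
    extend_len - start_len
    for extend_len, start_len in zip(extend_seq_lens, extend_logprob_start_lens_cpu)
]
print(extend_logprob_pruned_lens_cpu)  # [3, 3, 3]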
python/sglang/srt/layers/pooler.py

@@ -7,7 +7,7 @@ from enum import IntEnum
 import torch
 import torch.nn as nn

-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.model_runner import ForwardBatch


 class PoolingType(IntEnum):
@@ -36,10 +36,10 @@ class Pooler(nn.Module):
         self.normalize = normalize

     def forward(
-        self, hidden_states: torch.Tensor, input_metadata: InputMetadata
+        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
     ) -> EmbeddingPoolerOutput:
         if self.pooling_type == PoolingType.LAST:
-            last_token_indices = torch.cumsum(input_metadata.extend_seq_lens, dim=0) - 1
+            last_token_indices = torch.cumsum(forward_batch.extend_seq_lens, dim=0) - 1
             pooled_data = hidden_states[last_token_indices]
         else:
            raise ValueError(f"Invalid pooling type: {self.pooling_type}")
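A worked example of the LAST-pooling index math in `Pooler.forward`: in a packed batch, each request's last token sits at `cumsum(extend_seq_lens) - 1`. The tensor values here are illustrative.

import torch

extend_seq_lens = torch.tensor([3, 2, 4])               # packed batch: 9 tokens total
hidden_states = torch.arange(9.0).unsqueeze(1)          # one scalar "feature" per token
last_token_indices = torch.cumsum(extend_seq_lens, dim=0) - 1
print(last_token_indices.tolist())                      # [2, 4, 8]
print(hidden_states[last_token_indices].squeeze(1).tolist())  # [2.0, 4.0, 8.0]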
python/sglang/srt/layers/radix_attention.py

@@ -17,7 +17,7 @@ limitations under the License.
 from torch import nn

-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch


 class RadixAttention(nn.Module):
@@ -48,11 +48,11 @@ class RadixAttention(nn.Module):
         self.logit_cap = logit_cap
         self.sliding_window_size = sliding_window_size or -1

-    def forward(self, q, k, v, input_metadata: InputMetadata):
+    def forward(self, q, k, v, forward_batch: ForwardBatch):
         if k is not None:
             # For cross-layer sharing, kv can be None
             assert v is not None
             k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
             v = v.view(-1, self.tp_v_head_num, self.v_head_dim)

-        return input_metadata.attn_backend.forward(q, k, v, self, input_metadata)
+        return forward_batch.attn_backend.forward(q, k, v, self, forward_batch)
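A worked example of the k/v reshape that `RadixAttention.forward` performs before delegating to `forward_batch.attn_backend`: flat per-token projections are viewed as (tokens, heads, head_dim). The sizes here are arbitrary.

import torch

tp_k_head_num, qk_head_dim = 4, 8
k_flat = torch.randn(6, tp_k_head_num * qk_head_dim)  # 6 tokens, heads packed flat
k = k_flat.view(-1, tp_k_head_num, qk_head_dim)       # mirrors k.view(-1, self.tp_k_head_num, self.qk_head_dim)
print(k.shape)  # torch.Size([6, 4, 8])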
python/sglang/srt/lora/lora.py

@@ -40,7 +40,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.loader import DefaultModelLoader

-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode


 class BaseLayerWithLoRA(nn.Module):
python/sglang/srt/lora/lora_manager.py

@@ -23,7 +23,7 @@ import torch
 from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
 from sglang.srt.lora.lora_config import LoRAConfig
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import is_hip, replace_submodule

 # ROCm: flashinfer available later
@@ -207,9 +207,9 @@ class LoRAManager:
                     if lora_weight_name:
                         self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights)

-    def prepare_lora_batch(self, input_metadata: InputMetadata):
+    def prepare_lora_batch(self, forward_batch: ForwardBatch):
         # load active loras into lora memory pool
-        cur_uids = set(input_metadata.lora_paths)
+        cur_uids = set(forward_batch.lora_paths)
         assert len(cur_uids) <= self.max_loras_per_batch
         i = 0
         evictable_uids = list(self.active_uids)
@@ -229,14 +229,14 @@ class LoRAManager:
             return

         # setup lora in forward modules
-        bs = input_metadata.batch_size
+        bs = forward_batch.batch_size
         seg_lens = (
-            input_metadata.extend_seq_lens
-            if input_metadata.forward_mode.is_extend()
+            forward_batch.extend_seq_lens
+            if forward_batch.forward_mode.is_extend()
             else torch.ones(bs)
         )
         weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
-        for i, lora_path in enumerate(input_metadata.lora_paths):
+        for i, lora_path in enumerate(forward_batch.lora_paths):
             weight_indices[i] = self.buffer_id[lora_path]
         for module_name, module in self.lora_modules:
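A worked example of the weight-index mapping built in `prepare_lora_batch`: each request's adapter path is translated to its slot in the LoRA memory pool. Adapter names and the CPU device here are illustrative (the real code allocates on "cuda").

import torch

buffer_id = {"lora-a": 0, "lora-b": 1}        # adapter uid -> pool slot
lora_paths = ["lora-b", "lora-a", "lora-b"]   # per-request adapters in the batch
weight_indices = torch.empty((len(lora_paths),), dtype=torch.int64)
for i, lora_path in enumerate(lora_paths):
    weight_indices[i] = buffer_id[lora_path]
print(weight_indices.tolist())  # [1, 0, 1]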
python/sglang/srt/managers/schedule_batch.py

@@ -29,7 +29,7 @@ from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
@@ -511,8 +511,8 @@ class ScheduleBatch:
         self.extend_logprob_start_lens_cpu = [r.extend_logprob_start_len for r in reqs]
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)

-    def get_input_metadata(self):
-        return InputMetadata.from_schedule_batch(self)
+    def get_forward_batch(self):
+        return ForwardBatch.from_schedule_batch(self)

     def mix_with_running(self, running_batch: "ScheduleBatch"):
         self.forward_mode = ForwardMode.MIXED
python/sglang/srt/managers/scheduler_policy.py → python/sglang/srt/managers/schedule_policy.py

@@ -32,7 +32,7 @@ from sglang.srt.mem_cache.radix_cache import TreeNode
 CLIP_MAX_NEW_TOKENS = int(os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS", "4096"))


-class SchedulerPolicy:
+class SchedulePolicy:
     def __init__(self, policy: str, tree_cache: BasePrefixCache):
         if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
             # LPM and DFS-weight is meaningless when the tree cache is disabled.
python/sglang/srt/managers/scheduler.py

@@ -50,8 +50,8 @@ from sglang.srt.managers.schedule_batch import (
     Req,
     ScheduleBatch,
 )
-from sglang.srt.managers.scheduler_policy import PrefillAdder, SchedulerPolicy
-from sglang.srt.managers.tp_worker import ModelTpWorker
+from sglang.srt.managers.schedule_policy import PrefillAdder, SchedulePolicy
+from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -134,7 +134,7 @@ class Scheduler:
         )

         # Launch a tensor parallel worker
-        self.tp_worker = ModelTpWorker(
+        self.tp_worker = TpModelWorker(
             gpu_id=gpu_id,
             tp_rank=tp_rank,
             server_args=server_args,
@@ -179,7 +179,7 @@ class Scheduler:
             disable=server_args.disable_radix_cache,
         )
         self.tree_cache_metrics = {"total": 0, "hit": 0}
-        self.policy = SchedulerPolicy(self.schedule_policy, self.tree_cache)
+        self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)

         # Init running status
         self.waiting_queue: List[Req] = []
@@ -575,9 +575,9 @@ class Scheduler:
         if self.is_generation:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
-                input_metadata = batch.get_input_metadata()
+                forward_batch = batch.get_forward_batch()
                 logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
-                    input_metadata, batch
+                    forward_batch, batch
                 )
                 batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                     next_token_ids
@@ -641,8 +641,8 @@ class Scheduler:
             )
         else:
             assert batch.extend_num_tokens != 0
-            input_metadata = batch.get_input_metadata()
-            embeddings = self.tp_worker.forward_batch_embedding(input_metadata)
+            forward_batch = batch.get_forward_batch()
+            embeddings = self.tp_worker.forward_batch_embedding(forward_batch)

         # Check finish conditions
         for i, req in enumerate(batch.reqs):
@@ -771,9 +771,9 @@ class Scheduler:
         batch.prepare_for_decode()

         # Forward and sample the next tokens
-        input_metadata = batch.get_input_metadata()
+        forward_batch = batch.get_forward_batch()
         logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
-            input_metadata, batch
+            forward_batch, batch
         )
         batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
             next_token_ids
python/sglang/srt/managers/tp_worker.py

@@ -21,7 +21,7 @@ import logging
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import UpdateWeightReqInput
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import broadcast_pyobj, is_multimodal_model, set_random_seed
@@ -29,7 +29,9 @@ from sglang.srt.utils import broadcast_pyobj, is_multimodal_model, set_random_se
 logger = logging.getLogger(__name__)


-class ModelTpWorker:
+class TpModelWorker:
+    """A tensor parallel model worker."""
+
     def __init__(
         self,
         gpu_id: int,
@@ -106,13 +108,13 @@ class ModelTpWorker:
             self.random_seed,
         )

-    def forward_batch_generation(self, input_metadata: InputMetadata, batch):
-        logits_output = self.model_runner.forward(input_metadata)
+    def forward_batch_generation(self, forward_batch: ForwardBatch, batch):
+        logits_output = self.model_runner.forward(forward_batch)
         next_token_ids = self.model_runner.sample(logits_output, batch)
         return logits_output, next_token_ids

-    def forward_batch_embedding(self, input_metadata: InputMetadata):
-        logits_output = self.model_runner.forward(input_metadata)
+    def forward_batch_embedding(self, forward_batch: ForwardBatch):
+        logits_output = self.model_runner.forward(forward_batch)
         embeddings = logits_output.embeddings.tolist()
         return embeddings
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -31,7 +31,7 @@ from sglang.srt.layers.logits_processor import (
     LogitsProcessor,
     LogitsProcessorOutput,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import monkey_patch_vllm_all_gather

 if TYPE_CHECKING:
@@ -196,7 +196,7 @@ class CudaGraphRunner:
         # Run and capture
         def run_once():
-            input_metadata = InputMetadata(
+            forward_batch = ForwardBatch(
                 forward_mode=ForwardMode.DECODE,
                 batch_size=bs,
                 input_ids=input_ids,
@@ -210,7 +210,7 @@ class CudaGraphRunner:
                 top_logprobs_nums=[0] * bs,
                 positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
             )
-            return forward(input_ids, input_metadata.positions, input_metadata)
+            return forward(input_ids, forward_batch.positions, forward_batch)

         for _ in range(2):
             torch.cuda.synchronize()
@@ -233,9 +233,9 @@ class CudaGraphRunner:
         self.graph_memory_pool = graph.pool()
         return graph, out

-    def replay(self, input_metadata: InputMetadata):
-        assert input_metadata.out_cache_loc is not None
-        raw_bs = input_metadata.batch_size
+    def replay(self, forward_batch: ForwardBatch):
+        assert forward_batch.out_cache_loc is not None
+        raw_bs = forward_batch.batch_size

         # Pad
         index = bisect.bisect_left(self.capture_bs, raw_bs)
@@ -245,10 +245,10 @@ class CudaGraphRunner:
             self.out_cache_loc.zero_()

         # Common inputs
-        self.input_ids[:raw_bs] = input_metadata.input_ids
-        self.req_pool_indices[:raw_bs] = input_metadata.req_pool_indices
-        self.seq_lens[:raw_bs] = input_metadata.seq_lens
-        self.out_cache_loc[:raw_bs] = input_metadata.out_cache_loc
+        self.input_ids[:raw_bs] = forward_batch.input_ids
+        self.req_pool_indices[:raw_bs] = forward_batch.req_pool_indices
+        self.seq_lens[:raw_bs] = forward_batch.seq_lens
+        self.out_cache_loc[:raw_bs] = forward_batch.out_cache_loc

         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
@@ -271,15 +271,15 @@ class CudaGraphRunner:
         )

         # Extract logprobs
-        if input_metadata.return_logprob:
+        if forward_batch.return_logprob:
             logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
                 logits_output.next_token_logits, dim=-1
             )
-            return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
+            return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
             if return_top_logprob:
                 logits_metadata = LogitsMetadata(
                     forward_mode=ForwardMode.DECODE,
-                    top_logprobs_nums=input_metadata.top_logprobs_nums,
+                    top_logprobs_nums=forward_batch.top_logprobs_nums,
                 )
                 logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
                     logits_output.next_token_logprobs, logits_metadata
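A sketch of the padding trick in `CudaGraphRunner.replay` above: a batch of size `raw_bs` is copied into static buffers sized for the nearest captured batch size, found with `bisect_left`. The batch sizes and token ids here are illustrative stand-ins.

import bisect
import torch

capture_bs = [1, 2, 4, 8]                      # batch sizes with captured graphs
raw_bs = 3
index = bisect.bisect_left(capture_bs, raw_bs)
padded_bs = capture_bs[index]
print(padded_bs)                               # 4

static_input_ids = torch.zeros(padded_bs, dtype=torch.int64)  # graph's fixed buffer
real_input_ids = torch.tensor([11, 22, 33])
static_input_ids[:raw_bs] = real_input_ids     # only the real prefix is overwritten
print(static_input_ids.tolist())               # [11, 22, 33, 0]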
python/sglang/srt/model_executor/forward_batch_info.py

@@ -18,7 +18,7 @@ limitations under the License.
 """Meta data for a forward pass."""

 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List, Set
+from typing import TYPE_CHECKING, List

 import numpy as np
 import torch
@@ -53,8 +53,8 @@ class ForwardMode(IntEnum):
 @dataclass
-class InputMetadata:
-    """Store all inforamtion of a forward pass."""
+class ForwardBatch:
+    """Store all inputs of a forward pass."""

     # The forward mode
     forward_mode: ForwardMode
python/sglang/srt/model_executor/model_runner.py

@@ -48,7 +48,7 @@ from sglang.srt.mem_cache.memory_pool import (
     MLATokenToKVPool,
     ReqToTokenPool,
 )
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -466,47 +466,47 @@ class ModelRunner:
         logger.info("Capture cuda graph begin. This can take up to several minutes.")
         self.cuda_graph_runner = CudaGraphRunner(self)

-    def forward_decode(self, input_metadata: InputMetadata):
+    def forward_decode(self, forward_batch: ForwardBatch):
         if self.cuda_graph_runner and self.cuda_graph_runner.can_run(
-            input_metadata.batch_size
+            forward_batch.batch_size
         ):
-            return self.cuda_graph_runner.replay(input_metadata)
+            return self.cuda_graph_runner.replay(forward_batch)

         return self.model.forward(
-            input_metadata.input_ids, input_metadata.positions, input_metadata
+            forward_batch.input_ids, forward_batch.positions, forward_batch
         )

-    def forward_extend(self, input_metadata: InputMetadata):
+    def forward_extend(self, forward_batch: ForwardBatch):
         if self.is_generation:
             return self.model.forward(
-                input_metadata.input_ids, input_metadata.positions, input_metadata
+                forward_batch.input_ids, forward_batch.positions, forward_batch
             )
         else:
             # Only embedding models have get_embedding parameter
             return self.model.forward(
-                input_metadata.input_ids,
-                input_metadata.positions,
-                input_metadata,
+                forward_batch.input_ids,
+                forward_batch.positions,
+                forward_batch,
                 get_embedding=True,
             )

-    def forward(self, input_metadata: InputMetadata) -> LogitsProcessorOutput:
+    def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
         # Attach attention information
-        input_metadata.req_to_token_pool = self.req_to_token_pool
-        input_metadata.token_to_kv_pool = self.token_to_kv_pool
-        input_metadata.attn_backend = self.attn_backend
-        input_metadata.attn_backend.init_forward_metadata(input_metadata)
+        forward_batch.req_to_token_pool = self.req_to_token_pool
+        forward_batch.token_to_kv_pool = self.token_to_kv_pool
+        forward_batch.attn_backend = self.attn_backend
+        forward_batch.attn_backend.init_forward_metadata(forward_batch)

         # Attach lora information
         if self.server_args.lora_paths is not None:
-            self.lora_manager.prepare_lora_batch(input_metadata)
+            self.lora_manager.prepare_lora_batch(forward_batch)

-        if input_metadata.forward_mode.is_decode():
-            return self.forward_decode(input_metadata)
-        elif input_metadata.forward_mode.is_extend():
-            return self.forward_extend(input_metadata)
+        if forward_batch.forward_mode.is_decode():
+            return self.forward_decode(forward_batch)
+        elif forward_batch.forward_mode.is_extend():
+            return self.forward_extend(forward_batch)
         else:
-            raise ValueError(f"Invaid forward mode: {input_metadata.forward_mode}")
+            raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}")

     def _apply_logits_bias(
         self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
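The notable pattern in `ModelRunner.forward` above is attach-then-dispatch: runner-owned state (KV pools, attention backend) is stamped onto the batch just before the model runs, so model code only ever touches `forward_batch`. A toy stand-in sketch of that flow, with the mode reduced to a string for brevity:

class ForwardBatch:          # stand-in; only the attributes used below
    def __init__(self, mode):
        self.forward_mode = mode

class ModelRunner:           # stand-in for the class in the diff above
    def __init__(self, attn_backend):
        self.attn_backend = attn_backend

    def forward(self, forward_batch):
        # Stamp runner-owned state onto the batch, then route on its mode.
        forward_batch.attn_backend = self.attn_backend
        if forward_batch.forward_mode == "decode":
            return "forward_decode"
        return "forward_extend"

print(ModelRunner("triton").forward(ForwardBatch("decode")))  # forward_decode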
python/sglang/srt/models/baichuan.py

@@ -46,7 +46,7 @@ from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch


 def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
@@ -189,13 +189,13 @@ class BaiChuanAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.W_pack(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         if self.postion_embedding != "ALIBI":
             q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output
@@ -237,7 +237,7 @@ class BaiChuanDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -249,7 +249,7 @@ class BaiChuanDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )

         # Fully Connected
@@ -292,7 +292,7 @@ class BaiChuanModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         residual = None
@@ -301,7 +301,7 @@ class BaiChuanModel(nn.Module):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                input_metadata,
+                forward_batch,
                 residual,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -350,11 +350,11 @@ class BaiChuanBaseForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata)
+        hidden_states = self.model(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
python/sglang/srt/models/chatglm.py
View file @
36d5acfc
...
@@ -42,7 +42,7 @@ from sglang.srt.layers.linear import (
...
@@ -42,7 +42,7 @@ from sglang.srt.layers.linear import (
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.radix_attention
import
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch

 LoraConfig = None
...
@@ -118,7 +118,7 @@ class GLMAttention(nn.Module):
         self,
         hidden_states: torch.Tensor,
         position_ids: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.query_key_value(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
...
@@ -127,7 +127,7 @@ class GLMAttention(nn.Module):
             q,
             k,
             v,
-            input_metadata,
+            forward_batch,
         )
         attn_output, _ = self.dense(context_layer)
         return attn_output
...
@@ -220,7 +220,7 @@ class GLMBlock(nn.Module):
         self,
         hidden_states: torch.Tensor,
         position_ids: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         # hidden_states: [num_tokens, h]
         # Layer norm at the beginning of the transformer layer.
...
@@ -229,7 +229,7 @@ class GLMBlock(nn.Module):
         attention_output = self.self_attention(
             hidden_states=layernorm_output,
             position_ids=position_ids,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )
         # Residual connection.
...
@@ -288,14 +288,14 @@ class GLMTransformer(nn.Module):
         self,
         hidden_states: torch.Tensor,
         position_ids: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         for i in range(self.num_layers):
             layer = self.layers[i]
             hidden_states = layer(
                 hidden_states=hidden_states,
                 position_ids=position_ids,
-                input_metadata=input_metadata,
+                forward_batch=forward_batch,
             )
         # Final layer norm.
         if self.post_layer_norm:
...
@@ -328,7 +328,7 @@ class ChatGLMModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         inputs_embeds = self.embedding(input_ids)
...
@@ -336,7 +336,7 @@ class ChatGLMModel(nn.Module):
         hidden_states = self.encoder(
             hidden_states=inputs_embeds,
             position_ids=position_ids,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )
         return hidden_states
...
@@ -376,11 +376,11 @@ class ChatGLMForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        hidden_states = self.transformer(input_ids, positions, input_metadata)
+        hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
...
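Every hunk in this commit applies the same two-step substitution: the import of InputMetadata from sglang.srt.model_executor.forward_batch_info becomes ForwardBatch, and the input_metadata argument threaded through each forward() becomes forward_batch. As a minimal sketch of the resulting shape (the class name and module attributes here are an illustrative skeleton, not code from the commit), a migrated causal-LM entry point looks like this:

import torch
from torch import nn

from sglang.srt.model_executor.forward_batch_info import ForwardBatch


class ToyForCausalLM(nn.Module):
    # Illustrative skeleton: self.model, self.lm_head, and
    # self.logits_processor are assumed to be set up in __init__.
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,  # was: input_metadata: InputMetadata
    ) -> torch.Tensor:
        # The batch object passes unchanged through the decoder stack and
        # into the logits processor, mirroring the diffs in this commit.
        hidden_states = self.model(input_ids, positions, forward_batch)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, forward_batch
        )

The same shape repeats, with only layer names changing, in every model file below.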
python/sglang/srt/models/commandr.py
View file @ 36d5acfc
...
@@ -63,7 +63,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import set_weight_attrs
...
@@ -220,14 +220,14 @@ class CohereAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         if self.use_qk_norm:
             q, k = self._apply_qk_norm(q, k)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output
...
@@ -255,7 +255,7 @@ class CohereDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
...
@@ -264,7 +264,7 @@ class CohereDecoderLayer(nn.Module):
         hidden_states_attention = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )
         hidden_states_mlp = self.mlp(hidden_states)
         # Add everything together
...
@@ -299,7 +299,7 @@ class CohereModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         residual = None
...
@@ -308,7 +308,7 @@ class CohereModel(nn.Module):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                input_metadata,
+                forward_batch,
                 residual,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
...
@@ -333,15 +333,15 @@ class CohereForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.model(
             input_ids,
             positions,
-            input_metadata,
+            forward_batch,
         )
         return self.logits_processor(
-            input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
+            input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
...
python/sglang/srt/models/dbrx.py
View file @ 36d5acfc
...
@@ -44,7 +44,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import set_weight_attrs
...
@@ -249,14 +249,14 @@ class DbrxAttention(nn.Module):
         self,
         position_ids: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.Wqkv(hidden_states)
         if self.clip_qkv is not None:
             qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(position_ids, q, k)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         hidden_states, _ = self.out_proj(attn_output)
         return hidden_states
...
@@ -278,14 +278,14 @@ class DbrxFusedNormAttention(nn.Module):
         self,
         position_ids: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         residual = hidden_states
         hidden_states = self.norm_1(hidden_states)
         x = self.attn(
             position_ids=position_ids,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )
         hidden_states = residual + x
         residual = hidden_states
...
@@ -310,12 +310,12 @@ class DbrxBlock(nn.Module):
         self,
         position_ids: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states, residual = self.norm_attn_norm(
             position_ids=position_ids,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )
         hidden_states = self.ffn(hidden_states)
         hidden_states = hidden_states + residual
...
@@ -349,7 +349,7 @@ class DbrxModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
...
@@ -358,7 +358,7 @@ class DbrxModel(nn.Module):
             hidden_states = input_embeds
         for i in range(len(self.blocks)):
             block = self.blocks[i]
-            hidden_states = block(position_ids, hidden_states, input_metadata)
+            hidden_states = block(position_ids, hidden_states, forward_batch)
         hidden_states = self.norm_f(hidden_states)
         return hidden_states
...
@@ -388,11 +388,11 @@ class DbrxForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        hidden_states = self.transformer(input_ids, positions, input_metadata)
+        hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
...
python/sglang/srt/models/deepseek.py
View file @ 36d5acfc
...
@@ -46,7 +46,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch

 class DeepseekMLP(nn.Module):
...
@@ -246,12 +246,12 @@ class DeepseekAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output
...
@@ -303,7 +303,7 @@ class DeepseekDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
         # Self Attention
...
@@ -315,7 +315,7 @@ class DeepseekDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )
         # Fully Connected
...
@@ -356,14 +356,14 @@ class DeepseekModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
             hidden_states, residual = layer(
-                positions, hidden_states, input_metadata, residual
+                positions, hidden_states, forward_batch, residual
             )
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
...
@@ -391,11 +391,11 @@ class DeepseekForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata)
+        hidden_states = self.model(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
...
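The substitution is purely mechanical across the commit's 44 changed files, so an out-of-tree model that still uses the old names can be migrated the same way. Below is a hedged convenience sketch, not part of this commit: the target path is a placeholder, and the rewritten file should be reviewed before committing.

import re
from pathlib import Path

# Word-boundary patterns avoid rewriting longer identifiers that merely
# contain these names.
RENAMES = [
    (re.compile(r"\bInputMetadata\b"), "ForwardBatch"),
    (re.compile(r"\binput_metadata\b"), "forward_batch"),
]


def rename_symbols(path: Path) -> None:
    # Read the file, apply both renames, and write the result back in place.
    text = path.read_text()
    for pattern, replacement in RENAMES:
        text = pattern.sub(replacement, text)
    path.write_text(text)


rename_symbols(Path("python/sglang/srt/models/my_model.py"))  # placeholder path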