sglang / Commits / 69b3bb9a

Commit 69b3bb9a (unverified), authored Sep 09, 2024 by Liangsheng Yin, committed via GitHub on Sep 09, 2024.

Unify forward mode (#1360)

Parent: 689ff588

Showing 9 changed files with 54 additions and 58 deletions (+54 -58):
python/sglang/bench_latency.py                            +2   -3
python/sglang/srt/layers/logits_processor.py              +3   -3
python/sglang/srt/layers/radix_attention.py               +2   -2
python/sglang/srt/managers/schedule_batch.py              +7   -0
python/sglang/srt/managers/tp_worker.py                   +3   -8
python/sglang/srt/model_executor/forward_batch_info.py    +23  -18
python/sglang/srt/model_executor/model_runner.py          +10  -20
python/sglang/srt/models/llava.py                         +2   -2
python/sglang/srt/models/llavavid.py                      +2   -2
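Every file below applies the same refactor: mode comparisons move behind ForwardMode helper methods, the mode is recorded on ScheduleBatch by its prepare_for_extend / prepare_for_decode methods, and ModelRunner.forward() reads it from the batch instead of taking an explicit forward_mode argument. The following is a condensed, self-contained sketch of that pattern; the class bodies are simplified stand-ins for illustration, not the real sglang classes.

from dataclasses import dataclass
from enum import IntEnum, auto


class ForwardMode(IntEnum):
    PREFILL = auto()
    EXTEND = auto()
    # Decode one token.
    DECODE = auto()

    def is_prefill(self):
        return self == ForwardMode.PREFILL

    def is_extend(self):
        return self == ForwardMode.EXTEND

    def is_decode(self):
        return self == ForwardMode.DECODE


@dataclass
class ScheduleBatch:
    # Stand-in for sglang.srt.managers.schedule_batch.ScheduleBatch: the batch
    # now carries its own forward mode, set by the prepare_* methods.
    forward_mode: ForwardMode = None

    def prepare_for_extend(self):
        self.forward_mode = ForwardMode.EXTEND

    def prepare_for_decode(self):
        self.forward_mode = ForwardMode.DECODE


class ModelRunner:
    # Stand-in dispatcher: forward() no longer takes forward_mode explicitly.
    def forward(self, batch: ScheduleBatch):
        assert batch.forward_mode is not None
        if batch.forward_mode.is_extend():
            return "extend step"
        elif batch.forward_mode.is_decode():
            return "decode step"
        raise ValueError(f"Invalid forward mode: {batch.forward_mode}")


batch = ScheduleBatch()
batch.prepare_for_decode()
print(ModelRunner().forward(batch))  # -> decode step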
python/sglang/bench_latency.py

@@ -60,7 +60,6 @@ import torch.distributed as dist
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.model_config import ModelConfig
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs

@@ -208,14 +207,14 @@ def extend(reqs, model_runner):
         tree_cache=None,
     )
     batch.prepare_for_extend(model_runner.model_config.vocab_size)
-    sample_output, logits_output = model_runner.forward(batch, ForwardMode.EXTEND)
+    sample_output, logits_output = model_runner.forward(batch)
     next_token_ids = sample_output.batch_next_token_ids.tolist()
     return next_token_ids, logits_output.next_token_logits, batch


 def decode(input_token_ids, batch, model_runner):
     batch.prepare_for_decode(input_token_ids)
-    sample_output, logits_output = model_runner.forward(batch, ForwardMode.DECODE)
+    sample_output, logits_output = model_runner.forward(batch)
     next_token_ids = sample_output.batch_next_token_ids.tolist()
     return next_token_ids, logits_output.next_token_logits
python/sglang/srt/layers/logits_processor.py

@@ -103,7 +103,7 @@ class LogitsProcessor(nn.Module):
     @staticmethod
     def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
-        if logits_metadata.forward_mode == ForwardMode.DECODE:
+        if logits_metadata.forward_mode.is_decode():
             output_top_logprobs = []
             max_k = max(logits_metadata.top_logprobs_nums)
             ret = all_logprobs.topk(max_k, dim=1)

@@ -163,7 +163,7 @@ class LogitsProcessor(nn.Module):
         assert isinstance(logits_metadata, LogitsMetadata)

         # Get the last hidden states and last logits for the next token prediction
-        if logits_metadata.forward_mode == ForwardMode.DECODE:
+        if logits_metadata.forward_mode.is_decode():
             last_index = None
             last_hidden = hidden_states
         else:

@@ -195,7 +195,7 @@ class LogitsProcessor(nn.Module):
             )
         else:
             # When logprob is requested, compute the logits for all tokens.
-            if logits_metadata.forward_mode == ForwardMode.DECODE:
+            if logits_metadata.forward_mode.is_decode():
                 last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)

                 # Get the logprob of top-k tokens
python/sglang/srt/layers/radix_attention.py

@@ -197,9 +197,9 @@ class RadixAttention(nn.Module):
         k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
         v = v.view(-1, self.tp_v_head_num, self.v_head_dim)

-        if input_metadata.forward_mode == ForwardMode.EXTEND:
+        if input_metadata.forward_mode.is_extend():
             return self.extend_forward(q, k, v, input_metadata)
-        elif input_metadata.forward_mode == ForwardMode.DECODE:
+        elif input_metadata.forward_mode.is_decode():
             return self.decode_forward(q, k, v, input_metadata)

     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
python/sglang/srt/managers/schedule_batch.py

@@ -29,6 +29,7 @@ from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

 if TYPE_CHECKING:

@@ -334,6 +335,8 @@ class ScheduleBatch:
     token_to_kv_pool: BaseTokenToKVPool
     tree_cache: BasePrefixCache

+    forward_mode: ForwardMode = None
+
     # Batched arguments to model runner
     input_ids: torch.Tensor = None
     req_pool_indices: torch.Tensor = None

@@ -397,6 +400,8 @@ class ScheduleBatch:
         return out_cache_loc

     def prepare_for_extend(self, vocab_size: int):
+        self.forward_mode = ForwardMode.EXTEND
+
         bs = self.batch_size()
         reqs = self.reqs
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]

@@ -626,6 +631,8 @@ class ScheduleBatch:
         return jump_forward_reqs

     def prepare_for_decode(self, input_ids=None):
+        self.forward_mode = ForwardMode.DECODE
+
         if input_ids is None:
             input_ids = [
                 r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
python/sglang/srt/managers/tp_worker.py

@@ -53,7 +53,6 @@ from sglang.srt.managers.schedule_batch import (
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.model_config import ModelConfig
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (

@@ -521,9 +520,7 @@ class ModelTpServer:
         if self.model_runner.is_generation:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
-                sample_output, logits_output = self.model_runner.forward(
-                    batch, ForwardMode.EXTEND
-                )
+                sample_output, logits_output = self.model_runner.forward(batch)
                 next_token_ids = batch.check_sample_results(sample_output)
                 batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                     next_token_ids

@@ -588,7 +585,7 @@ class ModelTpServer:
                 pt += req.extend_input_len
         else:
             assert batch.extend_num_tokens != 0
-            logits_output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+            logits_output = self.model_runner.forward(batch)
             embeddings = logits_output.embeddings.tolist()

             # Check finish conditions

@@ -699,9 +696,7 @@ class ModelTpServer:
         batch.prepare_for_decode()

         # Forward and sample the next tokens
-        sample_output, logits_output = self.model_runner.forward(
-            batch, ForwardMode.DECODE
-        )
+        sample_output, logits_output = self.model_runner.forward(batch)
         next_token_ids = batch.check_sample_results(sample_output)
         batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
             next_token_ids
python/sglang/srt/model_executor/forward_batch_info.py

@@ -25,10 +25,9 @@ import torch
 import triton
 import triton.language as tl

-from sglang.srt.managers.schedule_batch import ScheduleBatch
-from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
-
 if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import ScheduleBatch
+    from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

@@ -41,6 +40,15 @@ class ForwardMode(IntEnum):
     # Decode one token.
     DECODE = auto()

+    def is_prefill(self):
+        return self == ForwardMode.PREFILL
+
+    def is_extend(self):
+        return self == ForwardMode.EXTEND
+
+    def is_decode(self):
+        return self == ForwardMode.DECODE
+

 @dataclass
 class InputMetadata:

@@ -102,7 +110,7 @@ class InputMetadata:
     def compute_positions(self, batch: ScheduleBatch):
         position_ids_offsets = batch.position_ids_offsets

-        if self.forward_mode == ForwardMode.DECODE:
+        if self.forward_mode.is_decode():
             if True:
                 self.positions = self.seq_lens - 1
             else:

@@ -141,7 +149,7 @@ class InputMetadata:
         self.positions = self.positions.to(torch.int64)

     def compute_extend_infos(self, batch: ScheduleBatch):
-        if self.forward_mode == ForwardMode.DECODE:
+        if self.forward_mode.is_decode():
             self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
             self.extend_seq_lens_cpu = self.logprob_start_lens_cpu = None
         else:

@@ -173,10 +181,9 @@ class InputMetadata:
         cls,
         model_runner: "ModelRunner",
         batch: ScheduleBatch,
-        forward_mode: ForwardMode,
     ):
         ret = cls(
-            forward_mode=forward_mode,
+            forward_mode=batch.forward_mode,
             sampling_info=batch.sampling_info,
             batch_size=batch.batch_size(),
             req_pool_indices=batch.req_pool_indices,

@@ -194,13 +201,11 @@ class InputMetadata:
         ret.compute_extend_infos(batch)

-        if (
-            forward_mode != ForwardMode.DECODE
-            or model_runner.server_args.disable_flashinfer
-        ):
+        fm = batch.forward_mode
+        if not fm.is_decode() or model_runner.server_args.disable_flashinfer:
             ret.total_num_tokens = int(torch.sum(ret.seq_lens))

-        if forward_mode != ForwardMode.DECODE:
+        if not fm.is_decode():
             ret.init_multimuldal_info(batch)

         if model_runner.server_args.disable_flashinfer:

@@ -209,7 +214,7 @@ class InputMetadata:
         flashinfer_use_ragged = False
         if not model_runner.server_args.disable_flashinfer:
             if (
-                forward_mode != ForwardMode.DECODE
+                not fm.is_decode()
                 and int(torch.sum(ret.seq_lens)) > 4096
                 and model_runner.sliding_window_size is None
             ):

@@ -226,7 +231,7 @@ class InputMetadata:
         self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
         self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)

-        if self.forward_mode == ForwardMode.DECODE:
+        if self.forward_mode.is_decode():
             self.triton_max_extend_len = None
         else:
             self.triton_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")

@@ -239,7 +244,7 @@ class InputMetadata:
         prefix_lens_cpu,
         flashinfer_use_ragged,
     ):
-        if self.forward_mode == ForwardMode.DECODE:
+        if self.forward_mode.is_decode():
             prefix_lens = None
         else:
             prefix_lens = self.extend_prefix_lens

@@ -339,7 +344,7 @@ def update_flashinfer_indices(
         kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")

-        if forward_mode == ForwardMode.DECODE:
+        if forward_mode.is_decode():
             # CUDA graph uses different flashinfer_decode_wrapper
             if flashinfer_decode_wrapper is None:
                 flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper

@@ -388,7 +393,7 @@ def update_flashinfer_indices(
         kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")

         for wrapper_id in range(2):
             if wrapper_id == 0:
-                if forward_mode == ForwardMode.DECODE:
+                if forward_mode.is_decode():
                     paged_kernel_lens = torch.minimum(
                         seq_lens, torch.tensor(model_runner.sliding_window_size + 1)
                     )

@@ -418,7 +423,7 @@ def update_flashinfer_indices(
                 kv_indices,
             )

-        if forward_mode == ForwardMode.DECODE:
+        if forward_mode.is_decode():
             # CUDA graph uses different flashinfer_decode_wrapper
             if flashinfer_decode_wrapper is None:
                 flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
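Note the import shuffle at the top of this file: schedule_batch.py now imports ForwardMode from forward_batch_info.py, so the module-level imports of ScheduleBatch and the memory pools here move under TYPE_CHECKING, presumably to avoid the circular import that would otherwise result. A minimal sketch of that pattern; the module names mirror the diff, but the function is an illustrative stand-in.

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by type checkers, never at runtime, so a module that in
    # turn imports this one no longer creates an import cycle.
    from sglang.srt.managers.schedule_batch import ScheduleBatch


def compute_positions_stub(batch: ScheduleBatch) -> None:
    # With deferred annotations, ScheduleBatch is never resolved at runtime here.
    ...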
python/sglang/srt/model_executor/model_runner.py

@@ -530,11 +530,7 @@ class ModelRunner:
         ):
             return self.cuda_graph_runner.replay(batch)

-        input_metadata = InputMetadata.from_schedule_batch(
-            self,
-            batch,
-            ForwardMode.DECODE,
-        )
+        input_metadata = InputMetadata.from_schedule_batch(self, batch)

         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata

@@ -542,11 +538,7 @@ class ModelRunner:
     @torch.inference_mode()
     def forward_extend(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.from_schedule_batch(
-            self,
-            batch,
-            forward_mode=ForwardMode.EXTEND,
-        )
+        input_metadata = InputMetadata.from_schedule_batch(self, batch)
         if self.is_generation:
             return self.model.forward(
                 batch.input_ids, input_metadata.positions, input_metadata

@@ -562,11 +554,7 @@ class ModelRunner:
     @torch.inference_mode()
     def forward_extend_multi_modal(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.from_schedule_batch(
-            self,
-            batch,
-            forward_mode=ForwardMode.EXTEND,
-        )
+        input_metadata = InputMetadata.from_schedule_batch(self, batch)
         return self.model.forward(
             batch.input_ids,
             input_metadata.positions,

@@ -577,16 +565,18 @@ class ModelRunner:
         )

     def forward(
-        self, batch: ScheduleBatch, forward_mode: ForwardMode
+        self, batch: ScheduleBatch
     ) -> Tuple[SampleOutput, LogitsProcessorOutput]:
-        if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
+        assert batch.forward_mode is not None
+
+        if self.is_multimodal_model and batch.forward_mode.is_extend():
             return self.forward_extend_multi_modal(batch)
-        elif forward_mode == ForwardMode.DECODE:
+        elif batch.forward_mode.is_decode():
             return self.forward_decode(batch)
-        elif forward_mode == ForwardMode.EXTEND:
+        elif batch.forward_mode.is_extend():
             return self.forward_extend(batch)
         else:
-            raise ValueError(f"Invaid forward mode: {forward_mode}")
+            raise ValueError(f"Invaid forward mode: {batch.forward_mode}")

     @lru_cache()
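The new assert in forward() makes the implicit contract explicit: a batch must go through prepare_for_extend() or prepare_for_decode() before being run. A tiny illustrative sketch of that fail-fast behavior; the class and error message are hypothetical stand-ins (the real assert carries no message).

from dataclasses import dataclass


@dataclass
class _Batch:
    # Stand-in for ScheduleBatch; forward_mode stays None until prepared.
    forward_mode: object = None


def forward(batch: _Batch):
    # Mirrors the guard added to ModelRunner.forward().
    assert batch.forward_mode is not None, "call prepare_for_extend/decode first"


try:
    forward(_Batch())
except AssertionError as exc:
    print(exc)  # -> call prepare_for_extend/decode first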
python/sglang/srt/models/llava.py

@@ -136,7 +136,7 @@ class LlavaBaseForCausalLM(nn.Module):
         image_sizes: Optional[List[List[int]]] = None,
         image_offsets: Optional[List[int]] = None,
     ) -> torch.Tensor:
-        if input_metadata.forward_mode == ForwardMode.EXTEND:
+        if input_metadata.forward_mode.is_extend():
             bs = input_metadata.batch_size
             # Got List[List[str]] extend it to List[str]
             # The length of the List should be equal to batch size

@@ -357,7 +357,7 @@ class LlavaBaseForCausalLM(nn.Module):
             return self.language_model(
                 input_ids, positions, input_metadata, input_embeds=input_embeds
             )
-        elif input_metadata.forward_mode == ForwardMode.DECODE:
+        elif input_metadata.forward_mode.is_decode():
             return self.language_model(input_ids, positions, input_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
python/sglang/srt/models/llavavid.py

@@ -116,7 +116,7 @@ class LlavaVidForCausalLM(nn.Module):
         image_sizes: Optional[List[List[int]]] = None,
         image_offsets: Optional[List[int]] = None,
     ) -> torch.Tensor:
-        if input_metadata.forward_mode == ForwardMode.EXTEND:
+        if input_metadata.forward_mode.is_extend():
             bs = input_metadata.batch_size
             # Embed text inputs

@@ -199,7 +199,7 @@ class LlavaVidForCausalLM(nn.Module):
             return self.language_model(
                 input_ids, positions, input_metadata, input_embeds=input_embeds
             )
-        elif input_metadata.forward_mode == ForwardMode.DECODE:
+        elif input_metadata.forward_mode.is_decode():
             return self.language_model(input_ids, positions, input_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):