Commit 69b3bb9a (Unverified)
Unify forward mode (#1360)
Authored Sep 09, 2024 by Liangsheng Yin, committed by GitHub on Sep 09, 2024
Parent: 689ff588

Showing 9 changed files with 54 additions and 58 deletions (+54, -58).
python/sglang/bench_latency.py                          +2  -3
python/sglang/srt/layers/logits_processor.py            +3  -3
python/sglang/srt/layers/radix_attention.py             +2  -2
python/sglang/srt/managers/schedule_batch.py            +7  -0
python/sglang/srt/managers/tp_worker.py                 +3  -8
python/sglang/srt/model_executor/forward_batch_info.py  +23 -18
python/sglang/srt/model_executor/model_runner.py        +10 -20
python/sglang/srt/models/llava.py                       +2  -2
python/sglang/srt/models/llavavid.py                    +2  -2
python/sglang/bench_latency.py

@@ -60,7 +60,6 @@ import torch.distributed as dist
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.model_config import ModelConfig
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs

@@ -208,14 +207,14 @@ def extend(reqs, model_runner):
         tree_cache=None,
     )
     batch.prepare_for_extend(model_runner.model_config.vocab_size)
-    sample_output, logits_output = model_runner.forward(batch, ForwardMode.EXTEND)
+    sample_output, logits_output = model_runner.forward(batch)
     next_token_ids = sample_output.batch_next_token_ids.tolist()
     return next_token_ids, logits_output.next_token_logits, batch


 def decode(input_token_ids, batch, model_runner):
     batch.prepare_for_decode(input_token_ids)
-    sample_output, logits_output = model_runner.forward(batch, ForwardMode.DECODE)
+    sample_output, logits_output = model_runner.forward(batch)
     next_token_ids = sample_output.batch_next_token_ids.tolist()
     return next_token_ids, logits_output.next_token_logits
python/sglang/srt/layers/logits_processor.py

@@ -103,7 +103,7 @@ class LogitsProcessor(nn.Module):
     @staticmethod
     def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
-        if logits_metadata.forward_mode == ForwardMode.DECODE:
+        if logits_metadata.forward_mode.is_decode():
             output_top_logprobs = []
             max_k = max(logits_metadata.top_logprobs_nums)
             ret = all_logprobs.topk(max_k, dim=1)

@@ -163,7 +163,7 @@ class LogitsProcessor(nn.Module):
         assert isinstance(logits_metadata, LogitsMetadata)

         # Get the last hidden states and last logits for the next token prediction
-        if logits_metadata.forward_mode == ForwardMode.DECODE:
+        if logits_metadata.forward_mode.is_decode():
             last_index = None
             last_hidden = hidden_states
         else:

@@ -195,7 +195,7 @@ class LogitsProcessor(nn.Module):
             )
         else:
             # When logprob is requested, compute the logits for all tokens.
-            if logits_metadata.forward_mode == ForwardMode.DECODE:
+            if logits_metadata.forward_mode.is_decode():
                 last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)

             # Get the logprob of top-k tokens
python/sglang/srt/layers/radix_attention.py

@@ -197,9 +197,9 @@ class RadixAttention(nn.Module):
         k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
         v = v.view(-1, self.tp_v_head_num, self.v_head_dim)

-        if input_metadata.forward_mode == ForwardMode.EXTEND:
+        if input_metadata.forward_mode.is_extend():
             return self.extend_forward(q, k, v, input_metadata)
-        elif input_metadata.forward_mode == ForwardMode.DECODE:
+        elif input_metadata.forward_mode.is_decode():
             return self.decode_forward(q, k, v, input_metadata)

     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
python/sglang/srt/managers/schedule_batch.py

@@ -29,6 +29,7 @@ from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

 if TYPE_CHECKING:

@@ -334,6 +335,8 @@ class ScheduleBatch:
     token_to_kv_pool: BaseTokenToKVPool
     tree_cache: BasePrefixCache

+    forward_mode: ForwardMode = None
+
     # Batched arguments to model runner
     input_ids: torch.Tensor = None
     req_pool_indices: torch.Tensor = None

@@ -397,6 +400,8 @@ class ScheduleBatch:
         return out_cache_loc

     def prepare_for_extend(self, vocab_size: int):
+        self.forward_mode = ForwardMode.EXTEND
+
         bs = self.batch_size()
         reqs = self.reqs
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]

@@ -626,6 +631,8 @@ class ScheduleBatch:
         return jump_forward_reqs

     def prepare_for_decode(self, input_ids=None):
+        self.forward_mode = ForwardMode.DECODE
+
         if input_ids is None:
             input_ids = [
                 r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
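
Taken together, the schedule_batch.py hunks above make the batch the single owner of its forward mode: the new field defaults to None and is stamped by prepare_for_extend / prepare_for_decode before any batched tensors are built. A minimal, self-contained sketch of that pattern (hypothetical Batch/Mode names, standard library only; these are stand-ins, not the sglang classes themselves):

from dataclasses import dataclass, field
from enum import IntEnum, auto
from typing import List, Optional


class Mode(IntEnum):
    EXTEND = auto()
    DECODE = auto()


@dataclass
class Batch:
    """Toy stand-in for ScheduleBatch: it records its own mode when prepared."""

    input_ids: List[int] = field(default_factory=list)
    forward_mode: Optional[Mode] = None  # set by the prepare_* methods below

    def prepare_for_extend(self, vocab_size: int):
        self.forward_mode = Mode.EXTEND
        # ... build the batched prefill/extend tensors here ...

    def prepare_for_decode(self, input_ids=None):
        self.forward_mode = Mode.DECODE
        # ... pick the last generated token per request here ...


b = Batch(input_ids=[1, 2, 3])
b.prepare_for_extend(vocab_size=32000)
assert b.forward_mode is Mode.EXTEND
b.prepare_for_decode()
assert b.forward_mode is Mode.DECODE
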
python/sglang/srt/managers/tp_worker.py

@@ -53,7 +53,6 @@ from sglang.srt.managers.schedule_batch import (
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.model_config import ModelConfig
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (

@@ -521,9 +520,7 @@ class ModelTpServer:
         if self.model_runner.is_generation:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
-                sample_output, logits_output = self.model_runner.forward(
-                    batch, ForwardMode.EXTEND
-                )
+                sample_output, logits_output = self.model_runner.forward(batch)
                 next_token_ids = batch.check_sample_results(sample_output)
                 batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                     next_token_ids

@@ -588,7 +585,7 @@ class ModelTpServer:
                 pt += req.extend_input_len
         else:
             assert batch.extend_num_tokens != 0
-            logits_output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+            logits_output = self.model_runner.forward(batch)
             embeddings = logits_output.embeddings.tolist()

             # Check finish conditions

@@ -699,9 +696,7 @@ class ModelTpServer:
         batch.prepare_for_decode()

         # Forward and sample the next tokens
-        sample_output, logits_output = self.model_runner.forward(
-            batch, ForwardMode.DECODE
-        )
+        sample_output, logits_output = self.model_runner.forward(batch)
         next_token_ids = batch.check_sample_results(sample_output)
         batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
             next_token_ids
python/sglang/srt/model_executor/forward_batch_info.py

@@ -25,10 +25,9 @@ import torch
 import triton
 import triton.language as tl

-from sglang.srt.managers.schedule_batch import ScheduleBatch
-from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
-
 if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import ScheduleBatch
+    from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

@@ -41,6 +40,15 @@ class ForwardMode(IntEnum):
     # Decode one token.
     DECODE = auto()

+    def is_prefill(self):
+        return self == ForwardMode.PREFILL
+
+    def is_extend(self):
+        return self == ForwardMode.EXTEND
+
+    def is_decode(self):
+        return self == ForwardMode.DECODE
+

 @dataclass
 class InputMetadata:

@@ -102,7 +110,7 @@ class InputMetadata:
     def compute_positions(self, batch: ScheduleBatch):
         position_ids_offsets = batch.position_ids_offsets

-        if self.forward_mode == ForwardMode.DECODE:
+        if self.forward_mode.is_decode():
             if True:
                 self.positions = self.seq_lens - 1
             else:

@@ -141,7 +149,7 @@ class InputMetadata:
         self.positions = self.positions.to(torch.int64)

     def compute_extend_infos(self, batch: ScheduleBatch):
-        if self.forward_mode == ForwardMode.DECODE:
+        if self.forward_mode.is_decode():
             self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
             self.extend_seq_lens_cpu = self.logprob_start_lens_cpu = None
         else:

@@ -173,10 +181,9 @@ class InputMetadata:
         cls,
         model_runner: "ModelRunner",
         batch: ScheduleBatch,
-        forward_mode: ForwardMode,
     ):
         ret = cls(
-            forward_mode=forward_mode,
+            forward_mode=batch.forward_mode,
             sampling_info=batch.sampling_info,
             batch_size=batch.batch_size(),
             req_pool_indices=batch.req_pool_indices,

@@ -194,13 +201,11 @@ class InputMetadata:
         ret.compute_extend_infos(batch)

-        if (
-            forward_mode != ForwardMode.DECODE
-            or model_runner.server_args.disable_flashinfer
-        ):
+        fm = batch.forward_mode
+        if not fm.is_decode() or model_runner.server_args.disable_flashinfer:
             ret.total_num_tokens = int(torch.sum(ret.seq_lens))

-        if forward_mode != ForwardMode.DECODE:
+        if not fm.is_decode():
             ret.init_multimuldal_info(batch)

         if model_runner.server_args.disable_flashinfer:

@@ -209,7 +214,7 @@ class InputMetadata:
         flashinfer_use_ragged = False
         if not model_runner.server_args.disable_flashinfer:
             if (
-                forward_mode != ForwardMode.DECODE
+                not fm.is_decode()
                 and int(torch.sum(ret.seq_lens)) > 4096
                 and model_runner.sliding_window_size is None
             ):

@@ -226,7 +231,7 @@ class InputMetadata:
             self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
             self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)

-            if self.forward_mode == ForwardMode.DECODE:
+            if self.forward_mode.is_decode():
                 self.triton_max_extend_len = None
             else:
                 self.triton_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")

@@ -239,7 +244,7 @@ class InputMetadata:
         prefix_lens_cpu,
         flashinfer_use_ragged,
     ):
-        if self.forward_mode == ForwardMode.DECODE:
+        if self.forward_mode.is_decode():
             prefix_lens = None
         else:
             prefix_lens = self.extend_prefix_lens

@@ -339,7 +344,7 @@ def update_flashinfer_indices(
         kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")

-        if forward_mode == ForwardMode.DECODE:
+        if forward_mode.is_decode():
             # CUDA graph uses different flashinfer_decode_wrapper
             if flashinfer_decode_wrapper is None:
                 flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper

@@ -388,7 +393,7 @@ def update_flashinfer_indices(
         kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")

         for wrapper_id in range(2):
             if wrapper_id == 0:
-                if forward_mode == ForwardMode.DECODE:
+                if forward_mode.is_decode():
                     paged_kernel_lens = torch.minimum(
                         seq_lens, torch.tensor(model_runner.sliding_window_size + 1)
                     )

@@ -418,7 +423,7 @@ def update_flashinfer_indices(
                 kv_indices,
             )

-        if forward_mode == ForwardMode.DECODE:
+        if forward_mode.is_decode():
             # CUDA graph uses different flashinfer_decode_wrapper
             if flashinfer_decode_wrapper is None:
                 flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
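
The nine added lines in the ForwardMode hunk above introduce simple predicate methods, which is what lets every call site in this commit replace `== ForwardMode.X` equality checks with `is_extend()` / `is_decode()`. A self-contained copy of the enum for experimentation (the PREFILL and EXTEND member definitions are assumed from the surrounding file; only DECODE and the three helpers appear verbatim in the hunk):

from enum import IntEnum, auto


class ForwardMode(IntEnum):
    # Prefill a new sequence (assumed member; referenced by is_prefill below).
    PREFILL = auto()
    # Extend a sequence with new tokens (assumed member; referenced by is_extend).
    EXTEND = auto()
    # Decode one token.
    DECODE = auto()

    def is_prefill(self):
        return self == ForwardMode.PREFILL

    def is_extend(self):
        return self == ForwardMode.EXTEND

    def is_decode(self):
        return self == ForwardMode.DECODE


# Call sites now read as predicates instead of equality checks:
mode = ForwardMode.DECODE
assert mode.is_decode()
assert not mode.is_extend()
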
python/sglang/srt/model_executor/model_runner.py

@@ -530,11 +530,7 @@ class ModelRunner:
         ):
             return self.cuda_graph_runner.replay(batch)

-        input_metadata = InputMetadata.from_schedule_batch(
-            self,
-            batch,
-            ForwardMode.DECODE,
-        )
+        input_metadata = InputMetadata.from_schedule_batch(self, batch)

         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata

@@ -542,11 +538,7 @@ class ModelRunner:
     @torch.inference_mode()
     def forward_extend(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.from_schedule_batch(
-            self,
-            batch,
-            forward_mode=ForwardMode.EXTEND,
-        )
+        input_metadata = InputMetadata.from_schedule_batch(self, batch)
         if self.is_generation:
             return self.model.forward(
                 batch.input_ids, input_metadata.positions, input_metadata

@@ -562,11 +554,7 @@ class ModelRunner:
     @torch.inference_mode()
     def forward_extend_multi_modal(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.from_schedule_batch(
-            self,
-            batch,
-            forward_mode=ForwardMode.EXTEND,
-        )
+        input_metadata = InputMetadata.from_schedule_batch(self, batch)
         return self.model.forward(
             batch.input_ids,
             input_metadata.positions,

@@ -577,16 +565,18 @@ class ModelRunner:
         )

     def forward(
-        self, batch: ScheduleBatch, forward_mode: ForwardMode
+        self, batch: ScheduleBatch
     ) -> Tuple[SampleOutput, LogitsProcessorOutput]:
-        if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
+        assert batch.forward_mode is not None
+
+        if self.is_multimodal_model and batch.forward_mode.is_extend():
             return self.forward_extend_multi_modal(batch)
-        elif forward_mode == ForwardMode.DECODE:
+        elif batch.forward_mode.is_decode():
             return self.forward_decode(batch)
-        elif forward_mode == ForwardMode.EXTEND:
+        elif batch.forward_mode.is_extend():
             return self.forward_extend(batch)
         else:
-            raise ValueError(f"Invaid forward mode: {forward_mode}")
+            raise ValueError(f"Invaid forward mode: {batch.forward_mode}")

     @lru_cache()
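
With the hunks above, ModelRunner.forward no longer receives a mode; it asserts that the batch was prepared and then dispatches on batch.forward_mode. A minimal, self-contained sketch of that dispatch (hypothetical toy types, standard library only; it mirrors only the control flow of the new forward method, not its real return values):

from enum import IntEnum, auto
from types import SimpleNamespace


class ForwardMode(IntEnum):
    EXTEND = auto()
    DECODE = auto()

    def is_extend(self):
        return self == ForwardMode.EXTEND

    def is_decode(self):
        return self == ForwardMode.DECODE


class ToyRunner:
    """Stand-in for ModelRunner: the batch, not the caller, supplies the mode."""

    def forward(self, batch):
        # Mirrors the new guard: a batch must be prepared before forwarding.
        assert batch.forward_mode is not None
        if batch.forward_mode.is_decode():
            return "ran decode step"
        elif batch.forward_mode.is_extend():
            return "ran extend step"
        else:
            raise ValueError(f"Invalid forward mode: {batch.forward_mode}")


runner = ToyRunner()
batch = SimpleNamespace(forward_mode=ForwardMode.DECODE)  # as set by prepare_for_decode
print(runner.forward(batch))  # -> "ran decode step"
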
python/sglang/srt/models/llava.py

@@ -136,7 +136,7 @@ class LlavaBaseForCausalLM(nn.Module):
         image_sizes: Optional[List[List[int]]] = None,
         image_offsets: Optional[List[int]] = None,
     ) -> torch.Tensor:
-        if input_metadata.forward_mode == ForwardMode.EXTEND:
+        if input_metadata.forward_mode.is_extend():
             bs = input_metadata.batch_size
             # Got List[List[str]] extend it to List[str]
             # The length of the List should be equal to batch size

@@ -357,7 +357,7 @@ class LlavaBaseForCausalLM(nn.Module):
             return self.language_model(
                 input_ids, positions, input_metadata, input_embeds=input_embeds
             )
-        elif input_metadata.forward_mode == ForwardMode.DECODE:
+        elif input_metadata.forward_mode.is_decode():
             return self.language_model(input_ids, positions, input_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
python/sglang/srt/models/llavavid.py

@@ -116,7 +116,7 @@ class LlavaVidForCausalLM(nn.Module):
         image_sizes: Optional[List[List[int]]] = None,
         image_offsets: Optional[List[int]] = None,
     ) -> torch.Tensor:
-        if input_metadata.forward_mode == ForwardMode.EXTEND:
+        if input_metadata.forward_mode.is_extend():
             bs = input_metadata.batch_size
             # Embed text inputs

@@ -199,7 +199,7 @@ class LlavaVidForCausalLM(nn.Module):
             return self.language_model(
                 input_ids, positions, input_metadata, input_embeds=input_embeds
             )
-        elif input_metadata.forward_mode == ForwardMode.DECODE:
+        elif input_metadata.forward_mode.is_decode():
             return self.language_model(input_ids, positions, input_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):