sglang · Commits · 1ac304ee

Unverified commit 1ac304ee, authored Aug 08, 2024 by Liangsheng Yin, committed by GitHub on Aug 08, 2024
Adjust `InputMetadata` and `ScheduleBatch` (#981)
parent 20a4f927

Showing 4 changed files with 202 additions and 191 deletions (+202, -191)
python/sglang/srt/managers/schedule_batch.py            +16  -44
python/sglang/srt/model_executor/cuda_graph_runner.py     +9   -9
python/sglang/srt/model_executor/forward_batch_info.py  +167 -105
python/sglang/srt/model_executor/model_runner.py         +10  -33
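At the call sites, the commit replaces the wide `InputMetadata.create(...)` constructor with a single `InputMetadata.from_schedule_batch(...)` classmethod that reads everything it needs from the `ScheduleBatch`. A condensed before/after sketch of a decode call site, taken from the model_runner.py hunks below (`self` is the `ModelRunner`):

    # Before: every per-request tensor is passed explicitly.
    input_metadata = InputMetadata.create(
        self,
        forward_mode=ForwardMode.DECODE,
        req_pool_indices=batch.req_pool_indices,
        seq_lens=batch.seq_lens,
        prefix_lens=batch.prefix_lens,
        position_ids_offsets=batch.position_ids_offsets,
        out_cache_loc=batch.out_cache_loc,
        top_logprobs_nums=batch.top_logprobs_nums,
        return_logprob=batch.return_logprob,
    )

    # After: the batch object carries the per-request state, so the call shrinks to one line.
    input_metadata = InputMetadata.from_schedule_batch(self, batch, ForwardMode.DECODE)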
python/sglang/srt/managers/schedule_batch.py  (view file @ 1ac304ee)
@@ -307,7 +307,6 @@ class ScheduleBatch:
     input_ids: torch.Tensor = None
     req_pool_indices: torch.Tensor = None
     seq_lens: torch.Tensor = None
-    prefix_lens: torch.Tensor = None
     position_ids_offsets: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
     extend_num_tokens: int = None
@@ -316,11 +315,6 @@ class ScheduleBatch:
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None

-    # For multimodal
-    pixel_values: List[torch.Tensor] = None
-    image_sizes: List[List[int]] = None
-    image_offsets: List[int] = None
-
     # Batched sampling params
     temperatures: torch.Tensor = None
     top_ps: torch.Tensor = None
@@ -412,59 +406,40 @@ class ScheduleBatch:
             self.logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias

     def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
-        device = "cuda"
         bs = self.batch_size()
         reqs = self.reqs
         input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
-        prefix_indices = [r.prefix_indices for r in reqs]
-
-        # Handle prefix
-        extend_lens = []
-        prefix_lens = []
+        extend_num_tokens = sum(len(ids) for ids in input_ids)
         seq_lens = []

         # Allocate memory
         req_pool_indices_cpu = self.alloc_req_slots(bs)
+        out_cache_loc = self.alloc_token_slots(extend_num_tokens)
+
+        pt = 0
         for i, req in enumerate(reqs):
             req.req_pool_idx = req_pool_indices_cpu[i]
-            extend_lens.append(len(input_ids[i]))
-
-            if len(prefix_indices[i]) == 0:
-                prefix_lens.append(0)
-            else:
-                prefix_lens.append(len(prefix_indices[i]))
-                self.req_to_token_pool.req_to_token[req.req_pool_idx][
-                    : len(prefix_indices[i])
-                ] = prefix_indices[i]
-
-            seq_lens.append(prefix_lens[-1] + extend_lens[-1])
+            pre_len, seq_len = len(req.prefix_indices), len(req.input_ids)
+            ext_len = seq_len - pre_len
+            seq_lens.append(seq_len)

-        # Allocate memory
-        seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
-        extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
-        out_cache_loc = self.alloc_token_slots(extend_num_tokens)
-
-        pt = 0
-        for i, req in enumerate(reqs):
-            self.req_to_token_pool.req_to_token[req.req_pool_idx][
-                prefix_lens[i] : prefix_lens[i] + extend_lens[i]
-            ] = out_cache_loc[pt : pt + extend_lens[i]]
-            pt += extend_lens[i]
+            if pre_len > 0:
+                self.req_to_token_pool.req_to_token[req.req_pool_idx][
+                    :pre_len
+                ] = req.prefix_indices
+
+            self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
+                out_cache_loc[pt : pt + ext_len]
+            )
+            pt += ext_len

         # Set fields
-        self.pixel_values = [r.pixel_values for r in reqs]
-        self.image_sizes = [r.image_size for r in reqs]
-        self.image_offsets = [
-            (r.image_offset - p_len) if r.image_offset is not None else 0
-            for r, p_len in zip(reqs, prefix_lens)
-        ]
-        self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
-        self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int64)
+        with torch.device("cuda"):
+            self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
+            self.req_pool_indices = torch.tensor(req_pool_indices_cpu)
+            self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
+            self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int32)
+
         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
         self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
@@ -642,7 +617,6 @@ class ScheduleBatch:
         ]

         self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
         self.seq_lens.add_(1)
-        self.prefix_lens = None

         # Alloc mem
         bs = self.batch_size()
@@ -667,7 +641,6 @@ class ScheduleBatch:
         self.seq_lens = self.seq_lens[new_indices]
         self.input_ids = None
         self.req_pool_indices = self.req_pool_indices[new_indices]
-        self.prefix_lens = None
         self.position_ids_offsets = self.position_ids_offsets[new_indices]
         self.out_cache_loc = None
         self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
@@ -692,7 +665,6 @@ class ScheduleBatch:
             [self.req_pool_indices, other.req_pool_indices]
         )
         self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
-        self.prefix_lens = None
         self.position_ids_offsets = torch.concat(
             [self.position_ids_offsets, other.position_ids_offsets]
         )
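The rewritten `prepare_for_extend` derives each request's slot layout directly from `pre_len`/`seq_len` instead of building intermediate `prefix_lens`/`extend_lens` arrays. A standalone sketch of that layout using plain Python lists in place of the CUDA tensors (the helper name and list arguments here are illustrative, not the real pool classes):

    # Illustrative sketch: how one request's row of the req_to_token table is filled.
    # row[:pre_len] reuses the already-cached prefix slots; row[pre_len:seq_len] takes
    # freshly allocated KV-cache slots, consumed from out_cache_loc starting at pt.
    def fill_row(prefix_indices, num_new_tokens, out_cache_loc, pt, row_size=16):
        row = [-1] * row_size
        pre_len = len(prefix_indices)
        seq_len = pre_len + num_new_tokens
        ext_len = seq_len - pre_len

        if pre_len > 0:
            row[:pre_len] = prefix_indices
        row[pre_len:seq_len] = out_cache_loc[pt : pt + ext_len]
        return row, pt + ext_len

    row, pt = fill_row([100, 101, 102], 4, list(range(200, 220)), pt=0)
    assert row[:7] == [100, 101, 102, 200, 201, 202, 203] and pt == 4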
python/sglang/srt/model_executor/cuda_graph_runner.py  (view file @ 1ac304ee)
@@ -33,7 +33,7 @@ from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
     InputMetadata,
-    init_flashinfer_args,
+    update_flashinfer_indices,
 )
 from sglang.srt.utils import monkey_patch_vllm_all_gather

@@ -165,7 +165,7 @@ class CudaGraphRunner:
                 paged_kv_indices_buffer=self.flashinfer_kv_indices,
                 paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
             )
-            init_flashinfer_args(
+            update_flashinfer_indices(
                 ForwardMode.DECODE,
                 self.model_runner,
                 req_pool_indices,

@@ -176,19 +176,19 @@ class CudaGraphRunner:
         # Run and capture
         def run_once():
-            input_metadata = InputMetadata.create(
-                self.model_runner,
+            input_metadata = InputMetadata(
                 forward_mode=ForwardMode.DECODE,
+                batch_size=bs,
                 req_pool_indices=req_pool_indices,
                 seq_lens=seq_lens,
-                prefix_lens=None,
-                position_ids_offsets=position_ids_offsets,
+                req_to_token_pool=self.model_runner.req_to_token_pool,
+                token_to_kv_pool=self.model_runner.token_to_kv_pool,
                 out_cache_loc=out_cache_loc,
                 return_logprob=False,
                 top_logprobs_nums=0,
-                skip_flashinfer_init=True,
+                positions=(seq_lens - 1).to(torch.int64),
+                flashinfer_decode_wrapper=flashinfer_decode_wrapper,
             )
-            input_metadata.flashinfer_decode_wrapper = flashinfer_decode_wrapper

             return forward(input_ids, input_metadata.positions, input_metadata)

@@ -222,7 +222,7 @@ class CudaGraphRunner:
         self.out_cache_loc[:raw_bs] = batch.out_cache_loc

         # FlashInfer inputs
-        init_flashinfer_args(
+        update_flashinfer_indices(
             ForwardMode.DECODE,
             self.model_runner,
             self.req_pool_indices[:bs],
python/sglang/srt/model_executor/forward_batch_info.py  (view file @ 1ac304ee)
@@ -16,13 +16,17 @@ limitations under the License.
 """ModelRunner runs the forward passes of the models."""
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import List
+from typing import TYPE_CHECKING, List

 import numpy as np
 import torch

+from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool

+if TYPE_CHECKING:
+    from sglang.srt.model_executor.model_runner import ModelRunner
+

 class ForwardMode(IntEnum):
     # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
@@ -39,25 +43,33 @@ class InputMetadata:
     forward_mode: ForwardMode
     batch_size: int
-    total_num_tokens: int
     req_pool_indices: torch.Tensor
     seq_lens: torch.Tensor
-    positions: torch.Tensor
     req_to_token_pool: ReqToTokenPool
     token_to_kv_pool: BaseTokenToKVPool

-    # For extend
-    extend_seq_lens: torch.Tensor
-    extend_start_loc: torch.Tensor
-    extend_no_prefix: bool
-
     # Output location of the KV cache
-    out_cache_loc: torch.Tensor = None
+    out_cache_loc: torch.Tensor
+
+    total_num_tokens: int = None
+
+    # Position information
+    positions: torch.Tensor = None
+
+    # For extend
+    extend_seq_lens: torch.Tensor = None
+    extend_start_loc: torch.Tensor = None
+    extend_no_prefix: bool = None

     # Output options
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None

+    # For multimodal
+    pixel_values: List[torch.Tensor] = None
+    image_sizes: List[List[int]] = None
+    image_offsets: List[int] = None
+
     # Trition attention backend
     triton_max_seq_len: int = 0
     triton_max_extend_len: int = 0
@@ -70,107 +82,170 @@ class InputMetadata:
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
     flashinfer_use_ragged: bool = False

-    @classmethod
-    def create(
-        cls,
-        model_runner,
-        forward_mode,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        top_logprobs_nums=None,
-        return_logprob=False,
-        skip_flashinfer_init=False,
-    ):
-        flashinfer_use_ragged = False
-        if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
-            if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
-                flashinfer_use_ragged = True
-            init_flashinfer_args(
-                forward_mode,
-                model_runner,
-                req_pool_indices,
-                seq_lens,
-                prefix_lens,
-                model_runner.flashinfer_decode_wrapper,
-                flashinfer_use_ragged,
-            )
-
-        batch_size = len(req_pool_indices)
-
-        if forward_mode == ForwardMode.DECODE:
-            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
-            extend_seq_lens = extend_start_loc = extend_no_prefix = None
-            if not model_runner.server_args.disable_flashinfer:
-                # This variable is not needed in this case,
-                # we do not compute it to make it compatbile with cuda graph.
-                total_num_tokens = None
-            else:
-                total_num_tokens = int(torch.sum(seq_lens))
-        else:
-            seq_lens_cpu = seq_lens.cpu().numpy()
-            prefix_lens_cpu = prefix_lens.cpu().numpy()
-            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
-            positions = torch.tensor(
-                np.concatenate(
-                    [
-                        np.arange(
-                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
-                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
-                        )
-                        for i in range(batch_size)
-                    ],
-                    axis=0,
-                ),
-                device="cuda",
-            )
-            extend_seq_lens = seq_lens - prefix_lens
-            extend_start_loc = torch.zeros_like(seq_lens)
-            extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
-            extend_no_prefix = torch.all(prefix_lens == 0)
-            total_num_tokens = int(torch.sum(seq_lens))
-
-        ret = cls(
-            forward_mode=forward_mode,
-            batch_size=batch_size,
-            total_num_tokens=total_num_tokens,
-            req_pool_indices=req_pool_indices,
-            seq_lens=seq_lens,
-            positions=positions,
-            req_to_token_pool=model_runner.req_to_token_pool,
-            token_to_kv_pool=model_runner.token_to_kv_pool,
-            out_cache_loc=out_cache_loc,
-            extend_seq_lens=extend_seq_lens,
-            extend_start_loc=extend_start_loc,
-            extend_no_prefix=extend_no_prefix,
-            return_logprob=return_logprob,
-            top_logprobs_nums=top_logprobs_nums,
-            flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
-            flashinfer_use_ragged=flashinfer_use_ragged,
-        )
-
-        if model_runner.server_args.disable_flashinfer:
-            (
-                ret.triton_max_seq_len,
-                ret.triton_max_extend_len,
-                ret.triton_start_loc,
-                ret.triton_prefix_lens,
-            ) = init_triton_args(forward_mode, seq_lens, prefix_lens)
-
-        return ret
-
-
-def init_flashinfer_args(
+    def init_multimuldal_info(self, batch: ScheduleBatch):
+        reqs = batch.reqs
+        self.pixel_values = [r.pixel_values for r in reqs]
+        self.image_sizes = [r.image_size for r in reqs]
+        self.image_offsets = [
+            (
+                (r.image_offset - len(r.prefix_indices))
+                if r.image_offset is not None
+                else 0
+            )
+            for r in reqs
+        ]
+
+    def compute_positions(self, batch: ScheduleBatch):
+        position_ids_offsets = batch.position_ids_offsets
+
+        if self.forward_mode == ForwardMode.DECODE:
+            if True:
+                self.positions = self.seq_lens - 1
+            else:
+                # Deprecated
+                self.positions = (self.seq_lens - 1) + position_ids_offsets
+        else:
+            if True:
+                self.positions = torch.tensor(
+                    np.concatenate(
+                        [
+                            np.arange(len(req.prefix_indices), len(req.input_ids))
+                            for req in batch.reqs
+                        ],
+                        axis=0,
+                    ),
+                    device="cuda",
+                )
+            else:
+                # Deprecated
+                position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
+                self.positions = torch.tensor(
+                    np.concatenate(
+                        [
+                            np.arange(
+                                len(req.prefix_indices) + position_ids_offsets_cpu[i],
+                                len(req.input_ids) + position_ids_offsets_cpu[i],
+                            )
+                            for i, req in enumerate(batch.reqs)
+                        ],
+                        axis=0,
+                    ),
+                    device="cuda",
+                )
+
+        # Positions should be in long type
+        self.positions = self.positions.to(torch.int64)
+
+    def compute_extend_infos(self, batch: ScheduleBatch):
+        if self.forward_mode == ForwardMode.DECODE:
+            self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
+        else:
+            prefix_lens_cpu = [
+                len(r.input_ids) - len(r.prefix_indices) for r in batch.reqs
+            ]
+            self.extend_seq_lens = torch.tensor(prefix_lens_cpu, device="cuda")
+            self.extend_start_loc = torch.zeros_like(self.seq_lens)
+            self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
+            self.extend_no_prefix = all(x == 0 for x in prefix_lens_cpu)
+
+    def init_total_num_tokens(self, batch: ScheduleBatch):
+        self.total_num_tokens = sum(len(req.input_ids) for req in batch.reqs)
+
+    @classmethod
+    def from_schedule_batch(
+        cls,
+        model_runner: "ModelRunner",
+        batch: ScheduleBatch,
+        forward_mode: ForwardMode,
+    ):
+        ret = cls(
+            forward_mode=forward_mode,
+            batch_size=batch.batch_size(),
+            req_pool_indices=batch.req_pool_indices,
+            seq_lens=batch.seq_lens,
+            req_to_token_pool=model_runner.req_to_token_pool,
+            token_to_kv_pool=model_runner.token_to_kv_pool,
+            out_cache_loc=batch.out_cache_loc,
+            return_logprob=batch.return_logprob,
+            top_logprobs_nums=batch.top_logprobs_nums,
+        )
+
+        ret.compute_positions(batch)
+        ret.compute_extend_infos(batch)
+        ret.init_total_num_tokens(batch)
+
+        if forward_mode != ForwardMode.DECODE:
+            ret.init_multimuldal_info(batch)
+
+        prefix_lens = None
+        if forward_mode != ForwardMode.DECODE:
+            prefix_lens = torch.tensor(
+                [len(r.prefix_indices) for r in batch.reqs], device="cuda"
+            )
+
+        if model_runner.server_args.disable_flashinfer:
+            ret.init_triton_args(batch, prefix_lens)
+
+        flashinfer_use_ragged = False
+        if not model_runner.server_args.disable_flashinfer:
+            if (
+                forward_mode != ForwardMode.DECODE
+                and int(torch.sum(ret.seq_lens)) > 4096
+            ):
+                flashinfer_use_ragged = True
+            ret.init_flashinfer_handlers(
+                model_runner, prefix_lens, flashinfer_use_ragged
+            )
+
+        return ret
+
+    def init_triton_args(self, batch: ScheduleBatch, prefix_lens):
+        """Init auxiliary variables for triton attention backend."""
+        self.triton_max_seq_len = max(len(r.input_ids) for r in batch.reqs)
+        self.triton_prefix_lens = prefix_lens
+        self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
+        self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)
+
+        if self.forward_mode == ForwardMode.DECODE:
+            self.triton_max_extend_len = None
+        else:
+            extend_seq_lens = self.seq_lens - prefix_lens
+            self.triton_max_extend_len = int(torch.max(extend_seq_lens))
+
+    def init_flashinfer_handlers(
+        self,
+        model_runner,
+        prefix_lens,
+        flashinfer_use_ragged,
+    ):
+        update_flashinfer_indices(
+            self.forward_mode,
+            model_runner,
+            self.req_pool_indices,
+            self.seq_lens,
+            prefix_lens,
+            flashinfer_use_ragged=flashinfer_use_ragged,
+        )
+
+        (
+            self.flashinfer_prefill_wrapper_ragged,
+            self.flashinfer_prefill_wrapper_paged,
+            self.flashinfer_decode_wrapper,
+            self.flashinfer_use_ragged,
+        ) = (
+            model_runner.flashinfer_prefill_wrapper_ragged,
+            model_runner.flashinfer_prefill_wrapper_paged,
+            model_runner.flashinfer_decode_wrapper,
+            flashinfer_use_ragged,
+        )
+
+
+def update_flashinfer_indices(
     forward_mode,
     model_runner,
     req_pool_indices,
     seq_lens,
     prefix_lens,
-    flashinfer_decode_wrapper,
+    flashinfer_decode_wrapper=None,
     flashinfer_use_ragged=False,
 ):
     """Init auxiliary variables for FlashInfer attention backend."""
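`compute_positions` is where the forward-mode split shows up most clearly: decode positions are simply `seq_lens - 1`, while extend positions enumerate every new (non-prefix) token of every request. A small runnable sketch of both branches with made-up lengths (only `torch` and `numpy` are assumed; the real method reads these values from the batch and puts the result on CUDA):

    import numpy as np
    import torch

    # Two requests: prefix lengths 3 and 0, total lengths 7 and 5.
    prefix_lens = [3, 0]
    seq_lens = torch.tensor([7, 5])

    # Decode: one new token per request, so its position is the last index.
    decode_positions = (seq_lens - 1).to(torch.int64)            # tensor([6, 4])

    # Extend: one position per new token, concatenated across requests.
    extend_positions = torch.tensor(
        np.concatenate(
            [np.arange(p, s) for p, s in zip(prefix_lens, seq_lens.tolist())], axis=0
        )
    ).to(torch.int64)                                            # tensor([3, 4, 5, 6, 0, 1, 2, 3, 4])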
@@ -178,7 +253,6 @@ def init_flashinfer_args(
     num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
     head_dim = model_runner.model_config.head_dim
     batch_size = len(req_pool_indices)
-    total_num_tokens = int(torch.sum(seq_lens))

     if flashinfer_use_ragged:
         paged_kernel_lens = prefix_lens
@@ -201,6 +275,10 @@ def init_flashinfer_args(
     kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")

     if forward_mode == ForwardMode.DECODE:
+        # CUDA graph uses different flashinfer_decode_wrapper
+        if flashinfer_decode_wrapper is None:
+            flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
+
         flashinfer_decode_wrapper.end_forward()
         flashinfer_decode_wrapper.begin_forward(
             kv_indptr,
@@ -238,19 +316,3 @@ def init_flashinfer_args(
             head_dim,
             1,
         )
-
-
-def init_triton_args(forward_mode, seq_lens, prefix_lens):
-    """Init auxiliary variables for triton attention backend."""
-    batch_size = len(seq_lens)
-    max_seq_len = int(torch.max(seq_lens))
-    start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
-    start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
-
-    if forward_mode == ForwardMode.DECODE:
-        max_extend_len = None
-    else:
-        extend_seq_lens = seq_lens - prefix_lens
-        max_extend_len = int(torch.max(extend_seq_lens))
-
-    return max_seq_len, max_extend_len, start_loc, prefix_lens
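Both `compute_extend_infos` and the new `init_triton_args` method rely on the same exclusive-prefix-sum pattern to turn per-request lengths into start offsets in the flattened token dimension. A tiny runnable illustration (only `torch` assumed, lengths made up):

    import torch

    extend_seq_lens = torch.tensor([4, 5, 2])

    # Exclusive prefix sum: where each request's new tokens start in the flat batch.
    extend_start_loc = torch.zeros_like(extend_seq_lens)
    extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
    # extend_start_loc == tensor([0, 4, 9]); request i's tokens occupy
    # [extend_start_loc[i], extend_start_loc[i] + extend_seq_lens[i]).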
python/sglang/srt/model_executor/model_runner.py  (view file @ 1ac304ee)
@@ -350,33 +350,18 @@ class ModelRunner:
         if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
             return self.cuda_graph_runner.replay(batch)

-        input_metadata = InputMetadata.create(
-            self,
-            forward_mode=ForwardMode.DECODE,
-            req_pool_indices=batch.req_pool_indices,
-            seq_lens=batch.seq_lens,
-            prefix_lens=batch.prefix_lens,
-            position_ids_offsets=batch.position_ids_offsets,
-            out_cache_loc=batch.out_cache_loc,
-            top_logprobs_nums=batch.top_logprobs_nums,
-            return_logprob=batch.return_logprob,
-        )
+        input_metadata = InputMetadata.from_schedule_batch(self, batch, ForwardMode.DECODE)
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
         )

     @torch.inference_mode()
     def forward_extend(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.create(
-            self,
-            forward_mode=ForwardMode.EXTEND,
-            req_pool_indices=batch.req_pool_indices,
-            seq_lens=batch.seq_lens,
-            prefix_lens=batch.prefix_lens,
-            position_ids_offsets=batch.position_ids_offsets,
-            out_cache_loc=batch.out_cache_loc,
-            top_logprobs_nums=batch.top_logprobs_nums,
-            return_logprob=batch.return_logprob,
+        input_metadata = InputMetadata.from_schedule_batch(
+            self, batch, forward_mode=ForwardMode.EXTEND
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -384,24 +369,16 @@ class ModelRunner:

     @torch.inference_mode()
     def forward_extend_multi_modal(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.create(
-            self,
-            forward_mode=ForwardMode.EXTEND,
-            req_pool_indices=batch.req_pool_indices,
-            seq_lens=batch.seq_lens,
-            prefix_lens=batch.prefix_lens,
-            position_ids_offsets=batch.position_ids_offsets,
-            out_cache_loc=batch.out_cache_loc,
-            return_logprob=batch.return_logprob,
-            top_logprobs_nums=batch.top_logprobs_nums,
+        input_metadata = InputMetadata.from_schedule_batch(
+            self, batch, forward_mode=ForwardMode.EXTEND
         )
         return self.model.forward(
             batch.input_ids,
             input_metadata.positions,
             input_metadata,
-            batch.pixel_values,
-            batch.image_sizes,
-            batch.image_offsets,
+            input_metadata.pixel_values,
+            input_metadata.image_sizes,
+            input_metadata.image_offsets,
         )

     def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):