change / sglang · Commits · 1ac304ee
"git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "2623fbf21c3a6e540710f77435ba55154b897b4b"
Unverified commit 1ac304ee, authored Aug 08, 2024 by Liangsheng Yin, committed by GitHub Aug 08, 2024

Adjust `InputMetadata` and `ScheduleBatch` (#981)
parent 20a4f927
Showing 4 changed files with 202 additions and 191 deletions (+202 -191)
python/sglang/srt/managers/schedule_batch.py            +16  -44
python/sglang/srt/model_executor/cuda_graph_runner.py    +9   -9
python/sglang/srt/model_executor/forward_batch_info.py   +167 -105
python/sglang/srt/model_executor/model_runner.py         +10  -33
python/sglang/srt/managers/schedule_batch.py
@@ -307,7 +306,6 @@ class ScheduleBatch:
     input_ids: torch.Tensor = None
     req_pool_indices: torch.Tensor = None
     seq_lens: torch.Tensor = None
-    prefix_lens: torch.Tensor = None
     position_ids_offsets: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
     extend_num_tokens: int = None
@@ -316,11 +315,6 @@ class ScheduleBatch:
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
 
-    # For multimodal
-    pixel_values: List[torch.Tensor] = None
-    image_sizes: List[List[int]] = None
-    image_offsets: List[int] = None
-
     # Batched sampling params
     temperatures: torch.Tensor = None
     top_ps: torch.Tensor = None
@@ -412,59 +406,40 @@ class ScheduleBatch:
             self.logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
 
     def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
-        device = "cuda"
         bs = self.batch_size()
         reqs = self.reqs
         input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
-        prefix_indices = [r.prefix_indices for r in reqs]
-
-        # Handle prefix
-        extend_lens = []
-        prefix_lens = []
+        extend_num_tokens = sum(len(ids) for ids in input_ids)
         seq_lens = []
 
+        # Allocate memory
         req_pool_indices_cpu = self.alloc_req_slots(bs)
+        out_cache_loc = self.alloc_token_slots(extend_num_tokens)
 
+        pt = 0
         for i, req in enumerate(reqs):
             req.req_pool_idx = req_pool_indices_cpu[i]
-            extend_lens.append(len(input_ids[i]))
+            pre_len, seq_len = len(req.prefix_indices), len(req.input_ids)
+            ext_len = seq_len - pre_len
+            seq_lens.append(seq_len)
 
-            if len(prefix_indices[i]) == 0:
-                prefix_lens.append(0)
-            else:
-                prefix_lens.append(len(prefix_indices[i]))
+            if pre_len > 0:
                 self.req_to_token_pool.req_to_token[req.req_pool_idx][
-                    : len(prefix_indices[i])
-                ] = prefix_indices[i]
-
-            seq_lens.append(prefix_lens[-1] + extend_lens[-1])
-
-        # Allocate memory
-        seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
-        extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
-        out_cache_loc = self.alloc_token_slots(extend_num_tokens)
-
-        pt = 0
-        for i, req in enumerate(reqs):
-            self.req_to_token_pool.req_to_token[req.req_pool_idx][
-                prefix_lens[i] : prefix_lens[i] + extend_lens[i]
-            ] = out_cache_loc[pt : pt + extend_lens[i]]
-            pt += extend_lens[i]
+                    :pre_len
+                ] = req.prefix_indices
+            self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
+                out_cache_loc[pt : pt + ext_len]
+            )
+            pt += ext_len
 
         # Set fields
         with torch.device("cuda"):
             self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
             self.req_pool_indices = torch.tensor(req_pool_indices_cpu)
             self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
-            self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int32)
-
-            self.pixel_values = [r.pixel_values for r in reqs]
-            self.image_sizes = [r.image_size for r in reqs]
-            self.image_offsets = [
-                (r.image_offset - p_len) if r.image_offset is not None else 0
-                for r, p_len in zip(reqs, prefix_lens)
-            ]
+            self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int64)
 
-        self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
         self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
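For orientation, the reworked bookkeeping above reduces to writing the cached prefix slots and the freshly allocated slots into consecutive ranges of a request's row in the req-to-token table. A minimal, self-contained sketch of that indexing pattern in plain PyTorch, with invented sizes and names (not the sglang pool classes):

import torch

# Toy req_to_token table: one row per request slot, one column per token position.
req_to_token = torch.zeros((4, 16), dtype=torch.int64)

# One request: 3 cached prefix tokens and 5 newly extended tokens.
prefix_indices = torch.tensor([10, 11, 12])          # KV slots already held by the prefix
out_cache_loc = torch.tensor([40, 41, 42, 43, 44])   # freshly allocated KV slots

pre_len = len(prefix_indices)
seq_len = pre_len + len(out_cache_loc)
req_pool_idx = 0

if pre_len > 0:
    req_to_token[req_pool_idx][:pre_len] = prefix_indices
req_to_token[req_pool_idx][pre_len:seq_len] = out_cache_loc

print(req_to_token[req_pool_idx][:seq_len])
# tensor([10, 11, 12, 40, 41, 42, 43, 44])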
@@ -642,7 +617,6 @@ class ScheduleBatch:
         ]
         self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
         self.seq_lens.add_(1)
-        self.prefix_lens = None
 
         # Alloc mem
         bs = self.batch_size()
@@ -667,7 +641,6 @@ class ScheduleBatch:
         self.seq_lens = self.seq_lens[new_indices]
         self.input_ids = None
         self.req_pool_indices = self.req_pool_indices[new_indices]
-        self.prefix_lens = None
         self.position_ids_offsets = self.position_ids_offsets[new_indices]
         self.out_cache_loc = None
         self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
@@ -692,7 +665,6 @@ class ScheduleBatch:
             [self.req_pool_indices, other.req_pool_indices]
         )
         self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
-        self.prefix_lens = None
         self.position_ids_offsets = torch.concat(
             [self.position_ids_offsets, other.position_ids_offsets]
         )
python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -33,7 +33,7 @@ from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
     InputMetadata,
-    init_flashinfer_args,
+    update_flashinfer_indices,
 )
 from sglang.srt.utils import monkey_patch_vllm_all_gather
@@ -165,7 +165,7 @@ class CudaGraphRunner:
             paged_kv_indices_buffer=self.flashinfer_kv_indices,
             paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
         )
-        init_flashinfer_args(
+        update_flashinfer_indices(
             ForwardMode.DECODE,
             self.model_runner,
             req_pool_indices,
@@ -176,19 +176,19 @@ class CudaGraphRunner:
         # Run and capture
         def run_once():
-            input_metadata = InputMetadata.create(
-                self.model_runner,
+            input_metadata = InputMetadata(
                 forward_mode=ForwardMode.DECODE,
+                batch_size=bs,
                 req_pool_indices=req_pool_indices,
                 seq_lens=seq_lens,
-                prefix_lens=None,
-                position_ids_offsets=position_ids_offsets,
+                req_to_token_pool=self.model_runner.req_to_token_pool,
+                token_to_kv_pool=self.model_runner.token_to_kv_pool,
                 out_cache_loc=out_cache_loc,
                 return_logprob=False,
                 top_logprobs_nums=0,
-                skip_flashinfer_init=True,
+                positions=(seq_lens - 1).to(torch.int64),
+                flashinfer_decode_wrapper=flashinfer_decode_wrapper,
             )
-            input_metadata.flashinfer_decode_wrapper = flashinfer_decode_wrapper
 
             return forward(input_ids, input_metadata.positions, input_metadata)
@@ -222,7 +222,7 @@ class CudaGraphRunner:
         self.out_cache_loc[:raw_bs] = batch.out_cache_loc
 
         # FlashInfer inputs
-        init_flashinfer_args(
+        update_flashinfer_indices(
            ForwardMode.DECODE,
            self.model_runner,
            self.req_pool_indices[:bs],
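The capture/replay path keeps the same idea as before: real batch data is copied into the front of fixed, preallocated buffers, and only `[:bs]` slices are handed to the renamed FlashInfer index update. A rough, self-contained illustration of that padding pattern with made-up sizes (not the CudaGraphRunner buffers themselves):

import torch

# Graph replay reuses fixed-size buffers; a smaller batch is copied into the
# front of each buffer and the captured graph always runs on the padded shape.
max_bs = 8
req_pool_indices_buf = torch.zeros(max_bs, dtype=torch.int32)
seq_lens_buf = torch.ones(max_bs, dtype=torch.int32)  # entries past raw_bs keep padding values

raw_bs = 3  # actual batch size for this step
req_pool_indices_buf[:raw_bs] = torch.tensor([5, 2, 7], dtype=torch.int32)
seq_lens_buf[:raw_bs] = torch.tensor([17, 33, 9], dtype=torch.int32)

bs = max_bs  # padded batch size seen by the captured graph
print(req_pool_indices_buf[:bs])
print(seq_lens_buf[:bs])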
python/sglang/srt/model_executor/forward_batch_info.py
@@ -16,13 +16,17 @@ limitations under the License.
 """ModelRunner runs the forward passes of the models."""
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import List
+from typing import TYPE_CHECKING, List
 
 import numpy as np
 import torch
 
+from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
 
 class ForwardMode(IntEnum):
     # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
@@ -39,25 +43,33 @@ class InputMetadata:
     forward_mode: ForwardMode
     batch_size: int
-    total_num_tokens: int
     req_pool_indices: torch.Tensor
     seq_lens: torch.Tensor
-    positions: torch.Tensor
     req_to_token_pool: ReqToTokenPool
     token_to_kv_pool: BaseTokenToKVPool
 
-    # For extend
-    extend_seq_lens: torch.Tensor
-    extend_start_loc: torch.Tensor
-    extend_no_prefix: bool
-
     # Output location of the KV cache
-    out_cache_loc: torch.Tensor = None
+    out_cache_loc: torch.Tensor
+
+    total_num_tokens: int = None
+
+    # Position information
+    positions: torch.Tensor = None
+
+    # For extend
+    extend_seq_lens: torch.Tensor = None
+    extend_start_loc: torch.Tensor = None
+    extend_no_prefix: bool = None
 
     # Output options
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
 
+    # For multimodal
+    pixel_values: List[torch.Tensor] = None
+    image_sizes: List[List[int]] = None
+    image_offsets: List[int] = None
+
     # Trition attention backend
     triton_max_seq_len: int = 0
     triton_max_extend_len: int = 0
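The reordering above also has to respect Python's dataclass rule that fields without defaults precede fields with defaults: `out_cache_loc` moves up as a required field, while `positions`, the extend infos, `total_num_tokens`, and the multimodal fields now carry `None` defaults and are filled in afterwards by the new helper methods. A tiny illustrative dataclass, not the real `InputMetadata`:

from dataclasses import dataclass

import torch


@dataclass
class TinyMetadata:
    # Required fields (no default) must come first...
    forward_mode: int
    out_cache_loc: torch.Tensor
    # ...then everything that may be computed later gets a default.
    total_num_tokens: int = None
    positions: torch.Tensor = None


m = TinyMetadata(forward_mode=1, out_cache_loc=torch.arange(4))
m.positions = torch.arange(4)  # computed afterwards, as the helper methods do
print(m)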
@@ -70,107 +82,170 @@ class InputMetadata:
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
     flashinfer_use_ragged: bool = False
 
-    @classmethod
-    def create(
-        cls,
-        model_runner,
-        forward_mode,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        top_logprobs_nums=None,
-        return_logprob=False,
-        skip_flashinfer_init=False,
-    ):
-        flashinfer_use_ragged = False
-        if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
-            if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
-                flashinfer_use_ragged = True
-            init_flashinfer_args(
-                forward_mode,
-                model_runner,
-                req_pool_indices,
-                seq_lens,
-                prefix_lens,
-                model_runner.flashinfer_decode_wrapper,
-                flashinfer_use_ragged,
-            )
-
-        batch_size = len(req_pool_indices)
-
-        if forward_mode == ForwardMode.DECODE:
-            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
-            extend_seq_lens = extend_start_loc = extend_no_prefix = None
-            if not model_runner.server_args.disable_flashinfer:
-                # This variable is not needed in this case,
-                # we do not compute it to make it compatbile with cuda graph.
-                total_num_tokens = None
-            else:
-                total_num_tokens = int(torch.sum(seq_lens))
-        else:
-            seq_lens_cpu = seq_lens.cpu().numpy()
-            prefix_lens_cpu = prefix_lens.cpu().numpy()
-            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
-            positions = torch.tensor(
-                np.concatenate(
-                    [
-                        np.arange(
-                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
-                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
-                        )
-                        for i in range(batch_size)
-                    ],
-                    axis=0,
-                ),
-                device="cuda",
-            )
-            extend_seq_lens = seq_lens - prefix_lens
-            extend_start_loc = torch.zeros_like(seq_lens)
-            extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
-            extend_no_prefix = torch.all(prefix_lens == 0)
-            total_num_tokens = int(torch.sum(seq_lens))
-
-        ret = cls(
-            forward_mode=forward_mode,
-            batch_size=batch_size,
-            total_num_tokens=total_num_tokens,
-            req_pool_indices=req_pool_indices,
-            seq_lens=seq_lens,
-            positions=positions,
-            req_to_token_pool=model_runner.req_to_token_pool,
-            token_to_kv_pool=model_runner.token_to_kv_pool,
-            out_cache_loc=out_cache_loc,
-            extend_seq_lens=extend_seq_lens,
-            extend_start_loc=extend_start_loc,
-            extend_no_prefix=extend_no_prefix,
-            return_logprob=return_logprob,
-            top_logprobs_nums=top_logprobs_nums,
-            flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
-            flashinfer_use_ragged=flashinfer_use_ragged,
-        )
-
-        if model_runner.server_args.disable_flashinfer:
-            (
-                ret.triton_max_seq_len,
-                ret.triton_max_extend_len,
-                ret.triton_start_loc,
-                ret.triton_prefix_lens,
-            ) = init_triton_args(forward_mode, seq_lens, prefix_lens)
-
-        return ret
+    def init_multimuldal_info(self, batch: ScheduleBatch):
+        reqs = batch.reqs
+        self.pixel_values = [r.pixel_values for r in reqs]
+        self.image_sizes = [r.image_size for r in reqs]
+        self.image_offsets = [
+            (r.image_offset - len(r.prefix_indices))
+            if r.image_offset is not None
+            else 0
+            for r in reqs
+        ]
+
+    def compute_positions(self, batch: ScheduleBatch):
+        position_ids_offsets = batch.position_ids_offsets
+
+        if self.forward_mode == ForwardMode.DECODE:
+            if True:
+                self.positions = self.seq_lens - 1
+            else:
+                # Deprecated
+                self.positions = (self.seq_lens - 1) + position_ids_offsets
+        else:
+            if True:
+                self.positions = torch.tensor(
+                    np.concatenate(
+                        [
+                            np.arange(len(req.prefix_indices), len(req.input_ids))
+                            for req in batch.reqs
+                        ],
+                        axis=0,
+                    ),
+                    device="cuda",
+                )
+            else:
+                # Deprecated
+                position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
+                self.positions = torch.tensor(
+                    np.concatenate(
+                        [
+                            np.arange(
+                                len(req.prefix_indices) + position_ids_offsets_cpu[i],
+                                len(req.input_ids) + position_ids_offsets_cpu[i],
+                            )
+                            for i, req in enumerate(batch.reqs)
+                        ],
+                        axis=0,
+                    ),
+                    device="cuda",
+                )
+
+        # Positions should be in long type
+        self.positions = self.positions.to(torch.int64)
+
+    def compute_extend_infos(self, batch: ScheduleBatch):
+        if self.forward_mode == ForwardMode.DECODE:
+            self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
+        else:
+            prefix_lens_cpu = [
+                len(r.input_ids) - len(r.prefix_indices) for r in batch.reqs
+            ]
+            self.extend_seq_lens = torch.tensor(prefix_lens_cpu, device="cuda")
+            self.extend_start_loc = torch.zeros_like(self.seq_lens)
+            self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
+            self.extend_no_prefix = all(x == 0 for x in prefix_lens_cpu)
+
+    def init_total_num_tokens(self, batch: ScheduleBatch):
+        self.total_num_tokens = sum(len(req.input_ids) for req in batch.reqs)
+
+    @classmethod
+    def from_schedule_batch(
+        cls,
+        model_runner: "ModelRunner",
+        batch: ScheduleBatch,
+        forward_mode: ForwardMode,
+    ):
+        ret = cls(
+            forward_mode=forward_mode,
+            batch_size=batch.batch_size(),
+            req_pool_indices=batch.req_pool_indices,
+            seq_lens=batch.seq_lens,
+            req_to_token_pool=model_runner.req_to_token_pool,
+            token_to_kv_pool=model_runner.token_to_kv_pool,
+            out_cache_loc=batch.out_cache_loc,
+            return_logprob=batch.return_logprob,
+            top_logprobs_nums=batch.top_logprobs_nums,
+        )
+
+        ret.compute_positions(batch)
+        ret.compute_extend_infos(batch)
+        ret.init_total_num_tokens(batch)
+
+        if forward_mode != ForwardMode.DECODE:
+            ret.init_multimuldal_info(batch)
+
+        prefix_lens = None
+        if forward_mode != ForwardMode.DECODE:
+            prefix_lens = torch.tensor(
+                [len(r.prefix_indices) for r in batch.reqs], device="cuda"
+            )
+
+        if model_runner.server_args.disable_flashinfer:
+            ret.init_triton_args(batch, prefix_lens)
+
+        flashinfer_use_ragged = False
+        if not model_runner.server_args.disable_flashinfer:
+            if (
+                forward_mode != ForwardMode.DECODE
+                and int(torch.sum(ret.seq_lens)) > 4096
+            ):
+                flashinfer_use_ragged = True
+            ret.init_flashinfer_handlers(
+                model_runner, prefix_lens, flashinfer_use_ragged
+            )
+
+        return ret
+
+    def init_triton_args(self, batch: ScheduleBatch, prefix_lens):
+        """Init auxiliary variables for triton attention backend."""
+        self.triton_max_seq_len = max(len(r.input_ids) for r in batch.reqs)
+        self.triton_prefix_lens = prefix_lens
+        self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
+        self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)
+
+        if self.forward_mode == ForwardMode.DECODE:
+            self.triton_max_extend_len = None
+        else:
+            extend_seq_lens = self.seq_lens - prefix_lens
+            self.triton_max_extend_len = int(torch.max(extend_seq_lens))
+
+    def init_flashinfer_handlers(
+        self, model_runner, prefix_lens, flashinfer_use_ragged
+    ):
+        update_flashinfer_indices(
+            self.forward_mode,
+            model_runner,
+            self.req_pool_indices,
+            self.seq_lens,
+            prefix_lens,
+            flashinfer_use_ragged=flashinfer_use_ragged,
+        )
+
+        (
+            self.flashinfer_prefill_wrapper_ragged,
+            self.flashinfer_prefill_wrapper_paged,
+            self.flashinfer_decode_wrapper,
+            self.flashinfer_use_ragged,
+        ) = (
+            model_runner.flashinfer_prefill_wrapper_ragged,
+            model_runner.flashinfer_prefill_wrapper_paged,
+            model_runner.flashinfer_decode_wrapper,
+            flashinfer_use_ragged,
+        )
 
 
-def init_flashinfer_args(
+def update_flashinfer_indices(
     forward_mode,
     model_runner,
     req_pool_indices,
     seq_lens,
     prefix_lens,
-    flashinfer_decode_wrapper,
+    flashinfer_decode_wrapper=None,
     flashinfer_use_ragged=False,
 ):
     """Init auxiliary variables for FlashInfer attention backend."""
@@ -178,7 +253,6 @@ def init_flashinfer_args(
     num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
     head_dim = model_runner.model_config.head_dim
     batch_size = len(req_pool_indices)
-    total_num_tokens = int(torch.sum(seq_lens))
 
     if flashinfer_use_ragged:
         paged_kernel_lens = prefix_lens
@@ -201,6 +275,10 @@ def init_flashinfer_args(
     kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
 
     if forward_mode == ForwardMode.DECODE:
+        # CUDA graph uses different flashinfer_decode_wrapper
+        if flashinfer_decode_wrapper is None:
+            flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
+
         flashinfer_decode_wrapper.end_forward()
         flashinfer_decode_wrapper.begin_forward(
             kv_indptr,
@@ -238,19 +316,3 @@ def init_flashinfer_args(
             head_dim,
             1,
         )
-
-
-def init_triton_args(forward_mode, seq_lens, prefix_lens):
-    """Init auxiliary variables for triton attention backend."""
-    batch_size = len(seq_lens)
-    max_seq_len = int(torch.max(seq_lens))
-    start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
-    start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
-
-    if forward_mode == ForwardMode.DECODE:
-        max_extend_len = None
-    else:
-        extend_seq_lens = seq_lens - prefix_lens
-        max_extend_len = int(torch.max(extend_seq_lens))
-
-    return max_seq_len, max_extend_len, start_loc, prefix_lens
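Both the removed module-level helper and the new `init_triton_args` method rest on the same two pieces of arithmetic: an exclusive prefix sum for per-request start offsets and a max over the extend lengths. The arithmetic in isolation, with invented lengths:

import torch

seq_lens = torch.tensor([8, 3, 5], dtype=torch.int32)
prefix_lens = torch.tensor([2, 0, 1], dtype=torch.int32)

# start_loc[i] = number of tokens belonging to requests before i (exclusive prefix sum).
start_loc = torch.zeros_like(seq_lens)
start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
print(start_loc)  # tensor([ 0,  8, 11], dtype=torch.int32)

# Longest extend segment in the batch, used as an upper bound for the extend kernel.
extend_seq_lens = seq_lens - prefix_lens
max_extend_len = int(torch.max(extend_seq_lens))
print(max_extend_len)  # 6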
python/sglang/srt/model_executor/model_runner.py
@@ -350,33 +350,18 @@ class ModelRunner:
         if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
             return self.cuda_graph_runner.replay(batch)
 
-        input_metadata = InputMetadata.create(
-            self,
-            forward_mode=ForwardMode.DECODE,
-            req_pool_indices=batch.req_pool_indices,
-            seq_lens=batch.seq_lens,
-            prefix_lens=batch.prefix_lens,
-            position_ids_offsets=batch.position_ids_offsets,
-            out_cache_loc=batch.out_cache_loc,
-            top_logprobs_nums=batch.top_logprobs_nums,
-            return_logprob=batch.return_logprob,
+        input_metadata = InputMetadata.from_schedule_batch(
+            self, batch, ForwardMode.DECODE
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
         )
 
     @torch.inference_mode()
     def forward_extend(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.create(
-            self,
-            forward_mode=ForwardMode.EXTEND,
-            req_pool_indices=batch.req_pool_indices,
-            seq_lens=batch.seq_lens,
-            prefix_lens=batch.prefix_lens,
-            position_ids_offsets=batch.position_ids_offsets,
-            out_cache_loc=batch.out_cache_loc,
-            top_logprobs_nums=batch.top_logprobs_nums,
-            return_logprob=batch.return_logprob,
+        input_metadata = InputMetadata.from_schedule_batch(
+            self, batch, forward_mode=ForwardMode.EXTEND
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -384,24 +369,16 @@ class ModelRunner:
     @torch.inference_mode()
     def forward_extend_multi_modal(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.create(
-            self,
-            forward_mode=ForwardMode.EXTEND,
-            req_pool_indices=batch.req_pool_indices,
-            seq_lens=batch.seq_lens,
-            prefix_lens=batch.prefix_lens,
-            position_ids_offsets=batch.position_ids_offsets,
-            out_cache_loc=batch.out_cache_loc,
-            return_logprob=batch.return_logprob,
-            top_logprobs_nums=batch.top_logprobs_nums,
+        input_metadata = InputMetadata.from_schedule_batch(
+            self, batch, forward_mode=ForwardMode.EXTEND
         )
         return self.model.forward(
             batch.input_ids,
             input_metadata.positions,
             input_metadata,
-            batch.pixel_values,
-            batch.image_sizes,
-            batch.image_offsets,
+            input_metadata.pixel_values,
+            input_metadata.image_sizes,
+            input_metadata.image_offsets,
         )
 
     def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
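On the model-runner side the call sites shrink because every derived quantity now lives behind a single classmethod factory that takes the whole batch. The pattern, reduced to a generic runnable toy with hypothetical types (not the sglang `ScheduleBatch`/`InputMetadata`):

from dataclasses import dataclass


@dataclass
class Batch:
    input_ids: list
    prefix_lens: list


@dataclass
class Metadata:
    batch_size: int
    extend_lens: list

    @classmethod
    def from_batch(cls, batch: Batch) -> "Metadata":
        # Derived quantities are computed here, so call sites only pass the batch.
        return cls(
            batch_size=len(batch.input_ids),
            extend_lens=[len(i) - p for i, p in zip(batch.input_ids, batch.prefix_lens)],
        )


meta = Metadata.from_batch(Batch(input_ids=[[1, 2, 3], [4, 5]], prefix_lens=[1, 0]))
print(meta)  # Metadata(batch_size=2, extend_lens=[2, 2])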