sglang / Commits / 3f0fe08d

Commit 3f0fe08d (unverified)
Authored Sep 29, 2024 by Lianmin Zheng; committed via GitHub on Sep 29, 2024

Let ModelRunner take InputMetadata as input, instead of ScheduleBatch (#1541)

Parent: 55b974f9

Showing 12 changed files with 143 additions and 158 deletions (+143, -158)
Files changed:

    python/sglang/bench_latency.py                            +4   -2
    python/sglang/srt/layers/attention_backend.py             +21  -16
    python/sglang/srt/lora/lora_manager.py                    +11  -8
    python/sglang/srt/managers/schedule_batch.py              +4   -1
    python/sglang/srt/managers/scheduler.py                   +8   -3
    python/sglang/srt/managers/tp_worker.py                   +5   -4
    python/sglang/srt/model_executor/cuda_graph_runner.py     +12  -16
    python/sglang/srt/model_executor/forward_batch_info.py    +38  -69
    python/sglang/srt/model_executor/model_runner.py          +24  -23
    python/sglang/test/runners.py                             +13  -13
    test/srt/models/test_lora.py                              +2   -3
    test/srt/run_suite.py                                     +1   -0
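At a glance, the refactor splits responsibilities: InputMetadata now carries everything the model forward pass needs, while ScheduleBatch stays on the scheduler side for sampling and bookkeeping. A minimal sketch of the new calling convention, using only names that appear in the diffs below (the surrounding setup of `batch` and `model_runner` is assumed):

```python
# Before this commit the runner consumed the scheduling structure directly:
#     logits_output = model_runner.forward(batch)            # batch: ScheduleBatch
# After it, the batch is first flattened into forward-only metadata:
input_metadata = batch.get_input_metadata()                  # new ScheduleBatch helper
logits_output = model_runner.forward(input_metadata)         # ModelRunner sees InputMetadata only
next_token_ids = model_runner.sample(logits_output, batch)   # sampling still uses the batch
```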
python/sglang/bench_latency.py

@@ -225,14 +225,16 @@ def extend(reqs, model_runner):
         tree_cache=None,
     )
     batch.prepare_for_extend(model_runner.model_config.vocab_size)
-    logits_output = model_runner.forward(batch)
+    input_metadata = batch.get_input_metadata()
+    logits_output = model_runner.forward(input_metadata)
     next_token_ids = model_runner.sample(logits_output, batch).tolist()
     return next_token_ids, logits_output.next_token_logits, batch


 def decode(input_token_ids, batch, model_runner):
     batch.prepare_for_decode(input_token_ids)
-    logits_output = model_runner.forward(batch)
+    input_metadata = batch.get_input_metadata()
+    logits_output = model_runner.forward(input_metadata)
     next_token_ids = model_runner.sample(logits_output, batch).tolist()
     return next_token_ids, logits_output.next_token_logits
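Both benchmark helpers now repeat the same two-step pattern. A hedged sketch of that pattern factored into one function; `_forward_and_sample` is a hypothetical helper for illustration, not part of the patch:

```python
def _forward_and_sample(batch, model_runner):
    # Hypothetical helper showing the pattern now shared by extend() and decode()
    # in bench_latency.py: build InputMetadata, forward with it, sample from the batch.
    input_metadata = batch.get_input_metadata()
    logits_output = model_runner.forward(input_metadata)
    next_token_ids = model_runner.sample(logits_output, batch).tolist()
    return next_token_ids, logits_output.next_token_logits
```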
python/sglang/srt/layers/attention_backend.py

@@ -15,7 +15,7 @@ import torch.nn as nn
 from sglang.global_config import global_config
 from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
-from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.utils import is_hip

@@ -37,9 +37,7 @@ class AttentionBackend(ABC):
     """The base class of attention backends"""

     @abstractmethod
-    def init_forward_metadata(
-        self, batch: ScheduleBatch, input_metadata: InputMetadata
-    ):
+    def init_forward_metadata(self, input_metadata: InputMetadata):
         """Init the metadata for a forward pass."""
         raise NotImplementedError()

@@ -133,12 +131,11 @@ class FlashInferAttnBackend(AttentionBackend):
         self.forward_metadata = None
         self.cuda_graph_metadata = {}

-    def init_forward_metadata(
-        self, batch: ScheduleBatch, input_metadata: InputMetadata
-    ):
+    def init_forward_metadata(self, input_metadata: InputMetadata):
         if input_metadata.forward_mode.is_decode():
             prefix_lens = None
             use_ragged = False
+            extend_no_prefix = False
             total_num_tokens = None
         else:
             prefix_lens = input_metadata.extend_prefix_lens

@@ -152,6 +149,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 use_ragged = True
                 total_num_tokens = torch.sum(input_metadata.seq_lens).item()
+                extend_no_prefix = not torch.any(input_metadata.extend_prefix_lens).item()

         update_flashinfer_indices(
             input_metadata.forward_mode,

@@ -162,7 +160,12 @@ class FlashInferAttnBackend(AttentionBackend):
             use_ragged=use_ragged,
         )

-        self.forward_metadata = (use_ragged, total_num_tokens, self.decode_wrapper)
+        self.forward_metadata = (
+            use_ragged,
+            extend_no_prefix,
+            total_num_tokens,
+            self.decode_wrapper,
+        )

     def init_cuda_graph_state(self, max_bs: int):
         self.cuda_graph_kv_indptr = torch.zeros(

@@ -228,7 +231,7 @@
         self.cuda_graph_metadata[bs] = decode_wrapper
-        self.forward_metadata = (False, None, decode_wrapper)
+        self.forward_metadata = (False, False, None, decode_wrapper)

     def init_forward_metadata_replay_cuda_graph(
         self, bs: int, req_pool_indices, seq_lens

@@ -254,7 +257,9 @@ class FlashInferAttnBackend(AttentionBackend):
         else:
             prefill_wrapper_paged = self.prefill_wrapper_paged[1]

-        use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata
+        use_ragged, extend_no_prefix, total_num_tokens, decode_wrapper = (
+            self.forward_metadata
+        )

         if not use_ragged:
             if k is not None:

@@ -280,7 +285,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 logits_soft_cap=layer.logit_cap,
             )

-            if input_metadata.extend_no_prefix:
+            if extend_no_prefix:
                 o = o1
             else:
                 o2, s2 = prefill_wrapper_paged.forward_return_lse(

@@ -300,7 +305,9 @@ class FlashInferAttnBackend(AttentionBackend):
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)

     def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
-        use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata
+        use_ragged, extend_no_prefix, total_num_tokens, decode_wrapper = (
+            self.forward_metadata
+        )

         if isinstance(decode_wrapper, list):
             if layer.sliding_window_size != -1:

@@ -351,9 +358,7 @@ class TritonAttnBackend(AttentionBackend):
         self.cuda_graph_max_seq_len = model_runner.model_config.context_len

-    def init_forward_metadata(
-        self, batch: ScheduleBatch, input_metadata: InputMetadata
-    ):
+    def init_forward_metadata(self, input_metadata: InputMetadata):
         """Init auxiliary variables for triton attention backend."""

         if input_metadata.forward_mode.is_decode():

@@ -371,7 +376,7 @@ class TritonAttnBackend(AttentionBackend):
             max_extend_len = None
         else:
             start_loc = attn_logits = max_seq_len = None
-            prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
+            prefix_lens = input_metadata.extend_prefix_lens
             max_extend_len = torch.max(input_metadata.seq_lens - prefix_lens).item()

         self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
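The FlashInfer backend now derives `extend_no_prefix` from tensors already stored on `InputMetadata` instead of reading a precomputed flag off the batch, and threads it through the widened `forward_metadata` tuple. A small self-contained sketch of that computation (the lengths are made up, and the real tensors live on CUDA):

```python
import torch

# Stand-ins for InputMetadata fields during an extend (prefill) pass.
seq_lens = torch.tensor([5, 8, 3])
extend_prefix_lens = torch.tensor([0, 0, 0])   # no request reuses a cached prefix

total_num_tokens = torch.sum(seq_lens).item()                 # 16
extend_no_prefix = not torch.any(extend_prefix_lens).item()   # True

# forward_metadata now carries four slots instead of three.
forward_metadata = (True, extend_no_prefix, total_num_tokens, None)
use_ragged, extend_no_prefix, total_num_tokens, decode_wrapper = forward_metadata
```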
python/sglang/srt/lora/lora_manager.py

@@ -18,13 +18,12 @@ limitations under the License.
 import re
 from dataclasses import dataclass

 import torch

 from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
 from sglang.srt.lora.lora_config import LoRAConfig
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.utils import is_hip, replace_submodule

 # ROCm: flashinfer available later

@@ -208,9 +207,9 @@ class LoRAManager:
                     if lora_weight_name:
                         self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights)

-    def prepare_lora_batch(self, batch, extend_seq_lens=None):
+    def prepare_lora_batch(self, input_metadata: InputMetadata):
         # load active loras into lora memory pool
-        cur_uids = set([req.lora_path for req in batch.reqs])
+        cur_uids = set(input_metadata.lora_paths)
         assert len(cur_uids) <= self.max_loras_per_batch
         i = 0
         evictable_uids = list(self.active_uids)

@@ -230,11 +229,15 @@ class LoRAManager:
             return

         # setup lora in forward modules
-        bs = len(batch.reqs)
-        seg_lens = extend_seq_lens if batch.forward_mode.is_extend() else torch.ones(bs)
+        bs = input_metadata.batch_size
+        seg_lens = (
+            input_metadata.extend_seq_lens
+            if input_metadata.forward_mode.is_extend()
+            else torch.ones(bs)
+        )
         weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
-        for i, req in enumerate(batch.reqs):
-            weight_indices[i] = self.buffer_id[req.lora_path]
+        for i, lora_path in enumerate(input_metadata.lora_paths):
+            weight_indices[i] = self.buffer_id[lora_path]

         for module_name, module in self.lora_modules:
             layer_id = get_layer_id(module_name)
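`prepare_lora_batch` now reads the per-request adapter paths from `InputMetadata.lora_paths` rather than walking `batch.reqs`. A self-contained sketch of the index-building step; the `buffer_id` mapping and paths are illustrative, and the CUDA device argument is omitted so the sketch runs on CPU:

```python
import torch

lora_paths = ["adapter_a", "adapter_b", "adapter_a"]   # InputMetadata.lora_paths (illustrative)
buffer_id = {"adapter_a": 0, "adapter_b": 1}           # stands in for LoRAManager.buffer_id

bs = len(lora_paths)                                   # matches InputMetadata.batch_size
weight_indices = torch.empty((bs,), dtype=torch.int64)
for i, lora_path in enumerate(lora_paths):
    weight_indices[i] = buffer_id[lora_path]

print(weight_indices)  # tensor([0, 1, 0])
```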
python/sglang/srt/managers/schedule_batch.py

@@ -29,7 +29,7 @@ from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs

@@ -511,6 +511,9 @@ class ScheduleBatch:
         self.extend_logprob_start_lens_cpu = [r.extend_logprob_start_len for r in reqs]
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)

+    def get_input_metadata(self):
+        return InputMetadata.from_schedule_batch(self)
+
     def mix_with_running(self, running_batch: "ScheduleBatch"):
         self.forward_mode = ForwardMode.MIXED
         running_bs = running_batch.batch_size()
python/sglang/srt/managers/scheduler.py

@@ -575,8 +575,9 @@ class Scheduler:
         if self.is_generation:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
+                input_metadata = batch.get_input_metadata()
                 logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
-                    batch
+                    input_metadata, batch
                 )
                 batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                     next_token_ids

@@ -640,7 +641,8 @@
                 )
             else:
                 assert batch.extend_num_tokens != 0
-                embeddings = self.tp_worker.forward_batch_embedding(batch)
+                input_metadata = batch.get_input_metadata()
+                embeddings = self.tp_worker.forward_batch_embedding(input_metadata)

             # Check finish conditions
             for i, req in enumerate(batch.reqs):

@@ -769,7 +771,10 @@
         batch.prepare_for_decode()

         # Forward and sample the next tokens
-        logits_output, next_token_ids = self.tp_worker.forward_batch_generation(batch)
+        input_metadata = batch.get_input_metadata()
+        logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
+            input_metadata, batch
+        )
         batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(next_token_ids)
python/sglang/srt/managers/tp_worker.py

@@ -21,6 +21,7 @@ import logging
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import UpdateWeightReqInput
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import broadcast_pyobj, is_multimodal_model, set_random_seed

@@ -105,13 +106,13 @@ class ModelTpWorker:
             self.random_seed,
         )

-    def forward_batch_generation(self, batch):
-        logits_output = self.model_runner.forward(batch)
+    def forward_batch_generation(self, input_metadata: InputMetadata, batch):
+        logits_output = self.model_runner.forward(input_metadata)
         next_token_ids = self.model_runner.sample(logits_output, batch)
         return logits_output, next_token_ids

-    def forward_batch_embedding(self, batch):
-        logits_output = self.model_runner.forward(batch)
+    def forward_batch_embedding(self, input_metadata: InputMetadata):
+        logits_output = self.model_runner.forward(input_metadata)
         embeddings = logits_output.embeddings.tolist()
         return embeddings
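After this change the worker's two entry points have different requirements: generation still receives the batch because sampling needs it, while the embedding path needs only the metadata. A hedged usage sketch, assuming a constructed `tp_worker` and a prepared `batch`:

```python
# Sketch of the new TP-worker API (names taken from the diff; setup assumed).
input_metadata = batch.get_input_metadata()

# Generation: metadata drives the forward pass, the batch is kept for sampling.
logits_output, next_token_ids = tp_worker.forward_batch_generation(input_metadata, batch)

# Embedding: the ScheduleBatch is no longer needed at all.
embeddings = tp_worker.forward_batch_embedding(input_metadata)
```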
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -31,7 +31,6 @@ from sglang.srt.layers.logits_processor import (
     LogitsProcessor,
     LogitsProcessorOutput,
 )
-from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.utils import monkey_patch_vllm_all_gather

@@ -143,7 +142,6 @@ class CudaGraphRunner:
         self.seq_lens = torch.full(
             (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
         )
-        self.position_ids_offsets = torch.ones((self.max_bs,), dtype=torch.int32)
         self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)

         # Capture

@@ -189,7 +187,6 @@ class CudaGraphRunner:
         input_ids = self.input_ids[:bs]
         req_pool_indices = self.req_pool_indices[:bs]
         seq_lens = self.seq_lens[:bs]
-        position_ids_offsets = self.position_ids_offsets[:bs]
         out_cache_loc = self.out_cache_loc[:bs]

         # Attention backend

@@ -202,6 +199,7 @@ class CudaGraphRunner:
         input_metadata = InputMetadata(
             forward_mode=ForwardMode.DECODE,
             batch_size=bs,
+            input_ids=input_ids,
             req_pool_indices=req_pool_indices,
             seq_lens=seq_lens,
             req_to_token_pool=self.model_runner.req_to_token_pool,

@@ -210,7 +208,7 @@ class CudaGraphRunner:
             out_cache_loc=out_cache_loc,
             return_logprob=False,
             top_logprobs_nums=[0] * bs,
-            positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64),
+            positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
         )

         return forward(input_ids, input_metadata.positions, input_metadata)

@@ -235,24 +233,22 @@ class CudaGraphRunner:
             self.graph_memory_pool = graph.pool()
         return graph, out

-    def replay(self, batch: ScheduleBatch):
-        assert batch.out_cache_loc is not None
-        raw_bs = len(batch.reqs)
+    def replay(self, input_metadata: InputMetadata):
+        assert input_metadata.out_cache_loc is not None
+        raw_bs = input_metadata.batch_size

         # Pad
         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
             self.seq_lens.fill_(self.seq_len_fill_value)
-            self.position_ids_offsets.fill_(1)
             self.out_cache_loc.zero_()

         # Common inputs
-        self.input_ids[:raw_bs] = batch.input_ids
-        self.req_pool_indices[:raw_bs] = batch.req_pool_indices
-        self.seq_lens[:raw_bs] = batch.seq_lens
-        self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets
-        self.out_cache_loc[:raw_bs] = batch.out_cache_loc
+        self.input_ids[:raw_bs] = input_metadata.input_ids
+        self.req_pool_indices[:raw_bs] = input_metadata.req_pool_indices
+        self.seq_lens[:raw_bs] = input_metadata.seq_lens
+        self.out_cache_loc[:raw_bs] = input_metadata.out_cache_loc

         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(

@@ -275,15 +271,15 @@ class CudaGraphRunner:
         )

         # Extract logprobs
-        if batch.return_logprob:
+        if input_metadata.return_logprob:
             logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
                 logits_output.next_token_logits, dim=-1
             )
-            return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
+            return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
             if return_top_logprob:
                 logits_metadata = LogitsMetadata(
                     forward_mode=ForwardMode.DECODE,
-                    top_logprobs_nums=batch.top_logprobs_nums,
+                    top_logprobs_nums=input_metadata.top_logprobs_nums,
                 )
                 logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
                     logits_output.next_token_logprobs, logits_metadata
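`replay` now takes the raw batch size from `InputMetadata.batch_size` and, as before, rounds it up to the nearest captured graph size with `bisect`. A self-contained sketch of that padding decision (the capture list is illustrative):

```python
import bisect

capture_bs = [1, 2, 4, 8, 16, 32]   # batch sizes captured as CUDA graphs (illustrative)
raw_bs = 5                          # InputMetadata.batch_size at replay time

index = bisect.bisect_left(capture_bs, raw_bs)
bs = capture_bs[index]              # 8: smallest captured graph that fits this batch
needs_padding = bs != raw_bs        # True: seq_lens / out_cache_loc buffers are padded first
```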
python/sglang/srt/model_executor/forward_batch_info.py

@@ -18,7 +18,7 @@ limitations under the License.
 """Meta data for a forward pass."""

 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Set

 import numpy as np
 import torch

@@ -27,7 +27,6 @@ if TYPE_CHECKING:
     from sglang.srt.layers.attention_backend import AttentionBackend
     from sglang.srt.managers.schedule_batch import ImageInputs, ScheduleBatch
     from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
-    from sglang.srt.model_executor.model_runner import ModelRunner


 class ForwardMode(IntEnum):

@@ -37,7 +36,7 @@ class ForwardMode(IntEnum):
     EXTEND = auto()
     # Decode one token.
     DECODE = auto()
-    # Contains both PREFILL and EXTEND.
+    # Contains both EXTEND and DECODE.
     MIXED = auto()

     def is_prefill(self):

@@ -57,15 +56,17 @@ class ForwardMode(IntEnum):
 class InputMetadata:
     """Store all inforamtion of a forward pass."""

+    # The forward mode
     forward_mode: ForwardMode
+    # The batch size
     batch_size: int
+    # The input ids
+    input_ids: torch.Tensor
+    # The indices of requests in the req_to_token_pool
     req_pool_indices: torch.Tensor
+    # The sequence length
     seq_lens: torch.Tensor
-    req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: BaseTokenToKVPool
-    attn_backend: AttentionBackend
-
-    # Output location of the KV cache
+    # The indices of output tokens in the token_to_kv_pool
     out_cache_loc: torch.Tensor

     # Position information

@@ -75,7 +76,6 @@ class InputMetadata:
     extend_seq_lens: torch.Tensor = None
     extend_prefix_lens: torch.Tensor = None
     extend_start_loc: torch.Tensor = None
-    extend_no_prefix: bool = None

     # For logprob
     return_logprob: bool = False

@@ -86,82 +86,51 @@ class InputMetadata:
     # For multimodal
     image_inputs: List[ImageInputs] = None

-    def init_multimuldal_info(self, batch: ScheduleBatch):
-        self.image_inputs = [r.image_inputs for r in batch.reqs]
+    # For LoRA
+    lora_paths: List[str] = None

-    def compute_positions(self, batch: ScheduleBatch):
-        if self.forward_mode.is_decode():
-            if True:
-                self.positions = self.seq_lens - 1
-            else:
-                # Deprecated
-                self.positions = (self.seq_lens - 1) + batch.position_ids_offsets
-        else:
-            if True:
-                self.positions = torch.tensor(
-                    np.concatenate(
-                        [
-                            np.arange(batch.prefix_lens_cpu[i], len(req.fill_ids))
-                            for i, req in enumerate(batch.reqs)
-                        ],
-                        axis=0,
-                    ),
-                    device="cuda",
-                )
-            else:
-                # Deprecated
-                position_ids_offsets_cpu = batch.position_ids_offsets.cpu().numpy()
-                self.positions = torch.tensor(
-                    np.concatenate(
-                        [
-                            np.arange(
-                                batch.prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
-                                len(req.fill_ids) + position_ids_offsets_cpu[i],
-                            )
-                            for i, req in enumerate(batch.reqs)
-                        ],
-                        axis=0,
-                    ),
-                    device="cuda",
-                )
-
-        # Positions should be in long type
-        self.positions = self.positions.to(torch.int64)
-
-    def compute_extend_infos(self, batch: ScheduleBatch):
-        self.extend_seq_lens = torch.tensor(batch.extend_lens_cpu, device="cuda")
-        self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
-        self.extend_start_loc = torch.zeros_like(self.extend_seq_lens)
-        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-        self.extend_no_prefix = all(x == 0 for x in batch.prefix_lens_cpu)
-        self.extend_seq_lens_cpu = batch.extend_lens_cpu
-        self.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens_cpu
+    # Attention backend
+    req_to_token_pool: ReqToTokenPool = None
+    token_to_kv_pool: BaseTokenToKVPool = None
+    attn_backend: AttentionBackend = None

     @classmethod
     def from_schedule_batch(
         cls,
-        model_runner: "ModelRunner",
         batch: ScheduleBatch,
     ):
         ret = cls(
             forward_mode=batch.forward_mode,
             batch_size=batch.batch_size(),
+            input_ids=batch.input_ids,
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
-            req_to_token_pool=model_runner.req_to_token_pool,
-            token_to_kv_pool=model_runner.token_to_kv_pool,
-            attn_backend=model_runner.attn_backend,
             out_cache_loc=batch.out_cache_loc,
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
+            lora_paths=[req.lora_path for req in batch.reqs],
         )

-        ret.compute_positions(batch)
-
-        if not batch.forward_mode.is_decode():
-            ret.init_multimuldal_info(batch)
-            ret.compute_extend_infos(batch)
-
-        model_runner.attn_backend.init_forward_metadata(batch, ret)
+        if ret.forward_mode.is_decode():
+            ret.positions = (ret.seq_lens - 1).to(torch.int64)
+        else:
+            ret.positions = torch.tensor(
+                np.concatenate(
+                    [
+                        np.arange(batch.prefix_lens_cpu[i], len(req.fill_ids))
+                        for i, req in enumerate(batch.reqs)
+                    ],
+                    axis=0,
+                ),
+                device="cuda",
+            ).to(torch.int64)
+            ret.image_inputs = [r.image_inputs for r in batch.reqs]
+            ret.extend_seq_lens = torch.tensor(batch.extend_lens_cpu, device="cuda")
+            ret.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
+            ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens)
+            ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0)
+            ret.extend_seq_lens_cpu = batch.extend_lens_cpu
+            ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens_cpu

         return ret
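The helper methods `compute_positions` and `compute_extend_infos` are folded into `from_schedule_batch`, and `position_ids_offsets` drops out of the position formula. A self-contained sketch of the two computations that remain, run on CPU tensors with made-up lengths (the real code allocates on CUDA):

```python
import numpy as np
import torch

# Illustrative per-request state for an extend (prefill) pass.
prefix_lens_cpu = [0, 2]     # cached prefix length per request (batch.prefix_lens_cpu)
fill_ids_lens = [5, 6]       # len(req.fill_ids) per request
extend_lens_cpu = [5, 4]     # new tokens per request (batch.extend_lens_cpu)

# Token positions, as now computed inline for the non-decode branch.
positions = torch.tensor(
    np.concatenate(
        [np.arange(prefix_lens_cpu[i], fill_ids_lens[i]) for i in range(len(fill_ids_lens))],
        axis=0,
    )
).to(torch.int64)            # tensor([0, 1, 2, 3, 4, 2, 3, 4, 5])

# Start offset of each request's new tokens in the flattened extend batch.
extend_seq_lens = torch.tensor(extend_lens_cpu)
extend_start_loc = torch.zeros_like(extend_seq_lens)
extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)   # tensor([0, 5])
```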
python/sglang/srt/model_executor/model_runner.py

@@ -466,46 +466,47 @@ class ModelRunner:
             logger.info("Capture cuda graph begin. This can take up to several minutes.")
             self.cuda_graph_runner = CudaGraphRunner(self)

-    def forward_decode(self, batch: ScheduleBatch):
-        if self.server_args.lora_paths is not None:
-            self.lora_manager.prepare_lora_batch(batch)
-
-        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
-            return self.cuda_graph_runner.replay(batch)
-
-        input_metadata = InputMetadata.from_schedule_batch(self, batch)
+    def forward_decode(self, input_metadata: InputMetadata):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(
+            input_metadata.batch_size
+        ):
+            return self.cuda_graph_runner.replay(input_metadata)

         return self.model.forward(
-            batch.input_ids, input_metadata.positions, input_metadata
+            input_metadata.input_ids, input_metadata.positions, input_metadata
         )

-    def forward_extend(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.from_schedule_batch(self, batch)
-        if self.server_args.lora_paths is not None:
-            self.lora_manager.prepare_lora_batch(batch, input_metadata.extend_seq_lens)
-
+    def forward_extend(self, input_metadata: InputMetadata):
         if self.is_generation:
             return self.model.forward(
-                batch.input_ids, input_metadata.positions, input_metadata
+                input_metadata.input_ids, input_metadata.positions, input_metadata
             )
         else:
             # Only embedding models have get_embedding parameter
             return self.model.forward(
-                batch.input_ids,
+                input_metadata.input_ids,
                 input_metadata.positions,
                 input_metadata,
                 get_embedding=True,
             )

-    def forward(self, batch: ScheduleBatch) -> Tuple[LogitsProcessorOutput]:
-        assert batch.forward_mode is not None
+    def forward(self, input_metadata: InputMetadata) -> LogitsProcessorOutput:
+        # Attach attention information
+        input_metadata.req_to_token_pool = self.req_to_token_pool
+        input_metadata.token_to_kv_pool = self.token_to_kv_pool
+        input_metadata.attn_backend = self.attn_backend
+        input_metadata.attn_backend.init_forward_metadata(input_metadata)
+
+        # Attach lora information
+        if self.server_args.lora_paths is not None:
+            self.lora_manager.prepare_lora_batch(input_metadata)

-        if batch.forward_mode.is_decode():
-            return self.forward_decode(batch)
-        elif batch.forward_mode.is_extend():
-            return self.forward_extend(batch)
+        if input_metadata.forward_mode.is_decode():
+            return self.forward_decode(input_metadata)
+        elif input_metadata.forward_mode.is_extend():
+            return self.forward_extend(input_metadata)
         else:
-            raise ValueError(f"Invaid forward mode: {batch.forward_mode}")
+            raise ValueError(f"Invaid forward mode: {input_metadata.forward_mode}")

     def _apply_logits_bias(
         self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
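Because `from_schedule_batch` no longer receives the ModelRunner, the runner attaches its own state (KV pools, attention backend) to the metadata at the top of `forward` before dispatching on the forward mode. A condensed sketch of that contract, not the full implementation (the LoRA branch and error wording are simplified):

```python
# Condensed sketch of the new ModelRunner.forward contract; assumes self owns the pools
# and the attention backend, exactly as in the diff above.
def forward(self, input_metadata):
    input_metadata.req_to_token_pool = self.req_to_token_pool
    input_metadata.token_to_kv_pool = self.token_to_kv_pool
    input_metadata.attn_backend = self.attn_backend
    input_metadata.attn_backend.init_forward_metadata(input_metadata)

    if input_metadata.forward_mode.is_decode():
        return self.forward_decode(input_metadata)
    elif input_metadata.forward_mode.is_extend():
        return self.forward_extend(input_metadata)
    raise ValueError(f"Invalid forward mode: {input_metadata.forward_mode}")
```

Attaching the pools here is what lets `InputMetadata.from_schedule_batch` drop its old `model_runner` parameter and be called directly from `ScheduleBatch.get_input_metadata`.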
python/sglang/test/runners.py

@@ -71,10 +71,10 @@ class ModelOutput:
 class HFRunner:
     def __init__(
         self,
-        model_path,
-        torch_dtype,
-        model_type="generation",
-        output_str_only=False,
+        model_path: str,
+        torch_dtype: torch.dtype,
+        model_type: str = "generation",
+        output_str_only: bool = False,
     ):
         self.model_type = model_type
         self.output_str_only = output_str_only

@@ -244,15 +244,15 @@ class HFRunner:
 class SRTRunner:
     def __init__(
         self,
-        model_path,
-        torch_dtype,
-        model_type,
-        tp_size=1,
-        port=DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
-        lora_paths=None,
-        max_loras_per_batch=4,
-        disable_cuda_graph=False,
-        disable_radix_cache=False,
+        model_path: str,
+        torch_dtype: torch.dtype,
+        model_type: str,
+        tp_size: int = 1,
+        port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
+        lora_paths: List[str] = None,
+        max_loras_per_batch: int = 4,
+        disable_cuda_graph: bool = False,
+        disable_radix_cache: bool = False,
     ):
         self.model_type = model_type
         self.is_generation = model_type == "generation"
test/srt/models/test_lora.py

@@ -15,7 +15,6 @@ limitations under the License.
 import multiprocessing as mp
 import unittest
-import uuid

 import torch

@@ -85,9 +84,9 @@ class TestLoRA(unittest.TestCase):
         with SRTRunner(
             base_path,
-            tp_size=tp_size,
             torch_dtype=torch_dtype,
-            is_generation=True,
+            model_type="generation",
+            tp_size=tp_size,
             lora_paths=all_lora_paths,
             max_loras_per_batch=3,
             disable_cuda_graph=True,
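The test now matches the annotated `SRTRunner` signature from `python/sglang/test/runners.py` above: the old `is_generation=True` flag becomes `model_type="generation"`. A hedged construction sketch with placeholder model and adapter paths (the actual test supplies its own `base_path` and LoRA list):

```python
import torch
from sglang.test.runners import SRTRunner

with SRTRunner(
    "base-model-path",                                # placeholder base model path
    torch_dtype=torch.float16,
    model_type="generation",                          # replaces the removed is_generation=True
    tp_size=1,
    lora_paths=["lora-adapter-a", "lora-adapter-b"],  # placeholder adapters
    max_loras_per_batch=3,
    disable_cuda_graph=True,
) as srt_runner:
    ...  # run generation against the loaded adapters
```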
test/srt/run_suite.py

@@ -7,6 +7,7 @@ suites = {
     "minimal": [
         "models/test_embedding_models.py",
         "models/test_generation_models.py",
+        # "models/test_lora.py",
         "models/test_reward_models.py",
         "sampling/penaltylib",
         "test_chunked_prefill.py",