sglang — Commit 87e8c090 (Unverified)

Organize code (rename, movement) (#953)

Authored Aug 06, 2024 by Liangsheng Yin; committed via GitHub on Aug 06, 2024.
Parent commit: ad56e684

The commit touches 29 files in total; this page shows 20 changed files, with 295 additions and 278 deletions (+295, -278).
Changed files shown on this page:

  python/sglang/bench_latency.py                            +3    -2
  python/sglang/srt/layers/logits_processor.py              +1    -1
  python/sglang/srt/layers/radix_attention.py               +2    -5
  python/sglang/srt/managers/schedule_batch.py              +2    -235
  python/sglang/srt/managers/tp_worker.py                   +10   -10
  python/sglang/srt/model_executor/cuda_graph_runner.py     +3    -3
  python/sglang/srt/model_executor/forward_batch_info.py    +256  -0    (new file)
  python/sglang/srt/model_executor/model_runner.py          +6    -10
  python/sglang/srt/models/chatglm.py                       +1    -1
  python/sglang/srt/models/commandr.py                      +1    -1
  python/sglang/srt/models/dbrx.py                          +1    -1
  python/sglang/srt/models/deepseek.py                      +1    -1
  python/sglang/srt/models/deepseek_v2.py                   +1    -1
  python/sglang/srt/models/gemma.py                         +1    -1
  python/sglang/srt/models/gemma2.py                        +1    -1
  python/sglang/srt/models/gpt_bigcode.py                   +1    -1
  python/sglang/srt/models/grok.py                          +1    -1
  python/sglang/srt/models/internlm2.py                     +1    -1
  python/sglang/srt/models/llama2.py                        +1    -1
  python/sglang/srt/models/llama_classification.py          +1    -1
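The bulk of the diff is a mechanical rename and move: `Batch` becomes `ScheduleBatch` (still living in `sglang.srt.managers.schedule_batch`), while `ForwardMode`, `InputMetadata`, and the FlashInfer/Triton init helpers move into the new `sglang.srt.model_executor.forward_batch_info` module. The snippet below is an illustrative sketch of what a caller changes, not code from the diff; it assumes an sglang checkout that includes this commit.

    # Before this commit, callers imported everything from schedule_batch:
    # from sglang.srt.managers.schedule_batch import Batch, ForwardMode, InputMetadata

    # After this commit, the batch keeps its scheduler-side home under a new name,
    # and the per-forward-pass metadata lives next to the model executor:
    from sglang.srt.managers.schedule_batch import ScheduleBatch
    from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata


    def forward_decode_sketch(model_runner, batch: ScheduleBatch):
        # Type annotations switch from Batch to ScheduleBatch; ForwardMode values
        # are used exactly as before, just imported from the new module.
        return model_runner.forward(batch, ForwardMode.DECODE)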
python/sglang/bench_latency.py

@@ -50,8 +50,9 @@ import torch
 import torch.distributed as dist
 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.managers.schedule_batch import Batch, ForwardMode, Req
+from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.model_config import ModelConfig
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs

@@ -188,7 +189,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
 def extend(reqs, model_runner):
-    batch = Batch.init_new(
+    batch = ScheduleBatch.init_new(
         reqs=reqs,
         req_to_token_pool=model_runner.req_to_token_pool,
         token_to_kv_pool=model_runner.token_to_kv_pool,
python/sglang/srt/layers/logits_processor.py

@@ -25,7 +25,7 @@ from vllm.distributed import (
     tensor_model_parallel_all_gather,
 )
-from sglang.srt.model_executor.model_runner import ForwardMode, InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata

 @dataclasses.dataclass
python/sglang/srt/layers/radix_attention.py

@@ -22,11 +22,8 @@ from torch import nn
 from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
-from sglang.srt.model_executor.model_runner import (
-    ForwardMode,
-    InputMetadata,
-    global_server_args_dict,
-)
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.model_executor.model_runner import global_server_args_dict

 class RadixAttention(nn.Module):
python/sglang/srt/managers/schedule_batch.py

@@ -18,7 +18,6 @@ limitations under the License.
 import logging
 import warnings
 from dataclasses import dataclass
-from enum import IntEnum, auto
 from typing import List, Union

 import numpy as np

@@ -46,15 +45,6 @@ global_server_args_dict = {
 logger = logging.getLogger(__name__)

-class ForwardMode(IntEnum):
-    # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
-    PREFILL = auto()
-    # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
-    EXTEND = auto()
-    # Decode one token.
-    DECODE = auto()
-
-
 class BaseFinishReason:
     def __init__(self, is_error: bool = False):
         self.is_error = is_error

@@ -284,7 +274,7 @@ class Req:
 @dataclass
-class Batch:
+class ScheduleBatch:
     """Store all inforamtion of a batch."""

     # Request, memory pool, and cache

@@ -673,7 +663,7 @@ class Batch:
         if self_val is not None:  # logit_bias can be None
             setattr(self, item, self_val[new_indices])

-    def merge(self, other: "Batch"):
+    def merge(self, other: "ScheduleBatch"):
         self.reqs.extend(other.reqs)

         self.req_pool_indices = torch.concat(
@@ -770,229 +760,6 @@ class Batch:
         return batch_next_token_ids

 [229 removed lines omitted here: the InputMetadata dataclass (including its
 create() classmethod) and the init_flashinfer_args() / init_triton_args()
 helpers. This commit moves that block, essentially verbatim, into the new file
 python/sglang/srt/model_executor/forward_batch_info.py, reproduced in full
 below.]

 def top_k_top_p_sampling_from_probs_torch(
     probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor
 ):
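The @@ -673 hunk above also shows the batch-filtering idiom that ScheduleBatch keeps after the rename: per-request tensor fields are re-indexed with a tensor of surviving row indices, and optional fields such as logit_bias are skipped when they are None. Below is a small, self-contained illustration of that pattern with toy field names; it is not sglang code.

    import torch


    class ToyBatch:
        """Toy stand-in for the per-request tensor fields of a batch."""

        def __init__(self, n):
            self.seq_lens = torch.arange(10, 10 + n)
            self.logit_bias = None  # optional field, may legitimately be None

        def filter_batch(self, new_indices: torch.Tensor):
            # Same shape of logic as the hunk: re-index every per-request
            # tensor field by the surviving indices, skipping None fields.
            for item in ["seq_lens", "logit_bias"]:
                self_val = getattr(self, item, None)
                if self_val is not None:  # logit_bias can be None
                    setattr(self, item, self_val[new_indices])


    b = ToyBatch(4)
    b.filter_batch(torch.tensor([0, 2]))   # keep requests 0 and 2
    print(b.seq_lens.tolist())             # [10, 12]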
python/sglang/srt/managers/tp_worker.py

@@ -39,13 +39,13 @@ from sglang.srt.managers.policy_scheduler import PolicyScheduler
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
     BaseFinishReason,
-    Batch,
-    ForwardMode,
     Req,
+    ScheduleBatch,
 )
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.model_config import ModelConfig
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (

@@ -172,7 +172,7 @@ class ModelTpServer:
         # Init running status
         self.waiting_queue: List[Req] = []
-        self.running_batch: Batch = None
+        self.running_batch: ScheduleBatch = None
         self.out_pyobjs = []
         self.decode_forward_ct = 0
         self.stream_interval = server_args.stream_interval

@@ -353,7 +353,7 @@ class ModelTpServer:
         )
         self.waiting_queue.append(req)

-    def get_new_prefill_batch(self) -> Optional[Batch]:
+    def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
         # TODO(lsyin): organize this function
         running_bs = (
             len(self.running_batch.reqs) if self.running_batch is not None else 0

@@ -526,7 +526,7 @@ class ModelTpServer:
         )

         # Return the new batch
-        new_batch = Batch.init_new(
+        new_batch = ScheduleBatch.init_new(
             can_run_list,
             self.req_to_token_pool,
             self.token_to_kv_pool,

@@ -535,7 +535,7 @@ class ModelTpServer:
         self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
         return new_batch

-    def forward_prefill_batch(self, batch: Batch):
+    def forward_prefill_batch(self, batch: ScheduleBatch):
         # Build batch tensors
         batch.prepare_for_extend(
             self.model_config.vocab_size, self.int_token_logit_bias

@@ -624,7 +624,7 @@ class ModelTpServer:
                     )
                     req.output_top_logprobs.append(output.output_top_logprobs[i])

-    def cache_filled_batch(self, batch: Batch):
+    def cache_filled_batch(self, batch: ScheduleBatch):
         req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
         for i, req in enumerate(batch.reqs):
             new_prefix_indices, new_last_node = self.tree_cache.cache_req(

@@ -641,7 +641,7 @@ class ModelTpServer:
                 # inflight request would get a new req idx
                 self.req_to_token_pool.free(int(req_pool_indices_cpu[i]))

-    def forward_decode_batch(self, batch: Batch):
+    def forward_decode_batch(self, batch: ScheduleBatch):
         # Check if decode out of memory
         if not batch.check_decode_mem():
             old_ratio = self.new_token_ratio

@@ -700,7 +700,7 @@ class ModelTpServer:
         self.handle_finished_requests(batch)

-    def handle_finished_requests(self, batch: Batch):
+    def handle_finished_requests(self, batch: ScheduleBatch):
         output_rids = []
         output_vids = []
         decoded_texts = []

@@ -800,7 +800,7 @@ class ModelTpServer:
         else:
             batch.reqs = []

-    def filter_out_inflight(self, batch: Batch):
+    def filter_out_inflight(self, batch: ScheduleBatch):
         # TODO(lsyin): reduce the overhead, make a special version for this
         if self.current_inflight_req is None:
             return
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -29,8 +29,8 @@ from sglang.srt.layers.logits_processor import (
     LogitsMetadata,
     LogitsProcessor,
 )
-from sglang.srt.managers.schedule_batch import (
-    Batch,
+from sglang.srt.managers.schedule_batch import ScheduleBatch
+from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
     InputMetadata,
     init_flashinfer_args,

@@ -202,7 +202,7 @@ class CudaGraphRunner:
             self.graph_memory_pool = graph.pool()
         return graph, None, out, flashinfer_decode_wrapper

-    def replay(self, batch: Batch):
+    def replay(self, batch: ScheduleBatch):
         assert batch.out_cache_loc is not None
         raw_bs = len(batch.reqs)
python/sglang/srt/model_executor/forward_batch_info.py  (new file, mode 100644)

"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""ModelRunner runs the forward passes of the models."""
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import List

import numpy as np
import torch

from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool


class ForwardMode(IntEnum):
    # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
    PREFILL = auto()
    # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
    EXTEND = auto()
    # Decode one token.
    DECODE = auto()


@dataclass
class InputMetadata:
    """Store all inforamtion of a forward pass."""

    forward_mode: ForwardMode
    batch_size: int
    total_num_tokens: int
    req_pool_indices: torch.Tensor
    seq_lens: torch.Tensor
    positions: torch.Tensor
    req_to_token_pool: ReqToTokenPool
    token_to_kv_pool: BaseTokenToKVPool

    # For extend
    extend_seq_lens: torch.Tensor
    extend_start_loc: torch.Tensor
    extend_no_prefix: bool

    # Output location of the KV cache
    out_cache_loc: torch.Tensor = None

    # Output options
    return_logprob: bool = False
    top_logprobs_nums: List[int] = None

    # Trition attention backend
    triton_max_seq_len: int = 0
    triton_max_extend_len: int = 0
    triton_start_loc: torch.Tensor = None
    triton_prefix_lens: torch.Tensor = None

    # FlashInfer attention backend
    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
    flashinfer_use_ragged: bool = False

    @classmethod
    def create(
        cls,
        model_runner,
        forward_mode,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
        top_logprobs_nums=None,
        return_logprob=False,
        skip_flashinfer_init=False,
    ):
        flashinfer_use_ragged = False
        if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
            if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
                flashinfer_use_ragged = True
            init_flashinfer_args(
                forward_mode,
                model_runner,
                req_pool_indices,
                seq_lens,
                prefix_lens,
                model_runner.flashinfer_decode_wrapper,
                flashinfer_use_ragged,
            )

        batch_size = len(req_pool_indices)

        if forward_mode == ForwardMode.DECODE:
            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
            extend_seq_lens = extend_start_loc = extend_no_prefix = None
            if not model_runner.server_args.disable_flashinfer:
                # This variable is not needed in this case,
                # we do not compute it to make it compatbile with cuda graph.
                total_num_tokens = None
            else:
                total_num_tokens = int(torch.sum(seq_lens))
        else:
            seq_lens_cpu = seq_lens.cpu().numpy()
            prefix_lens_cpu = prefix_lens.cpu().numpy()
            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
            positions = torch.tensor(
                np.concatenate(
                    [
                        np.arange(
                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
                        )
                        for i in range(batch_size)
                    ],
                    axis=0,
                ),
                device="cuda",
            )
            extend_seq_lens = seq_lens - prefix_lens
            extend_start_loc = torch.zeros_like(seq_lens)
            extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
            extend_no_prefix = torch.all(prefix_lens == 0)
            total_num_tokens = int(torch.sum(seq_lens))

        ret = cls(
            forward_mode=forward_mode,
            batch_size=batch_size,
            total_num_tokens=total_num_tokens,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            positions=positions,
            req_to_token_pool=model_runner.req_to_token_pool,
            token_to_kv_pool=model_runner.token_to_kv_pool,
            out_cache_loc=out_cache_loc,
            extend_seq_lens=extend_seq_lens,
            extend_start_loc=extend_start_loc,
            extend_no_prefix=extend_no_prefix,
            return_logprob=return_logprob,
            top_logprobs_nums=top_logprobs_nums,
            flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
            flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
            flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
            flashinfer_use_ragged=flashinfer_use_ragged,
        )

        if model_runner.server_args.disable_flashinfer:
            (
                ret.triton_max_seq_len,
                ret.triton_max_extend_len,
                ret.triton_start_loc,
                ret.triton_prefix_lens,
            ) = init_triton_args(forward_mode, seq_lens, prefix_lens)

        return ret


def init_flashinfer_args(
    forward_mode,
    model_runner,
    req_pool_indices,
    seq_lens,
    prefix_lens,
    flashinfer_decode_wrapper,
    flashinfer_use_ragged=False,
):
    """Init auxiliary variables for FlashInfer attention backend."""
    num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
    num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
    head_dim = model_runner.model_config.head_dim
    batch_size = len(req_pool_indices)
    total_num_tokens = int(torch.sum(seq_lens))

    if flashinfer_use_ragged:
        paged_kernel_lens = prefix_lens
    else:
        paged_kernel_lens = seq_lens

    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
    req_pool_indices_cpu = req_pool_indices.cpu().numpy()
    paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
    kv_indices = torch.cat(
        [
            model_runner.req_to_token_pool.req_to_token[
                req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
            ]
            for i in range(batch_size)
        ],
        dim=0,
    ).contiguous()
    kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")

    if forward_mode == ForwardMode.DECODE:
        flashinfer_decode_wrapper.end_forward()
        flashinfer_decode_wrapper.begin_forward(
            kv_indptr,
            kv_indices,
            kv_last_page_len,
            num_qo_heads,
            num_kv_heads,
            head_dim,
            1,
        )
    else:
        # extend part
        qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)

        if flashinfer_use_ragged:
            model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
            model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
                qo_indptr,
                qo_indptr,
                num_qo_heads,
                num_kv_heads,
                head_dim,
            )

        # cached part
        model_runner.flashinfer_prefill_wrapper_paged.end_forward()
        model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
            qo_indptr,
            kv_indptr,
            kv_indices,
            kv_last_page_len,
            num_qo_heads,
            num_kv_heads,
            head_dim,
            1,
        )


def init_triton_args(forward_mode, seq_lens, prefix_lens):
    """Init auxiliary variables for triton attention backend."""
    batch_size = len(seq_lens)
    max_seq_len = int(torch.max(seq_lens))
    start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
    start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)

    if forward_mode == ForwardMode.DECODE:
        max_extend_len = None
    else:
        extend_seq_lens = seq_lens - prefix_lens
        max_extend_len = int(torch.max(extend_seq_lens))

    return max_seq_len, max_extend_len, start_loc, prefix_lens
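For a concrete sense of the index bookkeeping that InputMetadata.create and init_flashinfer_args perform in extend mode, here is a small CPU-only walkthrough with made-up lengths (the real code builds the same quantities on CUDA tensors; variable values below are illustrative only).

    import numpy as np
    import torch

    # Two requests whose prompts are partially cached already.
    seq_lens = torch.tensor([7, 5])       # total tokens per request
    prefix_lens = torch.tensor([3, 0])    # tokens already in the KV cache
    position_ids_offsets = torch.tensor([0, 0])

    seq_lens_cpu = seq_lens.numpy()
    prefix_lens_cpu = prefix_lens.numpy()
    offsets_cpu = position_ids_offsets.numpy()

    # Positions of the newly computed (extend) tokens, as in InputMetadata.create.
    positions = torch.tensor(
        np.concatenate(
            [
                np.arange(prefix_lens_cpu[i] + offsets_cpu[i],
                          seq_lens_cpu[i] + offsets_cpu[i])
                for i in range(len(seq_lens))
            ]
        )
    )
    print(positions.tolist())            # [3, 4, 5, 6, 0, 1, 2, 3, 4]

    # Start offset of each request's extend segment in the flattened token batch.
    extend_seq_lens = seq_lens - prefix_lens          # [4, 5]
    extend_start_loc = torch.zeros_like(seq_lens)
    extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
    print(extend_start_loc.tolist())                  # [0, 4]

    # Indptr over the paged KV-cache lengths, as in init_flashinfer_args
    # (here paged_kernel_lens == seq_lens, i.e. the non-ragged path).
    kv_indptr = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)
    print(kv_indptr.tolist())                         # [0, 7, 12]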
python/sglang/srt/model_executor/model_runner.py

@@ -41,18 +41,14 @@ from vllm.distributed import (
 from vllm.model_executor.models import ModelRegistry

 from sglang.global_config import global_config
-from sglang.srt.managers.schedule_batch import (
-    Batch,
-    ForwardMode,
-    InputMetadata,
-    global_server_args_dict,
-)
+from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
 )
 from sglang.srt.model_config import AttentionArch
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,

@@ -350,7 +346,7 @@ class ModelRunner:
         )

     @torch.inference_mode()
-    def forward_decode(self, batch: Batch):
+    def forward_decode(self, batch: ScheduleBatch):
         if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
             return self.cuda_graph_runner.replay(batch)

@@ -370,7 +366,7 @@ class ModelRunner:
         )

     @torch.inference_mode()
-    def forward_extend(self, batch: Batch):
+    def forward_extend(self, batch: ScheduleBatch):
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.EXTEND,

@@ -387,7 +383,7 @@ class ModelRunner:
         )

     @torch.inference_mode()
-    def forward_extend_multi_modal(self, batch: Batch):
+    def forward_extend_multi_modal(self, batch: ScheduleBatch):
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.EXTEND,

@@ -408,7 +404,7 @@ class ModelRunner:
             batch.image_offsets,
         )

-    def forward(self, batch: Batch, forward_mode: ForwardMode):
+    def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
         if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
             return self.forward_extend_multi_modal(batch)
         elif forward_mode == ForwardMode.DECODE:
python/sglang/srt/models/chatglm.py

@@ -45,7 +45,7 @@ from vllm.transformers_utils.configs import ChatGLMConfig
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata

 LoraConfig = None

python/sglang/srt/models/commandr.py

@@ -64,7 +64,7 @@ from vllm.model_executor.utils import set_weight_attrs
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata

 @torch.compile

python/sglang/srt/models/dbrx.py

@@ -45,7 +45,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata

 class DbrxRouter(nn.Module):

python/sglang/srt/models/deepseek.py

@@ -46,7 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.schedule_batch import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata

 class DeepseekMLP(nn.Module):

python/sglang/srt/models/deepseek_v2.py

@@ -46,7 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata

 class DeepseekV2MLP(nn.Module):

python/sglang/srt/models/gemma.py

@@ -37,7 +37,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata

 class GemmaMLP(nn.Module):

python/sglang/srt/models/gemma2.py

@@ -42,7 +42,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata

 class GemmaRMSNorm(CustomOp):

python/sglang/srt/models/gpt_bigcode.py

@@ -35,7 +35,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.schedule_batch import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata

 class GPTBigCodeAttention(nn.Module):

python/sglang/srt/models/grok.py

@@ -52,7 +52,7 @@ from vllm.utils import print_warning_once
 from sglang.srt.layers.fused_moe import fused_moe
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata

 use_fused = True

python/sglang/srt/models/internlm2.py

@@ -40,7 +40,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata

 class InternLM2MLP(nn.Module):

python/sglang/srt/models/llama2.py

@@ -41,7 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata

 class LlamaMLP(nn.Module):

python/sglang/srt/models/llama_classification.py

@@ -25,7 +25,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitProcessorOutput
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.models.llama2 import LlamaModel