sglang / Commits / commit f25f4dfd (unverified)

Authored Aug 28, 2024 by Yineng Zhang; committed by GitHub on Aug 28, 2024.

hotfix: revert sampler CUDA Graph (#1242)

Parent: 184ae1c6

Changes: 33 files in total; this page shows 20 changed files with 88 additions and 222 deletions (+88 −222).
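The substance of the revert, as reconstructed from the diffs below: sampling moves back out of the model forward pass and out of the CUDA-graph replay path. `ModelRunner.forward` once again returns a single `LogitProcessorOutput` instead of a `(SampleOutput, LogitsProcessorOutput)` pair, and callers sample from the returned logits through `ScheduleBatch.sample`. A minimal before/after sketch, using only names that appear in this commit (not runnable on its own; it assumes a prepared `batch` and `model_runner`):

```python
# Before this commit: sampling ran inside the model forward / CUDA graph.
#   sample_output, logits_output = model_runner.forward(batch, ForwardMode.DECODE)
#   next_token_ids = batch.check_sample_results(sample_output)

# After this commit: forward returns logits only; the batch samples from them.
output = model_runner.forward(batch, ForwardMode.DECODE)
next_token_ids = batch.sample(output.next_token_logits)
```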
Files changed on this page:

  .github/workflows/e2e-test.yml                          +0   −5
  README.md                                               +1   −1
  python/pyproject.toml                                   +1   −1
  python/sglang/bench_latency.py                          +6   −4
  python/sglang/srt/layers/logits_processor.py            +4   −4
  python/sglang/srt/layers/sampler.py                     +15  −68
  python/sglang/srt/managers/schedule_batch.py            +8   −20
  python/sglang/srt/managers/tp_worker.py                 +20  −32
  python/sglang/srt/model_executor/cuda_graph_runner.py   +9   −24
  python/sglang/srt/model_executor/forward_batch_info.py  +1   −8
  python/sglang/srt/model_executor/model_runner.py        +3   −11
  python/sglang/srt/models/chatglm.py                     +12  −4
  python/sglang/srt/models/commandr.py                    +1   −5
  python/sglang/srt/models/dbrx.py                        +1   −5
  python/sglang/srt/models/deepseek.py                    +1   −5
  python/sglang/srt/models/deepseek_v2.py                 +1   −5
  python/sglang/srt/models/gemma.py                       +1   −5
  python/sglang/srt/models/gemma2.py                      +1   −5
  python/sglang/srt/models/gpt_bigcode.py                 +1   −5
  python/sglang/srt/models/grok.py                        +1   −5
.github/workflows/e2e-test.yml

```diff
@@ -38,11 +38,6 @@ jobs:
           cd test/srt
           python3 -m unittest test_serving_throughput.TestServingThroughput.test_default

-      - name: Benchmark Serving Latency
-        timeout-minutes: 10
-        run: |
-          python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3.1-8B-Instruct --batch-size 1 --input 128 --output 8
-
       - name: Benchmark Serving Throughput (w/o RadixAttention)
         timeout-minutes: 10
         run: |
```
README.md

````diff
@@ -56,7 +56,7 @@ pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
 ### Method 2: From source
 ```
 # Use the last release branch
-git clone -b v0.2.14 https://github.com/sgl-project/sglang.git
+git clone -b v0.2.14.post1 https://github.com/sgl-project/sglang.git
 cd sglang
 pip install --upgrade pip
````
python/pyproject.toml

```diff
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "sglang"
-version = "0.2.14"
+version = "0.2.14.post1"
 description = "SGLang is yet another fast serving framework for large language models and vision language models."
 readme = "README.md"
 requires-python = ">=3.8"
```
python/sglang/bench_latency.py

```diff
@@ -200,14 +200,16 @@ def extend(reqs, model_runner):
         tree_cache=None,
     )
     batch.prepare_for_extend(model_runner.model_config.vocab_size)
-    sample_output, logits_output = model_runner.forward(batch, ForwardMode.EXTEND)
-    return sample_output.batch_next_token_ids, logits_output.next_token_logits, batch
+    output = model_runner.forward(batch, ForwardMode.EXTEND)
+    next_token_ids = batch.sample(output.next_token_logits)
+    return next_token_ids, output.next_token_logits, batch


 def decode(input_token_ids, batch, model_runner):
     batch.prepare_for_decode(input_token_ids.cpu().numpy())
-    sample_output, logits_output = model_runner.forward(batch, ForwardMode.DECODE)
-    return sample_output.batch_next_token_ids, logits_output.next_token_logits
+    output = model_runner.forward(batch, ForwardMode.DECODE)
+    next_token_ids = batch.sample(output.next_token_logits)
+    return next_token_ids, output.next_token_logits


 @torch.inference_mode()
```
python/sglang/srt/layers/logits_processor.py

```diff
@@ -29,7 +29,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetad

 @dataclasses.dataclass
-class LogitsProcessorOutput:
+class LogitProcessorOutput:
     # The logits of the next tokens. shape: [#seq, vocab_size]
     next_token_logits: torch.Tensor
     # The logprobs of the next tokens. shape: [#seq, vocab_size]
@@ -185,7 +185,7 @@ class LogitsProcessor(nn.Module):
         # Return only last_logits if logprob is not requested
         if not logits_metadata.return_logprob:
-            return LogitsProcessorOutput(
+            return LogitProcessorOutput(
                 next_token_logits=last_logits,
                 next_token_logprobs=None,
                 normalized_prompt_logprobs=None,
@@ -209,7 +209,7 @@ class LogitsProcessor(nn.Module):
             else:
                 output_top_logprobs = None
-            return LogitsProcessorOutput(
+            return LogitProcessorOutput(
                 next_token_logits=last_logits,
                 next_token_logprobs=last_logprobs,
                 normalized_prompt_logprobs=None,
@@ -278,7 +278,7 @@ class LogitsProcessor(nn.Module):
         # Remove the last token logprob for the prefill tokens.
         input_token_logprobs = input_token_logprobs[:-1]

-        return LogitsProcessorOutput(
+        return LogitProcessorOutput(
             next_token_logits=last_logits,
             next_token_logprobs=last_logprobs,
             normalized_prompt_logprobs=normalized_prompt_logprobs,
```
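Note that the dataclass rename back to `LogitProcessorOutput` (dropping the `s`) is itself part of the revert; the same rename ripples through every consumer below, in `tp_worker.py` and `cuda_graph_runner.py`.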
python/sglang/srt/layers/sampler.py

```diff
-import dataclasses
 import logging
-from typing import Union

 import torch
 from flashinfer.sampling import (
@@ -11,8 +9,6 @@ from flashinfer.sampling import (
 )
 from vllm.model_executor.custom_op import CustomOp

-from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-
 # TODO: move this dict to another place
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -20,71 +16,30 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 logger = logging.getLogger(__name__)


-@dataclasses.dataclass
-class SampleOutput:
-    success: torch.Tensor
-    probs: torch.Tensor
-    batch_next_token_ids: torch.Tensor
-
-
 class Sampler(CustomOp):
     def __init__(self):
         super().__init__()

-    def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
-        # min-token, presence, frequency
-        if sampling_info.linear_penalties is not None:
-            logits += sampling_info.linear_penalties
-
-        # repetition
-        if sampling_info.scaling_penalties is not None:
-            logits = torch.where(
-                logits > 0,
-                logits / sampling_info.scaling_penalties,
-                logits * sampling_info.scaling_penalties,
-            )
-
-        return logits
-
-    def _get_probs(
-        self,
-        logits: torch.Tensor,
-        sampling_info: SamplingBatchInfo,
-        is_torch_compile: bool = False,
-    ):
+    def forward_cuda(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
         # Post process logits
         logits = logits.contiguous()
         logits.div_(sampling_info.temperatures)
-        if is_torch_compile:
-            # FIXME: Temporary workaround for unknown bugs in torch.compile
-            logits.add_(0)
-
         if sampling_info.logit_bias is not None:
             logits.add_(sampling_info.logit_bias)

         if sampling_info.vocab_mask is not None:
             logits = logits.masked_fill(~sampling_info.vocab_mask, float("-inf"))

-        logits = self._apply_penalties(logits, sampling_info)
+        logits = sampling_info.penalizer_orchestrator.apply(logits)

-        return torch.softmax(logits, dim=-1)
-
-    def forward_cuda(
-        self,
-        logits: Union[torch.Tensor, LogitsProcessorOutput],
-        sampling_info: SamplingBatchInfo,
-    ):
-        if isinstance(logits, LogitsProcessorOutput):
-            logits = logits.next_token_logits
-
-        probs = self._get_probs(logits, sampling_info)
+        probs = torch.softmax(logits, dim=-1)

         if not global_server_args_dict["disable_flashinfer_sampling"]:
             max_top_k_round, batch_size = 32, probs.shape[0]
             uniform_samples = torch.rand(
                 (max_top_k_round, batch_size), device=probs.device
             )
-            if sampling_info.need_min_p_sampling:
+            if sampling_info.min_ps.any():
                 probs = top_k_renorm_prob(probs, sampling_info.top_ks)
                 probs = top_p_renorm_prob(probs, sampling_info.top_ps)
                 batch_next_token_ids, success = min_p_sampling_from_probs(
@@ -100,23 +55,18 @@ class Sampler(CustomOp):
                 probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
             )

-        return SampleOutput(success, probs, batch_next_token_ids)
-
-    def forward_native(
-        self,
-        logits: Union[torch.Tensor, LogitsProcessorOutput],
-        sampling_info: SamplingBatchInfo,
-    ):
-        if isinstance(logits, LogitsProcessorOutput):
-            logits = logits.next_token_logits
-
-        probs = self._get_probs(logits, sampling_info, is_torch_compile=True)
+        if not torch.all(success):
+            logging.warning("Sampling failed, fallback to top_k=1 strategy")
+            probs = probs.masked_fill(torch.isnan(probs), 0.0)
+            argmax_ids = torch.argmax(probs, dim=-1)
+            batch_next_token_ids = torch.where(success, batch_next_token_ids, argmax_ids)

-        batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
-            probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
-        )
+        return batch_next_token_ids

-        return SampleOutput(success, probs, batch_next_token_ids)
+    def forward_native():
+        raise NotImplementedError("Native forward is not implemented yet.")


 def top_k_top_p_min_p_sampling_from_probs_torch(
@@ -137,10 +87,7 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
     try:
-        # FIXME: torch.multiomial does not support num_samples = 1
-        sampled_index = torch.multinomial(probs_sort, num_samples=2, replacement=True)[
-            :, :1
-        ]
+        sampled_index = torch.multinomial(probs_sort, num_samples=1)
     except RuntimeError as e:
         logger.warning(f"Sampling error: {e}")
         batch_next_token_ids = torch.zeros(
```
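For background, the restored torch-native fallback (`top_k_top_p_min_p_sampling_from_probs_torch`, partially visible in the last hunk) filters sorted probabilities by top-p, top-k, and min-p before a single `torch.multinomial` draw. The following is a self-contained sketch of that filtering scheme written for this page, not code copied from the repository; the function name and the example thresholds are illustrative:

```python
import torch

def top_k_top_p_min_p_torch(probs, top_ks, top_ps, min_ps):
    # Sort each row's probabilities in descending order.
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    # top-p: zero entries once the cumulative mass *before* them exceeds top_p.
    cum = torch.cumsum(probs_sort, dim=-1)
    probs_sort[(cum - probs_sort) > top_ps.view(-1, 1)] = 0.0
    # top-k: zero everything past the first k sorted entries.
    ranks = torch.arange(probs.shape[-1], device=probs.device)
    probs_sort[ranks.expand_as(probs_sort) >= top_ks.view(-1, 1)] = 0.0
    # min-p: drop tokens whose probability is below min_p * (row max).
    probs_sort[probs_sort < (probs_sort[:, :1] * min_ps.view(-1, 1))] = 0.0
    # Sample among the survivors; multinomial accepts unnormalized weights.
    sampled = torch.multinomial(probs_sort, num_samples=1)
    # Map sorted positions back to vocabulary ids.
    return torch.gather(probs_idx, dim=1, index=sampled).view(-1)

probs = torch.softmax(torch.randn(4, 32), dim=-1)
ids = top_k_top_p_min_p_torch(
    probs,
    top_ks=torch.tensor([1, 5, 10, 32]),
    top_ps=torch.tensor([1.0, 0.9, 0.9, 0.8]),
    min_ps=torch.tensor([0.0, 0.0, 0.05, 0.1]),
)
print(ids)  # one sampled token id per row
```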
python/sglang/srt/managers/schedule_batch.py

```diff
-from __future__ import annotations
-
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -19,7 +17,7 @@ limitations under the License.

 import logging
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import List, Optional, Union

 import torch
@@ -31,10 +29,6 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

-if TYPE_CHECKING:
-    from sglang.srt.layers.sampler import SampleOutput
-
-
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

 # Put some global args for easy access
@@ -684,17 +678,11 @@ class ScheduleBatch:
         self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)

-    def check_sample_results(self, sample_output: SampleOutput):
-        if not torch.all(sample_output.success):
-            probs = sample_output.probs
-            batch_next_token_ids = sample_output.batch_next_token_ids
-            logging.warning("Sampling failed, fallback to top_k=1 strategy")
-            probs = probs.masked_fill(torch.isnan(probs), 0.0)
-            argmax_ids = torch.argmax(probs, dim=-1)
-            batch_next_token_ids = torch.where(
-                sample_output.success, batch_next_token_ids, argmax_ids
-            )
-            sample_output.probs = probs
-            sample_output.batch_next_token_ids = batch_next_token_ids
-
-        return sample_output.batch_next_token_ids
+    def sample(self, logits: torch.Tensor):
+        from sglang.srt.layers.sampler import Sampler
+
+        sampler = Sampler()
+        batch_next_token_ids = sampler(logits, self.sampling_info)
+        return batch_next_token_ids
```
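Two observations on the restored `sample` method (inferences from the diff, not statements from the commit): `Sampler` is a vllm `CustomOp`, so calling the instance should dispatch to `forward_cuda` on CUDA builds; and the import is done inside the method, presumably to avoid a circular import between `schedule_batch.py` and `sampler.py`, which imports `global_server_args_dict` from this module.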
python/sglang/srt/managers/tp_worker.py

```diff
@@ -31,7 +31,7 @@ from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.logits_processor import LogitProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
@@ -505,29 +505,21 @@ class ModelTpServer:
         if self.model_runner.is_generation:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
-                sample_output, logits_output = self.model_runner.forward(
-                    batch, ForwardMode.EXTEND
-                )
-                next_token_ids = batch.check_sample_results(sample_output)
+                output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+                next_token_ids = batch.sample(output.next_token_logits)
                 batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                     next_token_ids
                 )

                 # Move logprobs to cpu
-                if logits_output.next_token_logprobs is not None:
-                    logits_output.next_token_logprobs = (
-                        logits_output.next_token_logprobs[
-                            torch.arange(len(next_token_ids), device=next_token_ids.device),
-                            next_token_ids,
-                        ].tolist()
-                    )
-                    logits_output.input_token_logprobs = (
-                        logits_output.input_token_logprobs.tolist()
-                    )
-                    logits_output.normalized_prompt_logprobs = (
-                        logits_output.normalized_prompt_logprobs.tolist()
+                if output.next_token_logprobs is not None:
+                    output.next_token_logprobs = output.next_token_logprobs[
+                        torch.arange(len(next_token_ids), device=next_token_ids.device),
+                        next_token_ids,
+                    ].tolist()
+                    output.input_token_logprobs = output.input_token_logprobs.tolist()
+                    output.normalized_prompt_logprobs = (
+                        output.normalized_prompt_logprobs.tolist()
                     )

                 next_token_ids = next_token_ids.tolist()
@@ -566,14 +558,12 @@ class ModelTpServer:
                     self.req_to_token_pool.free(req.req_pool_idx)

                 if req.return_logprob:
-                    self.add_logprob_return_values(
-                        i, req, pt, next_token_ids, logits_output
-                    )
+                    self.add_logprob_return_values(i, req, pt, next_token_ids, output)
                     pt += req.extend_input_len
         else:
             assert batch.extend_num_tokens != 0
-            logits_output = self.model_runner.forward(batch, ForwardMode.EXTEND)
-            embeddings = logits_output.embeddings.tolist()
+            output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+            embeddings = output.embeddings.tolist()

         # Check finish conditions
         for i, req in enumerate(batch.reqs):
@@ -601,7 +591,7 @@ class ModelTpServer:
         req: Req,
         pt: int,
         next_token_ids: List[int],
-        output: LogitsProcessorOutput,
+        output: LogitProcessorOutput,
     ):
         if req.normalized_prompt_logprob is None:
             req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
@@ -683,17 +673,15 @@ class ModelTpServer:
         batch.prepare_for_decode()

         # Forward and sample the next tokens
-        sample_output, logits_output = self.model_runner.forward(
-            batch, ForwardMode.DECODE
-        )
-        next_token_ids = batch.check_sample_results(sample_output)
+        output = self.model_runner.forward(batch, ForwardMode.DECODE)
+        next_token_ids = batch.sample(output.next_token_logits)
         batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
             next_token_ids
         )

         # Move logprobs to cpu
-        if logits_output.next_token_logprobs is not None:
-            next_token_logprobs = logits_output.next_token_logprobs[
+        if output.next_token_logprobs is not None:
+            next_token_logprobs = output.next_token_logprobs[
                 torch.arange(len(next_token_ids), device=next_token_ids.device),
                 next_token_ids,
             ].tolist()
@@ -719,7 +707,7 @@ class ModelTpServer:
                     (next_token_logprobs[i], next_token_id)
                 )
             if req.top_logprobs_num > 0:
-                req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
+                req.output_top_logprobs.append(output.output_top_logprobs[i])

         self.handle_finished_requests(batch)
```
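A small aside on the indexing idiom that appears twice in this file under "Move logprobs to cpu": `tensor[torch.arange(n, ...), ids]` gathers, for each row, the logprob of the token that was actually sampled. A runnable illustration with toy shapes (not sglang code):

```python
import torch

logprobs = torch.log_softmax(torch.randn(3, 8), dim=-1)  # [batch=3, vocab=8]
next_token_ids = torch.tensor([5, 0, 7])                 # chosen id per row
# Pair row i with column next_token_ids[i]: one logprob per sequence.
picked = logprobs[torch.arange(3), next_token_ids]       # shape [3]
print(picked.tolist())
```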
python/sglang/srt/model_executor/cuda_graph_runner.py

```diff
@@ -26,18 +26,16 @@ from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp

 from sglang.srt.layers.logits_processor import (
+    LogitProcessorOutput,
     LogitsMetadata,
     LogitsProcessor,
-    LogitsProcessorOutput,
 )
-from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
     InputMetadata,
     update_flashinfer_indices,
 )
-from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.utils import monkey_patch_vllm_all_gather
@@ -146,10 +144,6 @@ class CudaGraphRunner:
             self.flashinfer_kv_indices.clone(),
         ]

-        # Sampling inputs
-        vocab_size = model_runner.model_config.vocab_size
-        self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)
-
         self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []

         if use_torch_compile:
@@ -241,7 +235,6 @@ class CudaGraphRunner:
         def run_once():
             input_metadata = InputMetadata(
                 forward_mode=ForwardMode.DECODE,
-                sampling_info=self.sampling_info[:bs],
                 batch_size=bs,
                 req_pool_indices=req_pool_indices,
                 seq_lens=seq_lens,
@@ -306,35 +299,27 @@ class CudaGraphRunner:
             self.flashinfer_handlers[bs],
         )

-        # Sampling inputs
-        self.sampling_info.inplace_assign(raw_bs, batch.sampling_info)
-
         # Replay
         torch.cuda.synchronize()
         self.graphs[bs].replay()
         torch.cuda.synchronize()
-        sample_output, logits_output = self.output_buffers[bs]
+        output = self.output_buffers[bs]

         # Unpad
         if bs != raw_bs:
-            logits_output = LogitsProcessorOutput(
-                next_token_logits=logits_output.next_token_logits[:raw_bs],
+            output = LogitProcessorOutput(
+                next_token_logits=output.next_token_logits[:raw_bs],
                 next_token_logprobs=None,
                 normalized_prompt_logprobs=None,
                 input_token_logprobs=None,
                 input_top_logprobs=None,
                 output_top_logprobs=None,
             )
-            sample_output = SampleOutput(
-                sample_output.success[:raw_bs],
-                sample_output.probs[:raw_bs],
-                sample_output.batch_next_token_ids[:raw_bs],
-            )

         # Extract logprobs
         if batch.return_logprob:
-            logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
-                logits_output.next_token_logits, dim=-1
+            output.next_token_logprobs = torch.nn.functional.log_softmax(
+                output.next_token_logits, dim=-1
             )
             return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
             if return_top_logprob:
@@ -342,8 +327,8 @@ class CudaGraphRunner:
                     forward_mode=ForwardMode.DECODE,
                     top_logprobs_nums=batch.top_logprobs_nums,
                 )
-                logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
-                    logits_output.next_token_logprobs, logits_metadata
+                output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
+                    output.next_token_logprobs, logits_metadata
                 )[1]

-        return sample_output, logits_output
+        return output
```
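For context on why the sampling state had to come out of this file: CUDA graph replay re-executes a frozen kernel sequence against fixed memory, so anything the captured region reads (here, the old `self.sampling_info` buffers) must be copied into place before every replay. A generic PyTorch capture/replay sketch, written for this page rather than taken from the sglang sources, and requiring a CUDA device:

```python
import torch

model = torch.nn.Linear(16, 16).cuda()
static_in = torch.zeros(8, 16, device="cuda")  # fixed input buffer

# Warm up eagerly on a side stream so lazy init doesn't happen during capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_out = model(static_in)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out = model(static_in)  # captured once; buffer addresses frozen

# Each step: refresh the static input, then replay the recorded kernels.
static_in.copy_(torch.randn(8, 16, device="cuda"))
g.replay()  # recomputes static_out in place
print(static_out.sum().item())
```

After the revert, only the logits computation is captured; sampling runs eagerly on the replayed output, so per-step sampling parameters no longer need static buffers.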
python/sglang/srt/model_executor/forward_batch_info.py

```diff
-from __future__ import annotations
-
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,7 +16,7 @@ limitations under the License.
 """ModelRunner runs the forward passes of the models."""

 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional

 import numpy as np
 import torch
@@ -28,7 +26,6 @@ from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool

 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo


 class ForwardMode(IntEnum):
@@ -45,7 +42,6 @@ class InputMetadata:
     """Store all inforamtion of a forward pass."""

     forward_mode: ForwardMode
-    sampling_info: SamplingBatchInfo
     batch_size: int
     req_pool_indices: torch.Tensor
     seq_lens: torch.Tensor
@@ -183,7 +179,6 @@ class InputMetadata:
     ):
         ret = cls(
             forward_mode=forward_mode,
-            sampling_info=batch.sampling_info,
             batch_size=batch.batch_size(),
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
@@ -194,8 +189,6 @@ class InputMetadata:
             top_logprobs_nums=batch.top_logprobs_nums,
         )

-        ret.sampling_info.prepare_penalties()
-
         ret.compute_positions(batch)
         ret.compute_extend_infos(batch)
```
python/sglang/srt/model_executor/model_runner.py

```diff
@@ -21,7 +21,7 @@ import importlib.resources
 import logging
 import pkgutil
 from functools import lru_cache
-from typing import Optional, Tuple, Type
+from typing import Optional, Type

 import torch
 import torch.nn as nn
@@ -44,8 +44,6 @@ from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry

 from sglang.global_config import global_config
-from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
@@ -517,11 +515,7 @@ class ModelRunner:
     @torch.inference_mode()
     def forward_decode(self, batch: ScheduleBatch):
-        if (
-            self.cuda_graph_runner
-            and self.cuda_graph_runner.can_run(len(batch.reqs))
-            and not batch.sampling_info.has_bias()
-        ):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
             return self.cuda_graph_runner.replay(batch)

         input_metadata = InputMetadata.from_schedule_batch(
@@ -570,9 +564,7 @@ class ModelRunner:
             input_metadata.image_offsets,
         )

-    def forward(
-        self, batch: ScheduleBatch, forward_mode: ForwardMode
-    ) -> Tuple[SampleOutput, LogitsProcessorOutput]:
+    def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
         if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
             return self.forward_extend_multi_modal(batch)
         elif forward_mode == ForwardMode.DECODE:
```
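Worth noting (an inference from the diff, not stated in the commit): with sampling out of the replay path, `forward_decode` no longer needs the `batch.sampling_info.has_bias()` escape hatch. The captured graph now computes only logits, so per-request sampling parameters such as a logit bias can no longer invalidate a replay.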
python/sglang/srt/models/chatglm.py

```diff
@@ -31,18 +31,20 @@ from vllm.model_executor.layers.linear import (
 )
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import SamplerOutput
 from vllm.transformers_utils.configs import ChatGLMConfig

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

 LoraConfig = None
@@ -381,11 +383,17 @@ class ChatGLMForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
```
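The remaining model files on this page (`commandr.py` through `grok.py`) all receive the same mechanical change: the sglang `Sampler` import and the `self.sampler` member are removed, and `forward` returns the `LogitsProcessor` output directly. `chatglm.py` above is the one file that additionally regains a vllm-style `sample()` method backed by vllm's own `Sampler`.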
python/sglang/srt/models/commandr.py

```diff
@@ -64,7 +64,6 @@ from vllm.model_executor.utils import set_weight_attrs
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -327,7 +326,6 @@ class CohereForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
         self.model = CohereModel(config, quant_config)

     @torch.no_grad()
@@ -342,11 +340,9 @@ class CohereForCausalLM(nn.Module):
             positions,
             input_metadata,
         )
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
```
python/sglang/srt/models/dbrx.py

```diff
@@ -45,7 +45,6 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -383,7 +382,6 @@ class DbrxForCausalLM(nn.Module):
             padding_size=DEFAULT_VOCAB_PADDING_SIZE,
         )
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -393,11 +391,9 @@ class DbrxForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         expert_params_mapping = [
```
python/sglang/srt/models/deepseek.py

```diff
@@ -46,7 +46,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -386,7 +385,6 @@ class DeepseekForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -396,11 +394,9 @@ class DeepseekForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
```
python/sglang/srt/models/deepseek_v2.py

```diff
@@ -45,7 +45,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -633,7 +632,6 @@ class DeepseekV2ForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     def forward(
         self,
@@ -642,11 +640,9 @@ class DeepseekV2ForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
```
python/sglang/srt/models/gemma.py

```diff
@@ -37,7 +37,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -288,7 +287,6 @@ class GemmaForCausalLM(nn.Module):
         self.quant_config = quant_config
         self.model = GemmaModel(config, quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -299,11 +297,9 @@ class GemmaForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return (sample_output, logits_output)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
```
python/sglang/srt/models/gemma2.py

```diff
@@ -41,7 +41,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -397,7 +396,6 @@ class Gemma2ForCausalLM(nn.Module):
         self.quant_config = quant_config
         self.model = Gemma2Model(config, cache_config, quant_config)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -408,11 +406,9 @@ class Gemma2ForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def get_attention_sliding_window_size(self):
         return get_attention_sliding_window_size(self.config)
```
python/sglang/srt/models/gpt_bigcode.py

```diff
@@ -35,7 +35,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -262,7 +261,6 @@ class GPTBigCodeForCausalLM(nn.Module):
         if lora_config:
             self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -272,11 +270,9 @@ class GPTBigCodeForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
```
python/sglang/srt/models/grok.py

```diff
@@ -46,7 +46,6 @@ from sglang.srt.layers.fused_moe import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -298,7 +297,6 @@ class Grok1ModelForCausalLM(nn.Module):
         self.model = Grok1Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

         # Monkey patch _prepare_weights to load pre-sharded weights
         setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
@@ -315,11 +313,9 @@ class Grok1ModelForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
```
(Page 1 of 2: the remaining 13 changed files of this commit are listed on the next page.)