zhaoyu6 / sglang · Commits · 9c6ba248

Unverified commit 9c6ba248, authored Dec 30, 2024 by Lianmin Zheng, committed by GitHub on Dec 30, 2024.
Parent: b02da24a

Refactor logprob computation to return the real logprob used in sampling (#2664)

Showing 9 changed files with 307 additions and 314 deletions (+307, -314).
Changed files:

    python/sglang/srt/layers/logits_processor.py                        +164  -212
    python/sglang/srt/layers/sampler.py                                  +55   -21
    python/sglang/srt/managers/scheduler.py                              +9    -15
    python/sglang/srt/managers/tp_worker_overlap_thread.py               +3    -4
    python/sglang/srt/model_executor/cuda_graph_runner.py                +3    -30
    python/sglang/srt/model_executor/model_runner.py                     +13   -30
    python/sglang/srt/sampling/sampling_batch_info.py                    +23   -0
    test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py    +2    -2
    test/srt/test_srt_endpoint.py                                        +35   -0
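The thrust of the refactor: next-token logprobs are no longer derived by the LogitsProcessor from the raw logits; the Sampler now computes them from the same temperature- and top-p-adjusted distribution it actually samples from, and attaches them to the LogitsProcessorOutput in place. The sketch below is a simplified, hypothetical rendering of that division of labor (the stand-in dataclass and function names are illustrative, not the real sglang classes); the real Sampler additionally folds in top-p renormalization, shown in the sampler.py diff further down.

```python
import dataclasses
from typing import Optional

import torch


@dataclasses.dataclass
class TinyLogitsOutput:
    # Illustrative stand-in for LogitsProcessorOutput (heavily trimmed).
    next_token_logits: torch.Tensor
    next_token_logprobs: Optional[torch.Tensor] = None


def sample_and_attach_logprobs(
    out: TinyLogitsOutput, temperatures: torch.Tensor, return_logprob: bool
) -> torch.Tensor:
    """Mimics the new division of labor: the sampler owns logprob computation."""
    probs = torch.softmax(out.next_token_logits / temperatures, dim=-1)
    next_ids = torch.multinomial(probs, num_samples=1).squeeze(-1)
    if return_logprob:
        # Logprob of the distribution actually sampled from (temperature applied),
        # gathered at the chosen token ids -- this is what reaches the scheduler now.
        logprobs = torch.log(probs)
        out.next_token_logprobs = logprobs[torch.arange(len(next_ids)), next_ids]
    return next_ids


if __name__ == "__main__":
    out = TinyLogitsOutput(next_token_logits=torch.randn(4, 16))
    ids = sample_and_attach_logprobs(
        out, temperatures=torch.full((4, 1), 0.7), return_logprob=True
    )
    print(ids, out.next_token_logprobs)
```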
python/sglang/srt/layers/logits_processor.py (+164, -212)

@@ -17,6 +17,8 @@ import dataclasses — Triton imports are added for the new fused soft-cap kernel:

```python
from typing import List, Optional, Union

import torch
import triton
import triton.language as tl
from torch import nn
from vllm.distributed import (
    get_tensor_model_parallel_world_size,
```

@@ -33,76 +35,77 @@ from sglang.srt.model_executor.forward_batch_info import (...) — LogitsProcessorOutput is reorganized into three parts: what the LogitsProcessor returns, what the Sampler fills in, and the prefill-only fields. The old vocab-sized next_token_logprobs, the output_top_logprobs_val/idx fields, and the trailing hidden_states field are removed; a per-sequence next_token_logprobs and next_token_top_logprobs_val/idx are added:

```python
@dataclasses.dataclass
class LogitsProcessorOutput:
    ## First part. This part will be returned by python/sglang/srt/layers/logits_processor.py::LogitsProcessor.
    # The logits of the next tokens. shape: [#seq, vocab_size]
    next_token_logits: torch.Tensor
    # Used by speculative decoding (EAGLE)
    # The last hidden layers
    hidden_states: Optional[torch.Tensor] = None

    ## Second part. This part will be returned by python/sglang/srt/layers/sampler.py::Sampler.
    # The logprobs of the next tokens. shape: [#seq]
    next_token_logprobs: Optional[torch.Tensor] = None
    # The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
    next_token_top_logprobs_val: Optional[List] = None
    next_token_top_logprobs_idx: Optional[List] = None

    ## Third part. This part will be returned by python/sglang/srt/layers/logits_processor.py::LogitsProcessor. Prefill-only.
    # The normlaized logprobs of prompts. shape: [#seq]
    normalized_prompt_logprobs: torch.Tensor = None
    # The logprobs of input tokens. shape: [#token]
    input_token_logprobs: torch.Tensor = None
    # The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k]
    input_top_logprobs_val: List = None
    input_top_logprobs_idx: List = None
```

LogitsMetadata drops the batch-level return_logprob / return_top_logprob flags and the required top_logprobs_nums field in favor of extend-specific flags (top_logprobs_nums becomes optional), and from_forward_batch now only prepares logprob bookkeeping for extend batches:

```python
@dataclasses.dataclass
class LogitsMetadata:
    forward_mode: ForwardMode
    capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL

    extend_return_logprob: bool = False
    extend_return_top_logprob: bool = False
    extend_seq_lens: Optional[torch.Tensor] = None
    extend_seq_lens_cpu: Optional[List[int]] = None
    extend_logprob_start_lens_cpu: Optional[List[int]] = None
    extend_logprob_pruned_lens_cpu: Optional[List[int]] = None
    top_logprobs_nums: Optional[List[int]] = None

    @classmethod
    def from_forward_batch(cls, forward_batch: ForwardBatch):
        if forward_batch.spec_info:
            capture_hidden_mode = forward_batch.spec_info.capture_hidden_mode
        else:
            capture_hidden_mode = CaptureHiddenMode.NULL

        if forward_batch.forward_mode.is_extend() and forward_batch.return_logprob:
            extend_return_logprob = True
            extend_return_top_logprob = any(
                x > 0 for x in forward_batch.top_logprobs_nums
            )
            extend_logprob_pruned_lens_cpu = [
                extend_len - start_len
                for extend_len, start_len in zip(
                    forward_batch.extend_seq_lens_cpu,
                    forward_batch.extend_logprob_start_lens_cpu,
                )
            ]
        else:
            extend_return_logprob = extend_return_top_logprob = (
                extend_logprob_pruned_lens_cpu
            ) = False

        return cls(
            forward_mode=forward_batch.forward_mode,
            capture_hidden_mode=capture_hidden_mode,
            extend_return_logprob=extend_return_logprob,
            extend_return_top_logprob=extend_return_top_logprob,
            extend_seq_lens=forward_batch.extend_seq_lens,
            extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
            extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
            extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
            top_logprobs_nums=forward_batch.top_logprobs_nums,
        )
```

@@ -129,7 +132,6 @@ class LogitsProcessor(nn.Module): — the assert isinstance(logits_metadata, LogitsMetadata) check after the ForwardBatch conversion is removed.

@@ -142,18 +144,10 @@ — the inline tensor-parallel all-gather, vocab-size slicing, and soft-capping of last_logits move into _get_logits, and the early return switches from return_logprob to extend_return_logprob:

```python
            last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
            last_hidden = hidden_states[last_index]

        # Compute logits
        last_logits = self._get_logits(last_hidden, lm_head)

        if not logits_metadata.extend_return_logprob:
            # Decode mode or extend mode without return_logprob.
            return LogitsProcessorOutput(
                next_token_logits=last_logits,
                hidden_states=(
```

@@ -167,95 +161,60 @@ — the decode-mode logprob path (last_logprobs, output_top_logprobs_*) is deleted from the LogitsProcessor; decode logprobs now come from the Sampler. The remaining extend-only path slices the requested tokens, computes their logits, normalizes, and returns only the input-side logprob fields:

```python
        else:
            # Slice the requested tokens to compute logprob
            pt, pruned_states, pruned_input_ids = 0, [], []
            for start_len, extend_len in zip(
                logits_metadata.extend_logprob_start_lens_cpu,
                logits_metadata.extend_seq_lens_cpu,
            ):
                pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
                pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
                pt += extend_len

            # Compute the logits of all required tokens
            pruned_states = torch.cat(pruned_states)
            del hidden_states
            input_token_logits = self._get_logits(pruned_states, lm_head)
            del pruned_states

            # Normalize the logprob w/o temperature, top-p
            input_logprobs = input_token_logits
            input_logprobs = self.compute_temp_top_p_normalized_logprobs(
                input_logprobs, logits_metadata
            )

            # Get the logprob of top-k tokens
            if logits_metadata.extend_return_top_logprob:
                (
                    input_top_logprobs_val,
                    input_top_logprobs_idx,
                ) = self.get_top_logprobs(input_logprobs, logits_metadata)
            else:
                input_top_logprobs_val = input_top_logprobs_idx = None

            # Compute the normalized logprobs for the requested tokens.
            # Note that we pad a zero at the end for easy batching.
            input_token_logprobs = input_logprobs[
                torch.arange(input_logprobs.shape[0], device="cuda"),
                torch.cat(
                    [
                        torch.cat(pruned_input_ids)[1:],
                        torch.tensor([0], device="cuda"),
                    ]
                ),
            ]
            normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
                input_token_logprobs,
                logits_metadata,
            )

            return LogitsProcessorOutput(
                next_token_logits=last_logits,
                normalized_prompt_logprobs=normalized_prompt_logprobs,
                input_token_logprobs=input_token_logprobs,
                input_top_logprobs_val=input_top_logprobs_val,
                input_top_logprobs_idx=input_top_logprobs_idx,
            )
```

@@ -269,9 +228,19 @@ — _get_logits now applies the optional scaling factor, the tensor-parallel all-gather, the vocab-size slice, and the fused soft-cap before returning:

```python
            # GGUF models
            logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)

        # Optional scaling factor
        if self.logit_scale is not None:
            logits.mul_(self.logit_scale)  # In-place multiply

        if self.do_tensor_parallel_all_gather:
            logits = tensor_model_parallel_all_gather(logits)

        logits = logits[:, : self.config.vocab_size].float()

        if self.final_logit_softcapping:
            fused_softcap(logits, self.final_logit_softcapping)

        return logits
```

@@ -302,90 +271,73 @@ — get_top_logprobs loses its decode-mode branch and the output-position bookkeeping (both now live in the Sampler) and returns only the input-side top-k values and indices:

```python
        values = ret.values.tolist()
        indices = ret.indices.tolist()

        input_top_logprobs_val, input_top_logprobs_idx = [], []

        pt = 0
        for k, pruned_len in zip(
            logits_metadata.top_logprobs_nums,
            logits_metadata.extend_logprob_pruned_lens_cpu,
        ):
            if pruned_len <= 0:
                input_top_logprobs_val.append([])
                input_top_logprobs_idx.append([])
                continue

            input_top_logprobs_val.append(
                [values[pt + j][:k] for j in range(pruned_len - 1)]
            )
            input_top_logprobs_idx.append(
                [indices[pt + j][:k] for j in range(pruned_len - 1)]
            )
            pt += pruned_len

        return input_top_logprobs_val, input_top_logprobs_idx
```

compute_temp_top_p_normalized_logprobs is still a plain log-softmax for now:

```python
    @staticmethod
    def compute_temp_top_p_normalized_logprobs(
        last_logits: torch.Tensor, logits_metadata: LogitsMetadata
    ) -> torch.Tensor:
        # TODO: Implement the temp and top-p normalization
        return torch.nn.functional.log_softmax(last_logits, dim=-1)
```

At the end of the file, the old module-level test() helper (which exercised the normalized-prompt-logprob bookkeeping with a toy all_logprobs tensor, seq_lens = [2, 0, 3, 0], input_ids = [1, 2, 3, 0, 1], cumulative sums of token logprobs, clamped start/end indices, and an `if __name__ == "__main__": test()` guard) is removed, and a Triton fused soft-cap kernel is added:

```python
@triton.jit
def fused_softcap_kernel(
    full_logits_ptr,
    softcapping_value,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Load values
    x = tl.load(full_logits_ptr + offsets, mask=mask)

    # Perform operations in-place
    x = x / softcapping_value

    # Manual tanh implementation using exp
    exp2x = tl.exp(2 * x)
    x = (exp2x - 1) / (exp2x + 1)

    x = x * softcapping_value

    # Store result
    tl.store(full_logits_ptr + offsets, x, mask=mask)


def fused_softcap(full_logits, final_logit_softcapping):
    n_elements = full_logits.numel()
    BLOCK_SIZE = 1024
    grid = ((n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE, 1, 1)

    fused_softcap_kernel[grid](
        full_logits_ptr=full_logits,
        softcapping_value=final_logit_softcapping,
        n_elements=n_elements,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return full_logits
```
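The fused_softcap kernel replaces the eager div_/tanh/mul_ sequence that previously ran on the logits; both compute the standard soft-capping transform cap * tanh(x / cap). A quick torch-only cross-check of that equivalence (a sketch, not part of the commit; it compares the old eager sequence against the closed form on CPU):

```python
import torch


def softcap_reference(x: torch.Tensor, cap: float) -> torch.Tensor:
    # Same math as fused_softcap_kernel: x -> cap * tanh(x / cap).
    return cap * torch.tanh(x / cap)


# The old in-place sequence the kernel replaces.
x = torch.randn(4, 1000) * 50
old = x.clone()
old.div_(30.0)
torch.tanh(old, out=old)
old.mul_(30.0)

torch.testing.assert_close(old, softcap_reference(x, 30.0))
```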
python/sglang/srt/layers/sampler.py (+55, -21)

The Union import is replaced by List:

```python
import logging
from typing import List

import torch
from torch import nn
```

@@ -28,13 +28,12 @@ class Sampler(nn.Module): — forward no longer accepts a bare logits tensor; it takes the full LogitsProcessorOutput plus return_logprob and top_logprobs_nums, and the isinstance/contiguous handling is dropped:

```python
    def forward(
        self,
        logits_output: LogitsProcessorOutput,
        sampling_info: SamplingBatchInfo,
        return_logprob: bool,
        top_logprobs_nums: List[int],
    ):
        logits = logits_output.next_token_logits

        if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
            logger.warning("Detected errors during sampling! NaN in the logits.")
```

@@ -47,6 +46,8 @@ and @@ -54,6 +55,12 @@ — both the greedy path and the flashinfer path now produce logprobs when requested; in the sampled path they are taken from the top-p renormalized distribution:

```python
        if sampling_info.is_all_greedy:
            # Use torch.argmax if all requests use greedy sampling
            batch_next_token_ids = torch.argmax(logits, -1)
            if return_logprob:
                logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
        else:
            # Post process logits
            logits.div_(sampling_info.temperatures)
            ...
            del logits

            if global_server_args_dict["sampling_backend"] == "flashinfer":
                if return_logprob:
                    # NOTE: the top_p_renorm_prob from flashinfer has numerical problems
                    logprobs = torch.log(
                        top_p_normalize_probs_torch(probs, sampling_info.top_ps)
                    )

                max_top_k_round, batch_size = 32, probs.shape[0]
                uniform_samples = torch.rand(
                    (max_top_k_round, batch_size), device=probs.device
```

@@ -76,6 +83,7 @@ and @@ -85,12 +93,31 @@ — the pytorch backend gets the same logprob hook, and the function now attaches the results to logits_output in place before returning the token ids:

```python
            elif global_server_args_dict["sampling_backend"] == "pytorch":
                # A slower fallback implementation with torch native operations.
                batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
                    ...
                    sampling_info.min_ps,
                    sampling_info.need_min_p_sampling,
                )
                if return_logprob:
                    logprobs = torch.log(
                        top_p_normalize_probs_torch(probs, sampling_info.top_ps)
                    )
            else:
                raise ValueError(
                    f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
                )

        batch_next_token_ids = batch_next_token_ids.to(torch.int32)

        # Attach logprobs to logits_output (in-place modification)
        if return_logprob:
            if any(x > 0 for x in top_logprobs_nums):
                (
                    logits_output.next_token_top_logprobs_val,
                    logits_output.next_token_top_logprobs_idx,
                ) = get_top_logprobs(logprobs, top_logprobs_nums)

            logits_output.next_token_logprobs = logprobs[
                torch.arange(len(batch_next_token_ids), device=sampling_info.device),
                batch_next_token_ids,
            ]

        return batch_next_token_ids
```

@@ -120,20 +147,27 @@ — top_p_normalize_probs is renamed to top_p_normalize_probs_torch and loses its backend dispatch (the flashinfer top_p_renorm_prob branch and the ValueError fallback are removed); a module-level get_top_logprobs helper is added:

```python
def top_p_normalize_probs_torch(
    probs: torch.Tensor,
    top_ps: torch.Tensor,
):
    # See also top_k_top_p_min_p_sampling_from_probs_torch
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)


def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
    max_k = max(top_logprobs_nums)
    ret = logprobs.topk(max_k, dim=1)
    values = ret.values.tolist()
    indices = ret.indices.tolist()

    output_top_logprobs_val = []
    output_top_logprobs_idx = []
    for i, k in enumerate(top_logprobs_nums):
        output_top_logprobs_val.append(values[i][:k])
        output_top_logprobs_idx.append(indices[i][:k])
    return output_top_logprobs_val, output_top_logprobs_idx
```
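Because flashinfer's top_p_renorm_prob has numerical problems (per the NOTE in the diff), the returned logprobs go through the torch implementation in both backends whenever return_logprob is set. A small worked example of what that renormalization produces, with made-up probabilities (the code inlines the same logic as top_p_normalize_probs_torch above):

```python
import torch

probs = torch.tensor([[0.50, 0.30, 0.15, 0.05]])
top_ps = torch.tensor([0.7])

# Tokens whose cumulative mass before them already exceeds top_p are zeroed,
# then the survivors are renormalized to sum to 1.
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
renorm = torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)

print(renorm)             # tensor([[0.6250, 0.3750, 0.0000, 0.0000]])
print(torch.log(renorm))  # the logprobs reported for kept tokens; -inf elsewhere
```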
python/sglang/srt/managers/scheduler.py (+9, -15)

@@ -974,12 +974,10 @@ and @@ -987,7 +985,6 @@ — in prefill result processing, the per-token gather of next_token_logprobs at the chosen ids disappears (the Sampler already did it), so the scheduler only converts tensors to Python lists; next_token_ids is moved to CPU up front:

```python
            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
        else:
            # Move next_token_ids and logprobs to cpu
            next_token_ids = next_token_ids.tolist()
            if batch.return_logprob:
                logits_output.next_token_logprobs = (
                    logits_output.next_token_logprobs.tolist()
                )
                logits_output.input_token_logprobs = (
                    logits_output.input_token_logprobs.tolist()
                ...
                logits_output.normalized_prompt_logprobs = (
                    logits_output.normalized_prompt_logprobs.tolist()
                )

        # Check finish conditions
        logprob_pt = 0
```

@@ -1064,13 +1061,9 @@ — the decode path is simplified the same way:

```python
            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
            next_token_logprobs = logits_output.next_token_logprobs
        else:
            # Move next_token_ids and logprobs to cpu
            next_token_ids = next_token_ids.tolist()
            if batch.return_logprob:
                next_token_logprobs = logits_output.next_token_logprobs.tolist()

        self.token_to_kv_pool.free_group_begin()
```

@@ -1095,10 +1088,10 @@ — the per-request top-logprob fields are read from the renamed attributes:

```python
            req.output_token_logprobs_idx.append(next_token_id)
            if req.top_logprobs_num > 0:
                req.output_top_logprobs_val.append(
                    logits_output.next_token_top_logprobs_val[i]
                )
                req.output_top_logprobs_idx.append(
                    logits_output.next_token_top_logprobs_idx[i]
                )

            if req.grammar is not None:
```

@@ -1200,8 +1193,9 @@ — same rename (output_top_logprobs_* becomes next_token_top_logprobs_*) when extending the per-request logprob lists:

```python
            req.output_top_logprobs_idx.extend(
                output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :]
            )

        req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
        req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])

        return num_input_logprobs
```
python/sglang/srt/managers/tp_worker_overlap_thread.py (+3, -4)

@@ -144,10 +144,9 @@ class TpModelWorkerClient: — the overlap worker also stops gathering next_token_logprobs at the sampled ids (the Sampler already returns the per-sequence values) and simply copies them to the CPU:

```python
        # Copy results to the CPU
        if model_worker_batch.return_logprob:
            logits_output.next_token_logprobs = (
                logits_output.next_token_logprobs.to("cpu", non_blocking=True)
            )
            if logits_output.input_token_logprobs is not None:
                logits_output.input_token_logprobs = (
                    logits_output.input_token_logprobs.to("cpu", non_blocking=True)
```
python/sglang/srt/model_executor/cuda_graph_runner.py (+3, -30)

@@ -392,34 +392,7 @@ class CudaGraphRunner: — the entire post-replay logprob extraction is deleted (building a decode-mode LogitsMetadata, calling LogitsProcessor.compute_temp_top_p_normalized_logprobs, and pulling output_top_logprobs_val/idx from get_top_logprobs). The graph runner now returns only the logits; the Sampler computes logprobs:

```python
        self.graphs[bs].replay()
        next_token_logits = self.output_buffers[bs][:raw_bs]

        logits_output = LogitsProcessorOutput(
            next_token_logits=next_token_logits,
        )
        return logits_output
```
python/sglang/srt/model_executor/model_runner.py (+13, -30)

@@ -36,7 +36,7 @@ — the sampler import also pulls in the new helper, and @@ -48,7 +48,6 @@ drops the now-unused SamplingBatchInfo import:

```python
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import Sampler, get_top_logprobs
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.managers.schedule_batch import global_server_args_dict
```

@@ -192,7 +191,8 @@ — a comment fix plus a note about future xpu support:

```python
        torch.get_device_module(self.device).set_device(self.gpu_id)
        if self.device == "cuda":
            backend = "nccl"
        # TODO(liangan1):Just use gloo to bypass the initilization fail
        # Need to use xccl for xpu backend in the future
        elif self.device == "xpu":
            backend = "gloo"
```

@@ -704,6 +704,7 @@ and @@ -714,35 +715,17 @@ — sample() now delegates logit biasing to SamplingBatchInfo and passes the full logits_output plus the logprob flags to the Sampler; the old ModelRunner.apply_logits_bias method (logit_bias, linear and scaling penalties, regex vocab mask) is deleted, its logic having moved to SamplingBatchInfo:

```python
    def sample(
        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        # Apply logit bias
        sampling_info = forward_batch.sampling_info
        if sampling_info.sampling_info_done:
            # Overlap mode: the function update_regex_vocab_mask was executed
            ...
        else:
            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
            sampling_info.update_regex_vocab_mask()
            sampling_info.update_penalties()
        sampling_info.apply_logits_bias(logits_output.next_token_logits)

        # Sample the next tokens
        next_token_ids = self.sampler(
            logits_output,
            sampling_info,
            forward_batch.return_logprob,
            forward_batch.top_logprobs_nums,
        )
        return next_token_ids

    @property
    def model_is_mrope(self) -> bool:
        """Detect if the model has "mrope" rope_scaling type.
```
python/sglang/srt/sampling/sampling_batch_info.py (+23, -0)

@@ -232,3 +232,26 @@ class SamplingBatchInfo: — after the existing merge_bias_tensor call, a new apply_logits_bias method is added, carrying over the logic removed from ModelRunner:

```python
        self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
            self.logit_bias, other.logit_bias, len(self), len(other), self.device
        )

    def apply_logits_bias(self, logits: torch.Tensor):
        # Apply logit_bias
        if self.logit_bias is not None:
            logits.add_(self.logit_bias)

        # min-token, presence, frequency
        if self.linear_penalties is not None:
            logits.add_(self.linear_penalties)

        # repetition
        if self.scaling_penalties is not None:
            logits = torch.where(
                logits > 0,
                logits / self.scaling_penalties,
                logits * self.scaling_penalties,
            )

        # Apply regex vocab_mask
        if self.vocab_mask is not None:
            self.apply_mask(logits=logits, vocab_mask=self.vocab_mask)

        return logits
```
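As a toy illustration of the repetition-penalty branch above (values invented, plain tensors standing in for a real SamplingBatchInfo): positive logits are divided by the penalty and negative ones multiplied, so both move toward being less likely.

```python
import torch

logits = torch.tensor([[2.0, -1.0, 0.5]])
scaling_penalties = torch.tensor([[1.2, 1.2, 1.2]])

penalized = torch.where(
    logits > 0,
    logits / scaling_penalties,   # 2.0 -> 1.667, 0.5 -> 0.417
    logits * scaling_penalties,   # -1.0 -> -1.2
)
print(penalized)  # tensor([[ 1.6667, -1.2000,  0.4167]])
```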
test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py (+2, -2)

@@ -6,7 +6,7 @@ and @@ -17,7 +17,7 @@ class TestBatchPenalizerE2E(unittest.TestCase): — the test switches from DEFAULT_MODEL_NAME_FOR_TEST to DEFAULT_SMALL_MODEL_NAME_FOR_TEST, in both the import and setUpClass:

```python
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
```

```python
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
```
test/srt/test_srt_endpoint.py (+35, -0)

@@ -213,6 +213,41 @@ class TestSRTEndpoint(unittest.TestCase): — a new end-to-end test checks that top logprobs are returned for a grammar-constrained single-token generation:

```python
        max_diff = np.max(diff)
        self.assertLess(max_diff, 0.25)

    def test_logprob_grammar(self):
        prompts = "Question: Is Paris the Capital of France? Answer:"
        allowed_tokens = [" Yes", " No"]

        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": prompts,
                "sampling_params": {
                    "temperature": 1.0,
                    "max_new_tokens": 1,
                    "regex": "( Yes| No)",
                },
                "return_logprob": True,
                "top_logprobs_num": 5,
                "return_text_in_logprobs": True,
            },
        )
        response_json = response.json()
        output_top_logprobs = response_json["meta_info"]["output_top_logprobs"][0]
        print(f"{output_top_logprobs=}")

        # Parse results
        # This is becaues the grammar constraint allows all prefix tokens
        logprobs = [None] * 2
        for i in range(len(output_top_logprobs)):
            try:
                idx = allowed_tokens.index(output_top_logprobs[i][2])
            except ValueError:
                # Not found
                continue
            logprobs[idx] = output_top_logprobs[i][0]

        self.assertTrue(all(x is not None for x in logprobs))

    def test_get_server_info(self):
        response = requests.get(self.base_url + "/get_server_info")
        response_json = response.json()
```
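The test reads the top-k logprobs of the first generated token from meta_info["output_top_logprobs"]. With return_text_in_logprobs set, element 0 of each entry is the logprob and element 2 is the token text (the middle element is presumably the token id). A minimal stand-alone parser over a hand-written response of that shape (all numbers are made up; a real response comes from the /generate endpoint):

```python
# Illustrative response fragment, not real server output.
response_json = {
    "meta_info": {
        "output_top_logprobs": [
            [
                [-0.3, 3869, " Yes"],   # [logprob, token_id (assumed), token_text]
                [-1.4, 1939, " No"],
                [-6.2, 3363, " yes"],
            ]
        ]
    }
}

allowed_tokens = [" Yes", " No"]
logprobs = [None] * len(allowed_tokens)
for entry in response_json["meta_info"]["output_top_logprobs"][0]:
    logprob, _token_id, text = entry
    if text in allowed_tokens:
        logprobs[allowed_tokens.index(text)] = logprob

assert all(x is not None for x in logprobs)
print(dict(zip(allowed_tokens, logprobs)))
```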