zhaoyu6 / sglang / Commits / 9c6ba248

Commit 9c6ba248 (unverified), authored Dec 30, 2024 by Lianmin Zheng, committed by GitHub on Dec 30, 2024.

Refactor logprob computation to return the real logprob used in sampling (#2664)

Parent: b02da24a

Changes: 9 changed files with 307 additions and 314 deletions (+307 -314)
python/sglang/srt/layers/logits_processor.py                        +164 -212
python/sglang/srt/layers/sampler.py                                  +55  -21
python/sglang/srt/managers/scheduler.py                               +9  -15
python/sglang/srt/managers/tp_worker_overlap_thread.py                +3   -4
python/sglang/srt/model_executor/cuda_graph_runner.py                 +3  -30
python/sglang/srt/model_executor/model_runner.py                     +13  -30
python/sglang/srt/sampling/sampling_batch_info.py                    +23   -0
test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py     +2   -2
test/srt/test_srt_endpoint.py                                        +35   -0
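At a high level, this commit moves output-token logprob computation out of LogitsProcessor (which used to run log_softmax on the raw logits) and into Sampler, so the logprobs reported to clients come from the same post-processed distribution (temperature, penalties, top-p renormalization) that the token was actually sampled from. The sketch below is a simplified, self-contained illustration of that idea in plain PyTorch; it is not the sglang implementation, and the helper name sample_with_real_logprob is invented for the example.

import torch

def sample_with_real_logprob(logits: torch.Tensor, temperature: float, top_p: float):
    """Toy version of the refactored flow: sample from the post-processed
    distribution and report the logprob of the sampled token from that same
    distribution (illustrative only, not the sglang code)."""
    probs = torch.softmax(logits / temperature, dim=-1)

    # Top-p renormalization, mirroring top_p_normalize_probs_torch in sampler.py.
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cumsum = sorted_probs.cumsum(dim=-1)
    sorted_probs[(cumsum - sorted_probs) > top_p] = 0.0
    sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
    probs = torch.zeros_like(probs).scatter_(-1, sorted_idx, sorted_probs)

    # Sample, then read the logprob of the sampled token from the same distribution.
    next_ids = torch.multinomial(probs, num_samples=1).squeeze(-1)
    logprobs = torch.log(probs)  # the "real" logprobs used in sampling
    next_logprob = logprobs[torch.arange(logits.shape[0]), next_ids]
    return next_ids, next_logprob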
python/sglang/srt/layers/logits_processor.py (View file @ 9c6ba248)

@@ -17,6 +17,8 @@ import dataclasses
 from typing import List, Optional, Union

 import torch
+import triton
+import triton.language as tl
 from torch import nn
 from vllm.distributed import (
     get_tensor_model_parallel_world_size,

@@ -33,51 +35,55 @@ from sglang.srt.model_executor.forward_batch_info import (
 @dataclasses.dataclass
 class LogitsProcessorOutput:
+    ## First part. This part will be returned by python/sglang/srt/layers/logits_processor.py::LogitsProcessor.
     # The logits of the next tokens. shape: [#seq, vocab_size]
     next_token_logits: torch.Tensor
-    # The logprobs of the next tokens. shape: [#seq, vocab_size]
-    next_token_logprobs: torch.Tensor = None
+    # Used by speculative decoding (EAGLE)
+    # The last hidden layers
+    hidden_states: Optional[torch.Tensor] = None
+
+    ## Second part. This part will be returned by python/sglang/srt/layers/sampler.py::Sampler.
+    # The logprobs of the next tokens. shape: [#seq]
+    next_token_logprobs: Optional[torch.Tensor] = None
+    # The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
+    next_token_top_logprobs_val: Optional[List] = None
+    next_token_top_logprobs_idx: Optional[List] = None
+
+    ## Third part. This part will be returned by python/sglang/srt/layers/logits_processor.py::LogitsProcessor. Prefill-only.
     # The normlaized logprobs of prompts. shape: [#seq]
     normalized_prompt_logprobs: torch.Tensor = None
-    # The logprobs of input tokens. shape: [#token, vocab_size]
+    # The logprobs of input tokens. shape: [#token]
     input_token_logprobs: torch.Tensor = None
-    # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k]
+    # The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k]
     input_top_logprobs_val: List = None
     input_top_logprobs_idx: List = None
-    # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k]
-    output_top_logprobs_val: List = None
-    output_top_logprobs_idx: List = None
-    # Used by speculative decoding (EAGLE)
-    # The output of transformer layers
-    hidden_states: Optional[torch.Tensor] = None


 @dataclasses.dataclass
 class LogitsMetadata:
     forward_mode: ForwardMode
-    top_logprobs_nums: Optional[List[int]]
+    capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL

-    return_logprob: bool = False
-    return_top_logprob: bool = False
+    extend_return_logprob: bool = False
+    extend_return_top_logprob: bool = False
     extend_seq_lens: Optional[torch.Tensor] = None
     extend_seq_lens_cpu: Optional[List[int]] = None
     extend_logprob_start_lens_cpu: Optional[List[int]] = None
     extend_logprob_pruned_lens_cpu: Optional[List[int]] = None
-
-    capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL
+    top_logprobs_nums: Optional[List[int]] = None

     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
-        extend_logprob_pruned_lens_cpu = None
-        if forward_batch.return_logprob:
-            return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
-            if forward_batch.forward_mode.is_extend():
+        if forward_batch.spec_info:
+            capture_hidden_mode = forward_batch.spec_info.capture_hidden_mode
+        else:
+            capture_hidden_mode = CaptureHiddenMode.NULL
+
+        if forward_batch.forward_mode.is_extend() and forward_batch.return_logprob:
+            extend_return_logprob = True
+            extend_return_top_logprob = any(
+                x > 0 for x in forward_batch.top_logprobs_nums
+            )
             extend_logprob_pruned_lens_cpu = [
                 extend_len - start_len
                 for extend_len, start_len in zip(

@@ -86,23 +92,20 @@ class LogitsMetadata:
                 )
             ]
         else:
-            return_top_logprob = False
-
-        if forward_batch.spec_info:
-            capture_hidden_mode = forward_batch.spec_info.capture_hidden_mode
-        else:
-            capture_hidden_mode = CaptureHiddenMode.NULL
+            extend_return_logprob = extend_return_top_logprob = (
+                extend_logprob_pruned_lens_cpu
+            ) = False

         return cls(
             forward_mode=forward_batch.forward_mode,
-            top_logprobs_nums=forward_batch.top_logprobs_nums,
-            return_logprob=forward_batch.return_logprob,
-            return_top_logprob=return_top_logprob,
+            capture_hidden_mode=capture_hidden_mode,
+            extend_return_logprob=extend_return_logprob,
+            extend_return_top_logprob=extend_return_top_logprob,
             extend_seq_lens=forward_batch.extend_seq_lens,
             extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
             extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
             extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
-            capture_hidden_mode=capture_hidden_mode,
+            top_logprobs_nums=forward_batch.top_logprobs_nums,
         )
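As a reading aid for the reorganized LogitsProcessorOutput (first part filled by LogitsProcessor, second part filled in-place by Sampler, third part prefill-only), the sketch below paraphrases the call sites; the real entry points live in model_runner.py and scheduler.py, so treat this as an illustration rather than code from this commit.

# Hypothetical driver code, for illustration only.
logits_output = logits_processor(input_ids, hidden_states, lm_head, forward_batch)
# Part 1 is now set: next_token_logits (and hidden_states for EAGLE).
# Part 3 is also set when this is an extend batch with return_logprob.

next_token_ids = sampler(
    logits_output,
    forward_batch.sampling_info,
    forward_batch.return_logprob,
    forward_batch.top_logprobs_nums,
)
# Part 2 is now set in-place by the sampler: next_token_logprobs and,
# when top_logprobs_num > 0, next_token_top_logprobs_val/idx.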
@@ -129,7 +132,6 @@ class LogitsProcessor(nn.Module):
     ):
         if isinstance(logits_metadata, ForwardBatch):
             logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
-        assert isinstance(logits_metadata, LogitsMetadata)

         # Get the last hidden states and last logits for the next token prediction
         if (

@@ -142,18 +144,10 @@ class LogitsProcessor(nn.Module):
             last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
             last_hidden = hidden_states[last_index]

+        # Compute logits
         last_logits = self._get_logits(last_hidden, lm_head)
-        if self.do_tensor_parallel_all_gather:
-            last_logits = tensor_model_parallel_all_gather(last_logits)
-        last_logits = last_logits[:, : self.config.vocab_size].float()
-
-        if self.final_logit_softcapping:
-            last_logits.div_(self.final_logit_softcapping)
-            torch.tanh(last_logits, out=last_logits)
-            last_logits.mul_(self.final_logit_softcapping)
-
-        # Return only last_logits if logprob is not requested
-        if not logits_metadata.return_logprob:
+        if not logits_metadata.extend_return_logprob:
+            # Decode mode or extend mode without return_logprob.
             return LogitsProcessorOutput(
                 next_token_logits=last_logits,
                 hidden_states=(

@@ -166,74 +160,42 @@ class LogitsProcessor(nn.Module):
                 ),
             )
         else:
-            last_logprobs = self.compute_temp_top_p_normalized_logprobs(
-                last_logits, logits_metadata
-            )
-
-            if logits_metadata.forward_mode.is_decode():
-                if logits_metadata.return_top_logprob:
-                    output_top_logprobs_val, output_top_logprobs_idx = (
-                        self.get_top_logprobs(last_logprobs, logits_metadata)[2:4]
-                    )
-                else:
-                    output_top_logprobs_val = output_top_logprobs_idx = None
-                return LogitsProcessorOutput(
-                    next_token_logits=last_logits,
-                    next_token_logprobs=last_logprobs,
-                    output_top_logprobs_val=output_top_logprobs_val,
-                    output_top_logprobs_idx=output_top_logprobs_idx,
-                )
-            else:
-                # Slice the requested tokens to compute logprob
-                pt, states, pruned_input_ids = 0, [], []
-                for start_len, extend_len in zip(
-                    logits_metadata.extend_logprob_start_lens_cpu,
-                    logits_metadata.extend_seq_lens_cpu,
-                ):
-                    states.append(hidden_states[pt + start_len : pt + extend_len])
-                    pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
-                    pt += extend_len
-
-                # Compute the logits and logprobs for all required tokens
-                states = torch.cat(states, dim=0)
-                all_logits = self._get_logits(states, lm_head)
-                if self.do_tensor_parallel_all_gather:
-                    all_logits = tensor_model_parallel_all_gather(all_logits)
-                # The LM head's weights may be zero-padded for parallelism. Remove any
-                # extra logits that this padding may have produced.
-                all_logits = all_logits[:, : self.config.vocab_size].float()
-
-                if self.final_logit_softcapping:
-                    all_logits.div_(self.final_logit_softcapping)
-                    torch.tanh(all_logits, out=all_logits)
-                    all_logits.mul_(self.final_logit_softcapping)
-
-                all_logprobs = all_logits
-                del all_logits, hidden_states
-
-                all_logprobs = self.compute_temp_top_p_normalized_logprobs(
-                    all_logprobs, logits_metadata
-                )
-
-                # Get the logprob of top-k tokens
-                if logits_metadata.return_top_logprob:
-                    (
-                        input_top_logprobs_val,
-                        input_top_logprobs_idx,
-                        output_top_logprobs_val,
-                        output_top_logprobs_idx,
-                    ) = self.get_top_logprobs(all_logprobs, logits_metadata)
-                else:
-                    input_top_logprobs_val = input_top_logprobs_idx = (
-                        output_top_logprobs_val
-                    ) = output_top_logprobs_idx = None
-
-                # Compute the normalized logprobs for the requested tokens.
-                # Note that we pad a zero at the end for easy batching.
-                input_token_logprobs = all_logprobs[
-                    torch.arange(all_logprobs.shape[0], device="cuda"),
-                    torch.cat(
-                        [
-                            torch.cat(pruned_input_ids)[1:],
+            # Slice the requested tokens to compute logprob
+            pt, pruned_states, pruned_input_ids = 0, [], []
+            for start_len, extend_len in zip(
+                logits_metadata.extend_logprob_start_lens_cpu,
+                logits_metadata.extend_seq_lens_cpu,
+            ):
+                pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
+                pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
+                pt += extend_len
+
+            # Compute the logits of all required tokens
+            pruned_states = torch.cat(pruned_states)
+            del hidden_states
+            input_token_logits = self._get_logits(pruned_states, lm_head)
+            del pruned_states
+
+            # Normalize the logprob w/o temperature, top-p
+            input_logprobs = input_token_logits
+            input_logprobs = self.compute_temp_top_p_normalized_logprobs(
+                input_logprobs, logits_metadata
+            )
+
+            # Get the logprob of top-k tokens
+            if logits_metadata.extend_return_top_logprob:
+                (
+                    input_top_logprobs_val,
+                    input_top_logprobs_idx,
+                ) = self.get_top_logprobs(input_logprobs, logits_metadata)
+            else:
+                input_top_logprobs_val = input_top_logprobs_idx = None
+
+            # Compute the normalized logprobs for the requested tokens.
+            # Note that we pad a zero at the end for easy batching.
+            input_token_logprobs = input_logprobs[
+                torch.arange(input_logprobs.shape[0], device="cuda"),
+                torch.cat(
+                    [
+                        torch.cat(pruned_input_ids)[1:],

@@ -248,13 +210,10 @@ class LogitsProcessor(nn.Module):
             return LogitsProcessorOutput(
                 next_token_logits=last_logits,
-                next_token_logprobs=last_logprobs,
                 normalized_prompt_logprobs=normalized_prompt_logprobs,
                 input_token_logprobs=input_token_logprobs,
                 input_top_logprobs_val=input_top_logprobs_val,
                 input_top_logprobs_idx=input_top_logprobs_idx,
-                output_top_logprobs_val=output_top_logprobs_val,
-                output_top_logprobs_idx=output_top_logprobs_idx,
             )

     def _get_logits(

@@ -269,9 +228,19 @@ class LogitsProcessor(nn.Module):
             # GGUF models
             logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)

+        # Optional scaling factor
         if self.logit_scale is not None:
+            # In-place multiply
             logits.mul_(self.logit_scale)

+        if self.do_tensor_parallel_all_gather:
+            logits = tensor_model_parallel_all_gather(logits)
+
+        # Compute the normalized logprobs for the requested tokens.
+        # Note that we pad a zero at the end for easy batching.
+        logits = logits[:, : self.config.vocab_size].float()
+
+        if self.final_logit_softcapping:
+            fused_softcap(logits, self.final_logit_softcapping)
+
         return logits

     @staticmethod

@@ -302,16 +271,7 @@ class LogitsProcessor(nn.Module):
         values = ret.values.tolist()
         indices = ret.indices.tolist()

-        if logits_metadata.forward_mode.is_decode():
-            output_top_logprobs_val = []
-            output_top_logprobs_idx = []
-            for i, k in enumerate(logits_metadata.top_logprobs_nums):
-                output_top_logprobs_val.append(values[i][:k])
-                output_top_logprobs_idx.append(indices[i][:k])
-            return None, None, output_top_logprobs_val, output_top_logprobs_idx
-        else:
         input_top_logprobs_val, input_top_logprobs_idx = [], []
-        output_top_logprobs_val, output_top_logprobs_idx = [], []

         pt = 0
         for k, pruned_len in zip(

@@ -321,8 +281,6 @@ class LogitsProcessor(nn.Module):
             if pruned_len <= 0:
                 input_top_logprobs_val.append([])
                 input_top_logprobs_idx.append([])
-                output_top_logprobs_val.append([])
-                output_top_logprobs_idx.append([])
                 continue

             input_top_logprobs_val.append(

@@ -331,61 +289,55 @@ class LogitsProcessor(nn.Module):
             input_top_logprobs_idx.append(
                 [indices[pt + j][:k] for j in range(pruned_len - 1)]
             )
-            output_top_logprobs_val.append(
-                list(
-                    values[pt + pruned_len - 1][:k],
-                )
-            )
-            output_top_logprobs_idx.append(
-                list(
-                    indices[pt + pruned_len - 1][:k],
-                )
-            )
             pt += pruned_len

-        return (
-            input_top_logprobs_val,
-            input_top_logprobs_idx,
-            output_top_logprobs_val,
-            output_top_logprobs_idx,
-        )
+        return input_top_logprobs_val, input_top_logprobs_idx

     @staticmethod
     def compute_temp_top_p_normalized_logprobs(
         last_logits: torch.Tensor, logits_metadata: LogitsMetadata
     ) -> torch.Tensor:
+        # TODO: Implement the temp and top-p normalization
         return torch.nn.functional.log_softmax(last_logits, dim=-1)


-def test():
-    all_logprobs = torch.tensor(
-        #        s                 s              s
-        [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
-        dtype=torch.float32,
-        device="cuda",
-    )
-    seq_lens = torch.tensor([2, 0, 3, 0], dtype=torch.int32, device="cuda")
-    input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
-
-    token_logprobs = all_logprobs[
-        torch.arange(all_logprobs.shape[0], device="cuda"),
-        torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
-    ]
-    logprobs_cumsum = torch.cumsum(token_logprobs, dim=0, dtype=torch.float32)
-
-    len_cumsum = torch.cumsum(seq_lens, dim=0)
-    start = torch.cat((torch.tensor([0], device="cuda"), len_cumsum[:-1]), 0)
-    end = start + seq_lens - 2
-    start.clamp_(min=0, max=token_logprobs.shape[0] - 1)
-    end.clamp_(min=0, max=token_logprobs.shape[0] - 1)
-    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + token_logprobs[start]
-
-    # assert logprobs == [2, _, 2, 4, _]
-    print("token logprobs", token_logprobs)
-    print("start", start)
-    print("end", end)
-    print("sum_logp", sum_logp)
-
-
-if __name__ == "__main__":
-    test()
+@triton.jit
+def fused_softcap_kernel(
+    full_logits_ptr,
+    softcapping_value,
+    n_elements,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+
+    # Load values
+    x = tl.load(full_logits_ptr + offsets, mask=mask)
+
+    # Perform operations in-place
+    x = x / softcapping_value
+
+    # Manual tanh implementation using exp
+    exp2x = tl.exp(2 * x)
+    x = (exp2x - 1) / (exp2x + 1)
+
+    x = x * softcapping_value
+
+    # Store result
+    tl.store(full_logits_ptr + offsets, x, mask=mask)
+
+
+def fused_softcap(full_logits, final_logit_softcapping):
+    n_elements = full_logits.numel()
+    BLOCK_SIZE = 1024
+    grid = ((n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE, 1, 1)
+
+    fused_softcap_kernel[grid](
+        full_logits_ptr=full_logits,
+        softcapping_value=final_logit_softcapping,
+        n_elements=n_elements,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+
+    return full_logits
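For reference, the Triton kernel above fuses the logit "softcap" transform, softcap(x) = c * tanh(x / c), which the removed eager-mode code applied with div_/tanh/mul_. A minimal PyTorch equivalent, useful as a correctness check under the assumption that eager tanh is acceptable, could look like this (illustrative, not part of the commit):

import torch

def softcap_reference(logits: torch.Tensor, cap: float) -> torch.Tensor:
    """Eager-mode reference for fused_softcap: cap * tanh(logits / cap)."""
    return cap * torch.tanh(logits / cap)

# Example comparison against the fused kernel (requires a CUDA device):
# x = torch.randn(4, 32000, device="cuda")
# torch.testing.assert_close(fused_softcap(x.clone(), 30.0),
#                            softcap_reference(x, 30.0), rtol=1e-4, atol=1e-4)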
python/sglang/srt/layers/sampler.py (View file @ 9c6ba248)

 import logging
-from typing import Union
+from typing import List

 import torch
 from torch import nn

@@ -28,13 +28,12 @@ class Sampler(nn.Module):
     def forward(
         self,
-        logits: Union[torch.Tensor, LogitsProcessorOutput],
+        logits_output: LogitsProcessorOutput,
         sampling_info: SamplingBatchInfo,
+        return_logprob: bool,
+        top_logprobs_nums: List[int],
     ):
-        if isinstance(logits, LogitsProcessorOutput):
-            logits = logits.next_token_logits
-
-        logits = logits.contiguous()
+        logits = logits_output.next_token_logits

         if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
             logger.warning("Detected errors during sampling! NaN in the logits.")

@@ -47,6 +46,8 @@ class Sampler(nn.Module):
         if sampling_info.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
             batch_next_token_ids = torch.argmax(logits, -1)
+            if return_logprob:
+                logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
         else:
             # Post process logits
             logits.div_(sampling_info.temperatures)

@@ -54,6 +55,12 @@ class Sampler(nn.Module):
             del logits

             if global_server_args_dict["sampling_backend"] == "flashinfer":
+                if return_logprob:
+                    # NOTE: the top_p_renorm_prob from flashinfer has numerical problems
+                    logprobs = torch.log(
+                        top_p_normalize_probs_torch(probs, sampling_info.top_ps)
+                    )
                 max_top_k_round, batch_size = 32, probs.shape[0]
                 uniform_samples = torch.rand(
                     (max_top_k_round, batch_size), device=probs.device

@@ -76,6 +83,7 @@ class Sampler(nn.Module):
                 if self.use_nan_detectioin and not torch.all(success):
                     logger.warning("Detected errors during sampling!")
                     batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
             elif global_server_args_dict["sampling_backend"] == "pytorch":
                 # A slower fallback implementation with torch native operations.
                 batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(

@@ -85,12 +93,31 @@ class Sampler(nn.Module):
                     sampling_info.min_ps,
                     sampling_info.need_min_p_sampling,
                 )
+                if return_logprob:
+                    logprobs = torch.log(
+                        top_p_normalize_probs_torch(probs, sampling_info.top_ps)
+                    )
             else:
                 raise ValueError(
                     f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
                 )

-        return batch_next_token_ids.to(torch.int32)
+        batch_next_token_ids = batch_next_token_ids.to(torch.int32)
+
+        # Attach logprobs to logits_output (in-place modification)
+        if return_logprob:
+            if any(x > 0 for x in top_logprobs_nums):
+                (
+                    logits_output.next_token_top_logprobs_val,
+                    logits_output.next_token_top_logprobs_idx,
+                ) = get_top_logprobs(logprobs, top_logprobs_nums)
+
+            logits_output.next_token_logprobs = logprobs[
+                torch.arange(len(batch_next_token_ids), device=sampling_info.device),
+                batch_next_token_ids,
+            ]
+
+        return batch_next_token_ids


 def top_k_top_p_min_p_sampling_from_probs_torch(

@@ -120,20 +147,27 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     return batch_next_token_ids


-def top_p_normalize_probs(
+def top_p_normalize_probs_torch(
     probs: torch.Tensor,
     top_ps: torch.Tensor,
 ):
-    if global_server_args_dict["sampling_backend"] == "flashinfer":
-        return top_p_renorm_prob(probs, top_ps)
-    elif global_server_args_dict["sampling_backend"] == "pytorch":
-        # See also top_k_top_p_min_p_sampling_from_probs_torch
-        probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
-        probs_sum = torch.cumsum(probs_sort, dim=-1)
-        probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
-        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
-        return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
-    else:
-        raise ValueError(
-            f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
-        )
+    # See also top_k_top_p_min_p_sampling_from_probs_torch
+    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+    return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
+
+
+def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
+    max_k = max(top_logprobs_nums)
+    ret = logprobs.topk(max_k, dim=1)
+    values = ret.values.tolist()
+    indices = ret.indices.tolist()
+
+    output_top_logprobs_val = []
+    output_top_logprobs_idx = []
+    for i, k in enumerate(top_logprobs_nums):
+        output_top_logprobs_val.append(values[i][:k])
+        output_top_logprobs_idx.append(indices[i][:k])
+    return output_top_logprobs_val, output_top_logprobs_idx
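A small, self-contained illustration of what top_p_normalize_probs_torch computes: probability mass outside the top-p nucleus is zeroed and the remainder is renormalized. The numbers below are made up for the example.

import torch

# Toy distribution over 4 tokens for a single sequence, with top_p = 0.7.
probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
top_ps = torch.tensor([0.7])

probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
renormed = torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)

print(renormed)  # tensor([[0.6250, 0.3750, 0.0000, 0.0000]])
# torch.log(renormed) is then the logprob actually used when sampling with top_p = 0.7.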
python/sglang/srt/managers/scheduler.py (View file @ 9c6ba248)

@@ -974,12 +974,10 @@ class Scheduler:
             logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
         else:
             # Move next_token_ids and logprobs to cpu
+            next_token_ids = next_token_ids.tolist()
             if batch.return_logprob:
                 logits_output.next_token_logprobs = (
-                    logits_output.next_token_logprobs[
-                        torch.arange(len(next_token_ids), device=self.device),
-                        next_token_ids,
-                    ].tolist()
+                    logits_output.next_token_logprobs.tolist()
                 )
                 logits_output.input_token_logprobs = (
                     logits_output.input_token_logprobs.tolist()

@@ -987,7 +985,6 @@ class Scheduler:
                 logits_output.normalized_prompt_logprobs = (
                     logits_output.normalized_prompt_logprobs.tolist()
                 )
-            next_token_ids = next_token_ids.tolist()

         # Check finish conditions
         logprob_pt = 0

@@ -1064,13 +1061,9 @@ class Scheduler:
             logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
             next_token_logprobs = logits_output.next_token_logprobs
         else:
-            # Move next_token_ids and logprobs to cpu
-            if batch.return_logprob:
-                next_token_logprobs = logits_output.next_token_logprobs[
-                    torch.arange(len(next_token_ids), device=self.device),
-                    next_token_ids,
-                ].tolist()
             next_token_ids = next_token_ids.tolist()
+            if batch.return_logprob:
+                next_token_logprobs = logits_output.next_token_logprobs.tolist()

         self.token_to_kv_pool.free_group_begin()

@@ -1095,10 +1088,10 @@ class Scheduler:
                 req.output_token_logprobs_idx.append(next_token_id)
                 if req.top_logprobs_num > 0:
                     req.output_top_logprobs_val.append(
-                        logits_output.output_top_logprobs_val[i]
+                        logits_output.next_token_top_logprobs_val[i]
                     )
                     req.output_top_logprobs_idx.append(
-                        logits_output.output_top_logprobs_idx[i]
+                        logits_output.next_token_top_logprobs_idx[i]
                    )

             if req.grammar is not None:

@@ -1200,8 +1193,9 @@ class Scheduler:
             req.output_top_logprobs_idx.extend(
                 output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :]
             )
-        req.output_top_logprobs_val.append(output.output_top_logprobs_val[i])
-        req.output_top_logprobs_idx.append(output.output_top_logprobs_idx[i])
+        req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
+        req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])

         return num_input_logprobs
python/sglang/srt/managers/tp_worker_overlap_thread.py (View file @ 9c6ba248)

@@ -144,10 +144,9 @@ class TpModelWorkerClient:
         # Copy results to the CPU
         if model_worker_batch.return_logprob:
-            logits_output.next_token_logprobs = logits_output.next_token_logprobs[
-                torch.arange(len(next_token_ids), device=self.device),
-                next_token_ids,
-            ].to("cpu", non_blocking=True)
+            logits_output.next_token_logprobs = (
+                logits_output.next_token_logprobs.to("cpu", non_blocking=True)
+            )
             if logits_output.input_token_logprobs is not None:
                 logits_output.input_token_logprobs = (
                     logits_output.input_token_logprobs.to("cpu", non_blocking=True)
python/sglang/srt/model_executor/cuda_graph_runner.py (View file @ 9c6ba248)

@@ -392,34 +392,7 @@ class CudaGraphRunner:
         self.graphs[bs].replay()
         next_token_logits = self.output_buffers[bs][:raw_bs]

-        # Extract logprobs
-        if forward_batch.return_logprob:
-            logits_metadata = LogitsMetadata(
-                forward_mode=ForwardMode.DECODE,
-                top_logprobs_nums=forward_batch.top_logprobs_nums,
-            )
-            next_token_logprobs = (
-                LogitsProcessor.compute_temp_top_p_normalized_logprobs(
-                    next_token_logits, logits_metadata
-                )
-            )
-            logits_output = LogitsProcessorOutput(
-                next_token_logits=next_token_logits,
-                next_token_logprobs=next_token_logprobs,
-            )
-            return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
-            if return_top_logprob:
-                (
-                    logits_output.output_top_logprobs_val,
-                    logits_output.output_top_logprobs_idx,
-                ) = LogitsProcessor.get_top_logprobs(
-                    next_token_logprobs, logits_metadata
-                )[2:4]
-        else:
-            logits_output = LogitsProcessorOutput(
-                next_token_logits=next_token_logits,
-            )
+        logits_output = LogitsProcessorOutput(
+            next_token_logits=next_token_logits,
+        )

         return logits_output
python/sglang/srt/model_executor/model_runner.py (View file @ 9c6ba248)

@@ -36,7 +36,7 @@ from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.sampler import Sampler
+from sglang.srt.layers.sampler import Sampler, get_top_logprobs
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import global_server_args_dict

@@ -48,7 +48,6 @@ from sglang.srt.mem_cache.memory_pool import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader import get_model
-from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     enable_show_time_cost,

@@ -192,7 +191,8 @@ class ModelRunner:
         torch.get_device_module(self.device).set_device(self.gpu_id)
         if self.device == "cuda":
             backend = "nccl"
-        # ToDO(liangan1):Just use gloo to bypass the initilization fail
+        # TODO(liangan1):Just use gloo to bypass the initilization fail
         # Need to use xccl for xpu backend in the future
         elif self.device == "xpu":
             backend = "gloo"

@@ -704,6 +704,7 @@ class ModelRunner:
     def sample(
         self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
    ) -> torch.Tensor:
+        # Apply logit bias
         sampling_info = forward_batch.sampling_info
         if sampling_info.sampling_info_done:
             # Overlap mode: the function update_regex_vocab_mask was executed

@@ -714,34 +715,16 @@ class ModelRunner:
             # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
             sampling_info.update_regex_vocab_mask()
             sampling_info.update_penalties()
-        logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)
+        sampling_info.apply_logits_bias(logits_output.next_token_logits)

-        # Sample the next tokens.
-        next_token_ids = self.sampler(logits, sampling_info)
-        return next_token_ids
-
-    def apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
-        # Apply logit_bias
-        if sampling_info.logit_bias is not None:
-            logits.add_(sampling_info.logit_bias)
-
-        # min-token, presence, frequency
-        if sampling_info.linear_penalties is not None:
-            logits.add_(sampling_info.linear_penalties)
-
-        # repetition
-        if sampling_info.scaling_penalties is not None:
-            logits = torch.where(
-                logits > 0,
-                logits / sampling_info.scaling_penalties,
-                logits * sampling_info.scaling_penalties,
-            )
-
-        # Apply regex vocab_mask
-        if sampling_info.vocab_mask is not None:
-            sampling_info.apply_mask(logits=logits, vocab_mask=sampling_info.vocab_mask)
-
-        return logits
+        # Sample the next tokens
+        next_token_ids = self.sampler(
+            logits_output,
+            sampling_info,
+            forward_batch.return_logprob,
+            forward_batch.top_logprobs_nums,
+        )
+        return next_token_ids

     @property
     def model_is_mrope(self) -> bool:
python/sglang/srt/sampling/sampling_batch_info.py (View file @ 9c6ba248)

@@ -232,3 +232,26 @@ class SamplingBatchInfo:
         self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
             self.logit_bias, other.logit_bias, len(self), len(other), self.device
         )
+
+    def apply_logits_bias(self, logits: torch.Tensor):
+        # Apply logit_bias
+        if self.logit_bias is not None:
+            logits.add_(self.logit_bias)
+
+        # min-token, presence, frequency
+        if self.linear_penalties is not None:
+            logits.add_(self.linear_penalties)
+
+        # repetition
+        if self.scaling_penalties is not None:
+            logits = torch.where(
+                logits > 0,
+                logits / self.scaling_penalties,
+                logits * self.scaling_penalties,
+            )
+
+        # Apply regex vocab_mask
+        if self.vocab_mask is not None:
+            self.apply_mask(logits=logits, vocab_mask=self.vocab_mask)
+
+        return logits
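To make the penalty application above concrete, here is a toy example of the same operations applied to a tiny logits tensor; the numbers and the single-request setup are invented for illustration.

import torch

logits = torch.tensor([[2.0, -1.0, 0.5, 0.0]])

logit_bias = torch.tensor([[0.0, 0.0, 0.0, -100.0]])      # e.g. effectively ban token 3
linear_penalties = torch.tensor([[-0.5, 0.0, 0.0, 0.0]])  # presence/frequency-style penalty
scaling_penalties = torch.tensor([[1.2, 1.2, 1.2, 1.2]])  # repetition penalty

logits.add_(logit_bias)
logits.add_(linear_penalties)
# Positive logits are divided by the penalty, negative ones multiplied,
# mirroring the torch.where above.
logits = torch.where(logits > 0, logits / scaling_penalties, logits * scaling_penalties)

print(logits)  # tensor([[  1.2500,  -1.2000,   0.4167, -120.0000]])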
test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py (View file @ 9c6ba248)

@@ -6,7 +6,7 @@ import requests
 from sglang.srt.utils import kill_process_tree
 from sglang.test.test_utils import (
-    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,

@@ -17,7 +17,7 @@ class TestBatchPenalizerE2E(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
         cls.base_url = DEFAULT_URL_FOR_TEST
         cls.process = popen_launch_server(
             cls.model,
test/srt/test_srt_endpoint.py (View file @ 9c6ba248)

@@ -213,6 +213,41 @@ class TestSRTEndpoint(unittest.TestCase):
         max_diff = np.max(diff)
         self.assertLess(max_diff, 0.25)

+    def test_logprob_grammar(self):
+        prompts = "Question: Is Paris the Capital of France? Answer:"
+        allowed_tokens = [" Yes", " No"]
+
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "text": prompts,
+                "sampling_params": {
+                    "temperature": 1.0,
+                    "max_new_tokens": 1,
+                    "regex": "( Yes| No)",
+                },
+                "return_logprob": True,
+                "top_logprobs_num": 5,
+                "return_text_in_logprobs": True,
+            },
+        )
+        response_json = response.json()
+        output_top_logprobs = response_json["meta_info"]["output_top_logprobs"][0]
+        print(f"{output_top_logprobs=}")
+
+        # Parse results
+        # This is becaues the grammar constraint allows all prefix tokens
+        logprobs = [None] * 2
+        for i in range(len(output_top_logprobs)):
+            try:
+                idx = allowed_tokens.index(output_top_logprobs[i][2])
+            except ValueError:
+                # Not found
+                continue
+            logprobs[idx] = output_top_logprobs[i][0]
+
+        self.assertTrue(all(x is not None for x in logprobs))
+
     def test_get_server_info(self):
         response = requests.get(self.base_url + "/get_server_info")
         response_json = response.json()
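For readers unfamiliar with the endpoint, each entry of output_top_logprobs in the test above is a (logprob, token_id, text) triple when return_text_in_logprobs is set, which is why the test reads index [0] for the logprob and index [2] for the decoded text (the middle element being the token id is inferred from that usage). A minimal client-side sketch, with a placeholder URL and prompt:

import requests

resp = requests.post(
    "http://localhost:30000/generate",  # placeholder URL
    json={
        "text": "Question: Is Paris the Capital of France? Answer:",
        "sampling_params": {"temperature": 1.0, "max_new_tokens": 1},
        "return_logprob": True,
        "top_logprobs_num": 5,
        "return_text_in_logprobs": True,
    },
)
meta = resp.json()["meta_info"]
for logprob, token_id, text in meta["output_top_logprobs"][0]:
    # After this commit, these logprobs come from the distribution the
    # sampler actually drew from, not from the raw pre-sampling logits.
    print(f"{text!r}: {logprob:.3f} (token {token_id})")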