sglang / Commits / 2fa5cec7

Unverified commit 2fa5cec7, authored Sep 16, 2024 by Lianmin Zheng; committed via GitHub on Sep 16, 2024.
Parent: 27b557ae

Simplify sampler and its error handling (#1441)

Showing 4 changed files with 32 additions and 159 deletions (+32, -159).
python/sglang/srt/layers/sampler.py                +22  -91
python/sglang/srt/managers/schedule_batch.py        +1   -0
python/sglang/srt/model_executor/model_runner.py    +6  -18
python/sglang/srt/sampling/sampling_batch_info.py   +3  -50
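At a glance: the commit deletes the SampleOutput dataclass, the CustomOp forward_cuda/forward_native split, and the flashinfer custom-op wrapper, so Sampler becomes a plain nn.Module whose forward returns the sampled token ids directly; error handling moves inline (a NaN guard and zero-fill on failure) instead of a post-hoc check in ModelRunner. For orientation, here is a minimal self-contained sketch of the new control flow; it is illustrative only (the real code below dispatches on a sampling backend and takes a SamplingBatchInfo):

    import torch
    from torch import nn

    class SamplerSketch(nn.Module):
        # Illustrative reduction of the new Sampler.forward: temperature-scale
        # the logits, softmax, guard against NaN probabilities, then sample one
        # token per row.
        def forward(self, logits: torch.Tensor, temperatures: torch.Tensor) -> torch.Tensor:
            logits = logits / temperatures
            probs = torch.softmax(logits, dim=-1)
            if torch.any(torch.isnan(probs)):
                probs = torch.where(torch.isnan(probs), torch.full_like(probs, 1e-10), probs)
            return torch.multinomial(probs, num_samples=1).view(-1)

    sampler = SamplerSketch()
    next_ids = sampler(torch.randn(4, 128), torch.full((4, 1), 0.7))
    print(next_ids.shape)  # torch.Size([4])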
python/sglang/srt/layers/sampler.py
-import dataclasses
 import logging
-from typing import Tuple, Union
+from typing import Union
 
 import torch
 from flashinfer.sampling import (
...
@@ -9,43 +8,17 @@ from flashinfer.sampling import (
     top_k_top_p_sampling_from_probs,
     top_p_renorm_prob,
 )
-from torch.library import custom_op as torch_custom_op
-from vllm.model_executor.custom_op import CustomOp
+from torch import nn
 
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 
 # TODO: move this dict to another place
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 
 logger = logging.getLogger(__name__)
 
 
-@dataclasses.dataclass
-class SampleOutput:
-    success: torch.Tensor
-    probs: torch.Tensor
-    batch_next_token_ids: torch.Tensor
-
-
-class Sampler(CustomOp):
-    def __init__(self):
-        super().__init__()
-
-        # FIXME: torch.multinomial has too many bugs
-        self.forward_native = self.forward_cuda
-        self.is_torch_compile = False
-
-    def _get_probs(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
-        # Post process logits
-        logits = logits.contiguous()
-        logits.div_(sampling_info.temperatures)
-        if self.is_torch_compile:
-            # FIXME: Temporary workaround for unknown bugs in torch.compile
-            logits.add_(0)
-
-        return torch.softmax(logits, dim=-1)
-
-    def forward_cuda(
+class Sampler(nn.Module):
+    def forward(
         self,
         logits: Union[torch.Tensor, LogitsProcessorOutput],
         sampling_info: SamplingBatchInfo,
...
@@ -53,7 +26,15 @@ class Sampler(CustomOp):
         if isinstance(logits, LogitsProcessorOutput):
             logits = logits.next_token_logits
 
-        probs = self._get_probs(logits, sampling_info)
+        # Post process logits
+        logits.div_(sampling_info.temperatures)
+        probs = logits[:] = torch.softmax(logits, dim=-1)
+
+        if torch.any(torch.isnan(probs)):
+            logger.warning("Detected errors during sampling! NaN in the probability.")
+            probs = torch.where(
+                torch.isnan(probs), torch.full_like(probs, 1e-10), probs
+            )
 
         if global_server_args_dict["sampling_backend"] == "flashinfer":
             max_top_k_round, batch_size = 32, probs.shape[0]
...
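The hunk above inlines the error handling that the old SampleOutput path deferred to the model runner: NaN probabilities are replaced with a tiny positive mass so the batch can still be sampled. A quick demo of the idiom (values made up):

    import torch

    probs = torch.tensor([[0.5, float("nan"), 0.5]])
    probs = torch.where(torch.isnan(probs), torch.full_like(probs, 1e-10), probs)
    # torch.multinomial accepts unnormalized non-negative weights, so this row
    # can now be sampled without raising.
    print(torch.multinomial(probs, num_samples=1))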
@@ -67,12 +48,16 @@ class Sampler(CustomOp):
                     probs, uniform_samples, sampling_info.min_ps
                 )
             else:
-                batch_next_token_ids, success = flashinfer_top_k_top_p(
+                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
                     probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
                 )
+
+            if not torch.all(success):
+                logger.warning("Detected errors during sampling!")
+                batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
         elif global_server_args_dict["sampling_backend"] == "pytorch":
             # Here we provide a slower fallback implementation.
-            batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
+            batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
                 probs,
                 sampling_info.top_ks,
                 sampling_info.top_ps,
                 sampling_info.min_ps,
             )
         else:
...
@@ -80,48 +65,7 @@ class Sampler(CustomOp):
                 f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
             )
 
-        return SampleOutput(success, probs, batch_next_token_ids)
-
-    def forward_native(
-        self,
-        logits: Union[torch.Tensor, LogitsProcessorOutput],
-        sampling_info: SamplingBatchInfo,
-    ):
-        if isinstance(logits, LogitsProcessorOutput):
-            logits = logits.next_token_logits
-
-        probs = self._get_probs(logits, sampling_info)
-
-        batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
-            probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
-        )
-
-        return SampleOutput(success, probs, batch_next_token_ids)
-
-
-@torch_custom_op("my_lib::flashinfer_top_k_top_p", mutates_args={})
-def flashinfer_top_k_top_p(
-    probs: torch.Tensor,
-    uniform_samples: torch.Tensor,
-    top_ks: torch.Tensor,
-    top_ps: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    # NOTE: we do not use min_p neither in CUDA nor in torch.compile
-    return top_k_top_p_sampling_from_probs(probs, uniform_samples, top_ks, top_ps)
-
-
-@flashinfer_top_k_top_p.register_fake
-def _(
-    probs: torch.Tensor,
-    uniform_samples: torch.Tensor,
-    top_ks: torch.Tensor,
-    top_ps: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    bs = probs.shape[0]
-    return (
-        torch.ones(bs, dtype=torch.bool, device=probs.device),
-        torch.zeros(bs, dtype=torch.int32, device=probs.device),
-    )
+        return batch_next_token_ids
 
 
 def top_k_top_p_min_p_sampling_from_probs_torch(
...
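The hunk above deletes the torch.library custom-op wrapper that existed so torch.compile could trace the opaque flashinfer kernel: register_fake supplies a shape/dtype-only "meta" implementation for tracing. For readers unfamiliar with the pattern, a minimal sketch under PyTorch 2.4+ (my_lib::scale is a made-up op for illustration, not part of sglang):

    import torch
    from torch.library import custom_op

    @custom_op("my_lib::scale", mutates_args=())
    def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
        # Stand-in for an opaque kernel torch.compile cannot trace through.
        return x * factor

    @scale.register_fake
    def _(x: torch.Tensor, factor: float) -> torch.Tensor:
        # Fake (meta) impl: only shapes/dtypes, no real compute, used during tracing.
        return torch.empty_like(x)

    print(scale(torch.ones(3), 2.0))  # tensor([2., 2., 2.])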
@@ -141,19 +85,6 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     ] = 0.0
     probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
-    try:
-        # FIXME: torch.multiomial does not support num_samples = 1
-        sampled_index = torch.multinomial(probs_sort, num_samples=2, replacement=True)[
-            :, :1
-        ]
-    except RuntimeError as e:
-        logger.warning(f"Sampling error: {e}")
-        batch_next_token_ids = torch.zeros(
-            (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
-        )
-        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
-        return batch_next_token_ids, success
-
+    sampled_index = torch.multinomial(probs_sort, num_samples=1)
     batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
-    success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
-    return batch_next_token_ids, success
+    return batch_next_token_ids
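The last hunk shows only the tail of top_k_top_p_min_p_sampling_from_probs_torch. For reference, here is a minimal sketch of the whole technique (sort descending, mask tokens outside top-k/top-p/min-p, rescale, sample): an illustration of the approach, not the file's verbatim body.

    import torch

    def top_k_top_p_min_p_sample(probs, top_ks, top_ps, min_ps):
        # probs: (bs, vocab); top_ks: (bs,) int; top_ps/min_ps: (bs,) float in [0, 1]
        probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        # min-p threshold is relative to each row's max probability
        min_p_thresholds = probs_sort[:, 0] * min_ps
        # drop tokens whose preceding cumulative mass already exceeds top_p
        probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
        # drop tokens ranked at or beyond top_k
        ranks = torch.arange(probs.shape[-1], device=probs.device).view(1, -1)
        probs_sort[ranks >= top_ks.view(-1, 1)] = 0.0
        # drop tokens below the min-p threshold
        probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
        # rescale for numerical stability; torch.multinomial renormalizes anyway
        probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
        sampled_index = torch.multinomial(probs_sort, num_samples=1)
        return torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)

    ids = top_k_top_p_min_p_sample(
        torch.softmax(torch.randn(2, 50), dim=-1),
        top_ks=torch.tensor([40, 40]),
        top_ps=torch.tensor([0.9, 0.9]),
        min_ps=torch.tensor([0.05, 0.05]),
    )
    print(ids)  # two token ids in [0, 50)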
python/sglang/srt/managers/schedule_batch.py
...
@@ -360,6 +360,7 @@ class ScheduleBatch:
     tree_cache: BasePrefixCache
 
     forward_mode: ForwardMode = None
+    sampling_info: SamplingBatchInfo = None
 
     # Batched arguments to model runner
     input_ids: torch.Tensor = None
...
python/sglang/srt/model_executor/model_runner.py
...
@@ -40,7 +40,7 @@ from vllm.model_executor.models import ModelRegistry
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.sampler import SampleOutput, Sampler
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
...
@@ -516,21 +516,6 @@ class ModelRunner:
         else:
             raise ValueError(f"Invaid forward mode: {batch.forward_mode}")
 
-    def _check_sample_results(self, sample_output: SampleOutput):
-        if not torch.all(sample_output.success):
-            probs = sample_output.probs
-            batch_next_token_ids = sample_output.batch_next_token_ids
-            logging.warning("Sampling failed, fallback to top_k=1 strategy")
-            probs = probs.masked_fill(torch.isnan(probs), 0.0)
-            argmax_ids = torch.argmax(probs, dim=-1)
-            batch_next_token_ids = torch.where(
-                sample_output.success, batch_next_token_ids, argmax_ids
-            )
-            sample_output.probs = probs
-            sample_output.batch_next_token_ids = batch_next_token_ids
-
-        return sample_output.batch_next_token_ids
-
     def _apply_logits_bias(
         self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
     ):
...
@@ -559,13 +544,16 @@ class ModelRunner:
     def sample(
         self, logits_output: LogitsProcessorOutput, batch: ScheduleBatch
     ) -> torch.Tensor:
+        # Put CPU-heavy tasks here. They will be overlapped with the forward pass.
         batch.sampling_info.update_regex_vocab_mask(batch)
         batch.sampling_info.update_penalties()
         logits = self._apply_logits_bias(
             logits_output.next_token_logits, batch.sampling_info
         )
-        sample_output = self.sampler(logits, batch.sampling_info)
-        return self._check_sample_results(sample_output)
+
+        # Sample the next tokens.
+        next_token_ids = self.sampler(logits, batch.sampling_info)
+        return next_token_ids
 
 
 @lru_cache()
...
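With _check_sample_results gone, sample() now has a plain "logits in, token ids out" contract. A hypothetical call site follows; the runner/batch setup and the forward call are assumptions for illustration, not part of this diff:

    # Assuming an initialized ModelRunner `runner` and a prepared ScheduleBatch `batch`:
    logits_output = runner.forward(batch)                 # device-side forward pass
    next_token_ids = runner.sample(logits_output, batch)  # now returns a plain tensor
    assert next_token_ids.dim() == 1                      # one sampled id per request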
python/sglang/srt/sampling/sampling_batch_info.py
...
@@ -34,56 +34,6 @@ class SamplingBatchInfo:
     linear_penalties: torch.Tensor = None
     scaling_penalties: torch.Tensor = None
 
-    def __len__(self):
-        return len(self.temperatures)
-
-    def can_run_in_cuda_graph(self):
-        # Vocab bias and min_ps are not supported in CUDA graph
-        return (
-            self.logit_bias is None
-            and self.linear_penalties is None
-            and self.scaling_penalties is None
-            and not self.need_min_p_sampling
-        )
-
-    @classmethod
-    def dummy_one(cls, max_bs: int, vocab_size: int):
-        ret = cls(vocab_size=vocab_size)
-        with torch.device("cuda"):
-            ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float)
-            ret.top_ps = torch.ones((max_bs,), dtype=torch.float)
-            ret.top_ks = torch.ones((max_bs,), dtype=torch.int)
-            ret.vocab_mask = torch.zeros((max_bs, vocab_size), dtype=torch.bool)
-        return ret
-
-    def __getitem__(self, key):
-        if isinstance(key, slice):
-            # NOTE: This method is only used in CUDA graph
-            assert self.can_run_in_cuda_graph()
-            return SamplingBatchInfo(
-                vocab_size=self.vocab_size,
-                temperatures=self.temperatures[key],
-                top_ps=self.top_ps[key],
-                top_ks=self.top_ks[key],
-                vocab_mask=self.vocab_mask[key],
-            )
-        else:
-            raise NotImplementedError
-
-    def inplace_assign(self, bs: int, other: SamplingBatchInfo):
-        # NOTE: This method is only used in CUDA graph
-        assert self.can_run_in_cuda_graph()
-
-        self.vocab_size = other.vocab_size
-        self.temperatures[:bs] = other.temperatures
-        self.top_ps[:bs] = other.top_ps
-        self.top_ks[:bs] = other.top_ks
-
-        if other.vocab_mask is None:
-            self.vocab_mask[:bs].fill_(False)
-        else:
-            self.vocab_mask[:bs] = other.vocab_mask
-
     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
         reqs = batch.reqs
...
@@ -130,6 +80,9 @@ class SamplingBatchInfo:
         return ret
 
+    def __len__(self):
+        return len(self.temperatures)
+
     def update_penalties(self):
         self.scaling_penalties = None
         self.linear_penalties = None
...
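The removed dummy_one / __getitem__ / inplace_assign trio implemented the padded-buffer pattern CUDA graph capture requires: replay reuses fixed tensor addresses, so per-request state is preallocated at max_bs and each live batch is copied into a prefix slice. A minimal sketch of the pattern (CPU tensors and made-up sizes, for portability):

    import torch

    MAX_BS = 8
    # Preallocated once; their storage addresses must stay fixed across graph replays.
    temperatures = torch.ones((MAX_BS, 1))
    top_ps = torch.ones((MAX_BS,))

    def inplace_assign(bs: int, new_temperatures: torch.Tensor, new_top_ps: torch.Tensor):
        # Copy the live batch into the leading slice; the tail keeps neutral defaults.
        temperatures[:bs] = new_temperatures
        top_ps[:bs] = new_top_ps

    inplace_assign(2, torch.tensor([[0.7], [1.0]]), torch.tensor([0.9, 0.95]))
    print(temperatures[:3].squeeze(-1))  # tensor([0.7000, 1.0000, 1.0000])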