sglang / commit 2fa5cec7 (Unverified)

Simplify sampler and its error handling (#1441)

Authored Sep 16, 2024 by Lianmin Zheng; committed by GitHub on Sep 16, 2024.
Parent: 27b557ae
Showing 4 changed files with 32 additions and 159 deletions (+32 / -159):

    python/sglang/srt/layers/sampler.py                 +22 / -91
    python/sglang/srt/managers/schedule_batch.py         +1 / -0
    python/sglang/srt/model_executor/model_runner.py     +6 / -18
    python/sglang/srt/sampling/sampling_batch_info.py    +3 / -50
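In short, this commit replaces the CustomOp-based Sampler and its SampleOutput dataclass with a plain torch.nn.Module whose forward returns the sampled token ids directly, and it moves error handling (NaN probabilities, failed flashinfer sampling) into the sampler itself instead of a separate ModelRunner._check_sample_results pass. For orientation only, here is a minimal, standalone PyTorch sketch of the post-processing and NaN guard used on the new forward path; it is not the repository code, and the name sample_with_nan_guard is made up for this illustration:

import torch

def sample_with_nan_guard(logits: torch.Tensor, temperatures: torch.Tensor) -> torch.Tensor:
    # Temperature-scale and normalize, as the new Sampler.forward does.
    logits = logits / temperatures
    probs = torch.softmax(logits, dim=-1)

    # Guard against NaN probabilities instead of failing the whole batch.
    if torch.any(torch.isnan(probs)):
        probs = torch.where(torch.isnan(probs), torch.full_like(probs, 1e-10), probs)

    # num_samples=2 plus slicing mirrors the workaround kept in the diff for a
    # torch.multinomial issue with num_samples=1.
    sampled = torch.multinomial(probs, num_samples=2, replacement=True)[:, :1]
    return sampled.view(-1)

logits = torch.tensor([[1.0, 2.0, 0.5, 0.1], [0.2, 0.2, 3.0, 0.2]])
temperatures = torch.tensor([[1.0], [0.7]])
print(sample_with_nan_guard(logits, temperatures))  # e.g. tensor([1, 2])

Replacing NaNs with a tiny positive weight keeps torch.multinomial from raising; a row that was entirely NaN ends up being sampled uniformly.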
python/sglang/srt/layers/sampler.py

-import dataclasses
 import logging
-from typing import Tuple, Union
+from typing import Union

 import torch
 from flashinfer.sampling import (
@@ -9,43 +8,17 @@ from flashinfer.sampling import (
     top_k_top_p_sampling_from_probs,
     top_p_renorm_prob,
 )
-from torch.library import custom_op as torch_custom_op
-from vllm.model_executor.custom_op import CustomOp
+from torch import nn

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+
+# TODO: move this dict to another place
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

 logger = logging.getLogger(__name__)


-@dataclasses.dataclass
-class SampleOutput:
-    success: torch.Tensor
-    probs: torch.Tensor
-    batch_next_token_ids: torch.Tensor
-
-
-class Sampler(CustomOp):
-    def __init__(self):
-        super().__init__()
-        # FIXME: torch.multinomial has too many bugs
-        self.forward_native = self.forward_cuda
-        self.is_torch_compile = False
-
-    def _get_probs(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
-        # Post process logits
-        logits = logits.contiguous()
-        logits.div_(sampling_info.temperatures)
-        if self.is_torch_compile:
-            # FIXME: Temporary workaround for unknown bugs in torch.compile
-            logits.add_(0)
-        return torch.softmax(logits, dim=-1)
-
-    def forward_cuda(
+class Sampler(nn.Module):
+    def forward(
         self,
         logits: Union[torch.Tensor, LogitsProcessorOutput],
         sampling_info: SamplingBatchInfo,
@@ -53,7 +26,15 @@ class Sampler(CustomOp):
         if isinstance(logits, LogitsProcessorOutput):
             logits = logits.next_token_logits

-        probs = self._get_probs(logits, sampling_info)
+        # Post process logits
+        logits.div_(sampling_info.temperatures)
+        probs = logits[:] = torch.softmax(logits, dim=-1)
+
+        if torch.any(torch.isnan(probs)):
+            logger.warning("Detected errors during sampling! NaN in the probability.")
+            probs = torch.where(
+                torch.isnan(probs), torch.full_like(probs, 1e-10), probs
+            )

         if global_server_args_dict["sampling_backend"] == "flashinfer":
             max_top_k_round, batch_size = 32, probs.shape[0]
@@ -67,12 +48,16 @@ class Sampler(CustomOp):
                     probs, uniform_samples, sampling_info.min_ps
                 )
             else:
-                batch_next_token_ids, success = flashinfer_top_k_top_p(
+                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
                     probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
                 )
+
+            if not torch.all(success):
+                logger.warning("Detected errors during sampling!")
+                batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
         elif global_server_args_dict["sampling_backend"] == "pytorch":
             # Here we provide a slower fallback implementation.
-            batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
+            batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
                 probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
             )
         else:
@@ -80,48 +65,7 @@ class Sampler(CustomOp):
                 f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
             )

-        return SampleOutput(success, probs, batch_next_token_ids)
-
-    def forward_native(
-        self,
-        logits: Union[torch.Tensor, LogitsProcessorOutput],
-        sampling_info: SamplingBatchInfo,
-    ):
-        if isinstance(logits, LogitsProcessorOutput):
-            logits = logits.next_token_logits
-
-        probs = self._get_probs(logits, sampling_info)
-        batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
-            probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
-        )
-        return SampleOutput(success, probs, batch_next_token_ids)
-
-
-@torch_custom_op("my_lib::flashinfer_top_k_top_p", mutates_args={})
-def flashinfer_top_k_top_p(
-    probs: torch.Tensor,
-    uniform_samples: torch.Tensor,
-    top_ks: torch.Tensor,
-    top_ps: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    # NOTE: we do not use min_p neither in CUDA nor in torch.compile
-    return top_k_top_p_sampling_from_probs(probs, uniform_samples, top_ks, top_ps)
-
-
-@flashinfer_top_k_top_p.register_fake
-def _(
-    probs: torch.Tensor,
-    uniform_samples: torch.Tensor,
-    top_ks: torch.Tensor,
-    top_ps: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    bs = probs.shape[0]
-    return (
-        torch.ones(bs, dtype=torch.bool, device=probs.device),
-        torch.zeros(bs, dtype=torch.int32, device=probs.device),
-    )
+        return batch_next_token_ids


 def top_k_top_p_min_p_sampling_from_probs_torch(
@@ -141,19 +85,6 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     ] = 0.0
     probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
-    try:
-        sampled_index = torch.multinomial(probs_sort, num_samples=1)
-    except RuntimeError as e:
-        logger.warning(f"Sampling error: {e}")
-        batch_next_token_ids = torch.zeros(
-            (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
-        )
-        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
-        return batch_next_token_ids, success
+    # FIXME: torch.multiomial does not support num_samples = 1
+    sampled_index = torch.multinomial(probs_sort, num_samples=2, replacement=True)[:, :1]
     batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
-    success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
-    return batch_next_token_ids, success
+    return batch_next_token_ids
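The pytorch backend branch above calls top_k_top_p_min_p_sampling_from_probs_torch, of which only the tail is visible in the last hunk. Below is a self-contained sketch of that style of filtering; the sort/cumsum prelude and the helper name top_k_top_p_min_p_sample are assumptions, since the diff does not show the beginning of the function:

import torch

def top_k_top_p_min_p_sample(
    probs: torch.Tensor,    # (batch, vocab) probabilities
    top_ks: torch.Tensor,   # (batch,) int ranks to keep
    top_ps: torch.Tensor,   # (batch,) nucleus mass to keep
    min_ps: torch.Tensor,   # (batch,) fraction of the max prob to require
) -> torch.Tensor:
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    min_p_thresholds = probs_sort[:, 0] * min_ps

    # Zero out tokens outside the nucleus (top-p) and beyond the top-k rank.
    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
    rank = torch.arange(probs.shape[-1], device=probs.device).expand_as(probs_sort)
    probs_sort[rank >= top_ks.view(-1, 1)] = 0.0
    # Zero out tokens below the min-p threshold, as in the hunk above.
    probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0

    # Renormalize and sample; num_samples=2 plus slicing mirrors the diff's workaround.
    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
    sampled_index = torch.multinomial(probs_sort, num_samples=2, replacement=True)[:, :1]
    return torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)

probs = torch.softmax(torch.randn(2, 8), dim=-1)
ids = top_k_top_p_min_p_sample(
    probs,
    top_ks=torch.tensor([3, 5]),
    top_ps=torch.tensor([0.9, 0.8]),
    min_ps=torch.tensor([0.05, 0.1]),
)
print(ids)

Sorting keeps the masks simple: after the sort, top-k is a rank cutoff, top-p is a cumulative-sum cutoff, and probs_idx maps the chosen sorted position back to a vocabulary id.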
python/sglang/srt/managers/schedule_batch.py

@@ -360,6 +360,7 @@ class ScheduleBatch:
     tree_cache: BasePrefixCache

     forward_mode: ForwardMode = None
+    sampling_info: SamplingBatchInfo = None

     # Batched arguments to model runner
     input_ids: torch.Tensor = None
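The single added line gives ScheduleBatch a slot for its SamplingBatchInfo, so the per-request sampling parameters can be prepared once per batch and read later as batch.sampling_info in ModelRunner.sample. A tiny illustration with hypothetical stand-in classes (SamplingInfoStub and ScheduleBatchStub are not sglang types):

import dataclasses
from typing import List, Optional

@dataclasses.dataclass
class SamplingInfoStub:
    temperatures: List[float]
    top_ps: List[float]

@dataclasses.dataclass
class ScheduleBatchStub:
    reqs: List[str]
    sampling_info: Optional[SamplingInfoStub] = None  # the new field in the diff

batch = ScheduleBatchStub(reqs=["req0", "req1"])
batch.sampling_info = SamplingInfoStub(temperatures=[1.0, 0.7], top_ps=[0.9, 0.8])
print(len(batch.sampling_info.temperatures))  # 2, one entry per request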
python/sglang/srt/model_executor/model_runner.py

@@ -40,7 +40,7 @@ from vllm.model_executor.models import ModelRegistry
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.sampler import SampleOutput, Sampler
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
@@ -516,21 +516,6 @@ class ModelRunner:
         else:
             raise ValueError(f"Invaid forward mode: {batch.forward_mode}")

-    def _check_sample_results(self, sample_output: SampleOutput):
-        if not torch.all(sample_output.success):
-            probs = sample_output.probs
-            batch_next_token_ids = sample_output.batch_next_token_ids
-            logging.warning("Sampling failed, fallback to top_k=1 strategy")
-            probs = probs.masked_fill(torch.isnan(probs), 0.0)
-            argmax_ids = torch.argmax(probs, dim=-1)
-            batch_next_token_ids = torch.where(
-                sample_output.success, batch_next_token_ids, argmax_ids
-            )
-            sample_output.probs = probs
-            sample_output.batch_next_token_ids = batch_next_token_ids
-        return sample_output.batch_next_token_ids
-
     def _apply_logits_bias(
         self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
     ):
@@ -559,13 +544,16 @@ class ModelRunner:
     def sample(
         self, logits_output: LogitsProcessorOutput, batch: ScheduleBatch
     ) -> torch.Tensor:
+        # Put CPU-heavy tasks here. They will be overlapped with the forward pass.
         batch.sampling_info.update_regex_vocab_mask(batch)
         batch.sampling_info.update_penalties()
         logits = self._apply_logits_bias(
             logits_output.next_token_logits, batch.sampling_info
         )
-        sample_output = self.sampler(logits, batch.sampling_info)
-        return self._check_sample_results(sample_output)
+
+        # Sample the next tokens.
+        next_token_ids = self.sampler(logits, batch.sampling_info)
+        return next_token_ids


 @lru_cache()
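The deleted _check_sample_results method implemented a greedy fallback: wherever the sampler reported failure, it replaced the sampled id with the argmax of the NaN-cleaned probabilities. For reference, a standalone sketch of that pattern; argmax_fallback is an illustrative name, not repository code:

import torch

def argmax_fallback(success: torch.Tensor, probs: torch.Tensor,
                    sampled_ids: torch.Tensor) -> torch.Tensor:
    if torch.all(success):
        return sampled_ids
    # Clean NaNs, then use greedy (top_k=1) picks where sampling failed.
    probs = probs.masked_fill(torch.isnan(probs), 0.0)
    argmax_ids = torch.argmax(probs, dim=-1)
    return torch.where(success, sampled_ids, argmax_ids)

success = torch.tensor([True, False])
probs = torch.tensor([[0.1, 0.9], [0.7, 0.3]])
sampled_ids = torch.tensor([1, 1])
print(argmax_fallback(success, probs, sampled_ids))  # tensor([1, 0])

With the new sampler returning token ids directly and zeroing failed ids itself, this extra pass is no longer needed in ModelRunner.sample.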
python/sglang/srt/sampling/sampling_batch_info.py

@@ -34,56 +34,6 @@ class SamplingBatchInfo:
     linear_penalties: torch.Tensor = None
     scaling_penalties: torch.Tensor = None

-    def __len__(self):
-        return len(self.temperatures)
-
-    def can_run_in_cuda_graph(self):
-        # Vocab bias and min_ps are not supported in CUDA graph
-        return (
-            self.logit_bias is None
-            and self.linear_penalties is None
-            and self.scaling_penalties is None
-            and not self.need_min_p_sampling
-        )
-
-    @classmethod
-    def dummy_one(cls, max_bs: int, vocab_size: int):
-        ret = cls(vocab_size=vocab_size)
-        with torch.device("cuda"):
-            ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float)
-            ret.top_ps = torch.ones((max_bs,), dtype=torch.float)
-            ret.top_ks = torch.ones((max_bs,), dtype=torch.int)
-            ret.vocab_mask = torch.zeros((max_bs, vocab_size), dtype=torch.bool)
-        return ret
-
-    def __getitem__(self, key):
-        if isinstance(key, slice):
-            # NOTE:This method is only used in CUDA graph
-            assert self.can_run_in_cuda_graph()
-            return SamplingBatchInfo(
-                vocab_size=self.vocab_size,
-                temperatures=self.temperatures[key],
-                top_ps=self.top_ps[key],
-                top_ks=self.top_ks[key],
-                vocab_mask=self.vocab_mask[key],
-            )
-        else:
-            raise NotImplementedError
-
-    def inplace_assign(self, bs: int, other: SamplingBatchInfo):
-        # NOTE:This method is only used in CUDA graph
-        assert self.can_run_in_cuda_graph()
-        self.vocab_size = other.vocab_size
-        self.temperatures[:bs] = other.temperatures
-        self.top_ps[:bs] = other.top_ps
-        self.top_ks[:bs] = other.top_ks
-        if other.vocab_mask is None:
-            self.vocab_mask[:bs].fill_(False)
-        else:
-            self.vocab_mask[:bs] = other.vocab_mask
-
     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
         reqs = batch.reqs
@@ -130,6 +80,9 @@ class SamplingBatchInfo:

         return ret

+    def __len__(self):
+        return len(self.temperatures)
+
     def update_penalties(self):
         self.scaling_penalties = None
         self.linear_penalties = None
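The removed dummy_one, __getitem__, and inplace_assign helpers existed to pad or slice a live batch against fixed-size, preallocated tensors for CUDA graph replay, while the relocated __len__ still reports the batch size as len(self.temperatures). A standalone sketch of that padding pattern, with hypothetical sizes and values (not the repository code):

import torch

max_bs, vocab_size, bs = 8, 16, 3

# dummy_one-style preallocation for the largest supported batch.
temperatures = torch.ones((max_bs, 1), dtype=torch.float)
top_ps = torch.ones((max_bs,), dtype=torch.float)
top_ks = torch.ones((max_bs,), dtype=torch.int)
vocab_mask = torch.zeros((max_bs, vocab_size), dtype=torch.bool)

# inplace_assign-style copy of a live batch of size bs into the prefix rows.
live_temperatures = torch.full((bs, 1), 0.7)
live_top_ps = torch.tensor([0.9, 0.8, 1.0])
live_top_ks = torch.tensor([5, 3, 1], dtype=torch.int)

temperatures[:bs] = live_temperatures
top_ps[:bs] = live_top_ps
top_ks[:bs] = live_top_ks
vocab_mask[:bs].fill_(False)  # no vocab mask in this sketch

print(temperatures[:4].squeeze(-1))  # the first bs rows hold the live values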