Commit a5a134f3 in sglang (unverified)
Authored Sep 02, 2024 by Liangsheng Yin; committed by GitHub on Sep 02, 2024
Parent: 2561ed01

Fix bugs in sampler with CUDA graph / torch.compile (#1306)

Showing 4 changed files with 48 additions and 26 deletions:

  python/sglang/srt/layers/sampler.py                     +34  -10
  python/sglang/srt/model_executor/cuda_graph_runner.py    +2   -0
  python/sglang/srt/model_executor/model_runner.py         +1   -1
  python/sglang/srt/sampling/sampling_batch_info.py       +11  -15
python/sglang/srt/layers/sampler.py

 import dataclasses
 import logging
-from typing import Union
+from typing import Tuple, Union

 import torch
 from flashinfer.sampling import (
...
@@ -9,6 +9,7 @@ from flashinfer.sampling import (
     top_k_top_p_sampling_from_probs,
     top_p_renorm_prob,
 )
+from torch.library import custom_op as torch_custom_op
 from vllm.model_executor.custom_op import CustomOp

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
...
@@ -30,6 +31,9 @@ class SampleOutput:
 class Sampler(CustomOp):
     def __init__(self):
         super().__init__()
         # FIXME: torch.multinomial has too many bugs
         self.forward_native = self.forward_cuda
+        self.is_torch_compile = False

     def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
         # min-token, presence, frequency
...
@@ -46,16 +50,11 @@ class Sampler(CustomOp):
         return logits

-    def _get_probs(
-        self,
-        logits: torch.Tensor,
-        sampling_info: SamplingBatchInfo,
-        is_torch_compile: bool = False,
-    ):
+    def _get_probs(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
         # Post process logits
         logits = logits.contiguous()
         logits.div_(sampling_info.temperatures)
-        if is_torch_compile:
+        if self.is_torch_compile:
             # FIXME: Temporary workaround for unknown bugs in torch.compile
             logits.add_(0)
...
@@ -91,7 +90,7 @@ class Sampler(CustomOp):
                 probs, uniform_samples, sampling_info.min_ps
             )
         else:
-            batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+            batch_next_token_ids, success = flashinfer_top_k_top_p(
                 probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
             )
     else:
...
@@ -110,7 +109,7 @@ class Sampler(CustomOp):
         if isinstance(logits, LogitsProcessorOutput):
             logits = logits.next_token_logits

-        probs = self._get_probs(logits, sampling_info, is_torch_compile=True)
+        probs = self._get_probs(logits, sampling_info)
         batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
             probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
...
@@ -119,6 +118,31 @@ class Sampler(CustomOp):
         return SampleOutput(success, probs, batch_next_token_ids)


+@torch_custom_op("my_lib::flashinfer_top_k_top_p", mutates_args={})
+def flashinfer_top_k_top_p(
+    probs: torch.Tensor,
+    uniform_samples: torch.Tensor,
+    top_ks: torch.Tensor,
+    top_ps: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # NOTE: we do not use min_p neither in CUDA nor in torch.compile
+    return top_k_top_p_sampling_from_probs(probs, uniform_samples, top_ks, top_ps)
+
+
+@flashinfer_top_k_top_p.register_fake
+def _(
+    probs: torch.Tensor,
+    uniform_samples: torch.Tensor,
+    top_ks: torch.Tensor,
+    top_ps: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    bs = probs.shape[0]
+    return (
+        torch.ones(bs, dtype=torch.bool, device=probs.device),
+        torch.zeros(bs, dtype=torch.int32, device=probs.device),
+    )
+
+
 def top_k_top_p_min_p_sampling_from_probs_torch(
     probs: torch.Tensor,
     top_ks: torch.Tensor,
...
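The new flashinfer_top_k_top_p wrapper relies on the torch.library custom-op mechanism: torch.compile treats the registered op as opaque and, during tracing, calls the register_fake variant instead, which only has to return outputs with the right shapes, dtypes, and devices (hence the dummy ones/zeros above). Below is a minimal, self-contained sketch of the same pattern, assuming PyTorch 2.4+ where torch.library.custom_op is available; the op name toy_lib::scaled_add and the function names are illustrative, not part of this commit.

import torch
from torch.library import custom_op


@custom_op("toy_lib::scaled_add", mutates_args=())
def scaled_add(x: torch.Tensor, y: torch.Tensor, alpha: float) -> torch.Tensor:
    # Real implementation: opaque to torch.compile, so Dynamo will not
    # trace into it (the point, for kernels the compiler cannot handle).
    return x + alpha * y


@scaled_add.register_fake
def _(x: torch.Tensor, y: torch.Tensor, alpha: float) -> torch.Tensor:
    # Fake implementation: used only for tracing; it must produce an output
    # with the correct shape/dtype/device, not correct values.
    return torch.empty_like(x)


compiled = torch.compile(lambda x, y: scaled_add(x, y, 0.5))
print(compiled(torch.randn(4), torch.randn(4)))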
python/sglang/srt/model_executor/cuda_graph_runner.py

...
@@ -46,8 +46,10 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
         if isinstance(sub, CustomOp):
             if reverse:
                 sub._forward_method = sub.forward_cuda
+                setattr(sub, "is_torch_compile", False)
             else:
                 sub._forward_method = sub.forward_native
+                setattr(sub, "is_torch_compile", True)
         if isinstance(sub, torch.nn.Module):
             _to_torch(sub, reverse)
...
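This hunk is what feeds the Sampler's new self.is_torch_compile flag: when the module tree is switched to its torch-native forward methods for compilation, every CustomOp is also flagged, so each op can choose a compile-safe code path at runtime. A standalone sketch of the idea, using made-up names (FlaggedOp, toggle_compile_mode) rather than sglang's classes:

import torch


class FlaggedOp(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.is_torch_compile = False

    def forward(self, x):
        # The flag lets forward pick a workaround path under torch.compile.
        return x + 0 if self.is_torch_compile else x


def toggle_compile_mode(model: torch.nn.Module, reverse: bool = False):
    # Walk the whole module tree and flip the flag on every FlaggedOp.
    for sub in model.modules():
        if isinstance(sub, FlaggedOp):
            setattr(sub, "is_torch_compile", not reverse)


net = torch.nn.Sequential(FlaggedOp(), torch.nn.Identity())
toggle_compile_mode(net)                 # entering torch.compile mode
toggle_compile_mode(net, reverse=True)   # restoring the CUDA-native mode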
python/sglang/srt/model_executor/model_runner.py

...
@@ -523,7 +523,7 @@ class ModelRunner:
         if (
             self.cuda_graph_runner
             and self.cuda_graph_runner.can_run(len(batch.reqs))
-            and not batch.sampling_info.has_bias()
+            and batch.sampling_info.can_run_in_cuda_graph()
         ):
             return self.cuda_graph_runner.replay(batch)
...
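The gate for CUDA graph replay now delegates to a single eligibility method on the sampling metadata instead of the narrower has_bias() check, so min-p batches also fall back to the eager path. A hedged sketch of this gating logic in isolation (SamplingParamsStub and should_replay are stand-ins for illustration, not sglang APIs):

from dataclasses import dataclass


@dataclass
class SamplingParamsStub:
    logit_bias: object = None
    need_min_p_sampling: bool = False

    def can_run_in_cuda_graph(self) -> bool:
        # Graph-safe only when no per-token bias and no min-p sampling.
        return self.logit_bias is None and not self.need_min_p_sampling


def should_replay(graph_max_bs: int, batch_size: int, p: SamplingParamsStub) -> bool:
    # Replay the captured graph only if it has capacity for this batch
    # AND the sampling parameters are graph-safe.
    return batch_size <= graph_max_bs and p.can_run_in_cuda_graph()


assert should_replay(8, 4, SamplingParamsStub())
assert not should_replay(8, 4, SamplingParamsStub(need_min_p_sampling=True))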
python/sglang/srt/sampling/sampling_batch_info.py

...
@@ -34,12 +34,14 @@ class SamplingBatchInfo:
     linear_penalties: torch.Tensor = None
     scaling_penalties: torch.Tensor = None

-    def has_bias(self):
+    def can_run_in_cuda_graph(self):
+        # Vocab bias and min_ps are not supported in CUDA graph
         return (
-            self.logit_bias is not None
-            or self.vocab_mask is not None
-            or self.linear_penalties is not None
-            or self.scaling_penalties is not None
+            self.logit_bias is None
+            and self.vocab_mask is None
+            and self.linear_penalties is None
+            and self.scaling_penalties is None
+            and not self.need_min_p_sampling
         )

     @classmethod
...
@@ -48,35 +50,29 @@ class SamplingBatchInfo:
         ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float, device="cuda")
         ret.top_ps = torch.ones((max_bs,), dtype=torch.float, device="cuda")
         ret.top_ks = torch.ones((max_bs,), dtype=torch.int, device="cuda")
-        ret.min_ps = torch.zeros((max_bs,), dtype=torch.float, device="cuda")
         return ret

     def __getitem__(self, key):
         if isinstance(key, slice):
-            # NOTE: We do not use cuda graph when there is bias tensors
-            assert not self.has_bias()
+            # NOTE: This method is only used in CUDA graph
+            assert self.can_run_in_cuda_graph()
             return SamplingBatchInfo(
                 vocab_size=self.vocab_size,
                 temperatures=self.temperatures[key],
                 top_ps=self.top_ps[key],
                 top_ks=self.top_ks[key],
-                min_ps=self.min_ps[key],
-                need_min_p_sampling=self.need_min_p_sampling,
             )
         else:
             raise NotImplementedError

     def inplace_assign(self, bs: int, other: SamplingBatchInfo):
-        # NOTE: We do not use cuda graph when there is bias tensors
-        assert not self.has_bias()
+        # NOTE: This method is only used in CUDA graph
+        assert self.can_run_in_cuda_graph()
         self.vocab_size = other.vocab_size
-        self.need_min_p_sampling = other.need_min_p_sampling
         self.temperatures[:bs] = other.temperatures
         self.top_ps[:bs] = other.top_ps
         self.top_ks[:bs] = other.top_ks
-        self.min_ps[:bs] = other.min_ps

     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
...
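These helpers implement the usual pad-and-slice pattern for CUDA graphs: dummy() preallocates neutral sampling tensors at the maximum batch size the graph was captured with, inplace_assign() copies a live batch into the first bs rows without reallocating (so the pointers recorded by the graph stay valid), and __getitem__ slices a view back out. A standalone sketch of that pattern under assumed names (GraphSamplingBuffers is not an sglang class, and CPU tensors are used so the sketch runs anywhere):

import torch


class GraphSamplingBuffers:
    def __init__(self, max_bs: int, device: str = "cpu"):
        # Mirrors the dummy() idea: neutral defaults at max batch size.
        self.temperatures = torch.ones((max_bs, 1), device=device)
        self.top_ps = torch.ones((max_bs,), device=device)
        self.top_ks = torch.ones((max_bs,), dtype=torch.int, device=device)

    def inplace_assign(self, bs: int, temps, top_ps, top_ks):
        # Write the live batch into the first bs rows of the captured
        # buffers; the storage itself is never replaced.
        self.temperatures[:bs] = temps
        self.top_ps[:bs] = top_ps
        self.top_ks[:bs] = top_ks


bufs = GraphSamplingBuffers(max_bs=8)
bufs.inplace_assign(
    2,
    torch.full((2, 1), 0.7),
    torch.tensor([0.9, 0.8]),
    torch.tensor([40, 50], dtype=torch.int),
)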