change / sglang / Commits

Commit a5a134f3 (Unverified)
Authored Sep 02, 2024 by Liangsheng Yin; committed by GitHub on Sep 02, 2024
Fix bugs in sampler with CUDA graph / torch.compile (#1306)
parent 2561ed01

Showing 4 changed files with 48 additions and 26 deletions.
python/sglang/srt/layers/sampler.py                    +34 -10
python/sglang/srt/model_executor/cuda_graph_runner.py   +2  -0
python/sglang/srt/model_executor/model_runner.py        +1  -1
python/sglang/srt/sampling/sampling_batch_info.py      +11 -15
python/sglang/srt/layers/sampler.py (+34 -10)

 import dataclasses
 import logging
-from typing import Union
+from typing import Tuple, Union

 import torch
 from flashinfer.sampling import (
...
@@ -9,6 +9,7 @@ from flashinfer.sampling import (
     top_k_top_p_sampling_from_probs,
     top_p_renorm_prob,
 )
+from torch.library import custom_op as torch_custom_op
 from vllm.model_executor.custom_op import CustomOp

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
...
@@ -30,6 +31,9 @@ class SampleOutput:
 class Sampler(CustomOp):
     def __init__(self):
         super().__init__()
+        # FIXME: torch.multinomial has too many bugs
+        self.forward_native = self.forward_cuda
+        self.is_torch_compile = False

     def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
         # min-token, presence, frequency
...
@@ -46,16 +50,11 @@ class Sampler(CustomOp):
         return logits

-    def _get_probs(
-        self,
-        logits: torch.Tensor,
-        sampling_info: SamplingBatchInfo,
-        is_torch_compile: bool = False,
-    ):
+    def _get_probs(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
         # Post process logits
         logits = logits.contiguous()
         logits.div_(sampling_info.temperatures)
-        if is_torch_compile:
+        if self.is_torch_compile:
             # FIXME: Temporary workaround for unknown bugs in torch.compile
             logits.add_(0)
...
@@ -91,7 +90,7 @@ class Sampler(CustomOp):
                     probs, uniform_samples, sampling_info.min_ps
                 )
             else:
-                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                batch_next_token_ids, success = flashinfer_top_k_top_p(
                     probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
                 )
         else:
...
@@ -110,7 +109,7 @@ class Sampler(CustomOp):
         if isinstance(logits, LogitsProcessorOutput):
             logits = logits.next_token_logits

-        probs = self._get_probs(logits, sampling_info, is_torch_compile=True)
+        probs = self._get_probs(logits, sampling_info)
         batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
             probs,
             sampling_info.top_ks,
             sampling_info.top_ps,
             sampling_info.min_ps,
...
@@ -119,6 +118,31 @@ class Sampler(CustomOp):
         return SampleOutput(success, probs, batch_next_token_ids)


+@torch_custom_op("my_lib::flashinfer_top_k_top_p", mutates_args={})
+def flashinfer_top_k_top_p(
+    probs: torch.Tensor,
+    uniform_samples: torch.Tensor,
+    top_ks: torch.Tensor,
+    top_ps: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # NOTE: we do not use min_p neither in CUDA nor in torch.compile
+    return top_k_top_p_sampling_from_probs(probs, uniform_samples, top_ks, top_ps)
+
+
+@flashinfer_top_k_top_p.register_fake
+def _(
+    probs: torch.Tensor,
+    uniform_samples: torch.Tensor,
+    top_ks: torch.Tensor,
+    top_ps: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    bs = probs.shape[0]
+    return (
+        torch.ones(bs, dtype=torch.bool, device=probs.device),
+        torch.zeros(bs, dtype=torch.int32, device=probs.device),
+    )
+
+
 def top_k_top_p_min_p_sampling_from_probs_torch(
     probs: torch.Tensor,
     top_ks: torch.Tensor,
...
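Note: the core torch.compile fix above is the torch.library wrapper. flashinfer's top_k_top_p_sampling_from_probs is an opaque CUDA kernel that Dynamo cannot trace through, so the commit registers it as a custom op and attaches a register_fake implementation that fabricates outputs with the right shapes, dtypes, and devices for tracing. Below is a minimal, runnable sketch of the same pattern; it assumes torch >= 2.4 (where torch.library.custom_op and register_fake live), and demo_lib::threshold_sample is a hypothetical op, not part of sglang or flashinfer.

from typing import Tuple

import torch
from torch.library import custom_op as torch_custom_op


# Real implementation: pretend this body is an opaque external kernel that
# the compiler cannot trace through (as flashinfer's sampler is above).
@torch_custom_op("demo_lib::threshold_sample", mutates_args={})
def threshold_sample(
    probs: torch.Tensor, uniform_samples: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    ids = torch.argmax(probs - uniform_samples, dim=-1).to(torch.int32)
    success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
    return success, ids


# Fake implementation: runs during tracing instead of the real kernel; only
# the shapes, dtypes, and devices of the outputs matter here.
@threshold_sample.register_fake
def _(
    probs: torch.Tensor, uniform_samples: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    bs = probs.shape[0]
    return (
        torch.ones(bs, dtype=torch.bool, device=probs.device),
        torch.zeros(bs, dtype=torch.int32, device=probs.device),
    )


@torch.compile
def sample(probs: torch.Tensor, uniform_samples: torch.Tensor) -> torch.Tensor:
    success, ids = threshold_sample(probs, uniform_samples)
    return torch.where(success, ids, torch.zeros_like(ids))


probs = torch.rand(4, 32).softmax(dim=-1)
noise = torch.rand(4, 32)
print(sample(probs, noise))

During compilation only the fake runs (under FakeTensorMode), so it must do no real work; at runtime the real implementation executes as a single opaque node in the compiled graph, which is exactly how the commit's (success, token_ids) pair is handled.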
python/sglang/srt/model_executor/cuda_graph_runner.py (+2 -0)

...
@@ -46,8 +46,10 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
         if isinstance(sub, CustomOp):
             if reverse:
                 sub._forward_method = sub.forward_cuda
+                setattr(sub, "is_torch_compile", False)
             else:
                 sub._forward_method = sub.forward_native
+                setattr(sub, "is_torch_compile", True)
         if isinstance(sub, torch.nn.Module):
             _to_torch(sub, reverse)
...
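Note: _to_torch is the hook that flips every CustomOp between its torch-native and CUDA forward implementation around torch.compile capture; the two added setattr calls piggyback an is_torch_compile flag onto the same walk so that Sampler._get_probs can branch on it (the logits.add_(0) workaround). A self-contained sketch of this toggle pattern, with a hypothetical ToggleableOp standing in for a vLLM CustomOp:

import torch


class ToggleableOp(torch.nn.Module):
    """Stand-in for a CustomOp with two interchangeable forward paths."""

    def __init__(self):
        super().__init__()
        self._forward_method = self.forward_cuda
        self.is_torch_compile = False

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2  # stand-in for a fused CUDA path

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return x + x  # stand-in for a torch.compile-friendly path

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._forward_method(x)


def to_torch(model: torch.nn.Module, reverse: bool = False) -> None:
    # Mirrors the diff: walk the module tree, swap implementations, and tag
    # each op so its forward code can branch on the compile mode at runtime.
    for sub in model.children():
        if isinstance(sub, ToggleableOp):
            if reverse:
                sub._forward_method = sub.forward_cuda
                setattr(sub, "is_torch_compile", False)
            else:
                sub._forward_method = sub.forward_native
                setattr(sub, "is_torch_compile", True)
        to_torch(sub, reverse)


model = torch.nn.Sequential(ToggleableOp(), ToggleableOp())
to_torch(model)                # switch to the native implementations
print(model(torch.ones(2)))    # tensor([4., 4.]) via forward_native twice
to_torch(model, reverse=True)  # restore the CUDA-path stand-ins

Setting the flag via setattr keeps the walk generic: ops that never read is_torch_compile are simply unaffected by the extra attribute.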
python/sglang/srt/model_executor/model_runner.py (+1 -1)

...
@@ -523,7 +523,7 @@ class ModelRunner:
         if (
             self.cuda_graph_runner
             and self.cuda_graph_runner.can_run(len(batch.reqs))
-            and not batch.sampling_info.has_bias()
+            and batch.sampling_info.can_run_in_cuda_graph()
         ):
             return self.cuda_graph_runner.replay(batch)
...
python/sglang/srt/sampling/sampling_batch_info.py (+11 -15)

...
@@ -34,12 +34,14 @@ class SamplingBatchInfo:
     linear_penalties: torch.Tensor = None
     scaling_penalties: torch.Tensor = None

-    def has_bias(self):
+    def can_run_in_cuda_graph(self):
+        # Vocab bias and min_ps are not supported in CUDA graph
         return (
-            self.logit_bias is not None
-            or self.vocab_mask is not None
-            or self.linear_penalties is not None
-            or self.scaling_penalties is not None
+            self.logit_bias is None
+            and self.vocab_mask is None
+            and self.linear_penalties is None
+            and self.scaling_penalties is None
+            and not self.need_min_p_sampling
         )

     @classmethod
...
@@ -48,35 +50,29 @@ class SamplingBatchInfo:
         ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float, device="cuda")
         ret.top_ps = torch.ones((max_bs,), dtype=torch.float, device="cuda")
         ret.top_ks = torch.ones((max_bs,), dtype=torch.int, device="cuda")
+        ret.min_ps = torch.zeros((max_bs,), dtype=torch.float, device="cuda")
         return ret

     def __getitem__(self, key):
         if isinstance(key, slice):
-            # NOTE: We do not use cuda graph when there is bias tensors
-            assert not self.has_bias()
+            # NOTE: This method is only used in CUDA graph
+            assert self.can_run_in_cuda_graph()
             return SamplingBatchInfo(
                 vocab_size=self.vocab_size,
                 temperatures=self.temperatures[key],
                 top_ps=self.top_ps[key],
                 top_ks=self.top_ks[key],
+                min_ps=self.min_ps[key],
+                need_min_p_sampling=self.need_min_p_sampling,
             )
         else:
             raise NotImplementedError

     def inplace_assign(self, bs: int, other: SamplingBatchInfo):
-        # NOTE: We do not use cuda graph when there is bias tensors
-        assert not self.has_bias()
+        # NOTE: This method is only used in CUDA graph
+        assert self.can_run_in_cuda_graph()
         self.vocab_size = other.vocab_size
+        self.need_min_p_sampling = other.need_min_p_sampling
         self.temperatures[:bs] = other.temperatures
         self.top_ps[:bs] = other.top_ps
         self.top_ks[:bs] = other.top_ks
+        self.min_ps[:bs] = other.min_ps

     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
...
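Note: the SamplingBatchInfo changes follow the standard CUDA graph contract: replay re-executes the captured kernels against the captured storage, so per-step sampling parameters must live in pre-allocated max-batch-size buffers that are updated in place via [:bs] slice copies, never rebound to fresh tensors. This commit folds min_ps into that scheme (zeros are a safe fill, since min_p = 0 filters nothing) and carries the need_min_p_sampling flag alongside. A minimal illustration of the contract, assuming a CUDA device; the class below is illustrative, not sglang's:

import torch


class GraphSamplingBuffers:
    """Fixed-capacity buffers whose storage a captured CUDA graph reuses."""

    def __init__(self, max_bs: int, device: str = "cuda"):
        # Allocated once, before graph capture; addresses must never change.
        self.temperatures = torch.ones((max_bs, 1), dtype=torch.float, device=device)
        self.top_ps = torch.ones((max_bs,), dtype=torch.float, device=device)
        self.top_ks = torch.ones((max_bs,), dtype=torch.int, device=device)
        self.min_ps = torch.zeros((max_bs,), dtype=torch.float, device=device)

    def inplace_assign(self, bs: int, other: "GraphSamplingBuffers") -> None:
        # Copy the live batch into the leading slice of the captured buffers.
        # `self.min_ps = other.min_ps` would rebind the attribute to new
        # storage and silently desynchronize the graph's captured pointers.
        self.temperatures[:bs] = other.temperatures[:bs]
        self.top_ps[:bs] = other.top_ps[:bs]
        self.top_ks[:bs] = other.top_ks[:bs]
        self.min_ps[:bs] = other.min_ps[:bs]


src = GraphSamplingBuffers(max_bs=2)  # live batch of size 2
dst = GraphSamplingBuffers(max_bs=8)  # captured, graph-owned buffers
dst.inplace_assign(2, src)            # ok: writes through the captured storage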