Unverified commit 27acf63b, authored Jan 25, 2025 by Lianmin Zheng, committed by GitHub on Jan 25, 2025.

Use torch.compile for scaling penalty (#3133)

Parent: da6f8081

Showing 3 changed files with 14 additions and 29 deletions (+14 −29)
benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py   +0  −1
python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py        +10 −14
python/sglang/srt/sampling/sampling_batch_info.py                             +4  −14
benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py (view file @ 27acf63b)

import argparse
import itertools
import time

import torch
import triton
...
python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py (view file @ 27acf63b)

...
@@ -3,11 +3,16 @@ from typing import List
 import torch
 
 from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import get_compiler_backend
 
-is_cuda = is_cuda_available()
-if is_cuda:
-    from sgl_kernel import sampling_scaling_penalties
+
+@torch.compile(dynamic=True, backend=get_compiler_backend())
+def apply_scaling_penalties(logits, scaling_penalties):
+    logits[:] = torch.where(
+        logits > 0,
+        logits / scaling_penalties,
+        logits * scaling_penalties,
+    )
 
 
 class BatchedRepetitionPenalizer(_BatchedPenalizer):
...
@@ -61,16 +66,7 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
         self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
 
     def _apply(self, logits: torch.Tensor) -> torch.Tensor:
-        if is_cuda:
-            return sampling_scaling_penalties(
-                logits, self.cumulated_repetition_penalties
-            )
-        else:
-            return torch.where(
-                logits > 0,
-                logits / self.cumulated_repetition_penalties,
-                logits * self.cumulated_repetition_penalties,
-            )
+        apply_scaling_penalties(logits, self.cumulated_repetition_penalties)
 
     def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
         self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
...
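For reference, the compiled helper added above can be exercised on its own. The following is a minimal sketch, not part of the commit: it assumes PyTorch's default Inductor backend in place of sglang's get_compiler_backend(), and a toy repetition penalty of 1.2.

import torch


@torch.compile(dynamic=True)
def apply_scaling_penalties(logits, scaling_penalties):
    # Positive logits are divided by the penalty and negative logits are
    # multiplied by it, so (for penalty > 1) penalized tokens lose probability
    # either way. The slice assignment updates logits in place.
    logits[:] = torch.where(
        logits > 0,
        logits / scaling_penalties,
        logits * scaling_penalties,
    )


logits = torch.tensor([[2.0, -1.0, 0.5]])
penalties = torch.full_like(logits, 1.2)
expected = torch.where(logits > 0, logits / penalties, logits * penalties)
apply_scaling_penalties(logits, penalties)  # mutates logits in place, returns None
assert torch.allclose(logits, expected)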
python/sglang/srt/sampling/sampling_batch_info.py (view file @ 27acf63b)

...
@@ -7,14 +7,11 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
 
 import torch
 
-from sglang.srt.utils import is_cuda_available
-
-is_cuda = is_cuda_available()
-if is_cuda:
-    from sgl_kernel import sampling_scaling_penalties
-
 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
+from sglang.srt.sampling.penaltylib.penalizers.repetition_penalty import (
+    apply_scaling_penalties,
+)
 
 logger = logging.getLogger(__name__)
...
@@ -386,14 +383,7 @@ class SamplingBatchInfo:
         # repetition
         if self.scaling_penalties is not None:
-            if is_cuda:
-                logits[:] = sampling_scaling_penalties(logits, self.scaling_penalties)
-            else:
-                logits[:] = torch.where(
-                    logits > 0,
-                    logits / self.scaling_penalties,
-                    logits * self.scaling_penalties,
-                )
+            apply_scaling_penalties(logits, self.scaling_penalties)
 
         # Apply regex vocab_mask
         if self.vocab_mask is not None:
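The call-site pattern after this change can be illustrated with a simplified, hypothetical class. The names ToyBatchInfo and apply_logits_bias, and the vocab-mask convention (True marks forbidden tokens, applied with masked_fill_), are assumptions for this sketch; the import is the one added by the commit, so running it requires an sglang installation.

from dataclasses import dataclass
from typing import Optional

import torch

# Import added by this commit; running this sketch requires sglang.
from sglang.srt.sampling.penaltylib.penalizers.repetition_penalty import (
    apply_scaling_penalties,
)


@dataclass
class ToyBatchInfo:
    # Hypothetical stand-in for SamplingBatchInfo, reduced to two fields.
    scaling_penalties: Optional[torch.Tensor] = None
    vocab_mask: Optional[torch.Tensor] = None  # assumed: True = token is forbidden

    def apply_logits_bias(self, logits: torch.Tensor) -> None:
        # repetition: the compiled helper mutates logits in place, so the call
        # site no longer needs a logits[:] = ... assignment or an is_cuda branch.
        if self.scaling_penalties is not None:
            apply_scaling_penalties(logits, self.scaling_penalties)

        # regex vocab_mask: suppress masked tokens (assumed convention)
        if self.vocab_mask is not None:
            logits.masked_fill_(self.vocab_mask, float("-inf"))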