Commit 27acf63b (unverified)
Authored Jan 25, 2025 by Lianmin Zheng; committed by GitHub on Jan 25, 2025

Use torch.compile for scaling penalty (#3133)
Parent: da6f8081

3 changed files with 14 additions and 29 deletions
benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py   (+0, -1)
python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py        (+10, -14)
python/sglang/srt/sampling/sampling_batch_info.py                             (+4, -14)
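The scaling (repetition) penalty rescales each logit by a per-token penalty p: positive logits are divided by p and non-positive logits are multiplied by p, so with p > 1 the token becomes less likely either way. This commit drops the CUDA-only sampling_scaling_penalties kernel from sgl_kernel and its eager torch.where fallback in favor of a single torch.compile-wrapped helper shared by both call sites. A minimal standalone sketch of the same pattern (values, shapes, and the default backend here are illustrative assumptions, not taken from the diff):

import torch

# Same elementwise rule as the commit's apply_scaling_penalties:
# positive logits shrink by division, non-positive ones by multiplication,
# so a penalty > 1 always pushes the token's probability down.
@torch.compile(dynamic=True)  # dynamic=True tolerates varying batch sizes
def scale_penalty_sketch(logits: torch.Tensor, penalties: torch.Tensor) -> None:
    logits[:] = torch.where(logits > 0, logits / penalties, logits * penalties)

logits = torch.tensor([[2.0, -1.0, 0.5]])    # illustrative logits
penalties = torch.tensor([[1.2, 1.2, 1.2]])  # one penalty per vocab entry
scale_penalty_sketch(logits, penalties)
print(logits)  # tensor([[ 1.6667, -1.2000,  0.4167]])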
benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py

 import argparse
 import itertools
-import time

 import torch
 import triton
 ...
python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py

@@ -3,11 +3,16 @@ from typing import List
 import torch

 from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import get_compiler_backend

-is_cuda = is_cuda_available()
-if is_cuda:
-    from sgl_kernel import sampling_scaling_penalties
+
+@torch.compile(dynamic=True, backend=get_compiler_backend())
+def apply_scaling_penalties(logits, scaling_penalties):
+    logits[:] = torch.where(
+        logits > 0,
+        logits / scaling_penalties,
+        logits * scaling_penalties,
+    )


 class BatchedRepetitionPenalizer(_BatchedPenalizer):
@@ -61,16 +66,7 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
         self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]

     def _apply(self, logits: torch.Tensor) -> torch.Tensor:
-        if is_cuda:
-            return sampling_scaling_penalties(
-                logits, self.cumulated_repetition_penalties
-            )
-        else:
-            return torch.where(
-                logits > 0,
-                logits / self.cumulated_repetition_penalties,
-                logits * self.cumulated_repetition_penalties,
-            )
+        apply_scaling_penalties(logits, self.cumulated_repetition_penalties)

     def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
         self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
 ...
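Note the changed calling convention in _apply: the old CUDA path returned the result of sampling_scaling_penalties, whereas apply_scaling_penalties writes through logits[:] and returns None, so callers now depend on in-place mutation. A small sketch of why the slice assignment matters (function names here are hypothetical, for illustration only):

import torch

def rebind(logits, p):
    # Rebinding the local name leaves the caller's tensor untouched.
    logits = torch.where(logits > 0, logits / p, logits * p)

def mutate(logits, p):
    # Slice assignment writes the result into the caller's storage.
    logits[:] = torch.where(logits > 0, logits / p, logits * p)

t = torch.tensor([2.0, -1.0])
p = torch.tensor([2.0, 2.0])
rebind(t, p); print(t)  # tensor([ 2., -1.])  -- unchanged
mutate(t, p); print(t)  # tensor([ 1., -2.])  -- updated in place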
python/sglang/srt/sampling/sampling_batch_info.py

@@ -7,14 +7,11 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
 import torch

-from sglang.srt.utils import is_cuda_available
-
-is_cuda = is_cuda_available()
-if is_cuda:
-    from sgl_kernel import sampling_scaling_penalties
-
 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
+from sglang.srt.sampling.penaltylib.penalizers.repetition_penalty import (
+    apply_scaling_penalties,
+)

 logger = logging.getLogger(__name__)
 ...
@@ -386,14 +383,7 @@ class SamplingBatchInfo:
         # repetition
         if self.scaling_penalties is not None:
-            if is_cuda:
-                logits[:] = sampling_scaling_penalties(
-                    logits, self.scaling_penalties
-                )
-            else:
-                logits[:] = torch.where(
-                    logits > 0,
-                    logits / self.scaling_penalties,
-                    logits * self.scaling_penalties,
-                )
+            apply_scaling_penalties(logits, self.scaling_penalties)

         # Apply regex vocab_mask
         if self.vocab_mask is not None:
 ...
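The commit title frames this as a torch.compile optimization; a quick way to sanity-check the speedup on your own hardware is a micro-benchmark like the following (shapes, device, iteration count, and timing method are all assumptions for illustration, not part of the commit):

import time

import torch

def eager_penalty(logits, p):
    logits[:] = torch.where(logits > 0, logits / p, logits * p)

compiled_penalty = torch.compile(eager_penalty, dynamic=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
logits = torch.randn(64, 32768, device=device)  # batch x vocab, illustrative
p = torch.full_like(logits, 1.2)

for name, fn in (("eager", eager_penalty), ("compiled", compiled_penalty)):
    fn(logits.clone(), p)  # warmup; triggers compilation for the compiled path
    if device == "cuda":
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(100):
        fn(logits, p)
    if device == "cuda":
        torch.cuda.synchronize()
    print(f"{name}: {(time.perf_counter() - t0) / 100 * 1e6:.1f} us/iter")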