sglang · Commit d08c77c4 (Unverified)

Sampling penalties memory interface (#2870)

Authored Jan 13, 2025 by Xiaoyu Zhang; committed via GitHub on Jan 13, 2025.
Parent: c1e097ca

Showing 7 changed files with 256 additions and 46 deletions (+256 / -46).
Changed files:

  benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py   +2   -1
  python/pyproject.toml                                                          +1   -1
  python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py         +15  -5
  python/sglang/srt/sampling/sampling_batch_info.py                              +14  -5
  python/sglang/srt/utils.py                                                     +4   -0
  sgl-kernel/benchmark/benchmark_sampling_scaling_penalties.py                   +159 -0
  sgl-kernel/tests/test_moe_align.py                                             +61  -34
benchmark/kernels/fused_moe_triton/benchmark_moe_align_blocks.py → benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py

@@ -222,8 +222,9 @@ configs = list(itertools.product(batch_size_range, seq_length_range))
 def benchmark(batch_size, seq_len, provider):
     num_experts = 256
     block_size = 128
+    topk = 8
     topk_ids = torch.randint(
-        0, num_experts, (batch_size, seq_len), dtype=torch.int32, device="cuda"
+        0, num_experts, (batch_size * seq_len, topk), dtype=torch.int32, device="cuda"
     )
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
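Editorial note, not part of the diff: besides the deepseekv3_ rename, the benchmark now draws topk = 8 expert ids for each of the batch_size * seq_len tokens, matching DeepSeek-V3's routing of 8 out of 256 experts per token; the old (batch_size, seq_len) tensor had no top-k dimension. A quick size comparison with illustrative numbers only (not values taken from the benchmark's own sweep):

    # Hypothetical batch/sequence sizes, only to show how the workload grows
    # once every token carries topk routed-expert ids.
    batch_size, seq_len, topk = 32, 1024, 8
    old_numel = batch_size * seq_len          # one expert id per position
    new_numel = batch_size * seq_len * topk   # topk expert ids per routed token
    print(old_numel, new_numel)               # 32768 262144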
python/pyproject.toml

@@ -27,7 +27,7 @@ runtime_common = [
 ]
-srt = ["sglang[runtime_common]", "cuda-python", "sgl-kernel>=0.0.2.post11", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", "flashinfer==0.1.6"]
+srt = ["sglang[runtime_common]", "cuda-python", "sgl-kernel>=0.0.2.post12", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", "flashinfer==0.1.6"]
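Editorial note: the only change here is the sgl-kernel floor moving from 0.0.2.post11 to 0.0.2.post12, presumably the first release carrying the sampling_scaling_penalties op imported elsewhere in this commit; installing it by hand would look like pip install "sgl-kernel>=0.0.2.post12".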
python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py

@@ -3,6 +3,11 @@ from typing import List
 import torch

 from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
+from sglang.srt.utils import is_cuda_available
+
+is_cuda = is_cuda_available()
+if is_cuda:
+    from sgl_kernel import sampling_scaling_penalties


 class BatchedRepetitionPenalizer(_BatchedPenalizer):
@@ -56,11 +61,16 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
         self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]

     def _apply(self, logits: torch.Tensor) -> torch.Tensor:
-        return torch.where(
-            logits > 0,
-            logits / self.cumulated_repetition_penalties,
-            logits * self.cumulated_repetition_penalties,
-        )
+        if is_cuda:
+            return sampling_scaling_penalties(
+                logits, self.cumulated_repetition_penalties
+            )
+        else:
+            return torch.where(
+                logits > 0,
+                logits / self.cumulated_repetition_penalties,
+                logits * self.cumulated_repetition_penalties,
+            )

     def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
         self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
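Editorial note, not part of the commit: the else branch above is also the reference semantics of the fused op — positive logits are divided by the cumulated penalty, non-positive logits are multiplied by it. A minimal, CPU-runnable sketch of that reference follows; scaling_penalty_reference is an illustrative name, not anything exported by sgl_kernel:

    import torch

    def scaling_penalty_reference(logits: torch.Tensor, penalties: torch.Tensor) -> torch.Tensor:
        # Same math as the fallback branch of _apply(): divide positive logits by
        # the penalty, multiply non-positive ones. torch.where materializes both
        # branch tensors before selecting, i.e. two full-size temporaries on top
        # of the output -- the extra allocations the fused kernel aims to avoid.
        return torch.where(logits > 0, logits / penalties, logits * penalties)

    logits = torch.tensor([[2.0, -1.0, 0.0]])
    penalties = torch.tensor([[2.0, 2.0, 2.0]])
    print(scaling_penalty_reference(logits, penalties))  # tensor([[ 1., -2.,  0.]])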
python/sglang/srt/sampling/sampling_batch_info.py

@@ -7,6 +7,12 @@ from typing import TYPE_CHECKING, Callable, List, Optional
 import torch

+from sglang.srt.utils import is_cuda_available
+
+is_cuda = is_cuda_available()
+if is_cuda:
+    from sgl_kernel import sampling_scaling_penalties
+
 import sglang.srt.sampling.penaltylib as penaltylib

 logger = logging.getLogger(__name__)
@@ -245,11 +251,14 @@ class SamplingBatchInfo:
         # repetition
         if self.scaling_penalties is not None:
-            logits[:] = torch.where(
-                logits > 0,
-                logits / self.scaling_penalties,
-                logits * self.scaling_penalties,
-            )
+            if is_cuda:
+                logits[:] = sampling_scaling_penalties(logits, self.scaling_penalties)
+            else:
+                logits[:] = torch.where(
+                    logits > 0,
+                    logits / self.scaling_penalties,
+                    logits * self.scaling_penalties,
+                )

         # Apply regex vocab_mask
         if self.vocab_mask is not None:
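Editorial note, not part of the commit: both branches write through logits[:], so the penalized values are copied into the existing logits storage rather than rebinding the local name, and every other reference to that tensor sees the update. A small plain-PyTorch illustration:

    import torch

    logits = torch.tensor([4.0, -2.0])
    alias = logits                       # another reference to the same storage
    penalties = torch.tensor([2.0, 2.0])

    # In-place write, as in SamplingBatchInfo: the data changes, the tensor object does not.
    logits[:] = torch.where(logits > 0, logits / penalties, logits * penalties)

    print(alias)  # tensor([ 2., -4.])  -- the alias observes the penalized values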
python/sglang/srt/utils.py

@@ -97,6 +97,10 @@ def is_flashinfer_available():
     return torch.cuda.is_available() and torch.version.cuda


+def is_cuda_available():
+    return torch.cuda.is_available() and torch.version.cuda
+
+
 def is_ipv6(address):
     try:
         ipaddress.IPv6Address(address)
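Editorial note, not part of the commit: the new helper is the same expression as is_flashinfer_available() directly above it; the separate name documents the intent (gating the optional sgl_kernel import) rather than adding a new check. The torch.version.cuda term is what distinguishes CUDA builds of torch, where it is a version string, from ROCm or CPU-only builds, where it is None even though torch.cuda.is_available() may be True on ROCm. The resulting import-guard pattern, mirroring the two sampling files above:

    import torch

    def is_cuda_available():
        # None on ROCm/CPU builds, a version string such as "12.4" on CUDA builds.
        return torch.cuda.is_available() and torch.version.cuda

    if is_cuda_available():
        # Optional fused op; callers fall back to the torch.where path otherwise.
        from sgl_kernel import sampling_scaling_penalties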
sgl-kernel/benchmark/benchmark_sampling_scaling_penalties.py (new file, 0 → 100644)

import itertools

import torch
import triton
from sgl_kernel import sampling_scaling_penalties


def sampling_scaling_penalties_naive(logits, scaling_penalties):
    return torch.where(
        logits > 0, logits / scaling_penalties, logits * scaling_penalties
    )


def sampling_scaling_penalties_kernel(logits, scaling_penalties):
    return sampling_scaling_penalties(logits, scaling_penalties)


def test_memory(func, _iter):
    total_mem = []

    for _ in range(_iter):
        torch.cuda.memory.reset_peak_memory_stats()
        func()
        mem = torch.cuda.max_memory_allocated() / (2**20)
        total_mem.append(mem)

    return sum(total_mem) / len(total_mem)


def calculate_diff(batch_size, vocab_size):
    dtype = torch.bfloat16
    device = torch.device("cuda")

    logits = torch.randn(batch_size, vocab_size, device=device, dtype=dtype)
    scaling_penalties = (
        torch.rand(batch_size, vocab_size, device=device, dtype=dtype) + 0.5
    )

    output_naive = sampling_scaling_penalties_naive(
        logits.clone(), scaling_penalties.clone()
    )
    output_kernel = sampling_scaling_penalties_kernel(
        logits.clone(), scaling_penalties.clone()
    )

    print(f"Naive output={output_naive}")
    print(f"Kernel output={output_kernel}")

    if torch.allclose(output_naive, output_kernel, atol=1e-2, rtol=1e-2):
        print("✅ Both implementations match")
    else:
        print("❌ Implementations differ")


batch_size_range = [2**i for i in range(0, 12)]
vocab_size_range = [2**i for i in range(10, 17)]
configs = list(itertools.product(batch_size_range, vocab_size_range))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "vocab_size"],
        x_vals=[list(_) for _ in configs],
        line_arg="provider",
        line_vals=["naive", "kernel"],
        line_names=["PyTorch Naive", "SGL Kernel"],
        styles=[("blue", "-"), ("red", "-")],
        ylabel="us",
        plot_name="sampling-scaling-penalties-performance",
        args={},
    )
)
def benchmark(batch_size, vocab_size, provider):
    dtype = torch.bfloat16
    device = torch.device("cuda")

    logits = torch.randn(batch_size, vocab_size, device=device, dtype=dtype)
    scaling_penalties = (
        torch.rand(batch_size, vocab_size, device=device, dtype=dtype) + 0.5
    )

    quantiles = [0.5, 0.2, 0.8]

    if provider == "naive":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: sampling_scaling_penalties_naive(
                logits.clone(),
                scaling_penalties.clone(),
            ),
            quantiles=quantiles,
        )
    else:
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: sampling_scaling_penalties_kernel(
                logits.clone(),
                scaling_penalties.clone(),
            ),
            quantiles=quantiles,
        )

    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "vocab_size"],
        x_vals=[list(_) for _ in configs],
        line_arg="provider",
        line_vals=["naive", "kernel"],
        line_names=["PyTorch Naive", "SGL Kernel"],
        styles=[("blue", "-"), ("red", "-")],
        ylabel="GPU memory usage (MB)",
        plot_name="sampling-scaling-penalties-memory",
        args={},
    )
)
def benchmark_memory(batch_size, vocab_size, provider):
    dtype = torch.bfloat16
    device = torch.device("cuda")

    print(
        f"Running memory benchmark with batch_size={batch_size}, vocab_size={vocab_size}, provider={provider}"
    )

    def run_kernel():
        logits = torch.randn(batch_size, vocab_size, device=device, dtype=dtype)
        scaling_penalties = (
            torch.rand(batch_size, vocab_size, device=device, dtype=dtype) + 0.5
        )

        if provider == "naive":
            return sampling_scaling_penalties_naive(logits, scaling_penalties)
        else:
            return sampling_scaling_penalties_kernel(logits, scaling_penalties)

    mem = test_memory(run_kernel, _iter=10)
    return mem


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_path",
        type=str,
        default="./configs/benchmark_ops/sampling_scaling_penalties/",
        help="Path to save sampling_scaling_penalties benchmark results",
    )
    args = parser.parse_args()

    # Run correctness test
    calculate_diff(batch_size=4, vocab_size=4096)

    # Run performance benchmark
    benchmark.run(print_data=True, save_path=args.save_path)

    # Run memory benchmark
    benchmark_memory.run(print_data=True, save_path=args.save_path)
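Usage note (editorial): assuming sgl-kernel >= 0.0.2.post12 with the sampling_scaling_penalties op is installed and a CUDA device is visible, running the script executes the correctness check, the latency benchmark, and the peak-memory benchmark in sequence, writing the Triton perf-report output under --save_path:

    python3 benchmark_sampling_scaling_penalties.py \
        --save_path ./configs/benchmark_ops/sampling_scaling_penalties/

test_memory() averages torch.cuda.max_memory_allocated() over 10 iterations and divides by 2**20, so the "GPU memory usage (MB)" curve reports peak allocation in MiB per call.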
sgl-kernel/tests/test_moe_align.py

@@ -3,38 +3,65 @@ from sgl_kernel import moe_align_block_size

 def test_moe_align_block_size():
     # For DeepSeek V3, we have 256 experts
     num_experts = 256
-    block_size = 128
-    topk_ids = torch.randint(0, num_experts, (3, 4), dtype=torch.int32, device="cuda")
-    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
-    sorted_ids = torch.empty(
-        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
-    )
-    sorted_ids.fill_(topk_ids.numel())
-    max_num_m_blocks = max_num_tokens_padded // block_size
-    expert_ids = torch.empty(
-        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
-    )
-    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-    token_cnts_buffer = torch.empty(
-        (num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
-    )
-    cumsum_buffer = torch.empty(
-        num_experts + 1, dtype=torch.int32, device=topk_ids.device
-    )
-
-    moe_align_block_size(
-        topk_ids,
-        num_experts,
-        block_size,
-        sorted_ids,
-        expert_ids,
-        num_tokens_post_pad,
-        token_cnts_buffer,
-        cumsum_buffer,
-    )
-
-
-test_moe_align_block_size()
+
+    # Test different combinations of block_size, num_tokens and topk
+    for block_size in [32, 64, 128, 256]:
+        print(f"\nTesting block_size={block_size}")
+        for num_tokens in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]:
+            for topk in [1, 2, 4, 8, 16, 32, 64]:
+                print(
+                    f"Testing block_size={block_size}, num_tokens={num_tokens}, topk={topk}"
+                )
+                # Create random topk_ids with shape [num_tokens, topk]
+                topk_ids = torch.randint(
+                    0, num_experts, (num_tokens, topk), dtype=torch.int32, device="cuda"
+                )
+                max_num_tokens_padded = topk_ids.numel() + num_experts * (
+                    block_size - 1
+                )
+                sorted_ids = torch.empty(
+                    (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
+                )
+                sorted_ids.fill_(topk_ids.numel())
+                max_num_m_blocks = max_num_tokens_padded // block_size
+                expert_ids = torch.empty(
+                    (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
+                )
+                num_tokens_post_pad = torch.empty(
+                    (1), dtype=torch.int32, device=topk_ids.device
+                )
+                token_cnts_buffer = torch.empty(
+                    (num_experts + 1) * num_experts,
+                    dtype=torch.int32,
+                    device=topk_ids.device,
+                )
+                cumsum_buffer = torch.empty(
+                    num_experts + 1, dtype=torch.int32, device=topk_ids.device
+                )
+
+                try:
+                    moe_align_block_size(
+                        topk_ids,
+                        num_experts,
+                        block_size,
+                        sorted_ids,
+                        expert_ids,
+                        num_tokens_post_pad,
+                        token_cnts_buffer,
+                        cumsum_buffer,
+                    )
+                except Exception as e:
+                    print(
+                        f"Error occurred with block_size={block_size}, num_tokens={num_tokens}, topk={topk}"
+                    )
+                    print(f"Error message: {str(e)}")
+                    raise e
+
+
+if __name__ == "__main__":
+    test_moe_align_block_size()
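Editorial note on buffer sizing, shared by the old and new test bodies and by the renamed benchmark above: max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) is the worst-case size after block alignment, since each expert's token group is rounded up to a multiple of block_size and therefore gains at most block_size - 1 padding slots. A tiny check for the largest configuration the new loops exercise; worst_case_padded is an illustrative helper, not part of the test:

    def worst_case_padded(num_tokens: int, topk: int, num_experts: int, block_size: int) -> int:
        # topk_ids.numel() == num_tokens * topk; each of the num_experts groups is
        # padded by at most (block_size - 1) entries when rounded up to block_size.
        return num_tokens * topk + num_experts * (block_size - 1)

    print(worst_case_padded(4096, 64, 256, 256))  # 262144 + 65280 = 327424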