Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
006693ed
Commit
006693ed
authored
Dec 01, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.11.2' into v0.11.2-ori
parents
4b51e6f1
275de341
Changes
544
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
927 additions
and
467 deletions
+927
-467
benchmarks/kernels/benchmark_layernorm.py
benchmarks/kernels/benchmark_layernorm.py
+2
-1
benchmarks/kernels/benchmark_lora.py
benchmarks/kernels/benchmark_lora.py
+457
-40
benchmarks/kernels/benchmark_machete.py
benchmarks/kernels/benchmark_machete.py
+16
-17
benchmarks/kernels/benchmark_marlin.py
benchmarks/kernels/benchmark_marlin.py
+1
-1
benchmarks/kernels/benchmark_moe.py
benchmarks/kernels/benchmark_moe.py
+21
-4
benchmarks/kernels/benchmark_moe_permute_unpermute.py
benchmarks/kernels/benchmark_moe_permute_unpermute.py
+2
-2
benchmarks/kernels/benchmark_mrope.py
benchmarks/kernels/benchmark_mrope.py
+1
-1
benchmarks/kernels/benchmark_paged_attention.py
benchmarks/kernels/benchmark_paged_attention.py
+3
-4
benchmarks/kernels/benchmark_per_token_group_quant.py
benchmarks/kernels/benchmark_per_token_group_quant.py
+1
-1
benchmarks/kernels/benchmark_polynorm.py
benchmarks/kernels/benchmark_polynorm.py
+0
-155
benchmarks/kernels/benchmark_quant.py
benchmarks/kernels/benchmark_quant.py
+2
-1
benchmarks/kernels/benchmark_reshape_and_cache.py
benchmarks/kernels/benchmark_reshape_and_cache.py
+172
-0
benchmarks/kernels/benchmark_reshape_and_cache_flash.py
benchmarks/kernels/benchmark_reshape_and_cache_flash.py
+2
-4
benchmarks/kernels/benchmark_rmsnorm.py
benchmarks/kernels/benchmark_rmsnorm.py
+5
-6
benchmarks/kernels/benchmark_rope.py
benchmarks/kernels/benchmark_rope.py
+65
-92
benchmarks/kernels/benchmark_shapes.py
benchmarks/kernels/benchmark_shapes.py
+2
-2
benchmarks/kernels/benchmark_silu_mul_fp8_quant.py
benchmarks/kernels/benchmark_silu_mul_fp8_quant.py
+168
-123
benchmarks/kernels/benchmark_trtllm_decode_attention.py
benchmarks/kernels/benchmark_trtllm_decode_attention.py
+2
-5
benchmarks/kernels/benchmark_trtllm_prefill_attention.py
benchmarks/kernels/benchmark_trtllm_prefill_attention.py
+2
-5
benchmarks/kernels/benchmark_w8a8_block_fp8.py
benchmarks/kernels/benchmark_w8a8_block_fp8.py
+3
-3
No files found.
Too many changes to show.
To preserve performance only
544 of 544+
files are displayed.
Plain diff
Email patch
benchmarks/kernels/benchmark_layernorm.py
View file @
006693ed
...
@@ -7,7 +7,8 @@ import torch
...
@@ -7,7 +7,8 @@ import torch
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
FlexibleArgumentParser
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
...
...
benchmarks/kernels/benchmark_lora.py
View file @
006693ed
...
@@ -6,11 +6,12 @@ import copy
...
@@ -6,11 +6,12 @@ import copy
import
json
import
json
import
pickle
import
pickle
import
time
import
time
from
collections.abc
import
Callable
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
enum
import
Enum
,
auto
from
enum
import
Enum
,
auto
from
itertools
import
product
from
itertools
import
product
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
Optional
from
typing
import
Any
import
torch
import
torch
import
torch.utils.benchmark
as
TBenchmark
import
torch.utils.benchmark
as
TBenchmark
...
@@ -18,13 +19,24 @@ from torch.utils.benchmark import Measurement as TMeasurement
...
@@ -18,13 +19,24 @@ from torch.utils.benchmark import Measurement as TMeasurement
from
utils
import
ArgPool
,
Bench
,
CudaGraphBenchParams
from
utils
import
ArgPool
,
Bench
,
CudaGraphBenchParams
from
weight_shapes
import
WEIGHT_SHAPES
from
weight_shapes
import
WEIGHT_SHAPES
from
vllm.triton_utils
import
HAS_TRITON
from
vllm.lora.ops.triton_ops.utils
import
get_lora_op_configs
from
vllm.triton_utils
import
HAS_TRITON
,
triton
if
HAS_TRITON
:
if
HAS_TRITON
:
from
vllm.lora.ops.triton_ops
import
LoRAKernelMeta
,
lora_expand
,
lora_shrink
from
vllm.lora.ops.triton_ops
import
(
## added fused_moe_lora
LoRAKernelMeta
,
fused_moe_lora_expand
,
fused_moe_lora_shrink
,
lora_expand
,
lora_shrink
,
)
from
vllm.lora.ops.triton_ops.fused_moe_lora_op
import
(
_LORA_PTR_DICT
,
## added _LORA_PTR_DICT for fused_moe_lora
)
from
vllm.lora.ops.triton_ops.utils
import
_LORA_A_PTR_DICT
,
_LORA_B_PTR_DICT
from
vllm.lora.ops.triton_ops.utils
import
_LORA_A_PTR_DICT
,
_LORA_B_PTR_DICT
from
vllm
import
_custom_ops
as
ops
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.math_utils
import
round_up
DEFAULT_MODELS
=
list
(
WEIGHT_SHAPES
.
keys
())
DEFAULT_MODELS
=
list
(
WEIGHT_SHAPES
.
keys
())
DEFAULT_TP_SIZES
=
[
1
]
DEFAULT_TP_SIZES
=
[
1
]
...
@@ -58,6 +70,8 @@ DEFAULT_NUM_LORAS = [1, 2, 3, 4]
...
@@ -58,6 +70,8 @@ DEFAULT_NUM_LORAS = [1, 2, 3, 4]
DEFAULT_SORT_BY_LORA_IDS
=
[
False
,
True
]
DEFAULT_SORT_BY_LORA_IDS
=
[
False
,
True
]
DEFAULT_SEQ_LENGTHS
=
[
1
]
DEFAULT_SEQ_LENGTHS
=
[
1
]
DEFAULT_EXPAND_FN_ADD_INPUTS
=
[
True
,
False
]
DEFAULT_EXPAND_FN_ADD_INPUTS
=
[
True
,
False
]
DEFAULT_TOP_K_NUMS
=
[
1
]
# Added for MoE LoRA top_k
DEFAULT_NUM_EXPERTS
=
[
8
]
# Added for MoE LoRA num_experts
# Utilities
# Utilities
...
@@ -158,7 +172,7 @@ def ref_group_gemm(
...
@@ -158,7 +172,7 @@ def ref_group_gemm(
seq_lens_cpu
:
torch
.
Tensor
,
seq_lens_cpu
:
torch
.
Tensor
,
prompt_lora_mapping_cpu
:
torch
.
Tensor
,
prompt_lora_mapping_cpu
:
torch
.
Tensor
,
scaling
:
float
,
scaling
:
float
,
add_inputs
:
Optional
[
bool
]
,
add_inputs
:
bool
|
None
,
):
):
"""
"""
Torch group gemm reference implementation to test correctness of
Torch group gemm reference implementation to test correctness of
...
@@ -190,6 +204,11 @@ class OpType(Enum):
...
@@ -190,6 +204,11 @@ class OpType(Enum):
LORA_SHRINK
=
auto
()
LORA_SHRINK
=
auto
()
LORA_EXPAND
=
auto
()
LORA_EXPAND
=
auto
()
## Adding support for fused moe lora
FUSED_MOE_LORA_GATE_UP_SHRINK
=
auto
()
## Gate/Up projection variant with shrink
FUSED_MOE_LORA_GATE_UP_EXPAND
=
auto
()
## Gate/Up projection variant with expand
FUSED_MOE_LORA_DOWN_SHRINK
=
auto
()
## Down projection variant with shrink
FUSED_MOE_LORA_DOWN_EXPAND
=
auto
()
## Down projection variant with expand
@
staticmethod
@
staticmethod
def
from_str
(
s
:
str
)
->
"OpType"
:
def
from_str
(
s
:
str
)
->
"OpType"
:
...
@@ -197,6 +216,15 @@ class OpType(Enum):
...
@@ -197,6 +216,15 @@ class OpType(Enum):
return
OpType
.
LORA_SHRINK
return
OpType
.
LORA_SHRINK
if
s
.
lower
()
==
"lora_expand"
:
if
s
.
lower
()
==
"lora_expand"
:
return
OpType
.
LORA_EXPAND
return
OpType
.
LORA_EXPAND
# Adding support for fused moe lora, both in gate_up and down
if
s
.
lower
()
==
"fused_moe_lora_gate_up_shrink"
:
## Gate/Up variant with shrink
return
OpType
.
FUSED_MOE_LORA_GATE_UP_SHRINK
if
s
.
lower
()
==
"fused_moe_lora_gate_up_expand"
:
## Gate/Up variant with expand
return
OpType
.
FUSED_MOE_LORA_GATE_UP_EXPAND
if
s
.
lower
()
==
"fused_moe_lora_down_shrink"
:
## Down variant with shrink
return
OpType
.
FUSED_MOE_LORA_DOWN_SHRINK
if
s
.
lower
()
==
"fused_moe_lora_down_expand"
:
## Down variant with expand
return
OpType
.
FUSED_MOE_LORA_DOWN_EXPAND
raise
ValueError
(
f
"Unrecognized str
{
s
}
to convert to OpType"
)
raise
ValueError
(
f
"Unrecognized str
{
s
}
to convert to OpType"
)
def
is_shrink_fn
(
self
)
->
bool
:
def
is_shrink_fn
(
self
)
->
bool
:
...
@@ -205,19 +233,56 @@ class OpType(Enum):
...
@@ -205,19 +233,56 @@ class OpType(Enum):
def
is_expand_fn
(
self
)
->
bool
:
def
is_expand_fn
(
self
)
->
bool
:
return
self
in
[
OpType
.
LORA_EXPAND
]
return
self
in
[
OpType
.
LORA_EXPAND
]
def
is_fused_moe_lora_fn
(
self
)
->
bool
:
## adding for fused MoE LoRA
return
self
in
[
OpType
.
FUSED_MOE_LORA_GATE_UP_SHRINK
,
OpType
.
FUSED_MOE_LORA_DOWN_SHRINK
,
OpType
.
FUSED_MOE_LORA_GATE_UP_EXPAND
,
OpType
.
FUSED_MOE_LORA_DOWN_EXPAND
,
]
def
is_fused_moe_lora_gate_up_fn
(
self
,
)
->
bool
:
## adding for fused MoE LoRA Gate/Up
return
self
in
[
OpType
.
FUSED_MOE_LORA_GATE_UP_SHRINK
,
OpType
.
FUSED_MOE_LORA_GATE_UP_EXPAND
,
]
def
is_fused_moe_lora_down_fn
(
self
)
->
bool
:
## adding for fused MoE LoRA Down
return
self
in
[
OpType
.
FUSED_MOE_LORA_DOWN_SHRINK
,
OpType
.
FUSED_MOE_LORA_DOWN_EXPAND
,
]
def
is_fused_moe_lora_shrink_fn
(
self
)
->
bool
:
return
self
in
[
OpType
.
FUSED_MOE_LORA_GATE_UP_SHRINK
,
OpType
.
FUSED_MOE_LORA_DOWN_SHRINK
,
]
def
is_fused_moe_lora_expand_fn
(
self
)
->
bool
:
return
self
in
[
OpType
.
FUSED_MOE_LORA_GATE_UP_EXPAND
,
OpType
.
FUSED_MOE_LORA_DOWN_EXPAND
,
]
def
num_slices
(
self
)
->
list
[
int
]:
def
num_slices
(
self
)
->
list
[
int
]:
if
self
.
is_fused_moe_lora_gate_up_fn
():
return
[
2
]
elif
self
.
is_fused_moe_lora_down_fn
():
return
[
1
]
return
[
1
,
2
,
3
]
return
[
1
,
2
,
3
]
def
mkn
(
def
mkn
(
self
,
batch_size
:
int
,
seq_length
:
int
,
hidden_size
:
int
,
lora_rank
:
int
self
,
batch_size
:
int
,
seq_length
:
int
,
hidden_size
:
int
,
lora_rank
:
int
)
->
tuple
[
int
,
int
,
int
]:
)
->
tuple
[
int
,
int
,
int
]:
num_tokens
=
batch_size
*
seq_length
num_tokens
=
batch_size
*
seq_length
if
self
.
is_shrink_fn
():
if
self
.
is_shrink_fn
()
or
self
.
is_fused_moe_lora_fn
()
:
m
=
num_tokens
m
=
num_tokens
k
=
hidden_size
k
=
hidden_size
n
=
lora_rank
n
=
lora_rank
else
:
elif
self
.
is_expand_fn
():
assert
self
.
is_expand_fn
()
m
=
num_tokens
m
=
num_tokens
k
=
lora_rank
k
=
lora_rank
n
=
hidden_size
n
=
hidden_size
...
@@ -231,9 +296,36 @@ class OpType(Enum):
...
@@ -231,9 +296,36 @@ class OpType(Enum):
"""
"""
if
self
.
is_shrink_fn
():
if
self
.
is_shrink_fn
():
return
op_dtype
,
op_dtype
,
torch
.
float32
return
op_dtype
,
op_dtype
,
torch
.
float32
else
:
elif
self
.
is_expand_fn
():
assert
self
.
is_expand_fn
()
return
torch
.
float32
,
op_dtype
,
op_dtype
return
torch
.
float32
,
op_dtype
,
op_dtype
else
:
assert
self
.
is_fused_moe_lora_fn
()
return
op_dtype
,
op_dtype
,
op_dtype
def
matmul_shapes_fused_moe_lora
(
self
,
m
:
int
,
n
:
int
,
k
:
int
,
num_loras
:
int
,
num_slices
:
int
,
top_k_num
:
int
,
num_experts
:
int
,
)
->
tuple
[
tuple
[
int
],
tuple
[
int
],
tuple
[
int
],
tuple
[
int
]]:
if
self
.
is_fused_moe_lora_shrink_fn
():
input_shape
=
(
(
m
*
top_k_num
,
n
)
if
self
in
[
OpType
.
FUSED_MOE_LORA_DOWN_SHRINK
]
else
(
m
,
n
)
)
output_shape
=
(
num_slices
,
m
,
top_k_num
,
k
)
weight_shape
=
(
num_loras
,
num_experts
,
k
,
n
)
else
:
assert
self
.
is_fused_moe_lora_expand_fn
()
input_shape
=
(
num_slices
,
m
,
top_k_num
,
k
)
output_shape
=
(
m
,
top_k_num
,
n
*
num_slices
)
weight_shape
=
(
num_loras
,
num_experts
,
n
,
k
)
return
(
input_shape
,
weight_shape
,
output_shape
)
def
matmul_shapes
(
def
matmul_shapes
(
self
,
self
,
...
@@ -243,6 +335,8 @@ class OpType(Enum):
...
@@ -243,6 +335,8 @@ class OpType(Enum):
lora_rank
:
int
,
lora_rank
:
int
,
num_loras
:
int
,
num_loras
:
int
,
num_slices
:
int
,
num_slices
:
int
,
top_k_num
:
int
|
None
=
None
,
num_experts
:
int
|
None
=
None
,
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
"""
"""
Given num_slices, return the shapes of the A, B, and C matrices
Given num_slices, return the shapes of the A, B, and C matrices
...
@@ -257,6 +351,16 @@ class OpType(Enum):
...
@@ -257,6 +351,16 @@ class OpType(Enum):
if
self
in
[
OpType
.
LORA_EXPAND
]:
if
self
in
[
OpType
.
LORA_EXPAND
]:
# LoRA expand kernels support num_slices inherently in the kernel
# LoRA expand kernels support num_slices inherently in the kernel
return
((
num_slices
,
m
,
k
),
b_shape
,
(
m
,
n
*
num_slices
))
return
((
num_slices
,
m
,
k
),
b_shape
,
(
m
,
n
*
num_slices
))
if
self
.
is_fused_moe_lora_fn
():
return
self
.
matmul_shapes_fused_moe_lora
(
m
,
k
,
n
,
num_loras
,
num_slices
,
top_k_num
,
num_experts
,
)
raise
ValueError
(
f
"Unrecognized op_type
{
self
}
"
)
raise
ValueError
(
f
"Unrecognized op_type
{
self
}
"
)
def
bench_fn
(
self
)
->
Callable
:
def
bench_fn
(
self
)
->
Callable
:
...
@@ -264,6 +368,16 @@ class OpType(Enum):
...
@@ -264,6 +368,16 @@ class OpType(Enum):
return
lora_shrink
return
lora_shrink
if
self
==
OpType
.
LORA_EXPAND
:
if
self
==
OpType
.
LORA_EXPAND
:
return
lora_expand
return
lora_expand
if
self
in
[
OpType
.
FUSED_MOE_LORA_GATE_UP_SHRINK
,
OpType
.
FUSED_MOE_LORA_DOWN_SHRINK
,
]:
return
fused_moe_lora_shrink
if
self
in
[
OpType
.
FUSED_MOE_LORA_GATE_UP_EXPAND
,
OpType
.
FUSED_MOE_LORA_DOWN_EXPAND
,
]:
return
fused_moe_lora_expand
raise
ValueError
(
f
"Unrecognized optype
{
self
}
"
)
raise
ValueError
(
f
"Unrecognized optype
{
self
}
"
)
...
@@ -316,8 +430,10 @@ class BenchmarkContext:
...
@@ -316,8 +430,10 @@ class BenchmarkContext:
lora_rank
:
int
lora_rank
:
int
sort_by_lora_id
:
bool
sort_by_lora_id
:
bool
dtype
:
torch
.
dtype
dtype
:
torch
.
dtype
seq_length
:
Optional
[
int
]
=
None
seq_length
:
int
|
None
=
None
num_slices
:
Optional
[
int
]
=
None
# num_slices for slice based ops
num_experts
:
int
|
None
=
None
# num_experts for MoE based ops
top_k_num
:
int
|
None
=
None
# top_k for MoE based ops
num_slices
:
int
|
None
=
None
# num_slices for slice based ops
def
with_seq_length
(
self
,
seq_length
:
int
)
->
"BenchmarkContext"
:
def
with_seq_length
(
self
,
seq_length
:
int
)
->
"BenchmarkContext"
:
ctx
=
copy
.
copy
(
self
)
ctx
=
copy
.
copy
(
self
)
...
@@ -372,6 +488,11 @@ class BenchmarkTensors:
...
@@ -372,6 +488,11 @@ class BenchmarkTensors:
f
"
{
dtype_to_str
(
self
.
output
.
dtype
)
}
"
f
"
{
dtype_to_str
(
self
.
output
.
dtype
)
}
"
)
)
def
get_num_tokens
(
self
,
size
:
int
,
top_k_num
:
int
,
op_type
:
OpType
):
return
(
size
*
top_k_num
if
op_type
in
[
OpType
.
FUSED_MOE_LORA_DOWN_SHRINK
]
else
size
)
@
staticmethod
@
staticmethod
def
make
(
def
make
(
ctx
:
BenchmarkContext
,
op_type
:
OpType
,
device
:
str
=
"cuda"
ctx
:
BenchmarkContext
,
op_type
:
OpType
,
device
:
str
=
"cuda"
...
@@ -384,6 +505,8 @@ class BenchmarkTensors:
...
@@ -384,6 +505,8 @@ class BenchmarkTensors:
ctx
.
lora_rank
,
ctx
.
lora_rank
,
ctx
.
num_loras
,
ctx
.
num_loras
,
ctx
.
num_slices
,
ctx
.
num_slices
,
ctx
.
top_k_num
,
ctx
.
num_experts
,
)
)
a_type
,
b_type
,
c_type
=
op_type
.
matmul_dtypes
(
ctx
.
dtype
)
a_type
,
b_type
,
c_type
=
op_type
.
matmul_dtypes
(
ctx
.
dtype
)
input_tensor
,
lora_weights
,
output_tensor
=
make_rand_tensors
(
input_tensor
,
lora_weights
,
output_tensor
=
make_rand_tensors
(
...
@@ -431,17 +554,27 @@ class BenchmarkTensors:
...
@@ -431,17 +554,27 @@ class BenchmarkTensors:
prompt_lora_indices_tensor
,
prompt_lora_indices_tensor
,
)
)
def
sanity_check
(
self
)
->
None
:
def
sanity_check
(
self
,
ctx
:
BenchmarkContext
,
op_type
:
OpType
)
->
None
:
"""
"""
Fails asserts when non-conformality is detected.
Fails asserts when non-conformality is detected.
"""
"""
num_tokens
=
self
.
input
.
shape
[
-
2
]
num_tokens
=
(
self
.
input
.
shape
[
1
]
if
op_type
.
is_fused_moe_lora_expand_fn
()
else
self
.
input
.
shape
[
-
2
]
)
# check metadata tensors
# check metadata tensors
assert
torch
.
sum
(
self
.
seq_lens
)
==
num_tokens
## In down shrink case, each token is repeated top_k_num times
assert
num_tokens
==
self
.
get_num_tokens
(
torch
.
sum
(
self
.
seq_lens
),
ctx
.
top_k_num
,
op_type
),
f
"Expected
{
num_tokens
}
tokens, but got
{
torch
.
sum
(
self
.
seq_lens
)
}
"
num_seqs
=
self
.
seq_lens
.
shape
[
0
]
num_seqs
=
self
.
seq_lens
.
shape
[
0
]
# assert self.seq_start_loc.shape[0] == num_seqs
# assert self.seq_start_loc.shape[0] == num_seqs
## In down shrink case, each prompt corresponds to top_k_num sequences
assert
self
.
prompt_lora_mapping
.
shape
[
0
]
==
num_seqs
assert
self
.
prompt_lora_mapping
.
shape
[
0
]
==
num_seqs
assert
self
.
lora_kernel_meta
.
token_lora_mapping
.
shape
[
0
]
==
num_tokens
assert
self
.
get_num_tokens
(
self
.
lora_kernel_meta
.
token_lora_mapping
.
shape
[
0
],
ctx
.
top_k_num
,
op_type
)
def
to_device
(
self
,
device
:
str
):
def
to_device
(
self
,
device
:
str
):
"""
"""
...
@@ -470,21 +603,111 @@ class BenchmarkTensors:
...
@@ -470,21 +603,111 @@ class BenchmarkTensors:
to_device
(
field
)
if
field_name
!=
"no_lora_flag_cpu"
else
field
,
to_device
(
field
)
if
field_name
!=
"no_lora_flag_cpu"
else
field
,
)
)
def
metadata
(
self
)
->
tuple
[
int
,
int
,
int
]:
def
metadata
(
self
,
ctx
:
BenchmarkContext
,
op_type
:
OpType
)
->
tuple
[
int
,
int
,
int
]:
"""
"""
Return num_seqs, num_tokens and max_seq_len
Return num_seqs, num_tokens and max_seq_len
"""
"""
num_seqs
=
self
.
seq_lens
.
shape
[
0
]
num_seqs
=
self
.
seq_lens
.
shape
[
0
]
num_tokens
=
self
.
lora_kernel_meta
.
token_lora_mapping
.
shape
[
0
]
num_tokens
=
self
.
get_num_tokens
(
self
.
lora_kernel_meta
.
token_lora_mapping
.
shape
[
0
],
ctx
.
top_k_num
,
op_type
)
max_seq_len
=
torch
.
max
(
self
.
seq_lens
).
item
()
max_seq_len
=
torch
.
max
(
self
.
seq_lens
).
item
()
num_slices
=
len
(
self
.
lora_weights_lst
)
num_slices
=
len
(
self
.
lora_weights_lst
)
return
num_seqs
,
num_tokens
,
max_seq_len
,
num_slices
return
num_seqs
,
num_tokens
,
max_seq_len
,
num_slices
def
as_lora_shrink_kwargs
(
self
)
->
dict
[
str
,
Any
]:
def
fused_moe_lora_data_prepare
(
self
.
sanity_check
()
self
,
block_size
:
int
,
token_lora_mapping
:
torch
.
Tensor
,
ctx
:
BenchmarkContext
,
):
def
moe_lora_align_block_size
(
topk_ids
:
torch
.
Tensor
,
token_lora_mapping
:
torch
.
Tensor
,
block_size
:
int
,
num_experts
:
int
,
max_loras
:
int
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
pad_sorted_ids
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Aligns tokens and experts into block-sized chunks for LoRA-based
mixture-of-experts (MoE) execution.
"""
max_num_tokens_padded
=
topk_ids
.
numel
()
+
num_experts
*
(
block_size
-
1
)
if
pad_sorted_ids
:
max_num_tokens_padded
=
round_up
(
max_num_tokens_padded
,
block_size
)
sorted_ids
=
torch
.
empty
(
(
max_loras
*
max_num_tokens_padded
,),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
,
)
max_num_m_blocks
=
triton
.
cdiv
(
max_num_tokens_padded
,
block_size
)
# Expert ids must be set default to -1 to prevent a blank block
expert_ids
=
torch
.
empty
(
(
max_loras
*
max_num_m_blocks
,),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
,
)
num_tokens_post_pad
=
torch
.
empty
(
(
max_loras
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
ops
.
moe_lora_align_block_size
(
topk_ids
,
token_lora_mapping
,
num_experts
,
block_size
,
max_loras
,
max_num_tokens_padded
,
max_num_m_blocks
,
sorted_ids
,
expert_ids
,
num_tokens_post_pad
,
)
if
expert_map
is
not
None
:
expert_ids
=
expert_map
[
expert_ids
]
return
sorted_ids
,
expert_ids
,
num_tokens_post_pad
num_tokens
=
ctx
.
batch_size
curr_topk_ids
=
torch
.
randint
(
0
,
ctx
.
num_experts
,
(
num_tokens
,
ctx
.
top_k_num
),
device
=
"cuda"
,
dtype
=
torch
.
int32
,
)
topk_weights
=
torch
.
randint
(
0
,
ctx
.
num_experts
,
(
num_tokens
,
ctx
.
top_k_num
),
device
=
"cuda"
,
dtype
=
torch
.
int32
,
)
(
sorted_token_ids_lora
,
expert_ids_lora
,
num_tokens_post_padded_lora
)
=
(
moe_lora_align_block_size
(
topk_ids
=
curr_topk_ids
,
token_lora_mapping
=
token_lora_mapping
,
block_size
=
block_size
,
num_experts
=
ctx
.
num_experts
,
max_loras
=
ctx
.
num_loras
,
)
)
sorted_token_ids
=
sorted_token_ids_lora
.
view
(
ctx
.
num_loras
,
-
1
)
expert_ids
=
expert_ids_lora
.
view
(
ctx
.
num_loras
,
-
1
)
num_tokens_post_padded
=
num_tokens_post_padded_lora
return
(
topk_weights
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
)
def
as_lora_shrink_kwargs
(
self
,
ctx
:
BenchmarkContext
,
op_type
:
OpType
)
->
dict
[
str
,
Any
]:
self
.
sanity_check
(
ctx
,
op_type
)
self
.
to_device
(
self
.
input
.
device
)
self
.
to_device
(
self
.
input
.
device
)
_
,
num_tokens
,
_
,
num_slices
=
self
.
metadata
()
_
,
num_tokens
,
_
,
num_slices
=
self
.
metadata
(
ctx
,
op_type
)
# Sanity check matrix shapes.
# Sanity check matrix shapes.
i_shape
,
lw_shape
,
o_shape
=
(
i_shape
,
lw_shape
,
o_shape
=
(
...
@@ -519,11 +742,13 @@ class BenchmarkTensors:
...
@@ -519,11 +742,13 @@ class BenchmarkTensors:
"no_lora_flag_cpu"
:
self
.
lora_kernel_meta
.
no_lora_flag_cpu
,
"no_lora_flag_cpu"
:
self
.
lora_kernel_meta
.
no_lora_flag_cpu
,
}
}
def
as_lora_expand_kwargs
(
self
,
add_inputs
:
bool
)
->
dict
[
str
,
Any
]:
def
as_lora_expand_kwargs
(
self
.
sanity_check
()
self
,
ctx
:
BenchmarkContext
,
op_type
:
OpType
,
add_inputs
:
bool
)
->
dict
[
str
,
Any
]:
self
.
sanity_check
(
ctx
,
op_type
)
self
.
to_device
(
self
.
input
.
device
)
self
.
to_device
(
self
.
input
.
device
)
_
,
num_tokens
,
_
,
num_slices
=
self
.
metadata
()
_
,
num_tokens
,
_
,
num_slices
=
self
.
metadata
(
ctx
,
op_type
)
# Sanity check matrix shapes.
# Sanity check matrix shapes.
i_shape
,
lw_shape
,
o_shape
=
(
i_shape
,
lw_shape
,
o_shape
=
(
...
@@ -560,22 +785,177 @@ class BenchmarkTensors:
...
@@ -560,22 +785,177 @@ class BenchmarkTensors:
"no_lora_flag_cpu"
:
self
.
lora_kernel_meta
.
no_lora_flag_cpu
,
"no_lora_flag_cpu"
:
self
.
lora_kernel_meta
.
no_lora_flag_cpu
,
}
}
def
as_fused_moe_lora_shrink_kwargs
(
self
,
ctx
:
BenchmarkContext
,
op_type
:
OpType
)
->
dict
[
str
,
Any
]:
self
.
sanity_check
(
ctx
,
op_type
)
self
.
to_device
(
self
.
input
.
device
)
_
,
num_tokens
,
_
,
num_slices
=
self
.
metadata
(
ctx
,
op_type
)
# Sanity check matrix shapes.
i_shape
,
lw_shape
,
o_shape
=
(
self
.
input
.
shape
,
self
.
lora_weights_lst
[
0
].
shape
,
self
.
output
.
shape
,
)
# Expected input shape : [num_tokens, hidden_size] for gate_up
# Expected input shape : [top_k_num * num_tokens, hidden_size] for down
assert
len
(
i_shape
)
==
2
assert
i_shape
[
0
]
==
num_tokens
hidden_size
=
i_shape
[
1
]
# Expected lora weight shape [max_lora, num_experts, lora_rank, hidden_size]
assert
len
(
lw_shape
)
==
4
assert
lw_shape
[
-
1
]
==
hidden_size
lora_rank
=
lw_shape
[
-
2
]
# Expected output shape : [num_slices, num_tokens, top_k_num, lora_rank]
assert
len
(
o_shape
)
==
4
assert
(
o_shape
==
(
num_slices
,
num_tokens
//
ctx
.
top_k_num
,
ctx
.
top_k_num
,
lora_rank
)
if
op_type
in
[
OpType
.
FUSED_MOE_LORA_DOWN_SHRINK
]
else
o_shape
==
(
num_slices
,
num_tokens
,
ctx
.
top_k_num
,
lora_rank
)
)
kernel_config
=
get_lora_op_configs
(
op_type
.
name
.
lower
(),
max_loras
=
lw_shape
[
0
],
batch
=
num_tokens
,
hidden_size
=
hidden_size
,
rank
=
lora_rank
,
num_slices
=
num_slices
,
add_inputs
=
False
,
)
(
topk_weights
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
)
=
(
self
.
fused_moe_lora_data_prepare
(
block_size
=
kernel_config
[
"BLOCK_SIZE_M"
],
token_lora_mapping
=
self
.
lora_kernel_meta
.
token_lora_mapping
,
ctx
=
ctx
,
)
)
return
{
"qcurr_hidden_states"
:
self
.
input
,
"lora_a_stacked"
:
self
.
lora_weights_lst
,
"a_intermediate_cache1"
:
self
.
output
,
"topk_weights"
:
topk_weights
,
"sorted_token_ids"
:
sorted_token_ids
,
"expert_ids"
:
expert_ids
,
"num_tokens_post_padded"
:
num_tokens_post_padded
,
"top_k_num"
:
ctx
.
top_k_num
,
"device"
:
self
.
input
.
device
,
"N"
:
lora_rank
,
"M"
:
topk_weights
.
shape
[
0
],
"EM"
:
sorted_token_ids
.
shape
[
1
],
"K"
:
self
.
input
.
shape
[
1
],
"num_tokens"
:
num_tokens
,
"num_experts"
:
ctx
.
num_experts
,
"num_slices"
:
num_slices
,
"shrink_block_size_m"
:
kernel_config
[
"BLOCK_SIZE_M"
],
"shrink_block_size_n"
:
kernel_config
[
"BLOCK_SIZE_N"
],
"shrink_block_size_k"
:
kernel_config
[
"BLOCK_SIZE_K"
],
"shrink_group_size_m"
:
kernel_config
[
"GROUP_SIZE_M"
],
"shrink_num_warps"
:
kernel_config
[
"NUM_WARPS"
],
"shrink_num_stages"
:
kernel_config
[
"NUM_STAGES"
],
"shrink_split_k"
:
kernel_config
.
get
(
"SPLIT_K"
,
1
),
"mul_routed_weight"
:
op_type
.
is_fused_moe_lora_down_fn
(),
}
def
as_fused_moe_lora_expand_kwargs
(
self
,
ctx
:
BenchmarkContext
,
op_type
:
OpType
)
->
dict
[
str
,
Any
]:
self
.
sanity_check
(
ctx
,
op_type
)
self
.
to_device
(
self
.
input
.
device
)
_
,
num_tokens
,
_
,
num_slices
=
self
.
metadata
(
ctx
,
op_type
)
# Sanity check matrix shapes.
i_shape
,
lw_shape
,
o_shape
=
(
self
.
input
.
shape
,
self
.
lora_weights_lst
[
0
].
shape
,
self
.
output
.
shape
,
)
# Expected input shape : [num_slices, num_tokens, top_k_num, lora_rank]
assert
len
(
i_shape
)
==
4
assert
i_shape
[
0
]
==
num_slices
assert
i_shape
[
1
]
==
num_tokens
lora_rank
=
i_shape
[
-
1
]
# Expected lora weight shape : [num_loras, num_experts, hidden_size, lora_rank]
assert
len
(
lw_shape
)
==
4
assert
lw_shape
[
-
1
]
==
lora_rank
hidden_size
=
lw_shape
[
-
2
]
# Expected output shape : [num_tokens, top_k_num, hidden_size * num_slices]
assert
len
(
o_shape
)
==
3
assert
o_shape
==
(
num_tokens
,
ctx
.
top_k_num
,
hidden_size
*
num_slices
)
kernel_config
=
get_lora_op_configs
(
op_type
.
name
.
lower
(),
max_loras
=
lw_shape
[
0
],
batch
=
num_tokens
,
hidden_size
=
hidden_size
,
rank
=
lora_rank
,
num_slices
=
num_slices
,
add_inputs
=
False
,
)
(
topk_weights
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
)
=
(
self
.
fused_moe_lora_data_prepare
(
block_size
=
kernel_config
[
"BLOCK_SIZE_M"
],
token_lora_mapping
=
self
.
lora_kernel_meta
.
token_lora_mapping
,
ctx
=
ctx
,
)
)
return
{
"a_intermediate_cache1"
:
self
.
input
,
"lora_b_stacked"
:
self
.
lora_weights_lst
,
"output"
:
self
.
output
,
"topk_weights"
:
topk_weights
,
"sorted_token_ids"
:
sorted_token_ids
,
"expert_ids"
:
expert_ids
,
"num_tokens_post_padded"
:
num_tokens_post_padded
,
"top_k_num"
:
ctx
.
top_k_num
,
"device"
:
self
.
input
.
device
,
"N"
:
lora_rank
,
"M"
:
topk_weights
.
shape
[
0
],
"EM"
:
sorted_token_ids
.
shape
[
1
],
"K"
:
self
.
input
.
shape
[
1
],
"num_tokens"
:
num_tokens
,
"num_experts"
:
ctx
.
num_experts
,
"num_slices"
:
num_slices
,
"max_lora_rank"
:
lora_rank
,
"w1_output_dim_size"
:
lw_shape
[
2
],
"expand_block_size_m"
:
kernel_config
[
"BLOCK_SIZE_M"
],
"expand_block_size_n"
:
kernel_config
[
"BLOCK_SIZE_N"
],
"expand_block_size_k"
:
kernel_config
[
"BLOCK_SIZE_K"
],
"expand_group_size_m"
:
kernel_config
[
"GROUP_SIZE_M"
],
"expand_num_warps"
:
kernel_config
[
"NUM_WARPS"
],
"expand_num_stages"
:
kernel_config
[
"NUM_STAGES"
],
"expand_split_k"
:
kernel_config
.
get
(
"SPLIT_K"
,
1
),
"mul_routed_weight"
:
op_type
.
is_fused_moe_lora_down_fn
(),
}
def
bench_fn_kwargs
(
def
bench_fn_kwargs
(
self
,
op_type
:
OpType
,
add_inputs
:
Optional
[
bool
]
=
None
self
,
ctx
:
BenchmarkContext
,
op_type
:
OpType
,
add_inputs
:
bool
|
None
=
None
)
->
dict
[
str
,
Any
]:
)
->
dict
[
str
,
Any
]:
if
op_type
.
is_shrink_fn
():
if
op_type
.
is_shrink_fn
()
or
op_type
.
is_fused_moe_lora_fn
()
:
assert
add_inputs
is
None
assert
add_inputs
is
None
else
:
else
:
assert
add_inputs
is
not
None
assert
add_inputs
is
not
None
if
op_type
==
OpType
.
LORA_SHRINK
:
if
op_type
==
OpType
.
LORA_SHRINK
:
return
self
.
as_lora_shrink_kwargs
()
return
self
.
as_lora_shrink_kwargs
(
ctx
,
op_type
)
if
op_type
==
OpType
.
LORA_EXPAND
:
if
op_type
==
OpType
.
LORA_EXPAND
:
return
self
.
as_lora_expand_kwargs
(
add_inputs
)
return
self
.
as_lora_expand_kwargs
(
ctx
,
op_type
,
add_inputs
)
if
op_type
.
is_fused_moe_lora_shrink_fn
():
return
self
.
as_fused_moe_lora_shrink_kwargs
(
ctx
,
op_type
)
if
op_type
.
is_fused_moe_lora_expand_fn
():
return
self
.
as_fused_moe_lora_expand_kwargs
(
ctx
,
op_type
)
raise
ValueError
(
f
"Unrecognized optype
{
self
}
"
)
raise
ValueError
(
f
"Unrecognized optype
{
self
}
"
)
def
test_correctness
(
def
test_correctness
(
self
,
op_type
:
OpType
,
expand_fn_add_inputs
:
Optional
[
bool
]
self
,
op_type
:
OpType
,
expand_fn_add_inputs
:
bool
|
None
)
->
bool
:
)
->
bool
:
"""
"""
Test correctness of op_type implementation against a grouped gemm
Test correctness of op_type implementation against a grouped gemm
...
@@ -611,12 +991,12 @@ def bench_optype(
...
@@ -611,12 +991,12 @@ def bench_optype(
ctx
:
BenchmarkContext
,
ctx
:
BenchmarkContext
,
arg_pool_size
:
int
,
arg_pool_size
:
int
,
op_type
:
OpType
,
op_type
:
OpType
,
cuda_graph_nops
:
Optional
[
int
]
=
None
,
cuda_graph_nops
:
int
|
None
=
None
,
expand_fn_add_inputs
:
Optional
[
bool
]
=
None
,
expand_fn_add_inputs
:
bool
|
None
=
None
,
test_correctness
:
bool
=
False
,
test_correctness
:
bool
=
False
,
)
->
TMeasurement
:
)
->
TMeasurement
:
assert
arg_pool_size
>=
1
assert
arg_pool_size
>=
1
if
op_type
.
is_shrink_fn
():
if
op_type
.
is_shrink_fn
()
or
op_type
.
is_fused_moe_lora_fn
()
:
assert
expand_fn_add_inputs
is
None
assert
expand_fn_add_inputs
is
None
else
:
else
:
assert
expand_fn_add_inputs
is
not
None
assert
expand_fn_add_inputs
is
not
None
...
@@ -626,23 +1006,30 @@ def bench_optype(
...
@@ -626,23 +1006,30 @@ def bench_optype(
BenchmarkTensors
.
make
(
ctx
,
op_type
)
for
_
in
range
(
arg_pool_size
)
BenchmarkTensors
.
make
(
ctx
,
op_type
)
for
_
in
range
(
arg_pool_size
)
]
]
for
bt
in
bench_tensors
:
for
bt
in
bench_tensors
:
bt
.
sanity_check
()
bt
.
sanity_check
(
ctx
,
op_type
)
# Test correctness of our implementation.
# Test correctness of our implementation.
if
test_correctness
:
if
test_correctness
:
assert
op_type
in
[
OpType
.
LORA_SHRINK
,
OpType
.
LORA_EXPAND
],
(
f
"Correctness testing is not supported for
{
op_type
.
name
}
."
)
assert
all
(
assert
all
(
[
bt
.
test_correctness
(
op_type
,
expand_fn_add_inputs
)
for
bt
in
bench_tensors
]
[
bt
.
test_correctness
(
ctx
,
op_type
,
expand_fn_add_inputs
)
for
bt
in
bench_tensors
]
)
)
# BenchmarkTensors -> dict (kwargs)
# BenchmarkTensors -> dict (kwargs)
kwargs_list
=
[
kwargs_list
=
[
bt
.
bench_fn_kwargs
(
op_type
,
add_inputs
=
expand_fn_add_inputs
)
bt
.
bench_fn_kwargs
(
ctx
,
op_type
,
add_inputs
=
expand_fn_add_inputs
)
for
bt
in
bench_tensors
for
bt
in
bench_tensors
]
]
# Clear LoRA optimization hash-maps.
# Clear LoRA optimization hash-maps.
_LORA_A_PTR_DICT
.
clear
()
_LORA_A_PTR_DICT
.
clear
()
_LORA_B_PTR_DICT
.
clear
()
_LORA_B_PTR_DICT
.
clear
()
_LORA_PTR_DICT
.
clear
()
# Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are set up
# Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are set up
for
kwargs
in
kwargs_list
:
for
kwargs
in
kwargs_list
:
op_type
.
bench_fn
()(
**
kwargs
)
op_type
.
bench_fn
()(
**
kwargs
)
...
@@ -679,7 +1066,7 @@ def bench_torch_mm(
...
@@ -679,7 +1066,7 @@ def bench_torch_mm(
ctx
:
BenchmarkContext
,
ctx
:
BenchmarkContext
,
arg_pool_size
:
int
,
arg_pool_size
:
int
,
op_type
:
OpType
,
op_type
:
OpType
,
cuda_graph_nops
:
Optional
[
int
]
=
None
,
cuda_graph_nops
:
int
|
None
=
None
,
)
->
TMeasurement
:
)
->
TMeasurement
:
"""
"""
Benchmark basic torch.mm as a roofline.
Benchmark basic torch.mm as a roofline.
...
@@ -744,7 +1131,7 @@ def use_cuda_graph_recommendation() -> str:
...
@@ -744,7 +1131,7 @@ def use_cuda_graph_recommendation() -> str:
"""
"""
def
print_timers
(
timers
:
list
[
TMeasurement
],
args
:
Optional
[
argparse
.
Namespace
]
=
None
):
def
print_timers
(
timers
:
list
[
TMeasurement
],
args
:
argparse
.
Namespace
|
None
=
None
):
compare
=
TBenchmark
.
Compare
(
timers
)
compare
=
TBenchmark
.
Compare
(
timers
)
compare
.
print
()
compare
.
print
()
...
@@ -792,7 +1179,9 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
...
@@ -792,7 +1179,9 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
# Benchmark bench_op
# Benchmark bench_op
expand_fn_add_inputs
=
(
expand_fn_add_inputs
=
(
[
None
]
if
bench_op
.
is_shrink_fn
()
else
args
.
expand_fn_add_inputs
[
None
]
if
bench_op
.
is_shrink_fn
()
or
bench_op
.
is_fused_moe_lora_fn
()
else
args
.
expand_fn_add_inputs
)
)
for
add_input_arg
in
expand_fn_add_inputs
:
for
add_input_arg
in
expand_fn_add_inputs
:
seq_len_timers
.
append
(
seq_len_timers
.
append
(
...
@@ -830,12 +1219,22 @@ def as_benchmark_contexts(
...
@@ -830,12 +1219,22 @@ def as_benchmark_contexts(
hidden_sizes
:
list
[
int
],
lora_ranks
:
list
[
int
],
args
:
argparse
.
Namespace
hidden_sizes
:
list
[
int
],
lora_ranks
:
list
[
int
],
args
:
argparse
.
Namespace
)
->
list
[
BenchmarkContext
]:
)
->
list
[
BenchmarkContext
]:
ctxs
:
list
[
BenchmarkContext
]
=
[]
ctxs
:
list
[
BenchmarkContext
]
=
[]
for
batch_size
,
hidden_size
,
lora_rank
,
num_loras
,
sort_by_lora_id
in
product
(
# noqa
for
(
batch_size
,
hidden_size
,
lora_rank
,
num_loras
,
sort_by_lora_id
,
top_k_num
,
num_experts
,
)
in
product
(
# noqa
args
.
batch_sizes
,
args
.
batch_sizes
,
list
(
hidden_sizes
),
list
(
hidden_sizes
),
lora_ranks
,
lora_ranks
,
args
.
num_loras
,
args
.
num_loras
,
args
.
sort_by_lora_id
,
args
.
sort_by_lora_id
,
args
.
top_k_nums
,
args
.
num_experts
,
):
):
ctxs
.
append
(
ctxs
.
append
(
BenchmarkContext
(
BenchmarkContext
(
...
@@ -850,6 +1249,8 @@ def as_benchmark_contexts(
...
@@ -850,6 +1249,8 @@ def as_benchmark_contexts(
seq_length
=
None
,
seq_length
=
None
,
sort_by_lora_id
=
sort_by_lora_id
,
sort_by_lora_id
=
sort_by_lora_id
,
dtype
=
args
.
dtype
,
dtype
=
args
.
dtype
,
top_k_num
=
top_k_num
,
num_experts
=
num_experts
,
# To be filled based on the OpType to benchmark
# To be filled based on the OpType to benchmark
num_slices
=
None
,
num_slices
=
None
,
)
)
...
@@ -1011,6 +1412,22 @@ if __name__ == "__main__":
...
@@ -1011,6 +1412,22 @@ if __name__ == "__main__":
),
),
)
)
p
.
add_argument
(
"--top-k-nums"
,
nargs
=
"+"
,
type
=
int
,
default
=
DEFAULT_TOP_K_NUMS
,
help
=
"Top-K values for MoE LoRA operations"
,
)
p
.
add_argument
(
"--num-experts"
,
nargs
=
"+"
,
type
=
int
,
default
=
DEFAULT_NUM_EXPERTS
,
help
=
"Number of experts for MoE LoRA operations"
,
)
parser
=
FlexibleArgumentParser
(
parser
=
FlexibleArgumentParser
(
description
=
f
"""
description
=
f
"""
Benchmark LoRA kernels:
Benchmark LoRA kernels:
...
...
benchmarks/kernels/benchmark_machete.py
View file @
006693ed
...
@@ -8,10 +8,9 @@ import math
...
@@ -8,10 +8,9 @@ import math
import
os
import
os
import
pickle
as
pkl
import
pickle
as
pkl
import
time
import
time
from
collections.abc
import
Iterable
from
collections.abc
import
Callable
,
Iterable
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
itertools
import
product
from
itertools
import
product
from
typing
import
Callable
,
Optional
import
pandas
as
pd
import
pandas
as
pd
import
torch
import
torch
...
@@ -34,7 +33,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
...
@@ -34,7 +33,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
quantize_weights
,
quantize_weights
,
)
)
from
vllm.scalar_type
import
ScalarType
,
scalar_types
from
vllm.scalar_type
import
ScalarType
,
scalar_types
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
DEFAULT_MODELS
=
[
"meta-llama/Llama-3-8b"
,
"meta-llama/Llama-2-70b-hf"
]
DEFAULT_MODELS
=
[
"meta-llama/Llama-3-8b"
,
"meta-llama/Llama-2-70b-hf"
]
DEFAULT_BATCH_SIZES
=
[
1
,
16
,
32
,
64
,
128
,
256
,
512
,
1024
]
DEFAULT_BATCH_SIZES
=
[
1
,
16
,
32
,
64
,
128
,
256
,
512
,
1024
]
...
@@ -63,23 +62,23 @@ class BenchmarkTensors:
...
@@ -63,23 +62,23 @@ class BenchmarkTensors:
a
:
torch
.
Tensor
a
:
torch
.
Tensor
w_q
:
torch
.
Tensor
w_q
:
torch
.
Tensor
group_size
:
Optional
[
int
]
group_size
:
int
|
None
wtype
:
ScalarType
wtype
:
ScalarType
w_g_s
:
torch
.
Tensor
w_g_s
:
torch
.
Tensor
w_g_zp
:
Optional
[
torch
.
Tensor
]
w_g_zp
:
torch
.
Tensor
|
None
w_ch_s
:
Optional
[
torch
.
Tensor
]
w_ch_s
:
torch
.
Tensor
|
None
w_tok_s
:
Optional
[
torch
.
Tensor
]
w_tok_s
:
torch
.
Tensor
|
None
@
dataclass
@
dataclass
class
TypeConfig
:
class
TypeConfig
:
act_type
:
torch
.
dtype
act_type
:
torch
.
dtype
weight_type
:
ScalarType
weight_type
:
ScalarType
output_type
:
Optional
[
torch
.
dtype
]
output_type
:
torch
.
dtype
|
None
group_scale_type
:
Optional
[
torch
.
dtype
]
group_scale_type
:
torch
.
dtype
|
None
group_zero_type
:
Optional
[
torch
.
dtype
]
group_zero_type
:
torch
.
dtype
|
None
channel_scale_type
:
Optional
[
torch
.
dtype
]
channel_scale_type
:
torch
.
dtype
|
None
token_scale_type
:
Optional
[
torch
.
dtype
]
token_scale_type
:
torch
.
dtype
|
None
def
rand_data
(
shape
,
dtype
=
torch
.
float16
,
scale
=
1
):
def
rand_data
(
shape
,
dtype
=
torch
.
float16
,
scale
=
1
):
...
@@ -93,8 +92,8 @@ def quantize_and_pack(
...
@@ -93,8 +92,8 @@ def quantize_and_pack(
atype
:
torch
.
dtype
,
atype
:
torch
.
dtype
,
w
:
torch
.
Tensor
,
w
:
torch
.
Tensor
,
wtype
:
ScalarType
,
wtype
:
ScalarType
,
stype
:
Optional
[
torch
.
dtype
]
,
stype
:
torch
.
dtype
|
None
,
group_size
:
Optional
[
int
]
,
group_size
:
int
|
None
,
zero_points
:
bool
=
False
,
zero_points
:
bool
=
False
,
):
):
assert
wtype
.
is_integer
(),
"TODO: support floating point weights"
assert
wtype
.
is_integer
(),
"TODO: support floating point weights"
...
@@ -113,7 +112,7 @@ def quantize_and_pack(
...
@@ -113,7 +112,7 @@ def quantize_and_pack(
def
create_bench_tensors
(
def
create_bench_tensors
(
shape
:
tuple
[
int
,
int
,
int
],
types
:
TypeConfig
,
group_size
:
Optional
[
int
]
shape
:
tuple
[
int
,
int
,
int
],
types
:
TypeConfig
,
group_size
:
int
|
None
)
->
list
[
BenchmarkTensors
]:
)
->
list
[
BenchmarkTensors
]:
m
,
n
,
k
=
shape
m
,
n
,
k
=
shape
...
@@ -331,8 +330,8 @@ def bench_fns(label: str, sub_label: str, description: str, fns: list[Callable])
...
@@ -331,8 +330,8 @@ def bench_fns(label: str, sub_label: str, description: str, fns: list[Callable])
return
res
return
res
_SWEEP_SCHEDULES_RESULTS
:
Optional
[
pd
.
DataFrame
]
=
None
_SWEEP_SCHEDULES_RESULTS
:
pd
.
DataFrame
|
None
=
None
_SWEEP_SCHEDULES_RESULTS_CSV
:
Optional
[
str
]
=
None
_SWEEP_SCHEDULES_RESULTS_CSV
:
str
|
None
=
None
def
bench
(
def
bench
(
...
...
benchmarks/kernels/benchmark_marlin.py
View file @
006693ed
...
@@ -44,7 +44,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
...
@@ -44,7 +44,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
sort_weights
,
sort_weights
,
)
)
from
vllm.scalar_type
import
ScalarType
,
scalar_types
from
vllm.scalar_type
import
ScalarType
,
scalar_types
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
DEFAULT_MODELS
=
[
"meta-llama/Llama-2-7b-hf/TP1"
]
DEFAULT_MODELS
=
[
"meta-llama/Llama-2-7b-hf/TP1"
]
DEFAULT_BATCH_SIZES
=
[
1
,
16
,
32
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
,
8192
]
DEFAULT_BATCH_SIZES
=
[
1
,
16
,
32
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
,
8192
]
...
...
benchmarks/kernels/benchmark_moe.py
View file @
006693ed
...
@@ -22,7 +22,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import *
...
@@ -22,7 +22,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import *
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.transformers_utils.config
import
get_config
from
vllm.transformers_utils.config
import
get_config
from
vllm.triton_utils
import
triton
from
vllm.triton_utils
import
triton
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
...
@@ -211,7 +211,7 @@ def get_rocm_tuning_space(use_fp16):
...
@@ -211,7 +211,7 @@ def get_rocm_tuning_space(use_fp16):
num_warps_range
=
[
1
,
2
,
4
,
8
]
num_warps_range
=
[
1
,
2
,
4
,
8
]
group_m_range
=
[
1
,
4
,
8
,
16
,
32
]
group_m_range
=
[
1
,
4
,
8
,
16
,
32
]
num_stage_range
=
[
2
]
num_stage_range
=
[
2
]
waves_per_eu_range
=
[
0
]
waves_per_eu_range
=
[
0
,
1
,
2
,
4
]
matrix_instr_nonkdim_range
=
[
16
,
32
]
if
use_fp16
else
[]
matrix_instr_nonkdim_range
=
[
16
,
32
]
if
use_fp16
else
[]
kpack_range
=
[
1
,
2
]
if
use_fp16
else
[]
kpack_range
=
[
1
,
2
]
if
use_fp16
else
[]
...
@@ -579,19 +579,23 @@ def main(args: argparse.Namespace):
...
@@ -579,19 +579,23 @@ def main(args: argparse.Namespace):
E
=
config
.
ffn_config
.
moe_num_experts
E
=
config
.
ffn_config
.
moe_num_experts
topk
=
config
.
ffn_config
.
moe_top_k
topk
=
config
.
ffn_config
.
moe_top_k
intermediate_size
=
config
.
ffn_config
.
ffn_hidden_size
intermediate_size
=
config
.
ffn_config
.
ffn_hidden_size
hidden_size
=
config
.
hidden_size
elif
config
.
architectures
[
0
]
==
"JambaForCausalLM"
:
elif
config
.
architectures
[
0
]
==
"JambaForCausalLM"
:
E
=
config
.
num_experts
E
=
config
.
num_experts
topk
=
config
.
num_experts_per_tok
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
intermediate_size
intermediate_size
=
config
.
intermediate_size
hidden_size
=
config
.
hidden_size
elif
config
.
architectures
[
0
]
in
(
elif
config
.
architectures
[
0
]
in
(
"DeepseekV2ForCausalLM"
,
"DeepseekV2ForCausalLM"
,
"DeepseekV3ForCausalLM"
,
"DeepseekV3ForCausalLM"
,
"DeepseekV32ForCausalLM"
,
"DeepseekV32ForCausalLM"
,
"Glm4MoeForCausalLM"
,
"Glm4MoeForCausalLM"
,
"NemotronHForCausalLM"
,
):
):
E
=
config
.
n_routed_experts
E
=
config
.
n_routed_experts
topk
=
config
.
num_experts_per_tok
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
moe_intermediate_size
intermediate_size
=
config
.
moe_intermediate_size
hidden_size
=
config
.
hidden_size
elif
config
.
architectures
[
0
]
in
(
elif
config
.
architectures
[
0
]
in
(
"Qwen2MoeForCausalLM"
,
"Qwen2MoeForCausalLM"
,
"Qwen3MoeForCausalLM"
,
"Qwen3MoeForCausalLM"
,
...
@@ -600,10 +604,23 @@ def main(args: argparse.Namespace):
...
@@ -600,10 +604,23 @@ def main(args: argparse.Namespace):
E
=
config
.
num_experts
E
=
config
.
num_experts
topk
=
config
.
num_experts_per_tok
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
moe_intermediate_size
intermediate_size
=
config
.
moe_intermediate_size
hidden_size
=
config
.
hidden_size
elif
config
.
architectures
[
0
]
==
"Qwen3VLMoeForConditionalGeneration"
:
text_config
=
config
.
get_text_config
()
E
=
text_config
.
num_experts
topk
=
text_config
.
num_experts_per_tok
intermediate_size
=
text_config
.
moe_intermediate_size
hidden_size
=
text_config
.
hidden_size
elif
config
.
architectures
[
0
]
in
(
"HunYuanMoEV1ForCausalLM"
):
elif
config
.
architectures
[
0
]
in
(
"HunYuanMoEV1ForCausalLM"
):
E
=
config
.
num_experts
E
=
config
.
num_experts
topk
=
config
.
moe_topk
[
0
]
topk
=
config
.
moe_topk
[
0
]
intermediate_size
=
config
.
moe_intermediate_size
[
0
]
intermediate_size
=
config
.
moe_intermediate_size
[
0
]
hidden_size
=
config
.
hidden_size
elif
config
.
architectures
[
0
]
in
[
"Qwen3OmniMoeForConditionalGeneration"
]:
E
=
config
.
thinker_config
.
text_config
.
num_experts
topk
=
config
.
thinker_config
.
text_config
.
num_experts_per_tok
intermediate_size
=
config
.
thinker_config
.
text_config
.
moe_intermediate_size
hidden_size
=
config
.
thinker_config
.
text_config
.
hidden_size
else
:
else
:
# Support for llama4
# Support for llama4
config
=
config
.
get_text_config
()
config
=
config
.
get_text_config
()
...
@@ -611,6 +628,7 @@ def main(args: argparse.Namespace):
...
@@ -611,6 +628,7 @@ def main(args: argparse.Namespace):
E
=
config
.
num_local_experts
E
=
config
.
num_local_experts
topk
=
config
.
num_experts_per_tok
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
intermediate_size
intermediate_size
=
config
.
intermediate_size
hidden_size
=
config
.
hidden_size
enable_ep
=
bool
(
args
.
enable_expert_parallel
)
enable_ep
=
bool
(
args
.
enable_expert_parallel
)
if
enable_ep
:
if
enable_ep
:
ensure_divisibility
(
E
,
args
.
tp_size
,
"Number of experts"
)
ensure_divisibility
(
E
,
args
.
tp_size
,
"Number of experts"
)
...
@@ -619,8 +637,7 @@ def main(args: argparse.Namespace):
...
@@ -619,8 +637,7 @@ def main(args: argparse.Namespace):
else
:
else
:
ensure_divisibility
(
intermediate_size
,
args
.
tp_size
,
"intermediate_size"
)
ensure_divisibility
(
intermediate_size
,
args
.
tp_size
,
"intermediate_size"
)
shard_intermediate_size
=
2
*
intermediate_size
//
args
.
tp_size
shard_intermediate_size
=
2
*
intermediate_size
//
args
.
tp_size
hidden_size
=
config
.
hidden_size
dtype
=
torch
.
float16
if
current_platform
.
is_rocm
()
else
config
.
dtype
dtype
=
torch
.
float16
if
current_platform
.
is_rocm
()
else
config
.
torch_dtype
use_fp8_w8a8
=
args
.
dtype
==
"fp8_w8a8"
use_fp8_w8a8
=
args
.
dtype
==
"fp8_w8a8"
use_int8_w8a16
=
args
.
dtype
==
"int8_w8a16"
use_int8_w8a16
=
args
.
dtype
==
"int8_w8a16"
block_quant_shape
=
get_weight_block_size_safety
(
config
)
block_quant_shape
=
get_weight_block_size_safety
(
config
)
...
...
benchmarks/kernels/benchmark_moe_permute_unpermute.py
View file @
006693ed
...
@@ -17,7 +17,7 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
...
@@ -17,7 +17,7 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
)
)
from
vllm.model_executor.layers.fused_moe.utils
import
_fp8_quantize
from
vllm.model_executor.layers.fused_moe.utils
import
_fp8_quantize
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
...
@@ -344,7 +344,7 @@ def main(args: argparse.Namespace):
...
@@ -344,7 +344,7 @@ def main(args: argparse.Namespace):
topk
=
config
.
num_experts_per_tok
topk
=
config
.
num_experts_per_tok
hidden_size
=
config
.
hidden_size
hidden_size
=
config
.
hidden_size
dtype
=
torch
.
float16
if
current_platform
.
is_rocm
()
else
config
.
torch_
dtype
dtype
=
torch
.
float16
if
current_platform
.
is_rocm
()
else
config
.
dtype
use_fp8_w8a8
=
args
.
dtype
==
"fp8_w8a8"
use_fp8_w8a8
=
args
.
dtype
==
"fp8_w8a8"
use_int8_w8a16
=
args
.
dtype
==
"int8_w8a16"
use_int8_w8a16
=
args
.
dtype
==
"int8_w8a16"
use_customized_permute
=
args
.
use_customized_permute
use_customized_permute
=
args
.
use_customized_permute
...
...
benchmarks/kernels/benchmark_mrope.py
View file @
006693ed
...
@@ -39,7 +39,7 @@ import torch
...
@@ -39,7 +39,7 @@ import torch
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.transformers_utils.config
import
get_config
from
vllm.transformers_utils.config
import
get_config
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
...
...
benchmarks/kernels/benchmark_paged_attention.py
View file @
006693ed
...
@@ -3,16 +3,15 @@
...
@@ -3,16 +3,15 @@
import
random
import
random
import
time
import
time
from
typing
import
Optional
import
torch
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
(
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.torch_utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
STR_DTYPE_TO_TORCH_DTYPE
,
FlexibleArgumentParser
,
create_kv_caches_with_random
,
create_kv_caches_with_random
,
)
)
...
@@ -37,7 +36,7 @@ def main(
...
@@ -37,7 +36,7 @@ def main(
seed
:
int
,
seed
:
int
,
do_profile
:
bool
,
do_profile
:
bool
,
device
:
str
=
"cuda"
,
device
:
str
=
"cuda"
,
kv_cache_dtype
:
Optional
[
str
]
=
None
,
kv_cache_dtype
:
str
|
None
=
None
,
)
->
None
:
)
->
None
:
current_platform
.
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
...
...
benchmarks/kernels/benchmark_per_token_group_quant.py
View file @
006693ed
...
@@ -3,8 +3,8 @@
...
@@ -3,8 +3,8 @@
import
argparse
import
argparse
import
math
import
math
from
collections.abc
import
Callable
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
Callable
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
torch
import
torch
...
...
benchmarks/kernels/benchmark_polynorm.py
deleted
100644 → 0
View file @
4b51e6f1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
itertools
import
torch
from
vllm
import
_custom_ops
as
vllm_ops
from
vllm.triton_utils
import
triton
def polynorm_naive(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float = 1e-6,
):
    """Reference PolyNorm built from plain PyTorch ops.

    Computes ``weight[0] * norm(x**3) + weight[1] * norm(x**2)
    + weight[2] * norm(x) + bias``, where ``norm(t)`` divides ``t`` by
    ``sqrt(mean(t**2, dim=-1) + eps)``. The arithmetic is carried out in
    float32 and the result is cast to ``weight.dtype``.

    Args:
        x: Input tensor; normalization runs over its last dimension.
            Non-contiguous inputs are accepted.
        weight: Indexed at 0, 1 and 2 for the cubic, quadratic and linear
            coefficients respectively.
        bias: Added to the weighted sum.
        eps: Constant added under the square root.

    Returns:
        A tensor with the same shape as ``x`` and the dtype of ``weight``.
    """
    orig_shape = x.shape
    # reshape() instead of view(): view() raises when the leading dimensions
    # of a non-contiguous tensor cannot be merged without a copy.
    flat = x.reshape(-1, orig_shape[-1]).float()

    def rms_normalize(t: torch.Tensor) -> torch.Tensor:
        # Divide by the root-mean-square taken over the last dimension.
        return t / torch.sqrt(t.pow(2).mean(-1, keepdim=True) + eps)

    combined = (
        weight[0] * rms_normalize(flat**3)
        + weight[1] * rms_normalize(flat**2)
        + weight[2] * rms_normalize(flat)
        + bias
    )
    return combined.to(weight.dtype).reshape(orig_shape)
def polynorm_vllm(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float = 1e-6,
):
    """Run PolyNorm through the vLLM custom op ``vllm_ops.poly_norm``.

    The input is viewed as 2-D (all leading dimensions merged), the op
    writes into a freshly allocated buffer of the same shape and dtype,
    and that buffer is viewed back to the input's original shape.
    """
    input_shape = x.shape
    flat = x.view(-1, input_shape[-1])

    # The custom op takes its output buffer as the first argument.
    result = torch.empty_like(flat)
    vllm_ops.poly_norm(result, flat, weight, bias, eps)

    return result.view(input_shape)
def calculate_diff(batch_size, seq_len, hidden_dim):
    """Compare the naive and vLLM PolyNorm outputs and print the verdict.

    Builds a random bfloat16 input of shape
    ``(batch_size, seq_len, hidden_dim)`` on the CUDA device, with all-ones
    weight (3 elements) and bias (1 element), then checks both
    implementations agree within ``atol=1e-2`` / ``rtol=1e-2``.
    """
    tensor_opts = {"dtype": torch.bfloat16, "device": "cuda"}
    x = torch.randn(batch_size, seq_len, hidden_dim, **tensor_opts)
    weight = torch.ones(3, **tensor_opts)
    bias = torch.ones(1, **tensor_opts)

    reference = polynorm_naive(x, weight, bias)
    candidate = polynorm_vllm(x, weight, bias)

    matches = torch.allclose(reference, candidate, atol=1e-2, rtol=1e-2)
    print("✅ All implementations match" if matches else "❌ Implementations differ")
# Sweep axes for the performance benchmark (all powers of two except dims).
batch_size_range = [1 << exp for exp in range(0, 7, 2)]
seq_length_range = [1 << exp for exp in range(6, 11)]
dim_range = [2048, 4096]
# Cartesian product ordered as (dim, batch_size, seq_len), dim varying slowest.
configs = [
    (dim, batch, seq)
    for dim in dim_range
    for batch in batch_size_range
    for seq in seq_length_range
]
def get_benchmark():
    """Build the Triton perf-report benchmark comparing PolyNorm providers.

    Returns:
        The ``benchmark`` function wrapped by ``triton.testing.perf_report``.
        It sweeps every ``(dim, batch_size, seq_len)`` entry of ``configs``
        for the ``"naive"`` and ``"vllm"`` providers and reports latency in
        microseconds.
    """

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["dim", "batch_size", "seq_len"],
            x_vals=[list(cfg) for cfg in configs],
            line_arg="provider",
            line_vals=["naive", "vllm"],
            line_names=["Naive", "vLLM"],
            styles=[("blue", "-"), ("red", "-")],
            ylabel="us",
            plot_name="polynorm-perf",
            args={},
        )
    )
    def benchmark(dim, batch_size, seq_len, provider):
        dtype = torch.bfloat16
        # The benchmarked tensor's last dimension is four times ``dim``.
        hidden_dim = dim * 4

        x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda")
        weight = torch.ones(3, dtype=dtype, device="cuda")
        bias = torch.ones(1, dtype=dtype, device="cuda")

        # Any provider other than "naive" falls through to the vLLM kernel.
        impl = polynorm_naive if provider == "naive" else polynorm_vllm

        # Quantiles are requested as median, 20th and 80th percentile.
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: impl(x, weight, bias),
            quantiles=[0.5, 0.2, 0.8],
        )

        # perf_report unpacks the result as (value, min, max). This is a
        # latency metric, so the low quantile is the minimum; the swapped
        # (max, min) order only suits inverted metrics such as throughput.
        return 1000 * ms, 1000 * min_ms, 1000 * max_ms

    return benchmark
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    # (flag, type, default, help) for every command-line option.
    for _flag, _type, _default, _help in (
        ("--batch-size", int, 4, "Batch size"),
        ("--seq-len", int, 128, "Sequence length"),
        ("--hidden-dim", int, 8192, "Intermediate size of MLP"),
        (
            "--save-path",
            str,
            "./configs/polnorm/",
            "Path to save polnorm benchmark results",
        ),
    ):
        parser.add_argument(_flag, type=_type, default=_default, help=_help)

    args = parser.parse_args()

    # Correctness check first, so a mismatch is reported before timing.
    calculate_diff(
        batch_size=args.batch_size,
        seq_len=args.seq_len,
        hidden_dim=args.hidden_dim,
    )

    # Performance sweep; results are printed and written under --save-path.
    benchmark = get_benchmark()
    benchmark.run(print_data=True, save_path=args.save_path)
benchmarks/kernels/benchmark_quant.py
View file @
006693ed
...
@@ -7,7 +7,8 @@ import torch
...
@@ -7,7 +7,8 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
FlexibleArgumentParser
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
...
...
benchmarks/kernels/benchmark_reshape_and_cache.py
0 → 100644
View file @
006693ed
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
random
import
time
import
torch
from
tabulate
import
tabulate
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.torch_utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
create_kv_caches_with_random
,
)
# Module-level logger obtained from vLLM's logging helper.
logger = init_logger(__name__)
@torch.inference_mode()
def run_benchmark(
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    num_iters: int,
    benchmark_mode: str,
    device: str = "cuda",
) -> float:
    """Time ``ops.reshape_and_cache`` for one token count.

    Seeds the RNGs with 42, sets ``device`` as torch's default device,
    builds random key/value tensors plus a single-layer KV cache, and
    measures the kernel either directly or through a captured CUDA graph.

    Args:
        num_tokens: Number of tokens written to the cache per call.
        num_heads: Second dimension of the key/value tensors.
        head_size: Third dimension of the key/value tensors.
        block_size: Slots per cache block.
        num_blocks: Number of cache blocks.
        dtype: Dtype of the key/value tensors.
        kv_cache_dtype: Cache dtype string passed to the cache factory and
            the kernel.
        num_iters: Number of timed iterations.
        benchmark_mode: ``"cudagraph"`` replays a captured CUDA graph; any
            other value calls the kernel directly.
        device: Device on which all tensors are created.

    Returns:
        Mean wall-clock seconds per iteration over ``num_iters`` runs.

    Raises:
        ValueError: If ``kv_cache_dtype`` is ``"fp8"`` and ``head_size`` is
            not a multiple of 16, or if ``num_tokens`` exceeds
            ``block_size * num_blocks``.
    """
    if kv_cache_dtype == "fp8" and head_size % 16:
        raise ValueError("fp8 kv-cache requires head_size to be a multiple of 16.")

    current_platform.seed_everything(42)
    # NOTE: this changes torch's process-wide default device and is not undone.
    torch.set_default_device(device)

    # create random key / value tensors [T, H, D].
    key = torch.randn(num_tokens, num_heads, head_size, dtype=dtype, device=device)
    value = torch.randn_like(key)

    # prepare the slot mapping.
    # each token is assigned a unique slot in the KV-cache.
    num_slots = block_size * num_blocks
    if num_tokens > num_slots:
        raise ValueError("num_tokens cannot exceed the total number of cache slots")
    # random.sample draws without replacement, so no slot is reused.
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

    key_caches, value_caches = create_kv_caches_with_random(
        num_blocks,
        block_size,
        1,  # num_layers
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]
    # to free unused memory
    del key_caches, value_caches

    # compute per-kernel scaling factors for fp8 conversion (if used).
    k_scale = (key.amax() / 64.0).to(torch.float32)
    v_scale = (value.amax() / 64.0).to(torch.float32)

    # The lambda closes over the tensors above; the noqa markers silence
    # linters that see these names deleted further down.
    function_under_test = lambda: ops.reshape_and_cache(
        key,  # noqa: F821
        value,  # noqa: F821
        key_cache,  # noqa: F821
        value_cache,  # noqa: F821
        slot_mapping,  # noqa: F821
        kv_cache_dtype,
        k_scale,
        v_scale,
    )

    if benchmark_mode == "cudagraph":
        # Capture one kernel call into a CUDA graph and time its replay.
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            function_under_test()
        torch.cuda.synchronize()
        function_under_test = lambda: g.replay()

    def run_cuda_benchmark(n_iters: int) -> float:
        # Returns the mean seconds per iteration, syncing after every call.
        nonlocal key, value, key_cache, value_cache, slot_mapping
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(n_iters):
            function_under_test()
            torch.cuda.synchronize()
        end = time.perf_counter()
        return (end - start) / n_iters

    # warm-up
    run_cuda_benchmark(3)

    lat = run_cuda_benchmark(num_iters)

    # free tensors to mitigate OOM when sweeping
    del key, value, key_cache, value_cache, slot_mapping
    torch.cuda.empty_cache()

    return lat
def main(args):
    """Sweep token counts 2**1 .. 2**16 and print a latency table.

    Each token count is timed with ``run_benchmark`` on ``"cuda"`` using the
    head, block, dtype, iteration and mode settings from ``args``; latencies
    are reported in microseconds.
    """
    table = []
    for power in range(1, 17):
        token_count = 2**power
        seconds = run_benchmark(
            num_tokens=token_count,
            num_heads=args.num_heads,
            head_size=args.head_size,
            block_size=args.block_size,
            num_blocks=args.num_blocks,
            dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
            kv_cache_dtype=args.kv_cache_dtype,
            num_iters=args.iters,
            benchmark_mode=args.mode,
            device="cuda",
        )
        # run_benchmark returns seconds; the table shows microseconds.
        table.append([token_count, seconds * 1e6])

    print(
        f"Benchmark results for implementation cuda (measuring with {args.mode}):"
    )
    column_names = ["num_tokens", "latency (µs)"]
    print(tabulate(table, headers=column_names, floatfmt=".3f"))
if __name__ == "__main__":
    parser = FlexibleArgumentParser()

    # Every option with its argparse keyword arguments, in registration order.
    _cli_options = [
        ("--num-heads", {"type": int, "default": 128}),
        (
            "--head-size",
            {
                "type": int,
                "choices": [64, 80, 96, 112, 120, 128, 192, 256],
                "default": 128,
            },
        ),
        ("--block-size", {"type": int, "choices": [16, 32], "default": 16}),
        ("--num-blocks", {"type": int, "default": 128 * 128}),
        (
            "--dtype",
            {
                "type": str,
                "choices": ["half", "bfloat16", "float"],
                "default": "bfloat16",
            },
        ),
        (
            "--kv-cache-dtype",
            {"type": str, "choices": ["auto", "fp8"], "default": "auto"},
        ),
        ("--iters", {"type": int, "default": 200}),
        (
            "--mode",
            {
                "type": str,
                "choices": ["cudagraph", "no_graph"],
                "default": "cudagraph",
            },
        ),
    ]
    for _flag, _options in _cli_options:
        parser.add_argument(_flag, **_options)

    args = parser.parse_args()

    main(args)
benchmarks/kernels/benchmark_reshape_and_cache_flash.py
View file @
006693ed
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
import
random
import
random
import
time
import
time
...
@@ -14,9 +12,9 @@ from vllm.attention.ops.triton_reshape_and_cache_flash import (
...
@@ -14,9 +12,9 @@ from vllm.attention.ops.triton_reshape_and_cache_flash import (
)
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
(
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.torch_utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
STR_DTYPE_TO_TORCH_DTYPE
,
FlexibleArgumentParser
,
create_kv_caches_with_random_flash
,
create_kv_caches_with_random_flash
,
)
)
...
...
benchmarks/kernels/benchmark_rmsnorm.py
View file @
006693ed
...
@@ -2,7 +2,6 @@
...
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
itertools
import
itertools
from
typing
import
Optional
,
Union
import
torch
import
torch
from
flashinfer.norm
import
fused_add_rmsnorm
,
rmsnorm
from
flashinfer.norm
import
fused_add_rmsnorm
,
rmsnorm
...
@@ -21,8 +20,8 @@ class HuggingFaceRMSNorm(nn.Module):
...
@@ -21,8 +20,8 @@ class HuggingFaceRMSNorm(nn.Module):
def
forward
(
def
forward
(
self
,
self
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
torch
.
Tensor
|
None
=
None
,
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
]
:
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
orig_dtype
=
x
.
dtype
orig_dtype
=
x
.
dtype
x
=
x
.
to
(
torch
.
float32
)
x
=
x
.
to
(
torch
.
float32
)
if
residual
is
not
None
:
if
residual
is
not
None
:
...
@@ -41,7 +40,7 @@ class HuggingFaceRMSNorm(nn.Module):
...
@@ -41,7 +40,7 @@ class HuggingFaceRMSNorm(nn.Module):
def
rmsnorm_naive
(
def
rmsnorm_naive
(
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
torch
.
Tensor
|
None
=
None
,
eps
:
float
=
1e-6
,
eps
:
float
=
1e-6
,
):
):
naive_norm
=
HuggingFaceRMSNorm
(
x
.
shape
[
-
1
],
eps
=
eps
)
naive_norm
=
HuggingFaceRMSNorm
(
x
.
shape
[
-
1
],
eps
=
eps
)
...
@@ -65,7 +64,7 @@ def rmsnorm_naive(
...
@@ -65,7 +64,7 @@ def rmsnorm_naive(
def
rmsnorm_flashinfer
(
def
rmsnorm_flashinfer
(
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
torch
.
Tensor
|
None
=
None
,
eps
:
float
=
1e-6
,
eps
:
float
=
1e-6
,
):
):
orig_shape
=
x
.
shape
orig_shape
=
x
.
shape
...
@@ -89,7 +88,7 @@ def rmsnorm_flashinfer(
...
@@ -89,7 +88,7 @@ def rmsnorm_flashinfer(
def
rmsnorm_vllm
(
def
rmsnorm_vllm
(
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
torch
.
Tensor
|
None
=
None
,
eps
:
float
=
1e-6
,
eps
:
float
=
1e-6
,
):
):
orig_shape
=
x
.
shape
orig_shape
=
x
.
shape
...
...
benchmarks/kernels/benchmark_rope.py
View file @
006693ed
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
itertools
import
accumulate
import
itertools
from
typing
import
Optional
import
nvtx
import
torch
import
torch
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
,
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.
platforms
import
current_platform
from
vllm.
triton_utils
import
triton
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
batch_size_range
=
[
2
**
i
for
i
in
range
(
0
,
8
,
2
)]
seq_len_range
=
[
2
**
i
for
i
in
range
(
6
,
10
,
1
)]
num_heads_range
=
[
32
,
48
]
configs
=
list
(
itertools
.
product
(
batch_size_range
,
seq_len_range
,
num_heads_range
))
def
benchmark_rope_kernels_multi_lora
(
is_neox_style
:
bool
,
def
get_benchmark
(
head_size
,
rotary_dim
,
is_neox_style
,
device
):
batch_size
:
int
,
@
triton
.
testing
.
perf_report
(
seq_len
:
int
,
triton
.
testing
.
Benchmark
(
num_heads
:
int
,
x_names
=
[
"batch_size"
,
"seq_len"
,
"num_heads"
],
head_size
:
int
,
x_vals
=
[
list
(
_
)
for
_
in
configs
],
rotary_dim
:
Optional
[
int
],
line_arg
=
"provider"
,
dtype
:
torch
.
dtype
,
line_vals
=
[
"torch"
,
"flashinfer"
,
"vllm"
],
seed
:
int
,
line_names
=
[
"PyTorch"
,
"FlashInfer"
,
"vLLM"
],
device
:
str
,
styles
=
[(
"blue"
,
"-"
),
(
"green"
,
"-"
),
(
"red"
,
"-"
)],
max_position
:
int
=
8192
,
ylabel
=
"us"
,
base
:
float
=
10000
,
plot_name
=
f
"rope-perf
{
'-neox-style'
if
is_neox_style
else
''
}
"
,
)
->
None
:
args
=
{},
current_platform
.
seed_everything
(
seed
)
)
torch
.
set_default_device
(
device
)
if
rotary_dim
is
None
:
rotary_dim
=
head_size
# simulating serving 4 LoRAs
scaling_factors
=
[
1
,
2
,
4
,
8
]
# batched RoPE can take multiple scaling factors
batched_rope
=
get_rope
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
{
"rope_type"
:
"linear"
,
"factor"
:
tuple
(
scaling_factors
)},
)
)
# non-batched RoPE takes only one scaling factor, we create multiple
def
benchmark
(
batch_size
,
seq_len
,
num_heads
,
provider
):
# instances to simulate the same behavior
dtype
=
torch
.
bfloat16
non_batched_ropes
:
list
[
RotaryEmbedding
]
=
[]
max_position
=
8192
for
scaling_factor
in
scaling_factors
:
base
=
10000
non_batched_ropes
.
append
(
rope
=
get_rope
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
)
get_rope
(
rope
=
rope
.
to
(
dtype
=
dtype
,
device
=
device
)
head_size
,
cos_sin_cache
=
rope
.
cos_sin_cache
.
to
(
dtype
=
torch
.
float
,
device
=
device
)
rotary_dim
,
max_position
,
positions
=
torch
.
randint
(
0
,
max_position
,
(
batch_size
,
seq_len
),
device
=
device
)
base
,
query
=
torch
.
randn
(
is_neox_style
,
(
batch_size
,
seq_len
,
num_heads
*
head_size
),
dtype
=
dtype
,
device
=
device
{
"rope_type"
:
"linear"
,
"factor"
:
(
scaling_factor
,)},
)
)
)
key
=
torch
.
randn_like
(
query
)
positions
=
torch
.
randint
(
0
,
max_position
,
(
batch_size
,
seq_len
))
quantiles
=
[
0.5
,
0.2
,
0.8
]
query
=
torch
.
randn
(
batch_size
,
seq_len
,
num_heads
*
head_size
,
dtype
=
dtype
)
key
=
torch
.
randn_like
(
query
)
# create query offsets for batched RoPE, we concat multiple kv cache
if
provider
==
"torch"
:
# together and each query needs to find the right kv cache of its type
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
offset_map
=
torch
.
tensor
(
lambda
:
rope
.
forward_native
(
positions
,
query
.
clone
(),
key
.
clone
()),
list
(
quantiles
=
quantiles
,
accumulate
(
[
0
]
+
[
max_position
*
scaling_factor
*
2
for
scaling_factor
in
scaling_factors
[:
-
1
]
]
)
)
)
elif
provider
==
"flashinfer"
:
)
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
query_types
=
torch
.
randint
(
lambda
:
torch
.
ops
.
vllm
.
flashinfer_rotary_embedding
(
0
,
len
(
scaling_factors
),
(
batch_size
,
seq_len
),
device
=
device
positions
,
)
query
.
clone
(),
# map query types to offsets
key
.
clone
(),
query_offsets
=
offset_map
[
query_types
]
head_size
,
# the kernel takes flattened offsets
cos_sin_cache
,
flatten_offsets
=
query_offsets
.
flatten
()
is_neox_style
,
),
quantiles
=
quantiles
,
)
else
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
rope
.
forward_cuda
(
positions
,
query
.
clone
(),
key
.
clone
()),
quantiles
=
quantiles
,
)
return
1000
*
ms
,
1000
*
max_ms
,
1000
*
min_ms
# batched queries of the same type together for non-batched RoPE
return
benchmark
queries
=
[
query
[
query_types
==
i
]
for
i
in
range
(
len
(
scaling_factors
))]
keys
=
[
key
[
query_types
==
i
]
for
i
in
range
(
len
(
scaling_factors
))]
packed_qkr
=
zip
(
queries
,
keys
,
non_batched_ropes
)
# synchronize before start timing
torch
.
cuda
.
synchronize
()
with
nvtx
.
annotate
(
"non-batched"
,
color
=
"yellow"
):
for
q
,
k
,
r
in
packed_qkr
:
r
.
forward
(
positions
,
q
,
k
)
torch
.
cuda
.
synchronize
()
with
nvtx
.
annotate
(
"batched"
,
color
=
"green"
):
batched_rope
.
forward
(
positions
,
query
,
key
,
flatten_offsets
)
torch
.
cuda
.
synchronize
()
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
@@ -117,17 +95,12 @@ if __name__ == "__main__":
...
@@ -117,17 +95,12 @@ if __name__ == "__main__":
parser
.
add_argument
(
parser
.
add_argument
(
"--device"
,
type
=
str
,
choices
=
[
"cuda:0"
,
"cuda:1"
],
default
=
"cuda:0"
"--device"
,
type
=
str
,
choices
=
[
"cuda:0"
,
"cuda:1"
],
default
=
"cuda:0"
)
)
parser
.
add_argument
(
"--save-path"
,
type
=
str
,
default
=
"./configs/rope/"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
print
(
args
)
benchmark_rope_kernels_multi_lora
(
# Get the benchmark function
is_neox_style
=
args
.
is_neox_style
,
benchmark
=
get_benchmark
(
batch_size
=
args
.
batch_size
,
args
.
head_size
,
args
.
rotary_dim
,
args
.
is_neox_style
,
args
.
device
seq_len
=
args
.
seq_len
,
num_heads
=
args
.
num_heads
,
head_size
=
args
.
head_size
,
rotary_dim
=
args
.
rotary_dim
,
dtype
=
getattr
(
torch
,
args
.
dtype
),
seed
=
args
.
seed
,
device
=
args
.
device
,
)
)
# Run performance benchmark
benchmark
.
run
(
print_data
=
True
,
save_path
=
args
.
save_path
)
benchmarks/kernels/benchmark_shapes.py
View file @
006693ed
...
@@ -78,11 +78,11 @@ WEIGHT_SHAPES = {
...
@@ -78,11 +78,11 @@ WEIGHT_SHAPES = {
}
}
WEIGHT_SHAPES_MOE
=
{
WEIGHT_SHAPES_MOE
=
{
"
nm-testing
/Mixtral-8x7B-Instruct-v0.1"
:
[
"
mistralai
/Mixtral-8x7B-Instruct-v0.1"
:
[
[
8
,
2
,
4096
,
28672
],
[
8
,
2
,
4096
,
28672
],
[
8
,
2
,
14336
,
4096
],
[
8
,
2
,
14336
,
4096
],
],
],
"
nm-testing/d
eep
s
eek
v
2-
l
ite"
:
[
"
deepseek-ai/D
eep
S
eek
-V
2-
L
ite"
:
[
[
64
,
6
,
2048
,
1408
],
[
64
,
6
,
2048
,
1408
],
],
],
"ibm-granite/granite-3.0-1b-a400m"
:
[
"ibm-granite/granite-3.0-1b-a400m"
:
[
...
...
benchmarks/kernels/benchmark_silu_mul_fp8_quant.py
View file @
006693ed
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Comprehensive 3-way SiLU Benchmark Suite
This benchmark compares three SiLU implementations:
1. SiLU V2 (CUDA) - Optimized CUDA kernel implementation
2. Triton Kernel - Triton-based implementation
The suite generates detailed performance comparisons including:
- Memory bandwidth utilization
- Speedup ratios (baseline vs optimized implementations)
- Performance across different expert configurations and token distributions
"""
from
collections.abc
import
Callable
from
collections.abc
import
Callable
import
matplotlib.pyplot
as
plt
import
matplotlib.pyplot
as
plt
...
@@ -7,7 +21,7 @@ import numpy as np
...
@@ -7,7 +21,7 @@ import numpy as np
import
torch
import
torch
from
vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe
import
(
from
vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe
import
(
silu_mul_
fp8_
quant
_deep_gemm_cuda
,
persistent_masked_m_
silu_mul_quant
,
)
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
...
@@ -94,6 +108,7 @@ def silu_mul_fp8_quant_deep_gemm_triton(
...
@@ -94,6 +108,7 @@ def silu_mul_fp8_quant_deep_gemm_triton(
num_parallel_tokens
,
num_parallel_tokens
,
group_size
:
int
=
128
,
group_size
:
int
=
128
,
eps
:
float
=
1e-10
,
eps
:
float
=
1e-10
,
expert_offsets
:
torch
.
Tensor
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
"""Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
...
@@ -174,7 +189,7 @@ def silu_mul_fp8_quant_deep_gemm_triton(
...
@@ -174,7 +189,7 @@ def silu_mul_fp8_quant_deep_gemm_triton(
# Parse generation strategies
# Parse generation strategies
strategies
=
[
"uniform"
,
"max_t"
,
"first_t"
]
strategies
=
[
"random_imbalanced"
,
"uniform"
,
"max_t"
]
def
benchmark
(
def
benchmark
(
...
@@ -195,15 +210,27 @@ def benchmark(
...
@@ -195,15 +210,27 @@ def benchmark(
current_platform
.
seed_everything
(
42
+
seed_offset
)
current_platform
.
seed_everything
(
42
+
seed_offset
)
y
=
torch
.
rand
((
E
,
T
,
2
*
H
),
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
).
contiguous
()
y
=
torch
.
rand
((
E
,
T
,
2
*
H
),
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
).
contiguous
()
if
gen_strategy
==
"uniform"
:
if
gen_strategy
==
"random_imbalanced"
:
r
=
torch
.
rand
(
size
=
(
E
,),
device
=
"cuda"
)
def
generate_expert_loads
(
n_e
,
total_tokens
,
ratio
,
device
=
"cuda"
):
mean
=
total_tokens
//
n_e
min_max
=
mean
//
ratio
e
=
torch
.
ones
(
size
=
(
E
,),
dtype
=
torch
.
int64
,
device
=
device
)
*
mean
e
[
0
]
=
min_max
r
=
torch
.
rand
(
size
=
(
E
-
1
,))
r
/=
r
.
sum
()
r
*=
total_tokens
-
min_max
r
=
r
.
round
().
long
()
e
[
1
:]
=
r
.
to
(
device
=
device
)
return
e
tokens_per_expert
=
generate_expert_loads
(
E
,
total_tokens
,
0.7
,
"cuda"
)
elif
gen_strategy
==
"uniform"
:
r
=
torch
.
rand
(
size
=
(
E
,))
r
/=
r
.
sum
()
r
/=
r
.
sum
()
r
*=
total_tokens
r
*=
total_tokens
tokens_per_expert
=
r
.
int
()
r
=
r
.
round
().
long
()
tokens_per_expert
=
torch
.
minimum
(
tokens_per_expert
=
r
tokens_per_expert
,
torch
.
ones
((
E
,),
device
=
r
.
device
,
dtype
=
torch
.
int
)
*
T
,
)
elif
gen_strategy
==
"max_t"
:
elif
gen_strategy
==
"max_t"
:
tokens_per_expert
=
torch
.
empty
(
size
=
(
E
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
tokens_per_expert
=
torch
.
empty
(
size
=
(
E
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
tokens_per_expert
.
fill_
(
total_tokens
/
E
)
tokens_per_expert
.
fill_
(
total_tokens
/
E
)
...
@@ -281,40 +308,34 @@ def benchmark(
...
@@ -281,40 +308,34 @@ def benchmark(
def
create_comparison_plot
(
def
create_comparison_plot
(
ratio
,
cuda
_times
,
baseline
_times
,
config_labels
,
strategy_name
,
id
ratio
s
,
silu_v2
_times
,
triton
_times
,
config_labels
,
strategy_name
,
id
):
):
"""Create a comparison plot for a specific generation strategy"""
fig
,
ax
=
plt
.
subplots
(
1
,
1
,
figsize
=
(
18
,
6
))
fig
,
ax
=
plt
.
subplots
(
1
,
1
,
figsize
=
(
16
,
6
))
# Configure x-axis positions
# Configure x-axis positions
x
=
np
.
arange
(
len
(
config_labels
))
x
=
np
.
arange
(
len
(
config_labels
))
width
=
0.
3
5
width
=
0.
2
5
# Execution Time plot (lower is better)
# Execution Time plot (lower is better)
ax
.
bar
(
x
,
silu_v2_times
,
width
,
label
=
"SiLU V2 (CUDA)"
,
alpha
=
0.8
,
color
=
"blue"
)
ax
.
bar
(
ax
.
bar
(
x
-
width
/
2
,
cuda_times
,
width
,
label
=
"CUDA Kernel"
,
alpha
=
0.8
,
color
=
"blue"
x
+
width
,
triton_times
,
width
,
label
=
"Triton Kernel"
,
alpha
=
0.8
,
color
=
"green"
)
ax
.
bar
(
x
+
width
/
2
,
baseline_times
,
width
,
label
=
"Baseline"
,
alpha
=
0.8
,
color
=
"orange"
,
)
)
# Add speedup labels over each bar
pair
# Add speedup labels over each bar
trio
for
i
in
range
(
len
(
x
)):
for
i
in
range
(
len
(
x
)):
speedup
=
ratio
[
i
]
triton_v2_speedup
=
ratios
[
i
][
1
]
# triton/v2
max_height
=
max
(
cuda_times
[
i
],
baseline_times
[
i
])
max_height
=
max
(
silu_v2_times
[
i
],
triton_times
[
i
])
# Triton/V2 speedup
ax
.
text
(
ax
.
text
(
x
[
i
],
x
[
i
]
+
width
/
2
,
max_height
+
max_height
*
0.02
,
max_height
+
max_height
*
0.02
,
f
"
{
speedup
:.
2
f
}
x"
,
f
"
{
triton_v2_
speedup
:.
2
f
}
x"
,
ha
=
"center"
,
ha
=
"center"
,
va
=
"bottom"
,
va
=
"bottom"
,
fontweight
=
"bold"
,
fontweight
=
"bold"
,
fontsize
=
9
,
fontsize
=
8
,
)
)
ax
.
set_xlabel
(
"Configuration"
)
ax
.
set_xlabel
(
"Configuration"
)
...
@@ -332,56 +353,75 @@ def create_comparison_plot(
...
@@ -332,56 +353,75 @@ def create_comparison_plot(
def
create_combined_plot
(
all_results
):
def
create_combined_plot
(
all_results
):
"""Create a combined plot with all strategies in one PNG"""
num_strategies
=
len
(
all_results
)
num_strategies
=
len
(
all_results
)
fig
,
axes
=
plt
.
subplots
(
num_strategies
,
1
,
figsize
=
(
2
0
,
6
*
num_strategies
))
fig
,
axes
=
plt
.
subplots
(
num_strategies
,
1
,
figsize
=
(
2
2
,
7
*
num_strategies
))
if
num_strategies
==
1
:
if
num_strategies
==
1
:
axes
=
[
axes
]
axes
=
[
axes
]
for
idx
,
(
for
idx
,
(
strategy_name
,
strategy_name
,
ratio
,
all_
ratio
s
,
cuda_time
s
,
all_silu_v2_result
s
,
baseline_time
s
,
all_triton_result
s
,
config_labels
,
config_labels
,
config_x_axis
,
)
in
enumerate
(
all_results
):
)
in
enumerate
(
all_results
):
ax
=
axes
[
idx
]
ax
=
axes
[
idx
]
# Flatten the nested results to get bandwidth percentages for plotting
silu_v2_bandwidths
=
[]
triton_bandwidths
=
[]
flat_ratios
=
[]
for
config_results
in
all_silu_v2_results
:
for
result
in
config_results
:
silu_v2_bandwidths
.
append
(
result
[
3
])
# bandwidth percentage
for
config_results
in
all_triton_results
:
for
result
in
config_results
:
triton_bandwidths
.
append
(
result
[
3
])
# bandwidth percentage
for
config_ratios
in
all_ratios
:
for
ratio
in
config_ratios
:
flat_ratios
.
append
(
ratio
)
# Configure x-axis positions
# Configure x-axis positions
x
=
np
.
arange
(
len
(
config_labels
))
x
=
np
.
arange
(
len
(
config_labels
))
width
=
0.
3
5
width
=
0.
2
5
#
Execution Time
plot (
low
er is better)
#
Bandwidth utilization
plot (
high
er is better)
ax
.
bar
(
ax
.
bar
(
x
-
width
/
2
,
x
,
cuda_time
s
,
silu_v2_bandwidth
s
,
width
,
width
,
label
=
"
CUDA Kernel
"
,
label
=
"
SiLU V2 (CUDA)
"
,
alpha
=
0.8
,
alpha
=
0.8
,
color
=
"blue"
,
color
=
"blue"
,
)
)
ax
.
bar
(
ax
.
bar
(
x
+
width
/
2
,
x
+
width
,
baseline_time
s
,
triton_bandwidth
s
,
width
,
width
,
label
=
"
Baseli
ne"
,
label
=
"
Triton Ker
ne
l
"
,
alpha
=
0.8
,
alpha
=
0.8
,
color
=
"
orange
"
,
color
=
"
green
"
,
)
)
# Add speedup labels over each bar
pair
# Add speedup labels over each bar
trio
for
i
in
range
(
len
(
x
)):
for
i
in
range
(
len
(
x
)):
speedup
=
ratio
[
i
]
triton_v2_speedup
=
flat_ratios
[
i
]
# triton/v2
max_height
=
max
(
cuda_times
[
i
],
baseline_times
[
i
])
max_height
=
max
(
silu_v2_bandwidths
[
i
],
triton_bandwidths
[
i
])
# Triton/V2 speedup
ax
.
text
(
ax
.
text
(
x
[
i
],
x
[
i
]
+
width
/
2
,
max_height
+
max_height
*
0.02
,
max_height
+
max_height
*
0.02
,
f
"
{
speedup
:.
2
f
}
x"
,
f
"
{
triton_v2_
speedup
:.
2
f
}
x"
,
ha
=
"center"
,
ha
=
"center"
,
va
=
"bottom"
,
va
=
"bottom"
,
fontweight
=
"bold"
,
fontweight
=
"bold"
,
fontsize
=
9
,
fontsize
=
8
,
)
)
ax
.
set_xlabel
(
"Configuration"
)
ax
.
set_xlabel
(
"Configuration"
)
...
@@ -395,7 +435,7 @@ def create_combined_plot(all_results):
...
@@ -395,7 +435,7 @@ def create_combined_plot(all_results):
ax
.
grid
(
True
,
alpha
=
0.3
)
ax
.
grid
(
True
,
alpha
=
0.3
)
plt
.
tight_layout
()
plt
.
tight_layout
()
filename
=
"
../../silu_bench/
silu_benchmark_combined.png"
filename
=
"silu_benchmark_combined
_3way
.png"
plt
.
savefig
(
filename
,
dpi
=
300
,
bbox_inches
=
"tight"
)
plt
.
savefig
(
filename
,
dpi
=
300
,
bbox_inches
=
"tight"
)
plt
.
show
()
plt
.
show
()
...
@@ -405,7 +445,9 @@ def create_combined_plot(all_results):
...
@@ -405,7 +445,9 @@ def create_combined_plot(all_results):
outer_dim
=
7168
outer_dim
=
7168
configs
=
[
configs
=
[
# DeepSeekV3 Configs
# DeepSeekV3 Configs
# (1, 56, 7168),
(
8
,
1024
,
7168
),
(
8
,
1024
,
7168
),
# (32, 56, 7168),
# DeepSeekV3 Configs
# DeepSeekV3 Configs
(
32
,
1024
,
7168
),
(
32
,
1024
,
7168
),
# DeepSeekV3 Configs
# DeepSeekV3 Configs
...
@@ -417,6 +459,7 @@ num_warmups = 20
...
@@ -417,6 +459,7 @@ num_warmups = 20
strategy_descriptions
=
{
strategy_descriptions
=
{
"uniform"
:
"Uniform Random"
,
"uniform"
:
"Uniform Random"
,
"random_imbalanced"
:
"Imbalanced Random"
,
"max_t"
:
"Even Assignment"
,
"max_t"
:
"Even Assignment"
,
"first_t"
:
"experts[0] = T, experts[1:] = 0"
,
"first_t"
:
"experts[0] = T, experts[1:] = 0"
,
}
}
...
@@ -433,28 +476,31 @@ for id, strategy in enumerate(strategies):
...
@@ -433,28 +476,31 @@ for id, strategy in enumerate(strategies):
print
(
f
"Testing strategy:
{
strategy_descriptions
[
strategy
]
}
"
)
print
(
f
"Testing strategy:
{
strategy_descriptions
[
strategy
]
}
"
)
print
(
f
"
{
'='
*
60
}
"
)
print
(
f
"
{
'='
*
60
}
"
)
# Collect benchmark data for
both
algorithms
# Collect benchmark data for
all three
algorithms
config_labels
=
[]
config_labels
=
[]
config_x_axis
=
[]
config_x_axis
=
[]
all_
cuda
_results
=
[]
all_
silu_v2
_results
=
[]
all_
baseline
_results
=
[]
all_
triton
_results
=
[]
all_ratios
=
[]
all_ratios
=
[]
for
E
,
T
,
H
in
configs
:
for
E
,
T
,
H
in
configs
:
total_tokens_config
=
[
8
*
E
,
16
*
E
,
32
*
E
,
64
*
E
,
128
*
E
,
256
*
E
]
total_tokens_config
=
[]
for
i
in
[
8
,
16
,
32
,
64
,
128
,
256
,
512
]:
if
i
<=
T
:
total_tokens_config
.
append
(
i
*
E
)
config_x_axis
.
append
(
total_tokens_config
)
config_x_axis
.
append
(
total_tokens_config
)
cuda
_results
=
[]
silu_v2
_results
=
[]
baseline
_results
=
[]
triton
_results
=
[]
ratios
=
[]
ratios
=
[]
for
total_tokens
in
total_tokens_config
:
for
total_tokens
in
total_tokens_config
:
config_label
=
f
"E=
{
E
}
,T=
{
T
}
,H=
{
H
}
,TT=
{
total_tokens
}
"
config_label
=
f
"E=
{
E
}
,T=
{
T
}
,H=
{
H
}
,TT=
{
total_tokens
}
"
config_labels
.
append
(
config_label
)
config_labels
.
append
(
config_label
)
# CUDA kernel results
#
SiLU V2 (
CUDA kernel
)
results
time_ms_
cuda
,
gflops
,
gbps
,
perc
=
benchmark
(
time_ms_
silu_v2
,
gflops
,
gbps
,
perc
=
benchmark
(
silu_mul_
fp8_
quant
_deep_gemm_cuda
,
persistent_masked_m_
silu_mul_quant
,
E
,
E
,
T
,
T
,
H
,
H
,
...
@@ -463,9 +509,9 @@ for id, strategy in enumerate(strategies):
...
@@ -463,9 +509,9 @@ for id, strategy in enumerate(strategies):
num_warmups
=
num_warmups
,
num_warmups
=
num_warmups
,
gen_strategy
=
strategy
,
gen_strategy
=
strategy
,
)
)
cuda
_results
.
append
((
time_ms_
cuda
,
gflops
,
gbps
,
perc
))
silu_v2
_results
.
append
((
time_ms_
silu_v2
,
gflops
,
gbps
,
perc
))
#
Baseli
ne results
#
Triton ker
ne
l
results
time_ms_triton
,
gflops
,
gbps
,
perc
=
benchmark
(
time_ms_triton
,
gflops
,
gbps
,
perc
=
benchmark
(
silu_mul_fp8_quant_deep_gemm_triton
,
silu_mul_fp8_quant_deep_gemm_triton
,
E
,
E
,
...
@@ -476,12 +522,20 @@ for id, strategy in enumerate(strategies):
...
@@ -476,12 +522,20 @@ for id, strategy in enumerate(strategies):
num_warmups
=
num_warmups
,
num_warmups
=
num_warmups
,
gen_strategy
=
strategy
,
gen_strategy
=
strategy
,
)
)
baseline_results
.
append
((
time_ms_triton
,
gflops
,
gbps
,
perc
))
triton_results
.
append
((
time_ms_triton
,
gflops
,
gbps
,
perc
))
ratios
.
append
(
time_ms_triton
/
time_ms_cuda
)
print
(
f
"Completed:
{
config_label
}
"
)
# Calculate speedup ratios (triton baseline / implementation)
all_cuda_results
.
append
(
cuda_results
)
triton_v2_ratio
=
time_ms_triton
/
time_ms_silu_v2
all_baseline_results
.
append
(
baseline_results
)
ratios
.
append
(
triton_v2_ratio
)
print
(
f
"Completed:
{
config_label
}
:"
f
" V2:
{
time_ms_silu_v2
:.
3
f
}
ms,"
f
" Triton:
{
time_ms_triton
:.
3
f
}
ms"
)
all_silu_v2_results
.
append
(
silu_v2_results
)
all_triton_results
.
append
(
triton_results
)
all_ratios
.
append
(
ratios
)
all_ratios
.
append
(
ratios
)
# Store results for combined plotting
# Store results for combined plotting
...
@@ -489,8 +543,8 @@ for id, strategy in enumerate(strategies):
...
@@ -489,8 +543,8 @@ for id, strategy in enumerate(strategies):
(
(
strategy_descriptions
[
strategy
],
strategy_descriptions
[
strategy
],
all_ratios
,
all_ratios
,
all_
cuda
_results
,
all_
silu_v2
_results
,
all_
baseline
_results
,
all_
triton
_results
,
config_labels
,
config_labels
,
config_x_axis
,
config_x_axis
,
)
)
...
@@ -498,15 +552,18 @@ for id, strategy in enumerate(strategies):
...
@@ -498,15 +552,18 @@ for id, strategy in enumerate(strategies):
# Print summary table for this strategy
# Print summary table for this strategy
print
(
f
"
\n
Summary Table -
{
strategy_descriptions
[
strategy
]
}
:"
)
print
(
f
"
\n
Summary Table -
{
strategy_descriptions
[
strategy
]
}
:"
)
print
(
f
"
{
'Config'
:
<
20
}
{
'CUDA
Time(ms)'
:
<
12
}
{
'
Base
Time(ms)'
:
<
1
2
}
{
'
Speedup
'
:
<
8
}
"
)
print
(
f
"
{
'V2
Time(ms)'
:
<
12
}
{
'
Triton
Time(ms)'
:
<
1
4
}
{
'
Triton/V2
'
:
<
10
}
"
)
print
(
"-"
*
6
0
)
print
(
"-"
*
9
0
)
for
i
,
(
E
,
T
,
H
)
in
enumerate
(
configs
):
for
i
,
(
E
,
T
,
H
)
in
enumerate
(
configs
):
speedup
=
baseline_results
[
i
][
0
]
/
cuda_results
[
i
][
0
]
# Get the first result for each config (simplifying for summary)
v2_time
=
silu_v2_results
[
i
][
0
]
triton_time
=
triton_results
[
i
][
0
]
triton_v2_speedup
=
triton_time
/
v2_time
config_label
=
f
"E=
{
E
:
3
d
}
,T=
{
T
:
4
d
}
,H=
{
H
:
4
d
}
"
config_label
=
f
"E=
{
E
:
3
d
}
,T=
{
T
:
4
d
}
,H=
{
H
:
4
d
}
"
print
(
print
(
f
"
{
config_label
:
<
20
}
{
cuda_results
[
i
][
0
]:
8
.5
f
}
"
f
"
{
config_label
:
<
20
}
{
v2_time
:
8.5
f
}
{
triton_time
:
10
.5
f
}
"
f
"
{
baseline_results
[
i
][
0
]:
8.5
f
}
{
speedup
:
6
.2
f
}
x"
f
"
{
triton_v2_
speedup
:
8
.2
f
}
x"
)
)
...
@@ -514,15 +571,14 @@ def create_total_tokens_plot(all_results):
...
@@ -514,15 +571,14 @@ def create_total_tokens_plot(all_results):
num_strategies
=
len
(
all_results
)
num_strategies
=
len
(
all_results
)
num_configs
=
len
(
configs
)
num_configs
=
len
(
configs
)
# Create side-by-side subplots: 2 columns for speedup and bandwidth percentage
fig
,
axs
=
plt
.
subplots
(
fig
,
axs
=
plt
.
subplots
(
num_strategies
,
num_configs
*
2
,
figsize
=
(
2
8
,
6
*
num_strategies
)
num_strategies
,
num_configs
*
2
,
figsize
=
(
3
2
,
8
*
num_strategies
)
)
)
# Add main title to the entire figure
# Add main title to the entire figure
fig
.
suptitle
(
fig
.
suptitle
(
"Performance Analysis: Speedup vs Bandwidth Utilization (
Triton & CUDA
)"
,
"Performance Analysis: Speedup vs Bandwidth Utilization (
SiLU V2, and Triton
)"
,
fontsize
=
1
6
,
fontsize
=
1
8
,
fontweight
=
"bold"
,
fontweight
=
"bold"
,
y
=
0.98
,
y
=
0.98
,
)
)
...
@@ -539,8 +595,8 @@ def create_total_tokens_plot(all_results):
...
@@ -539,8 +595,8 @@ def create_total_tokens_plot(all_results):
(
(
strategy_name
,
strategy_name
,
all_ratios
,
all_ratios
,
all_
cuda
_results
,
all_
silu_v2
_results
,
all_
baseline
_results
,
all_
triton
_results
,
config_labels
,
config_labels
,
config_x_axis
,
config_x_axis
,
)
=
result
)
=
result
...
@@ -555,42 +611,54 @@ def create_total_tokens_plot(all_results):
...
@@ -555,42 +611,54 @@ def create_total_tokens_plot(all_results):
ratios
=
all_ratios
[
config_idx
]
ratios
=
all_ratios
[
config_idx
]
total_tokens_values
=
config_x_axis
[
config_idx
]
total_tokens_values
=
config_x_axis
[
config_idx
]
# Extract CUDA and Triton bandwidth percentages
# Extract speedup ratios
cuda_bandwidth_percentages
=
[
triton_v2_ratios
=
[
ratio
for
ratio
in
ratios
]
result
[
3
]
for
result
in
all_cuda_results
[
config_idx
]
# Extract bandwidth percentages for all implementations
v2_bandwidth_percentages
=
[
result
[
3
]
for
result
in
all_silu_v2_results
[
config_idx
]
]
]
triton_bandwidth_percentages
=
[
triton_bandwidth_percentages
=
[
result
[
3
]
for
result
in
all_
baseline
_results
[
config_idx
]
result
[
3
]
for
result
in
all_
triton
_results
[
config_idx
]
]
]
# Plot speedup ratios vs total tokens (left plot)
# Plot speedup ratios vs total tokens (left plot)
ax_speedup
.
plot
(
ax_speedup
.
plot
(
total_tokens_values
,
ratios
,
"bo-"
,
linewidth
=
3
,
markersize
=
8
total_tokens_values
,
triton_v2_ratios
,
"go-"
,
linewidth
=
3
,
markersize
=
8
,
label
=
"Triton/V2 Speedup"
,
)
)
ax_speedup
.
set_title
(
ax_speedup
.
set_title
(
f
"
{
strategy_name
}
\n
Speedup
(CUDA/
Triton)
\n
E=
{
E
}
, T=
{
T
}
, H=
{
H
}
"
,
f
"
{
strategy_name
}
\n
Speedup
vs Baseline (
Triton)
\n
E=
{
E
}
, T=
{
T
}
, H=
{
H
}
"
,
fontsize
=
12
,
fontsize
=
12
,
fontweight
=
"bold"
,
fontweight
=
"bold"
,
)
)
ax_speedup
.
set_xlabel
(
"Total Tokens"
,
fontweight
=
"bold"
,
fontsize
=
11
)
ax_speedup
.
set_xlabel
(
"Total Tokens"
,
fontweight
=
"bold"
,
fontsize
=
11
)
ax_speedup
.
set_ylabel
(
"Speedup Ratio"
,
fontweight
=
"bold"
,
fontsize
=
11
)
ax_speedup
.
set_ylabel
(
"Speedup Ratio"
,
fontweight
=
"bold"
,
fontsize
=
11
)
ax_speedup
.
legend
(
prop
=
{
"weight"
:
"bold"
})
ax_speedup
.
grid
(
True
,
alpha
=
0.3
)
ax_speedup
.
grid
(
True
,
alpha
=
0.3
)
# Plot bandwidth utilization (right plot)
ax_bandwidth
.
plot
(
ax_bandwidth
.
plot
(
total_tokens_values
,
total_tokens_values
,
cuda
_bandwidth_percentages
,
v2
_bandwidth_percentages
,
"
r
o-"
,
"o-"
,
linewidth
=
3
,
linewidth
=
3
,
markersize
=
8
,
markersize
=
8
,
label
=
"CUDA"
,
label
=
"SiLU V2"
,
color
=
"blue"
,
)
)
ax_bandwidth
.
plot
(
ax_bandwidth
.
plot
(
total_tokens_values
,
total_tokens_values
,
triton_bandwidth_percentages
,
triton_bandwidth_percentages
,
"
g
o-"
,
"o-"
,
linewidth
=
3
,
linewidth
=
3
,
markersize
=
8
,
markersize
=
8
,
label
=
"Triton"
,
label
=
"Triton"
,
color
=
"green"
,
)
)
ax_bandwidth
.
set_title
(
ax_bandwidth
.
set_title
(
f
"
{
strategy_name
}
\n
Bandwidth Utilization (Hopper)
\n
E=
{
E
}
, T=
{
T
}
, H=
{
H
}
"
,
f
"
{
strategy_name
}
\n
Bandwidth Utilization (Hopper)
\n
E=
{
E
}
, T=
{
T
}
, H=
{
H
}
"
,
...
@@ -618,38 +686,12 @@ def create_total_tokens_plot(all_results):
...
@@ -618,38 +686,12 @@ def create_total_tokens_plot(all_results):
for
label
in
ax
.
get_xticklabels
()
+
ax
.
get_yticklabels
():
for
label
in
ax
.
get_xticklabels
()
+
ax
.
get_yticklabels
():
label
.
set_fontweight
(
"bold"
)
label
.
set_fontweight
(
"bold"
)
# Add value labels on speedup points
# Add value labels on
Triton/V2
speedup points
for
x
,
y
in
zip
(
total_tokens_values
,
ratios
):
for
x
,
y
in
zip
(
total_tokens_values
,
triton_v2_
ratios
):
ax_speedup
.
annotate
(
ax_speedup
.
annotate
(
f
"
{
y
:.
2
f
}
x"
,
f
"
{
y
:.
2
f
}
x"
,
(
x
,
y
),
(
x
,
y
),
textcoords
=
"offset points"
,
textcoords
=
"offset points"
,
xytext
=
(
0
,
12
),
ha
=
"center"
,
fontsize
=
10
,
fontweight
=
"bold"
,
bbox
=
dict
(
boxstyle
=
"round,pad=0.3"
,
facecolor
=
"white"
,
alpha
=
0.7
),
)
# Add value labels on CUDA bandwidth points
for
x
,
y
in
zip
(
total_tokens_values
,
cuda_bandwidth_percentages
):
ax_bandwidth
.
annotate
(
f
"
{
y
:.
1
f
}
%"
,
(
x
,
y
),
textcoords
=
"offset points"
,
xytext
=
(
0
,
12
),
ha
=
"center"
,
fontsize
=
9
,
fontweight
=
"bold"
,
bbox
=
dict
(
boxstyle
=
"round,pad=0.2"
,
facecolor
=
"red"
,
alpha
=
0.3
),
)
# Add value labels on Triton bandwidth points
for
x
,
y
in
zip
(
total_tokens_values
,
triton_bandwidth_percentages
):
ax_bandwidth
.
annotate
(
f
"
{
y
:.
1
f
}
%"
,
(
x
,
y
),
textcoords
=
"offset points"
,
xytext
=
(
0
,
-
15
),
xytext
=
(
0
,
-
15
),
ha
=
"center"
,
ha
=
"center"
,
fontsize
=
9
,
fontsize
=
9
,
...
@@ -659,17 +701,20 @@ def create_total_tokens_plot(all_results):
...
@@ -659,17 +701,20 @@ def create_total_tokens_plot(all_results):
plt
.
tight_layout
()
plt
.
tight_layout
()
plt
.
subplots_adjust
(
top
=
0.93
)
# Make room for main title
plt
.
subplots_adjust
(
top
=
0.93
)
# Make room for main title
filename
=
"silu_benchmark_total_tokens.png"
filename
=
"silu_benchmark_total_tokens
_3way
.png"
plt
.
savefig
(
filename
,
dpi
=
300
,
bbox_inches
=
"tight"
)
plt
.
savefig
(
filename
,
dpi
=
300
,
bbox_inches
=
"tight"
)
plt
.
show
()
plt
.
show
()
return
filename
return
filename
# Create combined plot with all strategies
# Create comprehensive 3-way comparison plots
combined_plot_filename
=
create_total_tokens_plot
(
all_results
)
combined_plot_filename
=
create_combined_plot
(
all_results
)
total_tokens_plot_filename
=
create_total_tokens_plot
(
all_results
)
print
(
f
"
\n
{
'='
*
60
}
"
)
print
(
f
"
\n
{
'='
*
80
}
"
)
print
(
"Benchmark Complete!"
)
print
(
"3-Way Benchmark Suite Complete!"
)
print
(
f
"Generated combined plot:
{
combined_plot_filename
}
"
)
print
(
f
"Generated combined comparison plot:
{
combined_plot_filename
}
"
)
print
(
f
"
{
'='
*
60
}
"
)
print
(
f
"Generated total tokens analysis plot:
{
total_tokens_plot_filename
}
"
)
print
(
"Compared: SiLU V2 (CUDA), and Triton implementations"
)
print
(
f
"
{
'='
*
80
}
"
)
benchmarks/kernels/benchmark_trtllm_decode_attention.py
View file @
006693ed
...
@@ -4,12 +4,11 @@
...
@@ -4,12 +4,11 @@
import
csv
import
csv
import
os
import
os
from
datetime
import
datetime
from
datetime
import
datetime
from
typing
import
Optional
import
flashinfer
import
flashinfer
import
torch
import
torch
from
vllm.utils
import
round_up
from
vllm.utils
.math_utils
import
round_up
FLOAT32_BYTES
=
torch
.
finfo
(
torch
.
float
).
bits
//
8
FLOAT32_BYTES
=
torch
.
finfo
(
torch
.
float
).
bits
//
8
FP8_DTYPE
=
torch
.
float8_e4m3fn
FP8_DTYPE
=
torch
.
float8_e4m3fn
...
@@ -28,9 +27,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
...
@@ -28,9 +27,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
benchmark_decode
(
def
benchmark_decode
(
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
quant_dtypes
:
tuple
[
quant_dtypes
:
tuple
[
torch
.
dtype
|
None
,
torch
.
dtype
|
None
,
torch
.
dtype
|
None
],
Optional
[
torch
.
dtype
],
Optional
[
torch
.
dtype
],
Optional
[
torch
.
dtype
]
],
batch_size
:
int
,
batch_size
:
int
,
max_seq_len
:
int
,
max_seq_len
:
int
,
num_heads
:
tuple
[
int
,
int
]
=
(
64
,
8
),
num_heads
:
tuple
[
int
,
int
]
=
(
64
,
8
),
...
...
benchmarks/kernels/benchmark_trtllm_prefill_attention.py
View file @
006693ed
...
@@ -4,12 +4,11 @@
...
@@ -4,12 +4,11 @@
import
csv
import
csv
import
os
import
os
from
datetime
import
datetime
from
datetime
import
datetime
from
typing
import
Optional
import
flashinfer
import
flashinfer
import
torch
import
torch
from
vllm.utils
import
round_up
from
vllm.utils
.math_utils
import
round_up
FLOAT32_BYTES
=
torch
.
finfo
(
torch
.
float
).
bits
//
8
FLOAT32_BYTES
=
torch
.
finfo
(
torch
.
float
).
bits
//
8
FP8_DTYPE
=
torch
.
float8_e4m3fn
FP8_DTYPE
=
torch
.
float8_e4m3fn
...
@@ -28,9 +27,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
...
@@ -28,9 +27,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
benchmark_prefill
(
def
benchmark_prefill
(
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
quant_dtypes
:
tuple
[
quant_dtypes
:
tuple
[
torch
.
dtype
|
None
,
torch
.
dtype
|
None
,
torch
.
dtype
|
None
],
Optional
[
torch
.
dtype
],
Optional
[
torch
.
dtype
],
Optional
[
torch
.
dtype
]
],
batch_size
:
int
,
batch_size
:
int
,
max_seq_len
:
int
,
max_seq_len
:
int
,
num_heads
:
tuple
[
int
,
int
]
=
(
64
,
8
),
num_heads
:
tuple
[
int
,
int
]
=
(
64
,
8
),
...
...
benchmarks/kernels/benchmark_w8a8_block_fp8.py
View file @
006693ed
...
@@ -14,11 +14,11 @@ import torch
...
@@ -14,11 +14,11 @@ import torch
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
_w8a8_
block_fp8_matmul
,
_w8a8_
triton_block_scaled_mm
,
)
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
triton
from
vllm.triton_utils
import
triton
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
mp
.
set_start_method
(
"spawn"
,
force
=
True
)
mp
.
set_start_method
(
"spawn"
,
force
=
True
)
...
@@ -83,7 +83,7 @@ def w8a8_block_matmul(
...
@@ -83,7 +83,7 @@ def w8a8_block_matmul(
)
)
if
A
.
dtype
==
torch
.
float8_e4m3fn
:
if
A
.
dtype
==
torch
.
float8_e4m3fn
:
kernel
=
_w8a8_
block_fp8_matmul
kernel
=
_w8a8_
triton_block_scaled_mm
else
:
else
:
raise
RuntimeError
(
"Currently, only support tune w8a8 block fp8 kernel."
)
raise
RuntimeError
(
"Currently, only support tune w8a8 block fp8 kernel."
)
...
...
Prev
1
2
3
4
5
6
7
8
9
10
…
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment