sglang · commit bae4fdc7 (unverified)

add fbgemm moe grouped gemm kernel benchmark (#6924)

Authored by Xiaoyu Zhang on Jun 07, 2025; committed via GitHub on Jun 07, 2025.
Parent: 6153f2ff

Showing 3 changed files with 1983 additions and 0 deletions:
- benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py (+366, -0)
- benchmark/fbgemm/fbgemm_grouped_gemm.py (+1294, -0)
- benchmark/fbgemm/test_grouped_gemm.py (+323, -0)

benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py (new file, mode 100644)

# python3 benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py --model Qwen/Qwen2-57B-A14B-Instruct --tp-size 4 --use-fp8-w8a8
import argparse

import torch
import triton
from fbgemm_grouped_gemm import grouped_gemm as fbgemm_grouped_gemm
from fbgemm_grouped_gemm import (
    grouped_gemm_fp8_rowwise as fbgemm_grouped_gemm_fp8_rowwise,
)
from transformers import AutoConfig

from sglang.srt.layers.moe.ep_moe.kernels import (
    grouped_gemm_triton as sglang_grouped_gemm,
)

def get_model_config(model_name: str, tp_size: int):
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    if config.architectures[0] == "DbrxForCausalLM":
        num_groups = config.ffn_config.moe_num_experts
        intermediate_size = config.ffn_config.ffn_hidden_size
    elif config.architectures[0] == "JambaForCausalLM":
        num_groups = config.num_experts
        intermediate_size = config.intermediate_size
    elif config.architectures[0] == "Qwen2MoeForCausalLM":
        num_groups = config.num_experts
        intermediate_size = config.moe_intermediate_size
    elif config.architectures[0] == "Qwen3MoeForCausalLM":
        num_groups = config.num_experts
        intermediate_size = config.moe_intermediate_size
    elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
        num_groups = (
            config.n_routed_experts + 1
            if config.architectures[0] in ["DeepseekV3ForCausalLM"]
            else config.n_routed_experts
        )
        intermediate_size = config.moe_intermediate_size
    elif config.architectures[0] == "Llama4ForConditionalGeneration":
        num_groups = config.text_config.num_local_experts
        intermediate_size = config.text_config.intermediate_size
    elif config.architectures[0] in [
        "Grok1ForCausalLM",
        "Grok1ImgGen",
        "Grok1AForCausalLM",
    ]:
        num_groups = config.num_local_experts
        intermediate_size = config.moe_intermediate_size
    else:
        num_groups = config.num_local_experts
        intermediate_size = config.intermediate_size

    shape_configs = {
        "num_groups": num_groups,
        "hidden_size": config.hidden_size,
        "intermediate_size": intermediate_size,
        "dtype": config.torch_dtype,
    }
    print(f"{shape_configs=}")
    return shape_configs

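# Note: tp_size is accepted here but not currently folded into the shapes
# above. For the default --model (mistralai/Mixtral-8x7B-Instruct-v0.1) the
# final branch should apply, giving num_groups=8, hidden_size=4096,
# intermediate_size=14336, matching the fallback config in main() below.
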
def create_test_data(batch_size, num_groups, hidden_size, intermediate_size):
    torch.manual_seed(42)
    tokens_per_group = batch_size // num_groups
    m_sizes = torch.full(
        (num_groups,), tokens_per_group, dtype=torch.int64, device="cuda"
    )
    x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device="cuda")
    base_weights = torch.randn(
        num_groups, intermediate_size, hidden_size, dtype=torch.bfloat16, device="cuda"
    )
    w_fbgemm = base_weights.reshape(num_groups * intermediate_size, hidden_size)
    w_sglang = base_weights
    c_fbgemm = torch.empty(
        batch_size, intermediate_size, dtype=torch.bfloat16, device="cuda"
    )
    c_sglang = torch.empty(
        batch_size, intermediate_size, dtype=torch.bfloat16, device="cuda"
    )
    seg_indptr = torch.zeros(num_groups + 1, dtype=torch.int64, device="cuda")
    for i in range(1, num_groups + 1):
        seg_indptr[i] = seg_indptr[i - 1] + tokens_per_group
    weight_indices = torch.arange(num_groups, dtype=torch.int64, device="cuda")
    return (
        x,
        w_fbgemm,
        w_sglang,
        c_fbgemm,
        c_sglang,
        m_sizes,
        seg_indptr,
        weight_indices,
    )

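# The same weights feed both kernels in different layouts: FBGEMM consumes the
# stacked 2-D matrix w_fbgemm of shape (num_groups * intermediate_size,
# hidden_size), while the SGLang Triton kernel consumes the 3-D tensor
# w_sglang of shape (num_groups, intermediate_size, hidden_size) together with
# seg_indptr, an exclusive prefix sum over the per-group row counts, so group
# i owns rows seg_indptr[i]:seg_indptr[i + 1] of x.
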
def create_fp8_test_data(batch_size, num_groups, hidden_size, intermediate_size):
    torch.manual_seed(42)
    tokens_per_group = batch_size // num_groups
    m_sizes = torch.full(
        (num_groups,), tokens_per_group, dtype=torch.int64, device="cuda"
    )
    x_fp16 = torch.randn(batch_size, hidden_size, dtype=torch.float16, device="cuda")
    w_fp16 = torch.randn(
        num_groups * intermediate_size, hidden_size, dtype=torch.float16, device="cuda"
    )
    x_fp8 = x_fp16.to(torch.float8_e4m3fn)
    w_fp8 = w_fp16.to(torch.float8_e4m3fn)
    x_scale = torch.randn(batch_size, dtype=torch.float32, device="cuda").abs() + 1e-4
    w_scale = torch.randn(num_groups, dtype=torch.float32, device="cuda").abs() + 1e-4
    return x_fp8, w_fp8, m_sizes, x_scale, w_scale

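# Rowwise FP8 here means one float32 dequantization scale per input row
# (x_scale, shape (batch_size,)) and one per weight group (w_scale, shape
# (num_groups,)); abs() + 1e-4 keeps every scale strictly positive.
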
def get_benchmark_config(use_fp8_w8a8=False):
    if use_fp8_w8a8:
        return {
            "line_vals": ["fbgemm_grouped_gemm_fp8", "sglang_grouped_gemm"],
            "line_names": ["FBGEMM Grouped GEMM FP8", "SGLang Grouped GEMM FP8"],
            "styles": [("blue", "-"), ("red", "-")],
        }
    else:
        return {
            "line_vals": ["fbgemm_grouped_gemm", "sglang_grouped_gemm"],
            "line_names": ["FBGEMM Grouped GEMM BF16", "SGLang Grouped GEMM BF16"],
            "styles": [("blue", "-"), ("green", "-")],
        }

def run_benchmark(
    model_config, use_fp8_w8a8=False, save_path="./benchmark_grouped_gemm/"
):
    config = get_benchmark_config(use_fp8_w8a8)
    benchmark_config = triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096],
        line_arg="provider",
        line_vals=config["line_vals"],
        line_names=config["line_names"],
        styles=config["styles"],
        ylabel="Time (ms)",
        plot_name="grouped-gemm-performance",
        args={},
    )

    @triton.testing.perf_report(benchmark_config)
    def dynamic_benchmark(batch_size, provider, model_config, use_fp8_w8a8=False):
        print(f"Benchmarking {provider} with batch_size={batch_size}")
        torch.cuda.manual_seed_all(0)
        num_groups = model_config["num_groups"]
        hidden_size = model_config["hidden_size"]
        intermediate_size = model_config["intermediate_size"]

        if provider == "fbgemm_grouped_gemm_fp8":
            try:
                test_data = create_fp8_test_data(
                    batch_size, num_groups, hidden_size, intermediate_size
                )
                x_fp8, w_fp8, m_sizes, x_scale, w_scale = test_data

                def run_func():
                    return fbgemm_grouped_gemm_fp8_rowwise(
                        x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True
                    )

            except Exception as e:
                print(f"FP8 not supported, skipping: {e}")
                return float("inf"), float("inf"), float("inf")
        else:
            test_data = create_test_data(
                batch_size, num_groups, hidden_size, intermediate_size
            )
            (
                x,
                w_fbgemm,
                w_sglang,
                c_fbgemm,
                c_sglang,
                m_sizes,
                seg_indptr,
                weight_indices,
            ) = test_data

            if provider == "fbgemm_grouped_gemm":

                def run_func():
                    return fbgemm_grouped_gemm(
                        x, w_fbgemm, m_sizes, use_fast_accum=True
                    )

            else:

                def run_func():
                    return sglang_grouped_gemm(
                        x,
                        w_sglang,
                        c_sglang,
                        num_groups,
                        weight_column_major=True,
                        seg_indptr=seg_indptr,
                        weight_indices=weight_indices,
                        c_dtype=c_sglang.dtype,
                    )

        for _ in range(10):
            try:
                run_func()
            except Exception as e:
                print(f"Error during warmup for {provider}: {e}")
                return float("inf"), float("inf"), float("inf")
        torch.cuda.synchronize()

        try:
            quantiles = [0.5, 0.2, 0.8]
            ms, min_ms, max_ms = triton.testing.do_bench(run_func, quantiles=quantiles)
            return ms, min_ms, max_ms
        except Exception as e:
            print(f"Error during benchmarking for {provider}: {e}")
            return float("inf"), float("inf"), float("inf")

    dynamic_benchmark.run(
        show_plots=True,
        print_data=True,
        save_path=save_path,
        model_config=model_config,
        use_fp8_w8a8=use_fp8_w8a8,
    )

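# With quantiles=[0.5, 0.2, 0.8], triton.testing.do_bench returns the median,
# 20th-, and 80th-percentile latencies in milliseconds, which map onto the
# (ms, min_ms, max_ms) triple above; failing providers report inf rather than
# aborting the whole sweep.
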
def verify_correctness(model_config, use_fp8_w8a8):
    print("Verifying correctness...")
    batch_size = 128
    num_groups = model_config["num_groups"]
    hidden_size = model_config["hidden_size"]
    intermediate_size = model_config["intermediate_size"]
    test_data = create_test_data(
        batch_size, num_groups, hidden_size, intermediate_size
    )
    (x, w_fbgemm, w_sglang, c_fbgemm, c_sglang, m_sizes, seg_indptr, weight_indices) = (
        test_data
    )

    try:
        result_fbgemm = fbgemm_grouped_gemm(x, w_fbgemm, m_sizes, use_fast_accum=True)
        result_sglang = sglang_grouped_gemm(
            x,
            w_sglang,
            c_sglang,
            num_groups,
            weight_column_major=True,
            seg_indptr=seg_indptr,
            weight_indices=weight_indices,
            c_dtype=c_sglang.dtype,
        )
        if torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3):
            print("✓ BF16 Correctness verification passed!")
        else:
            max_diff = torch.max(torch.abs(result_fbgemm - result_sglang))
            print(f"✗ BF16 Correctness verification failed! Max diff: {max_diff}")
            return False

        if use_fp8_w8a8:
            try:
                fp8_data = create_fp8_test_data(
                    batch_size, num_groups, hidden_size, intermediate_size
                )
                x_fp8, w_fp8, m_sizes_fp8, x_scale, w_scale = fp8_data
                result_fp8 = fbgemm_grouped_gemm_fp8_rowwise(
                    x_fp8, w_fp8, m_sizes_fp8, x_scale, w_scale, use_fast_accum=True
                )
                assert result_fp8.shape == (batch_size, intermediate_size)
                print("✓ FP8 functionality test passed!")
            except Exception as e:
                print(f"FP8 test failed (possibly unsupported): {e}")
                return False

        return True
    except Exception as e:
        print(f"✗ Error during correctness verification: {e}")
        return False

def main():
    parser = argparse.ArgumentParser(
        description="Benchmark FBGEMM vs SGLang Grouped GEMM"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="mistralai/Mixtral-8x7B-Instruct-v0.1",
        help="Model name to get configuration from",
    )
    parser.add_argument(
        "--tp-size", type=int, default=1, help="Tensor parallelism size"
    )
    parser.add_argument(
        "--use-fp8-w8a8", action="store_true", help="Enable FP8 W8A8 benchmark"
    )
    parser.add_argument(
        "--save-path",
        type=str,
        default="./benchmark_grouped_gemm/",
        help="Path to save benchmark results",
    )
    parser.add_argument(
        "--verify-correctness",
        action="store_true",
        help="Verify correctness before benchmarking",
    )
    args = parser.parse_args()

    try:
        model_config = get_model_config(args.model, args.tp_size)
    except Exception as e:
        print(f"Failed to get model config: {e}")
        print("Using default configuration...")
        model_config = {
            "num_groups": 8,
            "hidden_size": 4096,
            "intermediate_size": 14336,
            "dtype": torch.bfloat16,
        }

    print("Running benchmark with:")
    print(f"  num_groups: {model_config['num_groups']}")
    print(f"  hidden_size: {model_config['hidden_size']}")
    print(f"  intermediate_size: {model_config['intermediate_size']}")
    print(f"  use_fp8_w8a8: {args.use_fp8_w8a8}")

    if args.verify_correctness:
        if not verify_correctness(model_config, args.use_fp8_w8a8):
            print("Correctness verification failed. Exiting...")
            return

    try:
        run_benchmark(
            model_config=model_config,
            use_fp8_w8a8=args.use_fp8_w8a8,
            save_path=args.save_path,
        )
    except Exception as e:
        print(f"Benchmark failed: {e}")


if __name__ == "__main__":
    main()
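
To reproduce a run, the command in the header comment works from the repository root; adding --verify-correctness cross-checks the two BF16 kernels (and exercises the FP8 path when --use-fp8-w8a8 is set) before the timing sweep starts.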

benchmark/fbgemm/fbgemm_grouped_gemm.py (new file, mode 100644)

(Diff collapsed in this view; the 1294 added lines are not shown.)
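As a hedged sketch inferred only from the call sites in the benchmark and test files (parameter names, shapes, and dtypes here are assumptions, not the file's actual contents), the two entry points look like:

def grouped_gemm(x, w, m_sizes, use_fast_accum=True):
    # x: (batch_size, hidden_size), bfloat16
    # w: (num_groups * intermediate_size, hidden_size), bfloat16, groups stacked row-wise
    # m_sizes: (num_groups,), int64, rows of x owned by each group
    # returns: (batch_size, intermediate_size), bfloat16
    ...


def grouped_gemm_fp8_rowwise(x, w, m_sizes, x_scale, w_scale, use_fast_accum=True):
    # x, w: same shapes as above, float8_e4m3fn
    # x_scale: (batch_size,), float32, one dequant scale per input row
    # w_scale: (num_groups,), float32, one dequant scale per weight group
    # returns: (batch_size, intermediate_size), bfloat16 (asserted by the tests)
    ...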

benchmark/fbgemm/test_grouped_gemm.py (new file, mode 100644)

import os
import sys

import pytest
import torch

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

try:
    from fbgemm_grouped_gemm import grouped_gemm as fbgemm_grouped_gemm
    from fbgemm_grouped_gemm import (
        grouped_gemm_fp8_rowwise as fbgemm_grouped_gemm_fp8_rowwise,
    )

    FBGEMM_AVAILABLE = True
    print("✓ Successfully imported FBGEMM grouped GEMM")
except ImportError as e:
    print(f"✗ Failed to import FBGEMM grouped GEMM: {e}")
    FBGEMM_AVAILABLE = False

try:
    from sglang.srt.layers.moe.ep_moe.kernels import (
        grouped_gemm_triton as sglang_grouped_gemm,
    )

    SGLANG_AVAILABLE = True
    print("✓ Successfully imported SGLang grouped GEMM")
except ImportError as e:
    print(f"✗ Failed to import SGLang grouped GEMM: {e}")
    SGLANG_AVAILABLE = False

def create_uniform_groups(batch_size, num_groups, device):
    tokens_per_group = batch_size // num_groups
    return torch.full(
        (num_groups,), tokens_per_group, dtype=torch.int64, device=device
    )


def create_non_uniform_groups(batch_size, num_groups, device):
    remaining = batch_size
    m_sizes = []
    for i in range(num_groups - 1):
        if remaining <= 1:
            size = 1
        else:
            max_size = remaining - (num_groups - i - 1) + 1
            size = torch.randint(1, max_size, (1,)).item()
        m_sizes.append(size)
        remaining -= size
    m_sizes.append(remaining)
    return torch.tensor(m_sizes, dtype=torch.int64, device=device)

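# create_non_uniform_groups keeps every group non-empty and the sizes summing
# to batch_size: each of the first num_groups - 1 draws is capped so at least
# one row remains for each later group, and the final group absorbs the
# remainder. E.g. batch_size=10, num_groups=3 might yield tensor([6, 1, 3]).
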
def create_sglang_inputs(x, w, m_sizes, num_groups, intermediate_size, device):
    batch_size = x.shape[0]
    c_sglang = torch.empty(
        batch_size, intermediate_size, dtype=torch.bfloat16, device=device
    )
    seg_indptr = torch.zeros(num_groups + 1, dtype=torch.int64, device=device)
    current_pos = 0
    for i, size in enumerate(m_sizes):
        current_pos += size
        seg_indptr[i + 1] = current_pos
    weight_indices = torch.arange(num_groups, dtype=torch.int64, device=device)
    w_sglang = w.view(num_groups, intermediate_size, -1)
    return c_sglang, seg_indptr, weight_indices, w_sglang


def create_fp8_data(batch_size, num_groups, hidden_size, intermediate_size, device):
    torch.manual_seed(42)
    x_fp16 = torch.randn(batch_size, hidden_size, dtype=torch.float16, device=device)
    w_fp16 = torch.randn(
        num_groups * intermediate_size, hidden_size, dtype=torch.float16, device=device
    )
    x_fp8 = x_fp16.to(torch.float8_e4m3fn)
    w_fp8 = w_fp16.to(torch.float8_e4m3fn)
    x_scale = torch.randn(batch_size, dtype=torch.float32, device=device).abs() + 1e-4
    w_scale = torch.randn(num_groups, dtype=torch.float32, device=device).abs() + 1e-4
    return x_fp8, w_fp8, x_scale, w_scale


@pytest.fixture
def device():
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    return torch.device("cuda")

@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available")
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize("num_groups", [2, 4, 8])
@pytest.mark.parametrize("hidden_size", [512, 1024])
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
def test_uniform_groups(batch_size, num_groups, hidden_size, intermediate_size, device):
    if batch_size % num_groups != 0:
        pytest.skip(
            f"batch_size {batch_size} not divisible by num_groups {num_groups}"
        )
    torch.manual_seed(42)
    m_sizes = create_uniform_groups(batch_size, num_groups, device)
    x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
    w = torch.randn(
        num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
    )
    result_fbgemm = fbgemm_grouped_gemm(x, w, m_sizes, use_fast_accum=True)
    c_sglang, seg_indptr, weight_indices, w_sglang = create_sglang_inputs(
        x, w, m_sizes, num_groups, intermediate_size, device
    )
    result_sglang = sglang_grouped_gemm(
        x,
        w_sglang,
        c_sglang,
        num_groups,
        weight_column_major=True,
        seg_indptr=seg_indptr,
        weight_indices=weight_indices,
        c_dtype=c_sglang.dtype,
    )
    assert torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3)

@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available")
@pytest.mark.parametrize("batch_size", [63, 100, 127])
@pytest.mark.parametrize("num_groups", [3, 5, 7])
@pytest.mark.parametrize("hidden_size", [512, 1024])
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
def test_non_uniform_groups(
    batch_size, num_groups, hidden_size, intermediate_size, device
):
    torch.manual_seed(42)
    m_sizes = create_non_uniform_groups(batch_size, num_groups, device)
    x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
    w = torch.randn(
        num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
    )
    result_fbgemm = fbgemm_grouped_gemm(x, w, m_sizes, use_fast_accum=True)
    c_sglang, seg_indptr, weight_indices, w_sglang = create_sglang_inputs(
        x, w, m_sizes, num_groups, intermediate_size, device
    )
    result_sglang = sglang_grouped_gemm(
        x,
        w_sglang,
        c_sglang,
        num_groups,
        weight_column_major=True,
        seg_indptr=seg_indptr,
        weight_indices=weight_indices,
        c_dtype=c_sglang.dtype,
    )
    assert torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3)

@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available")
@pytest.mark.parametrize("batch_size,num_groups", [(64, 4), (128, 8), (256, 16)])
@pytest.mark.parametrize("hidden_size", [768, 2048, 4096])
@pytest.mark.parametrize("intermediate_size", [2048, 4096, 8192])
def test_large_dimensions(
    batch_size, num_groups, hidden_size, intermediate_size, device
):
    torch.manual_seed(42)
    m_sizes = create_uniform_groups(batch_size, num_groups, device)
    x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
    w = torch.randn(
        num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
    )
    result_fbgemm = fbgemm_grouped_gemm(x, w, m_sizes, use_fast_accum=True)
    c_sglang, seg_indptr, weight_indices, w_sglang = create_sglang_inputs(
        x, w, m_sizes, num_groups, intermediate_size, device
    )
    result_sglang = sglang_grouped_gemm(
        x,
        w_sglang,
        c_sglang,
        num_groups,
        weight_column_major=True,
        seg_indptr=seg_indptr,
        weight_indices=weight_indices,
        c_dtype=c_sglang.dtype,
    )
    assert torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3)

@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
@pytest.mark.parametrize("batch_size", [32, 64])
@pytest.mark.parametrize("num_groups", [2, 4])
@pytest.mark.parametrize("hidden_size", [512, 1024])
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
def test_fp8_uniform_groups(
    batch_size, num_groups, hidden_size, intermediate_size, device
):
    if batch_size % num_groups != 0:
        pytest.skip(
            f"batch_size {batch_size} not divisible by num_groups {num_groups}"
        )
    torch.manual_seed(42)
    m_sizes = create_uniform_groups(batch_size, num_groups, device)
    x_fp8, w_fp8, x_scale, w_scale = create_fp8_data(
        batch_size, num_groups, hidden_size, intermediate_size, device
    )
    try:
        result_fp8 = fbgemm_grouped_gemm_fp8_rowwise(
            x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True
        )
        assert result_fp8.shape == (batch_size, intermediate_size)
        assert result_fp8.dtype == torch.bfloat16
    except Exception as e:
        pytest.skip(f"FP8 test failed (possibly unsupported): {e}")

@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
@pytest.mark.parametrize("batch_size", [63, 100])
@pytest.mark.parametrize("num_groups", [3, 5])
@pytest.mark.parametrize("hidden_size", [512, 1024])
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
def test_fp8_non_uniform_groups(
    batch_size, num_groups, hidden_size, intermediate_size, device
):
    torch.manual_seed(42)
    m_sizes = create_non_uniform_groups(batch_size, num_groups, device)
    x_fp8, w_fp8, x_scale, w_scale = create_fp8_data(
        batch_size, num_groups, hidden_size, intermediate_size, device
    )
    try:
        result_fp8 = fbgemm_grouped_gemm_fp8_rowwise(
            x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True
        )
        assert result_fp8.shape == (batch_size, intermediate_size)
        assert result_fp8.dtype == torch.bfloat16
    except Exception as e:
        pytest.skip(f"FP8 test failed (possibly unsupported): {e}")

@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
def test_fbgemm_only_uniform(device):
    torch.manual_seed(42)
    batch_size, num_groups = 64, 4
    hidden_size, intermediate_size = 512, 1024
    m_sizes = create_uniform_groups(batch_size, num_groups, device)
    x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
    w = torch.randn(
        num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
    )
    result = fbgemm_grouped_gemm(x, w, m_sizes, use_fast_accum=True)
    assert result.shape == (batch_size, intermediate_size)
    assert result.dtype == torch.bfloat16

@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available")
def test_sglang_only_uniform(device):
    torch.manual_seed(42)
    batch_size, num_groups = 64, 4
    hidden_size, intermediate_size = 512, 1024
    m_sizes = create_uniform_groups(batch_size, num_groups, device)
    x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
    w = torch.randn(
        num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
    )
    c_sglang, seg_indptr, weight_indices, w_sglang = create_sglang_inputs(
        x, w, m_sizes, num_groups, intermediate_size, device
    )
    result = sglang_grouped_gemm(
        x,
        w_sglang,
        c_sglang,
        num_groups,
        weight_column_major=True,
        seg_indptr=seg_indptr,
        weight_indices=weight_indices,
        c_dtype=c_sglang.dtype,
    )
    assert result.shape == (batch_size, intermediate_size)
    assert result.dtype == torch.bfloat16

def test_imports():
    assert (
        FBGEMM_AVAILABLE or SGLANG_AVAILABLE
    ), "Neither FBGEMM nor SGLang is available"


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
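
The suite can be run directly with pytest benchmark/fbgemm/test_grouped_gemm.py -v (the __main__ guard above does the same). Tests skip themselves when CUDA, FBGEMM, or SGLang is unavailable, and the FP8 cases additionally skip on stacks without float8 support.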