Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0c0fdae8
Unverified
Commit
0c0fdae8
authored
May 09, 2025
by
Pavani Majety
Committed by
GitHub
May 09, 2025
Browse files
[Hardware/NVIDIA/Kernel] Enable nvidia/DeepSeek-R1-FP4 Model (#16362)
parent
3b602cde
Changes
16
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
1994 additions
and
112 deletions
+1994
-112
CMakeLists.txt
CMakeLists.txt
+5
-2
benchmarks/kernels/benchmark_cutlass_fp4_moe.py
benchmarks/kernels/benchmark_cutlass_fp4_moe.py
+408
-0
csrc/ops.h
csrc/ops.h
+12
-0
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
+11
-7
csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
+402
-0
csrc/quantization/fp4/nvfp4_experts_quant.cu
csrc/quantization/fp4/nvfp4_experts_quant.cu
+404
-0
csrc/quantization/fp4/nvfp4_quant_entry.cu
csrc/quantization/fp4/nvfp4_quant_entry.cu
+23
-1
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+15
-0
tests/kernels/moe/test_nvfp4_moe.py
tests/kernels/moe/test_nvfp4_moe.py
+144
-0
tests/kernels/quantization/nvfp4_utils.py
tests/kernels/quantization/nvfp4_utils.py
+66
-0
tests/kernels/quantization/test_nvfp4_scaled_mm.py
tests/kernels/quantization/test_nvfp4_scaled_mm.py
+14
-84
vllm/_custom_ops.py
vllm/_custom_ops.py
+91
-7
vllm/model_executor/layers/fused_moe/__init__.py
vllm/model_executor/layers/fused_moe/__init__.py
+2
-1
vllm/model_executor/layers/fused_moe/cutlass_moe.py
vllm/model_executor/layers/fused_moe/cutlass_moe.py
+125
-1
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+20
-3
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+252
-6
No files found.
CMakeLists.txt
View file @
0c0fdae8
...
...
@@ -288,6 +288,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu"
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
"csrc/cutlass_extensions/common.cpp"
"csrc/attention/mla/cutlass_mla_entry.cu"
)
...
...
@@ -495,7 +496,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
if
(
${
CMAKE_CUDA_COMPILER_VERSION
}
VERSION_GREATER 12.8 AND FP4_ARCHS
)
set
(
SRCS
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
)
"csrc/quantization/fp4/nvfp4_experts_quant.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu"
)
set_gencode_flags_for_srcs
(
SRCS
"
${
SRCS
}
"
CUDA_ARCHS
"
${
FP4_ARCHS
}
"
)
...
...
@@ -533,7 +536,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works
# on Hopper). get_cutlass_moe_mm_data should only be compiled if it's possible
# to compile MoE kernels that use its output.
cuda_archs_loose_intersection
(
SCALED_MM_ARCHS
"9.0a;"
"
${
CUDA_ARCHS
}
"
)
cuda_archs_loose_intersection
(
SCALED_MM_ARCHS
"9.0a;
10.0a
"
"
${
CUDA_ARCHS
}
"
)
if
(
${
CMAKE_CUDA_COMPILER_VERSION
}
VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS
)
set
(
SRCS
"csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu"
"csrc/quantization/cutlass_w8a8/moe/moe_data.cu"
)
...
...
benchmarks/kernels/benchmark_cutlass_fp4_moe.py
0 → 100644
View file @
0c0fdae8
# SPDX-License-Identifier: Apache-2.0
"""
Benchmark the performance of the cutlass_moe_fp4 kernel vs the triton_moe
kernel. The cutlass_moe_fp4 kernel takes in fp4 quantized weights and 16-bit
activations. The triton_moe kernel takes in fp8 weights (tensor scaled to fp8)
and 16-bit activations.
"""
import
nvtx
import
torch
import
torch.utils.benchmark
as
benchmark
from
vllm
import
_custom_ops
as
ops
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
cutlass_moe_fp4
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_experts
,
fused_topk
)
from
vllm.scalar_type
import
scalar_types
from
vllm.utils
import
FlexibleArgumentParser
# Per-model list of MoE layer shapes. main() reads each entry positionally as
# [num_experts, topk, size_k, size_n]; size_n is divided by the TP size there.
WEIGHT_SHAPES_MOE = {
    "nvidia/DeepSeek-R1-FP4": [
        [256, 8, 2048, 7168],
    ],
}

# Models benchmarked when --models is not given.
DEFAULT_MODELS = [
    "nvidia/DeepSeek-R1-FP4",
]

# Values of M (rows of the activation matrix) swept when --batch-sizes is not
# given.
DEFAULT_BATCH_SIZES = [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
# Tensor-parallel sizes swept when --tp-sizes is not given.
DEFAULT_TP_SIZES = [1]

# Swept by main() and forwarded to bench_run(), where they only appear in the
# printed/recorded sub-label.
PER_ACT_TOKEN_OPTS = [False]
PER_OUT_CH_OPTS = [False]
# Largest representable magnitudes of the two formats; bench_run() uses their
# product divided by a weight's abs-max as the per-expert global scale.
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
def
to_fp8
(
tensor
:
torch
.
Tensor
):
finfo
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
return
torch
.
round
(
tensor
.
clamp
(
min
=
finfo
.
min
,
max
=
finfo
.
max
)).
to
(
dtype
=
torch
.
float8_e4m3fn
)
def bench_run(results: list[benchmark.Measurement], model: str,
              num_experts: int, topk: int, per_act_token: bool,
              per_out_ch: bool, mkn: tuple[int, int, int]):
    """Benchmark one (model, shape) configuration and append the timings.

    Four measurements are appended to *results*, in this order:
    ``triton_moe``, ``triton_moe_cuda_graphs``, ``cutlass_moe_fp4`` and
    ``cutlass_moe_fp4_cuda_graphs``.

    Args:
        results: Output list; one ``benchmark.Measurement`` is appended per
            timed variant.
        model: Model name, used only in the printed/recorded sub-label.
        num_experts: Number of experts (leading dimension of the weights).
        topk: Number of experts selected per row by ``fused_topk``.
        per_act_token: Only recorded in the sub-label.
        per_out_ch: Only recorded in the sub-label.
        mkn: ``(m, k, n)``; activations are ``(m, k)``, ``w1`` is
            ``(num_experts, 2 * n, k)`` and ``w2`` is ``(num_experts, k, n)``.
    """
    label = "NVFP4 Blockscaled CUTLASS MOE vs FP8 Tensor Scaled Triton"

    sub_label = (
        "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, "
        "MKN=({})".format(model, num_experts, topk, per_act_token, per_out_ch,
                          mkn))

    print(f"Testing: {sub_label}")

    (m, k, n) = mkn

    dtype = torch.half
    device = "cuda"
    a = torch.randn((m, k), device=device, dtype=dtype) / 10
    w1 = torch.randn((num_experts, 2 * n, k), device=device, dtype=dtype) / 10
    w2 = torch.randn((num_experts, k, n), device=device, dtype=dtype) / 10

    # Only the activation scale is needed for the FP8 baseline; the quantized
    # activations themselves are discarded.
    _, a_fp8_scale = ops.scaled_fp8_quant(a)

    # FP8 (per-expert tensor-scaled) weights for the Triton baseline. They are
    # quantized straight into the tensors the baseline consumes; the original
    # clone + unused transposed views were dead work and extra GPU memory.
    w1_fp8q_notransp = torch.empty((num_experts, 2 * n, k),
                                   device=device,
                                   dtype=torch.float8_e4m3fn)
    w2_fp8q_notransp = torch.empty((num_experts, k, n),
                                   device=device,
                                   dtype=torch.float8_e4m3fn)
    w1_fp8scale = torch.empty((num_experts, 1, 1),
                              device=device,
                              dtype=torch.float32)
    w2_fp8scale = torch.empty((num_experts, 1, 1),
                              device=device,
                              dtype=torch.float32)

    for expert in range(num_experts):
        w1_fp8q_notransp[expert], w1_fp8scale[expert] = ops.scaled_fp8_quant(
            w1[expert])
        w2_fp8q_notransp[expert], w2_fp8scale[expert] = ops.scaled_fp8_quant(
            w2[expert])

    score = torch.randn((m, num_experts), device=device, dtype=dtype)

    topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)

    # NVFP4 weights: two 4-bit values packed per uint8, one fp8 block scale per
    # `quant_blocksize` elements along the last dimension.
    quant_blocksize = 16
    w1_blockscale = torch.empty((num_experts, 2 * n, k // quant_blocksize),
                                device=device,
                                dtype=torch.float8_e4m3fn)
    w2_blockscale = torch.empty((num_experts, k, n // quant_blocksize),
                                device=device,
                                dtype=torch.float8_e4m3fn)

    w1_fp4 = torch.empty((num_experts, 2 * n, k // 2),
                         device=device,
                         dtype=torch.uint8)
    w2_fp4 = torch.empty((num_experts, k, n // 2),
                         device=device,
                         dtype=torch.uint8)

    w1_gs = torch.empty((num_experts, ), device=device, dtype=torch.float32)
    w2_gs = torch.empty((num_experts, ), device=device, dtype=torch.float32)
    a1_gs = torch.ones((num_experts, ), device=device, dtype=torch.float32)
    a2_gs = torch.ones((num_experts, ), device=device, dtype=torch.float32)

    for expert in range(num_experts):
        w1_e = w1[expert]
        w2_e = w2[expert]
        w1_amax = torch.abs(w1_e).max().to(torch.float32)
        w2_amax = torch.abs(w2_e).max().to(torch.float32)
        # Per-expert global scale derived from the weight's abs-max.
        w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
        w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax

        w1_fp4[expert], w1_blockscale[expert] = ops.scaled_fp4_quant(
            w1_e, w1_gs[expert])

        w2_fp4[expert], w2_blockscale[expert] = ops.scaled_fp4_quant(
            w2_e, w2_gs[expert])

    def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
                       topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                       w1_scale: torch.Tensor, w2_scale: torch.Tensor,
                       a_fp8_scale: torch.Tensor, num_repeats: int):
        """Run the FP8 Triton fused-experts baseline *num_repeats* times."""
        for _ in range(num_repeats):
            fused_experts(a,
                          w1,
                          w2,
                          topk_weights,
                          topk_ids,
                          use_fp8_w8a8=True,
                          w1_scale=w1_scale,
                          w2_scale=w2_scale,
                          a1_scale=a_fp8_scale)

    def run_cutlass_moe_fp4(a: torch.Tensor, w1_fp4: torch.Tensor,
                            w2_fp4: torch.Tensor, w1_blockscale: torch.Tensor,
                            w2_blockscale: torch.Tensor, w1_gs: torch.Tensor,
                            w2_gs: torch.Tensor, a1_gs: torch.Tensor,
                            a2_gs: torch.Tensor, topk_weights: torch.Tensor,
                            topk_ids: torch.Tensor, m: int, n: int, k: int,
                            e: int, device: torch.device, num_repeats: int):
        """Run the NVFP4 CUTLASS MoE *num_repeats* times under an NVTX range."""
        for _ in range(num_repeats):
            with nvtx.annotate("cutlass_moe_fp4", color="green"):
                cutlass_moe_fp4(a=a,
                                a1_gscale=a1_gs,
                                a2_gscale=a2_gs,
                                w1_fp4=w1_fp4,
                                w1_blockscale=w1_blockscale,
                                w1_alphas=w1_gs,
                                w2_fp4=w2_fp4,
                                w2_blockscale=w2_blockscale,
                                w2_alphas=w2_gs,
                                topk_weights=topk_weights,
                                topk_ids=topk_ids,
                                m=m,
                                n=n,
                                k=k,
                                # Use the argument, not the enclosing scope.
                                e=e,
                                device=device)

    def run_cutlass_from_graph(
            a: torch.Tensor, a1_gscale: torch.Tensor, w1_fp4: torch.Tensor,
            w1_blockscale: torch.Tensor, w1_alphas: torch.Tensor,
            a2_gscale: torch.Tensor, w2_fp4: torch.Tensor,
            w2_blockscale: torch.Tensor, w2_alphas: torch.Tensor,
            topk_weights: torch.Tensor, topk_ids: torch.Tensor, m: int, n: int,
            k: int, e: int, device: torch.device):
        """Single NVFP4 CUTLASS MoE call, used as the CUDA-graph body."""
        with set_current_vllm_config(
                VllmConfig(parallel_config=ParallelConfig(
                    pipeline_parallel_size=1))):
            # The original read a1_gs / a2_gs / num_experts from the closure
            # and silently ignored the matching parameters.
            return cutlass_moe_fp4(a=a,
                                   a1_gscale=a1_gscale,
                                   w1_fp4=w1_fp4,
                                   w1_blockscale=w1_blockscale,
                                   w1_alphas=w1_alphas,
                                   a2_gscale=a2_gscale,
                                   w2_fp4=w2_fp4,
                                   w2_blockscale=w2_blockscale,
                                   w2_alphas=w2_alphas,
                                   topk_weights=topk_weights,
                                   topk_ids=topk_ids,
                                   m=m,
                                   n=n,
                                   k=k,
                                   e=e,
                                   device=device)

    def run_triton_from_graph(a: torch.Tensor, w1: torch.Tensor,
                              w2: torch.Tensor, topk_weights: torch.Tensor,
                              topk_ids: torch.Tensor, w1_scale: torch.Tensor,
                              w2_scale: torch.Tensor,
                              a_fp8_scale: torch.Tensor):
        """Single FP8 Triton fused-experts call, used as the CUDA-graph body."""
        with set_current_vllm_config(
                VllmConfig(parallel_config=ParallelConfig(
                    pipeline_parallel_size=1))):
            return fused_experts(a,
                                 w1,
                                 w2,
                                 topk_weights,
                                 topk_ids,
                                 use_fp8_w8a8=True,
                                 w1_scale=w1_scale,
                                 w2_scale=w2_scale,
                                 a1_scale=a_fp8_scale)

    def replay_graph(graph, num_repeats):
        """Replay a captured CUDA graph *num_repeats* times, then synchronize."""
        for _ in range(num_repeats):
            graph.replay()
        torch.cuda.synchronize()

    # Capture each implementation into its own CUDA graph on its own stream.
    cutlass_stream = torch.cuda.Stream()
    cutlass_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
        run_cutlass_from_graph(a=a,
                               a1_gscale=a1_gs,
                               w1_fp4=w1_fp4,
                               w1_blockscale=w1_blockscale,
                               w1_alphas=w1_gs,
                               a2_gscale=a2_gs,
                               w2_fp4=w2_fp4,
                               w2_blockscale=w2_blockscale,
                               w2_alphas=w2_gs,
                               topk_weights=topk_weights,
                               topk_ids=topk_ids,
                               m=m,
                               n=n,
                               k=k,
                               e=num_experts,
                               device=device)
    torch.cuda.synchronize()

    triton_stream = torch.cuda.Stream()
    triton_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(triton_graph, stream=triton_stream):
        run_triton_from_graph(a, w1_fp8q_notransp, w2_fp8q_notransp,
                              topk_weights, topk_ids, w1_fp8scale, w2_fp8scale,
                              a_fp8_scale)
    torch.cuda.synchronize()

    min_run_time = 5
    num_warmup = 5
    num_runs = 25

    # Namespace in which the Timer statements below are evaluated. Named
    # `timer_globals` so the `globals` builtin is not shadowed.
    timer_globals = {
        # Baseline params
        "w1": w1,
        "w2": w2,
        "score": score,
        "topk": topk,
        "w1_fp8q_notransp": w1_fp8q_notransp,
        "w2_fp8q_notransp": w2_fp8q_notransp,
        "w1_fp8scale": w1_fp8scale,
        "w2_fp8scale": w2_fp8scale,
        "a_fp8_scale": a_fp8_scale,
        # Cutlass params
        "a": a,
        "a1_gscale": a1_gs,
        "w1_fp4": w1_fp4,
        "w1_blockscale": w1_blockscale,
        "w1_alphas": w1_gs,
        "a2_gscale": a2_gs,
        "w2_fp4": w2_fp4,
        "w2_blockscale": w2_blockscale,
        "w2_alphas": w2_gs,
        "topk_weights": topk_weights,
        "topk_ids": topk_ids,
        "m": m,
        "n": n,
        "k": k,
        "e": num_experts,
        "device": device,
        # cuda graph params
        "cutlass_graph": cutlass_graph,
        "triton_graph": triton_graph,
        # Gen params
        "num_runs": num_runs,
        # Kernels
        "run_triton_moe": run_triton_moe,
        "run_cutlass_moe_fp4": run_cutlass_moe_fp4,
        "replay_graph": replay_graph,
    }

    def record(stmt: str, description: str) -> None:
        """Time *stmt* with blocked_autorange and append the measurement."""
        results.append(
            benchmark.Timer(
                stmt=stmt,
                globals=timer_globals,
                label=label,
                sub_label=sub_label,
                description=description,
            ).blocked_autorange(min_run_time=min_run_time))

    # Warmup
    run_triton_moe(a, w1_fp8q_notransp, w2_fp8q_notransp, topk_weights,
                   topk_ids, w1_fp8scale, w2_fp8scale, a_fp8_scale, num_warmup)
    record(
        "run_triton_moe(a, w1_fp8q_notransp, w2_fp8q_notransp, topk_weights, topk_ids, w1_fp8scale, w2_fp8scale, a_fp8_scale, num_runs)",  # noqa: E501
        "triton_moe")

    # Warmup
    replay_graph(triton_graph, num_warmup)
    record("replay_graph(triton_graph, num_runs)", "triton_moe_cuda_graphs")

    # Warmup
    run_cutlass_moe_fp4(a, w1_fp4, w2_fp4, w1_blockscale, w2_blockscale, w1_gs,
                        w2_gs, a1_gs, a2_gs, topk_weights, topk_ids, m, n, k,
                        num_experts, device, num_warmup)
    record(
        "run_cutlass_moe_fp4(a, w1_fp4, w2_fp4, w1_blockscale, w2_blockscale, w1_alphas, w2_alphas, a1_gscale, a2_gscale, topk_weights, topk_ids, m, n, k, e, device, num_runs)",  # noqa: E501
        "cutlass_moe_fp4")

    # Warmup
    replay_graph(cutlass_graph, num_warmup)
    record("replay_graph(cutlass_graph, num_runs)",
           "cutlass_moe_fp4_cuda_graphs")
def main(args):
    """Sweep every requested model / TP size / shape / batch size.

    For each combination that passes the ``--limit-k`` / ``--limit-n``
    filters, ``bench_run`` is invoked and its measurements are collected;
    a ``benchmark.Compare`` table of all results is printed at the end.
    """
    print("Benchmarking models:")
    for index, name in enumerate(args.models):
        print(f"[{index}]  {name}")

    results: list[benchmark.Measurement] = []

    for model in args.models:
        for tp in args.tp_sizes:
            for layer in WEIGHT_SHAPES_MOE[model]:
                num_experts, topk, size_k = layer[0], layer[1], layer[2]
                # Only N is sharded across tensor-parallel ranks.
                size_n = layer[3] // tp

                # An empty limit list means "no restriction".
                k_filtered_out = (len(args.limit_k) > 0
                                  and size_k not in args.limit_k)
                if k_filtered_out:
                    continue

                n_filtered_out = (len(args.limit_n) > 0
                                  and size_n not in args.limit_n)
                if n_filtered_out:
                    continue

                for per_act_token in PER_ACT_TOKEN_OPTS:
                    for per_out_ch in PER_OUT_CH_OPTS:
                        for size_m in args.batch_sizes:
                            bench_run(results, model, num_experts, topk,
                                      per_act_token, per_out_ch,
                                      (size_m, size_k, size_n))

    benchmark.Compare(results).print()
if __name__ == "__main__":
    # Command-line entry point: build the parser, then hand off to main().
    parser = FlexibleArgumentParser(
        description="Benchmark NVFP4 CUTLASS MOE across specified "
        "models/shapes/batches")
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES_MOE.keys(),
    )

    # Every remaining option is a list of ints; register them in the same
    # order as before so the generated help text is unchanged.
    _int_list_options = (
        ("--tp-sizes", DEFAULT_TP_SIZES),
        ("--batch-sizes", DEFAULT_BATCH_SIZES),
        ("--limit-k", []),
        ("--limit-n", []),
        ("--limit-num-groups", []),
        ("--limit-per-act-token", []),
        ("--limit-per-out-ch", []),
    )
    for _flag, _default in _int_list_options:
        parser.add_argument(_flag, nargs="+", type=int, default=_default)

    args = parser.parse_args()
    main(args)
csrc/ops.h
View file @
0c0fdae8
...
...
@@ -208,6 +208,12 @@ void cutlass_moe_mm(
torch
::
Tensor
const
&
problem_sizes
,
torch
::
Tensor
const
&
a_strides
,
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
);
// Grouped FP4 matrix multiply entry point (declaration only; the definition
// is not in this header).
// NOTE(review): presumably `output` receives the result and the remaining
// tensors are read-only inputs, as the const qualifiers suggest. Expected
// shapes/dtypes of `a`, `b`, the block scales, `alphas`, `problem_sizes`,
// `expert_offsets` and `sf_offsets` are not stated here — confirm against the
// implementation before relying on them.
void cutlass_fp4_group_mm(
    torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
    const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
    const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
    const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets);
void
get_cutlass_moe_mm_data
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
...
...
@@ -235,6 +241,12 @@ std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);
void
scaled_fp4_quant
(
torch
::
Tensor
&
output
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
&
output_scale
,
torch
::
Tensor
const
&
input_scale
);
// Per-expert variant of scaled_fp4_quant (declaration only; the definition is
// not in this header). `output` and `output_scale` are the mutable outputs;
// the other tensors are taken by const reference.
// NOTE(review): presumably `input_offset_by_experts` and
// `output_scale_offset_by_experts` delimit each expert's rows in `input` and
// in `output_scale` respectively — confirm against the implementation.
void scaled_fp4_experts_quant(
    torch::Tensor& output, torch::Tensor& output_scale,
    torch::Tensor const& input, torch::Tensor const& input_global_scale,
    torch::Tensor const& input_offset_by_experts,
    torch::Tensor const& output_scale_offset_by_experts);
#endif
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
View file @
0c0fdae8
...
...
@@ -37,12 +37,6 @@ void cutlass_moe_mm_sm90(
torch
::
Tensor
const
&
problem_sizes
,
torch
::
Tensor
const
&
a_strides
,
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
);
void
get_cutlass_moe_mm_data_caller
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
input_permutation
,
torch
::
Tensor
&
output_permutation
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
);
#endif
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
...
...
@@ -53,6 +47,15 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#endif
#if defined(ENABLE_SCALED_MM_SM90) && ENABLE_SCALED_MM_SM90 || \
defined(ENABLE_SCALED_MM_SM100) && ENABLE_SCALED_MM_SM100
void
get_cutlass_moe_mm_data_caller
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
input_permutation
,
torch
::
Tensor
&
output_permutation
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
);
#endif
void
cutlass_scaled_mm_azp_sm75
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
...
...
@@ -224,7 +227,8 @@ void get_cutlass_moe_mm_data(
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t
version_num
=
get_sm_version_num
();
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
// Take the CUTLASS path when either the SM90 MoE kernels or the SM100
// scaled-mm kernels are compiled in. The SM100 clause must test the value of
// ENABLE_SCALED_MM_SM100 itself; testing ENABLE_SCALED_MM_SM90 there disabled
// this path on builds that only target SM100.
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
    (defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100)
get_cutlass_moe_mm_data_caller
(
topk_ids
,
expert_offsets
,
problem_sizes1
,
problem_sizes2
,
input_permutation
,
output_permutation
,
num_experts
,
n
,
k
);
...
...
csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
0 → 100644
View file @
0c0fdae8
#include <torch/all.h>
#include <cutlass/arch/arch.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include <cassert>
using
namespace
cute
;
// Fills the per-expert pointer tables and scale-factor layouts consumed by the
// grouped GEMM. Each thread handles the expert whose index equals its
// threadIdx.x and writes that expert's A/B/output/scale/alpha start pointers
// plus its SFA/SFB layouts.
//
// problem_sizes_as_shapes is read as three consecutive int32 values (m, n, k)
// per expert; expert_offsets / sf_offsets give each expert's starting row in
// the activation / activation-scale buffers.
template <typename ElementAB, typename ElementC, typename ElementSF,
          typename ElementAccumulator, typename LayoutSFA, typename LayoutSFB,
          typename ScaleConfig>
__global__ void __get_group_gemm_starts(
    ElementAB** a_offsets, ElementAB** b_offsets, ElementC** out_offsets,
    ElementSF** a_scales_offsets, ElementSF** b_scales_offsets,
    ElementAccumulator** alpha_offsets, LayoutSFA* layout_sfa_base_as_int,
    LayoutSFB* layout_sfb_base_as_int, ElementAB* a_base_as_int,
    ElementAB* b_base_as_int, ElementC* out_base_as_int,
    ElementSF* a_scales_base_as_int, ElementSF* b_scales_base_as_int,
    ElementAccumulator* alphas_base_as_int, const int32_t* expert_offsets,
    const int32_t* sf_offsets, const int32_t* problem_sizes_as_shapes,
    const int K, const int N) {
  int64_t expert_id = threadIdx.x;
  // NOTE(review): threadIdx.x is always below blockDim.x, so this bound can
  // never be reached; it is kept unchanged.
  if (expert_id >= gridDim.x * blockDim.x) {
    return;
  }

  // The inputs are int32; every quantity is widened to int64 up front so the
  // offset products below cannot overflow.
  int64_t row_start = static_cast<int64_t>(expert_offsets[expert_id]);
  int64_t scale_row_start = static_cast<int64_t>(sf_offsets[expert_id]);
  // Number of elements that share one block scale.
  int64_t scale_block = 16;

  const int32_t* shape = problem_sizes_as_shapes + expert_id * 3;
  int64_t m = static_cast<int64_t>(shape[0]);
  int64_t n = static_cast<int64_t>(shape[1]);
  int64_t k = static_cast<int64_t>(shape[2]);
  assert((m >= 0 && n == N && k == K && k % 2 == 0) &&
         "unexpected problem sizes");

  int64_t packed_k = k / 2;            // K // 2
  int64_t scale_k = k / scale_block;   // K // group_size

  // Shape of A as uint8/byte = [M, K // 2]
  // Shape of B as uint8/byte = [E, N, K // 2]
  a_offsets[expert_id] = a_base_as_int + row_start * packed_k;
  b_offsets[expert_id] = b_base_as_int + expert_id * n * packed_k;

  // Shape of C = [M, N]
  out_offsets[expert_id] = out_base_as_int + row_start * n;

  // Shape of a_scale = [sum(sf_sizes), K // group_size]
  ElementSF* a_scale_start = a_scales_base_as_int + scale_row_start * scale_k;
  a_scales_offsets[expert_id] = a_scale_start;
  assert((reinterpret_cast<uintptr_t>(a_scale_start) % 128) == 0 &&
         "TMA requires 128-byte alignment");

  // Shape of B scale = [E, N, K // group_size]
  ElementSF* b_scale_start = b_scales_base_as_int + expert_id * n * scale_k;
  b_scales_offsets[expert_id] = b_scale_start;
  assert((reinterpret_cast<uintptr_t>(b_scale_start) % 128) == 0 &&
         "TMA requires 128-byte alignment");

  // Shape of alpha = [E]
  alpha_offsets[expert_id] = alphas_base_as_int + expert_id;

  // Both scale-factor layouts are derived from the same (m, n, k, 1) shape.
  auto problem_shape =
      cute::make_shape(static_cast<int>(m), static_cast<int>(n),
                       static_cast<int>(k), 1);
  layout_sfa_base_as_int[expert_id] =
      ScaleConfig::tile_atom_to_shape_SFA(problem_shape);
  layout_sfb_base_as_int[expert_id] =
      ScaleConfig::tile_atom_to_shape_SFB(problem_shape);
}
// Expands to one `else if` arm of a dtype dispatch chain: when
// out_tensors.dtype() equals TENSOR_C_TYPE it launches
// __get_group_gemm_starts<ELEMENT_AB_TYPE, C_TYPE, SF_TYPE, float, ...> as a
// single block of `num_experts` threads on `stream`.
// The expansion site must start the chain with an `if` and must have the
// names used below in scope (a_starts, b_starts, out_starts, a_scales_starts,
// b_scales_starts, alpha_starts, layout_sfa, layout_sfb, a_tensors,
// b_tensors, out_tensors, a_scales, b_scales, alphas, expert_offsets,
// sf_offsets, problem_sizes, num_experts, stream, K, N).
#define __CALL_GET_STARTS_KERNEL_BLOCKSCALE(ELEMENT_AB_TYPE, SF_TYPE, \
TENSOR_C_TYPE, C_TYPE, LayoutSFA, \
LayoutSFB, ScaleConfig) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
__get_group_gemm_starts<ELEMENT_AB_TYPE, C_TYPE, SF_TYPE, float, \
LayoutSFA, LayoutSFB, ScaleConfig> \
<<<1, num_experts, 0, stream>>>( \
static_cast<ELEMENT_AB_TYPE**>(a_starts.data_ptr()), \
static_cast<ELEMENT_AB_TYPE**>(b_starts.data_ptr()), \
static_cast<C_TYPE**>(out_starts.data_ptr()), \
static_cast<SF_TYPE**>(a_scales_starts.data_ptr()), \
static_cast<SF_TYPE**>(b_scales_starts.data_ptr()), \
static_cast<float**>(alpha_starts.data_ptr()), \
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()), \
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr()), \
static_cast<ELEMENT_AB_TYPE*>(a_tensors.data_ptr()), \
static_cast<ELEMENT_AB_TYPE*>(b_tensors.data_ptr()), \
static_cast<C_TYPE*>(out_tensors.data_ptr()), \
static_cast<SF_TYPE*>(a_scales.data_ptr()), \
static_cast<SF_TYPE*>(b_scales.data_ptr()), \
static_cast<float*>(alphas.data_ptr()), \
static_cast<int32_t*>(expert_offsets.data_ptr()), \
static_cast<int32_t*>(sf_offsets.data_ptr()), \
static_cast<int32_t*>(problem_sizes.data_ptr()), K, N); \
}
// Host-side launcher for __get_group_gemm_starts.
//
// Validates the output / weight shapes against N and K, then dispatches on the
// dtype of `out_tensors` (bfloat16 or float16) to fill the per-expert pointer
// tables (`*_starts`) and scale-factor layouts (`layout_sfa`, `layout_sfb`).
// The number of experts is taken from expert_offsets.size(0). `M` is accepted
// for signature symmetry but is not used here.
//
// Raises (via TORCH_CHECK) when the shapes do not match, when the number of
// experts does not fit in a single thread block, or when the output dtype is
// neither float16 nor bfloat16.
template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
void run_get_group_gemm_starts(
    const torch::Tensor& a_starts, const torch::Tensor& b_starts,
    const torch::Tensor& out_starts, const torch::Tensor& a_scales_starts,
    const torch::Tensor& b_scales_starts, const torch::Tensor& alpha_starts,
    const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb,
    /*these are used for their base addresses*/
    torch::Tensor const& a_tensors, torch::Tensor const& b_tensors,
    torch::Tensor const& out_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& alphas,
    torch::Tensor const& expert_offsets, torch::Tensor const& sf_offsets,
    torch::Tensor const& problem_sizes, int M, int N, int K) {
  int num_experts = static_cast<int>(expert_offsets.size(0));
  auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());

  TORCH_CHECK(out_tensors.size(1) == N,
              "Output tensor shape doesn't match expected shape");
  TORCH_CHECK(K / 2 == b_tensors.size(2),
              "b_tensors(dim = 2) and a_tensors(dim = 1) trailing"
              " dimension must match");

  // The kernel is launched as one block with one thread per expert. A block
  // larger than the device limit makes the launch fail without running, which
  // would leave the pointer tables uninitialised for the GEMM that follows.
  const int max_threads_per_block =
      at::cuda::getDeviceProperties(a_tensors.device().index())
          ->maxThreadsPerBlock;
  TORCH_CHECK(num_experts <= max_threads_per_block, "Number of experts (",
              num_experts, ") exceeds the maximum threads per block (",
              max_threads_per_block,
              ") supported by the single-block launch");

  // The macro expands to `else if` arms, so the chain needs a leading `if`.
  if (false) {
  }
  //(ELEMENT_AB_TYPE, BS_TYPE, TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB,
  // ScaleConfig)
  __CALL_GET_STARTS_KERNEL_BLOCKSCALE(
      cutlass::float_e2m1_t, cutlass::float_ue4m3_t, torch::kBFloat16,
      cutlass::bfloat16_t, LayoutSFA, LayoutSFB, ScaleConfig)
  __CALL_GET_STARTS_KERNEL_BLOCKSCALE(cutlass::float_e2m1_t,
                                      cutlass::float_ue4m3_t, torch::kFloat16,
                                      half, LayoutSFA, LayoutSFB, ScaleConfig)
  else {
    TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
  }
}
template
<
typename
OutType
>
void
run_fp4_blockwise_scaled_group_mm
(
torch
::
Tensor
&
output
,
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b
,
const
torch
::
Tensor
&
a_blockscale
,
const
torch
::
Tensor
&
b_blockscales
,
const
torch
::
Tensor
&
alphas
,
const
torch
::
Tensor
&
problem_sizes
,
const
torch
::
Tensor
&
expert_offsets
,
const
torch
::
Tensor
&
sf_offsets
,
int
M
,
int
N
,
int
K
)
{
using
ProblemShape
=
cutlass
::
gemm
::
GroupProblemShape
<
Shape
<
int32_t
,
int32_t
,
int32_t
>>
;
using
ElementType
=
cutlass
::
float_e2m1_t
;
using
ElementSFType
=
cutlass
::
float_ue4m3_t
;
using
ElementA
=
cutlass
::
nv_float4_t
<
cutlass
::
float_e2m1_t
>
;
using
ElementB
=
cutlass
::
nv_float4_t
<
cutlass
::
float_e2m1_t
>
;
using
ElementC
=
OutType
;
using
ElementD
=
ElementC
;
using
ElementAccumulator
=
float
;
// Layout definitions
using
LayoutA
=
cutlass
::
layout
::
RowMajor
;
using
LayoutB
=
cutlass
::
layout
::
ColumnMajor
;
using
LayoutC
=
cutlass
::
layout
::
RowMajor
;
using
LayoutD
=
LayoutC
;
// Alignment constraints
static
constexpr
int
AlignmentA
=
32
;
static
constexpr
int
AlignmentB
=
32
;
static
constexpr
int
AlignmentC
=
128
/
cutlass
::
sizeof_bits
<
ElementC
>::
value
;
static
constexpr
int
AlignmentD
=
128
/
cutlass
::
sizeof_bits
<
ElementD
>::
value
;
// Architecture definitions
using
ArchTag
=
cutlass
::
arch
::
Sm100
;
using
EpilogueOperatorClass
=
cutlass
::
arch
::
OpClassTensorOp
;
// Epilogue Operator class tag
using
MainloopOperatorClass
=
cutlass
::
arch
::
OpClassBlockScaledTensorOp
;
// Mainloop Operator class tag
using
StageCountType
=
cutlass
::
gemm
::
collective
::
StageCountAuto
;
// Stage count maximized based
// on the tile size
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
struct
MMA1SMConfig
{
using
MmaTileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
KernelSchedule
=
cutlass
::
gemm
::
KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100
;
// Kernel to launch
using
EpilogueSchedule
=
cutlass
::
epilogue
::
PtrArrayTmaWarpSpecialized1Sm
;
// Epilogue to launch
};
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
ArchTag
,
EpilogueOperatorClass
,
typename
MMA1SMConfig
::
MmaTileShape
,
ClusterShape
,
Shape
<
_128
,
_64
>
,
ElementAccumulator
,
ElementAccumulator
,
ElementC
,
LayoutC
*
,
AlignmentC
,
ElementD
,
LayoutC
*
,
AlignmentD
,
typename
MMA1SMConfig
::
EpilogueSchedule
>::
CollectiveOp
;
using
CollectiveMainloop
=
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
ArchTag
,
MainloopOperatorClass
,
ElementA
,
LayoutA
*
,
AlignmentA
,
ElementB
,
LayoutB
*
,
AlignmentB
,
ElementAccumulator
,
typename
MMA1SMConfig
::
MmaTileShape
,
ClusterShape
,
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
))
>
,
typename
MMA1SMConfig
::
KernelSchedule
>::
CollectiveOp
;
using
GemmKernel
=
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
ProblemShape
,
CollectiveMainloop
,
CollectiveEpilogue
>
;
using
Gemm1SM
=
cutlass
::
gemm
::
device
::
GemmUniversalAdapter
<
GemmKernel
>
;
using
Gemm
=
Gemm1SM
;
using
StrideA
=
typename
Gemm
::
GemmKernel
::
InternalStrideA
;
using
StrideB
=
typename
Gemm
::
GemmKernel
::
InternalStrideB
;
using
StrideC
=
typename
Gemm
::
GemmKernel
::
InternalStrideC
;
using
StrideD
=
typename
Gemm
::
GemmKernel
::
InternalStrideD
;
using
LayoutSFA
=
typename
Gemm
::
GemmKernel
::
CollectiveMainloop
::
InternalLayoutSFA
;
using
LayoutSFB
=
typename
Gemm
::
GemmKernel
::
CollectiveMainloop
::
InternalLayoutSFB
;
using
ScaleConfig
=
typename
Gemm
::
GemmKernel
::
CollectiveMainloop
::
Sm1xxBlkScaledConfig
;
using
UnderlyingProblemShape
=
ProblemShape
::
UnderlyingProblemShape
;
int
num_experts
=
static_cast
<
int
>
(
expert_offsets
.
size
(
0
));
auto
options_int
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt64
).
device
(
a
.
device
());
torch
::
Tensor
a_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
b_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
out_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
a_scales_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
b_scales_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
alpha_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
layout_sfa
=
torch
::
empty
({
num_experts
,
5
},
options_int
);
torch
::
Tensor
layout_sfb
=
torch
::
empty
({
num_experts
,
5
},
options_int
);
torch
::
Tensor
c_strides1
=
torch
::
full
({
num_experts
},
output
.
stride
(
0
),
options_int
);
torch
::
Tensor
a_strides1
=
torch
::
full
({
num_experts
},
a
.
stride
(
0
)
*
2
,
options_int
);
torch
::
Tensor
b_strides1
=
torch
::
full
({
num_experts
},
b
.
stride
(
1
)
*
2
,
options_int
);
run_get_group_gemm_starts
<
LayoutSFA
,
LayoutSFB
,
ScaleConfig
>
(
a_ptrs
,
b_ptrs
,
out_ptrs
,
a_scales_ptrs
,
b_scales_ptrs
,
alpha_ptrs
,
layout_sfa
,
layout_sfb
,
a
,
b
,
output
,
a_blockscale
,
b_blockscales
,
alphas
,
expert_offsets
,
sf_offsets
,
problem_sizes
,
M
,
N
,
K
);
// Create an instance of the GEMM
Gemm
gemm_op
;
// Initialize problem_sizes_as_shapes correctly
UnderlyingProblemShape
*
problem_sizes_as_shapes
=
static_cast
<
UnderlyingProblemShape
*>
(
problem_sizes
.
data_ptr
());
// Set the Scheduler info
cutlass
::
KernelHardwareInfo
hw_info
;
using
RasterOrderOptions
=
typename
cutlass
::
gemm
::
kernel
::
detail
::
PersistentTileSchedulerSm100GroupParams
<
typename
ProblemShape
::
UnderlyingProblemShape
>::
RasterOrderOptions
;
typename
Gemm
::
GemmKernel
::
TileSchedulerArguments
scheduler
;
scheduler
.
raster_order
=
RasterOrderOptions
::
AlongM
;
hw_info
.
device_id
=
a
.
get_device
();
static
std
::
unordered_map
<
int
,
int
>
cached_sm_counts
;
if
(
cached_sm_counts
.
find
(
hw_info
.
device_id
)
==
cached_sm_counts
.
end
())
{
cached_sm_counts
[
hw_info
.
device_id
]
=
cutlass
::
KernelHardwareInfo
::
query_device_multiprocessor_count
(
hw_info
.
device_id
);
}
hw_info
.
sm_count
=
min
(
cached_sm_counts
[
hw_info
.
device_id
],
INT_MAX
);
// Mainloop Arguments
typename
GemmKernel
::
MainloopArguments
mainloop_args
{
static_cast
<
const
ElementType
**>
(
a_ptrs
.
data_ptr
()),
static_cast
<
StrideA
*>
(
a_strides1
.
data_ptr
()),
static_cast
<
const
ElementType
**>
(
b_ptrs
.
data_ptr
()),
static_cast
<
StrideB
*>
(
b_strides1
.
data_ptr
()),
static_cast
<
const
ElementSFType
**>
(
a_scales_ptrs
.
data_ptr
()),
reinterpret_cast
<
LayoutSFA
*>
(
layout_sfa
.
data_ptr
()),
static_cast
<
const
ElementSFType
**>
(
b_scales_ptrs
.
data_ptr
()),
reinterpret_cast
<
LayoutSFB
*>
(
layout_sfb
.
data_ptr
())};
// Epilogue Arguments
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{
{},
// epilogue.thread
nullptr
,
static_cast
<
StrideC
*>
(
c_strides1
.
data_ptr
()),
static_cast
<
ElementD
**>
(
out_ptrs
.
data_ptr
()),
static_cast
<
StrideC
*>
(
c_strides1
.
data_ptr
())};
auto
&
fusion_args
=
epilogue_args
.
thread
;
fusion_args
.
alpha_ptr_array
=
reinterpret_cast
<
float
**>
(
alpha_ptrs
.
data_ptr
());
fusion_args
.
dAlpha
=
{
_0
{},
_0
{},
1
};
// Gemm Arguments
typename
GemmKernel
::
Arguments
args
{
cutlass
::
gemm
::
GemmUniversalMode
::
kGrouped
,
{
num_experts
,
problem_sizes_as_shapes
,
nullptr
},
mainloop_args
,
epilogue_args
,
hw_info
,
scheduler
};
size_t
workspace_size
=
Gemm
::
get_workspace_size
(
args
);
auto
const
workspace_options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
a
.
device
());
auto
workspace
=
torch
::
empty
(
workspace_size
,
workspace_options
);
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
a
.
get_device
());
auto
can_implement_status
=
gemm_op
.
can_implement
(
args
);
TORCH_CHECK
(
can_implement_status
==
cutlass
::
Status
::
kSuccess
,
"Failed to implement GEMM"
);
// Run the GEMM
auto
status
=
gemm_op
.
initialize
(
args
,
workspace
.
data_ptr
());
TORCH_CHECK
(
status
==
cutlass
::
Status
::
kSuccess
,
"Failed to initialize GEMM"
);
status
=
gemm_op
.
run
(
args
,
workspace
.
data_ptr
(),
stream
);
TORCH_CHECK
(
status
==
cutlass
::
Status
::
kSuccess
,
"Failed to run GEMM"
);
}
// Storage dtype of packed NVFP4 data (handled as raw bytes).
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
// Storage dtype of the per-block scale factors.
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;

// Validation helpers. `m` is the tensor's display name and always comes first
// in the error message so all three checks read the same way.
#define CHECK_TYPE(x, st, m) \
  TORCH_CHECK(x.scalar_type() == st, m, ": Inconsistency of Tensor type.")
#define CHECK_TH_CUDA(x, m) \
  TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.")
#define CHECK_CONTIGUOUS(x, m) \
  TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.")
// Wrapped in do/while(0) so the three checks behave as ONE statement, e.g.
// under an unbraced `if`. Callers still terminate the use with `;`.
#define CHECK_INPUT(x, st, m) \
  do {                        \
    CHECK_TH_CUDA(x, m);      \
    CHECK_CONTIGUOUS(x, m);   \
    CHECK_TYPE(x, st, m);     \
  } while (0)
// Entry point for the NVFP4 block-scaled grouped GEMM used by the MoE path.
//
// Validates the operands, derives M / N / K from the tensor shapes and
// dispatches to run_fp4_blockwise_scaled_group_mm for the output dtype.
//
//   output          result tensor; must be bfloat16 or float16
//   a, b            packed FP4 operands stored as uint8 (K = 2 * b.size(2))
//   a_blockscale    2-D float8_e4m3fn block scales for `a`
//   b_blockscales   3-D float8_e4m3fn block scales for `b`
//   alphas          float32 scales forwarded to the GEMM
//   problem_sizes   int32 tensor of shape (num_experts, 3)
//   expert_offsets  one entry per row of problem_sizes
//   sf_offsets      forwarded unchanged to the GEMM launcher
//
// Raises (via TORCH_CHECK) when validation fails, and a NotImplemented error
// when the extension was built without ENABLE_NVFP4.
void cutlass_fp4_group_mm(
    torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
    const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
    const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
    const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
  // Input validation
  CHECK_INPUT(a, FLOAT4_E2M1X2, "a");
  CHECK_INPUT(b, FLOAT4_E2M1X2, "b");
  CHECK_INPUT(a_blockscale, SF_DTYPE, "a_blockscale");
  CHECK_INPUT(b_blockscales, SF_DTYPE, "b_blockscales");
  CHECK_INPUT(alphas, at::ScalarType::Float, "alphas");

  // a.size(0), b.size(0..2) are read below, so pin the ranks first to get a
  // readable error instead of an index failure.
  TORCH_CHECK(a.dim() == 2,
              "expected a to be a 2D tensor, observed rank: ", a.dim());
  TORCH_CHECK(b.dim() == 3,
              "expected b to be a 3D tensor, observed rank: ", b.dim());
  TORCH_CHECK(a_blockscale.dim() == 2,
              "expected a_blockscale to be a 2D tensor, observed rank: ",
              a_blockscale.dim());
  TORCH_CHECK(b_blockscales.dim() == 3,
              "expected b_blockscale to be of shape: "
              " [num_experts, n, k // group_size], observed rank: ",
              b_blockscales.dim());
  TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be a 2D tensor");
  TORCH_CHECK(problem_sizes.size(1) == 3,
              "problem_sizes must have the shape (num_experts, 3)");
  TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0),
              "Number of experts in problem_sizes must match expert_offsets");
  TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
              "problem_sizes must be int32.");

  int M = static_cast<int>(a.size(0));
  int N = static_cast<int>(b.size(1));
  // Two FP4 values are packed per byte, hence the factor of two.
  int K = static_cast<int>(2 * b.size(2));

  if (output.scalar_type() == torch::kBFloat16) {
    run_fp4_blockwise_scaled_group_mm<cutlass::bfloat16_t>(
        output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
        expert_offsets, sf_offsets, M, N, K);
  } else {
    // Previously every non-bf16 dtype fell through to the half kernel, which
    // would write 16-bit results into a buffer of another element size.
    TORCH_CHECK(output.scalar_type() == torch::kHalf,
                "output must be bfloat16 or float16, got ",
                output.scalar_type());
    run_fp4_blockwise_scaled_group_mm<cutlass::half_t>(
        output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
        expert_offsets, sf_offsets, M, N, K);
  }
#else
  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_fp4_group_mm kernel, vLLM must "
      "be compiled with ENABLE_NVFP4 for SM100+ and CUDA "
      "12.8 or above.");
#endif
}
csrc/quantization/fp4/nvfp4_experts_quant.cu
0 → 100644
View file @
0c0fdae8
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <cuda_fp8.h>
// Compile-time map between a scalar 16-bit float type and its two-element
// packed counterpart: half <-> half2, __nv_bfloat16 <-> __nv_bfloat162.
// Any type without an explicit specialization resolves to half2.
template <typename T>
struct TypeConverter {
  using Type = half2;
};  // keep for generality

// half2 -> half
template <>
struct TypeConverter<half2> {
  using Type = half;
};

// half -> half2
template <>
struct TypeConverter<half> {
  using Type = half2;
};

// __nv_bfloat162 -> __nv_bfloat16
template <>
struct TypeConverter<__nv_bfloat162> {
  using Type = __nv_bfloat16;
};

// __nv_bfloat16 -> __nv_bfloat162
template <>
struct TypeConverter<__nv_bfloat16> {
  using Type = __nv_bfloat162;
};
#define ELTS_PER_THREAD 8
constexpr
int
CVT_FP4_ELTS_PER_THREAD
=
8
;
constexpr
int
CVT_FP4_SF_VEC_SIZE
=
16
;
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
//
// Each cvt instruction turns one pair of floats into one byte; the four bytes
// are then packed into the 32-bit result with mov.b32. Note that every pair is
// passed to cvt as (odd element, even element), i.e. %2 before %1.
// Only implemented for __CUDA_ARCH__ >= 1000; otherwise returns 0.
inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  uint32_t val;
  asm volatile(
      "{\n"
      ".reg .b8 byte0;\n"
      ".reg .b8 byte1;\n"
      ".reg .b8 byte2;\n"
      ".reg .b8 byte3;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte0, %2, %1;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte1, %4, %3;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte2, %6, %5;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte3, %8, %7;\n"
      "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
      "}"
      : "=r"(val)
      : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]),
        "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7]));
  return val;
#else
  // No FP4 conversion instruction on this target.
  return 0;
#endif
}
// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
//
// Same conversion as the float[8] overload, with the inputs supplied as four
// (x, y) pairs: pair i becomes byte i of the result, and each cvt receives
// the pair as (y, x), i.e. %2 before %1.
// Only implemented for __CUDA_ARCH__ >= 1000; otherwise returns 0.
inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  uint32_t val;
  asm volatile(
      "{\n"
      ".reg .b8 byte0;\n"
      ".reg .b8 byte1;\n"
      ".reg .b8 byte2;\n"
      ".reg .b8 byte3;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte0, %2, %1;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte1, %4, %3;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte2, %6, %5;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte3, %8, %7;\n"
      "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
      "}"
      : "=r"(val)
      : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
        "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
  return val;
#else
  // No FP4 conversion instruction on this target.
  return 0;
#endif
}
// Fast reciprocal.
// Returns an approximation of 1 / a using the PTX rcp.approx.ftz.f32
// instruction (approximate, flush-to-zero variant), trading accuracy for
// speed compared with a full-precision division.
inline __device__ float reciprocal_approximate_ftz(float a) {
  float b;
  asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
  return b;
}
// Returns the address of the scale-factor byte that belongs to the element
// group at (rowIdx, colIdx) inside the swizzled scale-factor buffer `SFout`,
// or nullptr for threads that do not own a scale-factor write (and on targets
// below __CUDA_ARCH__ 1000).
//
//   rowIdx   row of the input, relative to the start of `SFout`
//   colIdx   per-thread column index (one thread handles 8 elements)
//   numCols  number of input columns (elements) per row
//   SFout    base of the scale-factor storage; addressed byte-wise
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx,
                                                       int numCols,
                                                       SFType* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 ||
                CVT_FP4_NUM_THREADS_PER_SF == 2);

  // One pair of threads write one SF to global memory: only the first thread
  // of each group gets a destination, the others get nullptr.
  // TODO: stage through smem for packed STG.32
  // is it better than STG.8 from 4 threads ?
  if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF != 0) {
    return nullptr;
  }

  // SF vector index (16 elements share one SF in the K dimension).
  int32_t const sfCol = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
  int32_t const sfRow = rowIdx;

  // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
  // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
  // One K tile covers 4 scale factors, i.e. 4 * 16 input columns.
  int const colsPerKTile = CVT_FP4_SF_VEC_SIZE * 4;
  int32_t const kTileCount = (numCols + colsPerKTile - 1) / colsPerKTile;

  // Strides of the five layout dimensions, outermost first.
  int64_t const strideMTile = kTileCount * 32 * 4 * 4;
  int64_t const strideKTile = 32 * 4 * 4;
  int64_t const strideOuterM = 4 * 4;
  int64_t const strideInnerM = 4;
  int64_t const strideInnerK = 1;

  // Coordinates along each dimension. M tile layout [32, 4] is column-major.
  int32_t const mTile = sfRow / (32 * 4);
  int32_t const kTile = sfCol / 4;
  int32_t const outerM = sfRow % 32;
  int32_t const innerM = (sfRow % (32 * 4)) / 32;
  int32_t const innerK = sfCol % 4;

  // Compute the global offset.
  int64_t const byteOffset = mTile * strideMTile + kTile * strideKTile +
                             outerM * strideOuterM + innerM * strideInnerM +
                             innerK * strideInnerK;

  return reinterpret_cast<uint8_t*>(SFout) + byteOffset;
#else
  return nullptr;
#endif
}
// Define a 16 bytes packed data type.
// Generic form: four packed two-element values (TypeConverter<Type>::Type,
// e.g. half2 for half), i.e. eight scalars of `Type`.
template <class Type>
struct PackedVec {
  typename TypeConverter<Type>::Type elts[4];
};

// FP8 form: eight packed fp8x2 values.
template <>
struct PackedVec<__nv_fp8_e4m3> {
  __nv_fp8x2_e4m3 elts[8];
};
// Quantizes the provided PackedVec into the uint32_t output
//
// Each thread holds 8 input values in `vec`; the absolute maximum is shared
// with the neighbouring lane (lane ^ 1) so that 16 values use one scale
// factor. The scale factor is stored through `SFout` when it is non-null,
// and the 8 scaled values are returned packed as e2m1 in one uint32_t.
//
//   vec         8 input values (half or bfloat16, packed in pairs)
//   SFScaleVal  global scale applied to the scale factor
//   SFout       destination byte for the scale factor, or nullptr
//   UE8M0_SF    when true the scale factor keeps only the float32 exponent
//               byte; otherwise it is stored as fp8 e4m3
//
// Only implemented for __CUDA_ARCH__ >= 1000; otherwise returns 0.
// NOTE(review): the shuffle uses a full-warp mask, so this presumably expects
// all 32 lanes of the warp to reach the call together — confirm with callers.
template <class Type, bool UE8M0_SF = false>
__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
                                         uint8_t* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  // Get absolute maximum values among the local 8 values.
  auto localMax = __habs2(vec.elts[0]);

  // Local maximum value.
#pragma unroll
  for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
    localMax = __hmax2(localMax, __habs2(vec.elts[i]));
  }

  // Get the absolute maximum among all 16 values (two threads).
  localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
  // Get the final absolute maximum values.
  float vecMax = float(__hmax(localMax.x, localMax.y));

  // Get the SF (max value of the vector / max value of e2m1).
  // maximum value of e2m1 = 6.0.
  // TODO: use half as compute data type.
  float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
  // 8 bits representation of the SF.
  uint8_t fp8SFVal;
  // Round SFValue to its 8-bit storage format and keep the rounded value so
  // the data is scaled with exactly what will be stored.
  if constexpr (UE8M0_SF) {
    // Extract the 8 exponent bits from float32.
    // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits.
    uint32_t tmp = reinterpret_cast<uint32_t&>(SFValue) >> 23;
    fp8SFVal = tmp & 0xff;
    // Convert back to fp32.
    reinterpret_cast<uint32_t&>(SFValue) = tmp << 23;
  } else {
    // Here SFValue is always positive, so E4M3 is the same as UE4M3.
    __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
    reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp;
    // Convert back to fp32.
    SFValue = float(tmp);
  }
  // Get the output scale.
  // Recipe: outputScale = reciprocal(SFValue * reciprocal(SFScaleVal)),
  // where SFValue is the rounded scale factor from above; a zero scale factor
  // yields an output scale of 0 instead of dividing by zero.
  float outputScale =
      SFValue != 0 ? reciprocal_approximate_ftz(
                         SFValue * reciprocal_approximate_ftz(SFScaleVal))
                   : 0.0f;

  if (SFout) {
    // Write the SF to global memory (STG.8).
    *SFout = fp8SFVal;
  }

  // Convert the input to float.
  float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];

#pragma unroll
  for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
    if constexpr (std::is_same_v<Type, half>) {
      fp2Vals[i] = __half22float2(vec.elts[i]);
    } else {
      // Every non-half Type is treated as bfloat16 here.
      fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
    }
    fp2Vals[i].x *= outputScale;
    fp2Vals[i].y *= outputScale;
  }

  // Convert to e2m1 values.
  uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);

  // Return the packed e2m1 values; the caller stores them.
  return e2m1Vec;
#else
  return 0;
#endif
}
// Use UE4M3 by default.
//
// Quantizes a [numRows, numCols] half/bfloat16 matrix to packed FP4 and writes
// the per-block scale factors into each expert's region of `SFout`.
//
//   numRows, numCols                rows / columns of `in`
//   in                              input values, read 8 at a time
//   SFScale                         per-expert global scale, or nullptr for 1.0
//   out                             packed output, one uint32_t per 8 inputs
//   SFout                           base of the scale-factor storage
//   input_offset_by_experts         n_experts + 1 row offsets; expert i owns
//                                   rows [offset[i], offset[i + 1])
//   output_scale_offset_by_experts  per-expert row offset into `SFout`
//   n_experts                       number of experts
//
// Blocks stride over rows and threads stride over 8-element column groups.
// The body is only compiled for __CUDA_ARCH__ >= 1000.
template <class Type, bool UE8M0_SF = false>
__global__ void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__(512, 4) cvt_fp16_to_fp4(
#else
cvt_fp16_to_fp4(
#endif
    int32_t numRows, int32_t numCols, Type const* in, float const* SFScale,
    uint32_t* out, uint32_t* SFout, uint32_t* input_offset_by_experts,
    uint32_t* output_scale_offset_by_experts, int n_experts) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  using PackedVec = PackedVec<Type>;
  static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
      (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
                "Vec size is not matched.");

  // Values that do not depend on the row or column are computed once.
  int const vecsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
  int const factor = CVT_FP4_SF_VEC_SIZE * 4;
  // The actual output_scales dim is computed from the padded numCols.
  int32_t const numCols_padded = (numCols + factor - 1) / factor * factor;
  int const numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;

  // Input tensor row/col loops.
  for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
    // Find index within the experts. This depends on the row only, so it is
    // done once per row instead of once per column group. Falls back to
    // expert 0 / row 0 when no range matches.
    int rowIdx_in_expert = 0;
    int expert_idx = 0;
    for (int i = 0; i < n_experts; i++) {
      if (rowIdx >= input_offset_by_experts[i] &&
          rowIdx < input_offset_by_experts[i + 1]) {
        rowIdx_in_expert = rowIdx - input_offset_by_experts[i];
        expert_idx = i;
        break;
      }
    }

    // Get the global scaling factor, which will be applied to the SF.
    // Note SFScale is the same as next GEMM's alpha, which is
    // (448.f / (Alpha_A / 6.f)).
    float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];

    // Widen before multiplying so the pointer offset cannot wrap in 32 bits.
    uint32_t* SFout_in_expert =
        SFout +
        static_cast<int64_t>(output_scale_offset_by_experts[expert_idx]) *
            numCols_SFout;

    for (int colIdx = threadIdx.x; colIdx < vecsPerRow; colIdx += blockDim.x) {
      // Widen rowIdx first: rowIdx * vecsPerRow can exceed INT_MAX for large
      // inputs if evaluated in 32-bit arithmetic.
      int64_t inOffset = static_cast<int64_t>(rowIdx) * vecsPerRow + colIdx;
      PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
      // Get the output tensor offset.
      // Same as inOffset because 8 elements are packed into one uint32_t.
      int64_t outOffset = inOffset;
      auto& out_pos = out[outOffset];

      auto sf_out =
          cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
                                             CVT_FP4_NUM_THREADS_PER_SF>(
              rowIdx_in_expert, colIdx, numCols, SFout_in_expert);

      out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
    }
  }
#endif
}
template
<
typename
T
>
void
quant_impl
(
void
*
output
,
void
*
output_scale
,
void
*
input
,
void
*
input_global_scale
,
void
*
input_offset_by_experts
,
void
*
output_scale_offset_by_experts
,
int
m_topk
,
int
k
,
int
n_experts
,
cudaStream_t
stream
)
{
// TODO: this multiProcessorCount should be cached.
int
device
;
cudaGetDevice
(
&
device
);
int
multiProcessorCount
;
cudaDeviceGetAttribute
(
&
multiProcessorCount
,
cudaDevAttrMultiProcessorCount
,
device
);
// Grid, Block size.
// Each thread converts 8 values.
dim3
block
(
std
::
min
(
int
(
k
/
ELTS_PER_THREAD
),
512
));
// Get number of blocks per SM (assume we can fully utilize the SM).
int
const
numBlocksPerSM
=
2048
/
block
.
x
;
dim3
grid
(
std
::
min
(
int
(
m_topk
),
multiProcessorCount
*
numBlocksPerSM
));
cvt_fp16_to_fp4
<
T
,
false
><<<
grid
,
block
,
0
,
stream
>>>
(
m_topk
,
k
,
reinterpret_cast
<
T
*>
(
input
),
reinterpret_cast
<
float
*>
(
input_global_scale
),
reinterpret_cast
<
uint32_t
*>
(
output
),
reinterpret_cast
<
uint32_t
*>
(
output_scale
),
reinterpret_cast
<
uint32_t
*>
(
input_offset_by_experts
),
reinterpret_cast
<
uint32_t
*>
(
output_scale_offset_by_experts
),
n_experts
);
}
/*Quantization entry for fp4 experts quantization*/
// Validation helpers. `m` is a message prefix; the fixed text is appended
// directly after it, without any separator.
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \
  TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
// Wrapped in do/while(0) so both checks behave as ONE statement, e.g. under
// an unbraced `if`. Callers terminate the use with `;`.
#define CHECK_INPUT(x, m)   \
  do {                      \
    CHECK_TH_CUDA(x, m);    \
    CHECK_CONTIGUOUS(x, m); \
  } while (0)

// Short aliases for the tensor dtypes validated by the entry point.
constexpr auto HALF = at::ScalarType::Half;
constexpr auto BF16 = at::ScalarType::BFloat16;
constexpr auto FLOAT = at::ScalarType::Float;
constexpr auto INT = at::ScalarType::Int;
constexpr auto UINT8 = at::ScalarType::Byte;
// Quantizes `input` ([m_topk, k], half or bfloat16) to packed NVFP4 with
// per-expert global scales, writing the data to `output` and the block scale
// factors to `output_scale`.
//
//   output                          uint8 [m_topk, k / 2]
//   output_scale                    int32, 2-D; size(1) * 4 must equal k / 16
//                                   rounded up to a multiple of 4
//   input                           half / bfloat16 [m_topk, k], k % 16 == 0
//   input_global_scale              float32 [n_experts]
//   input_offset_by_experts         int32 [n_experts + 1]
//   output_scale_offset_by_experts  int32 [n_experts + 1]
//
// Every tensor must be a contiguous CUDA tensor; violations raise through
// TORCH_CHECK.
void scaled_fp4_experts_quant_sm100a(
    torch::Tensor& output, torch::Tensor& output_scale,
    torch::Tensor const& input, torch::Tensor const& input_global_scale,
    torch::Tensor const& input_offset_by_experts,
    torch::Tensor const& output_scale_offset_by_experts) {
  // CHECK_INPUT appends "must be a CUDA tensor" / "must be contiguous" to the
  // label, so the label is only the tensor name plus a trailing space.
  CHECK_INPUT(output, "output ");
  CHECK_INPUT(output_scale, "output_scale ");
  CHECK_INPUT(input, "input ");
  CHECK_INPUT(input_global_scale, "input_global_scale ");
  CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts ");
  CHECK_INPUT(output_scale_offset_by_experts,
              "output_scale_offset_by_experts ");

  TORCH_CHECK(output.dim() == 2, "output must be 2D");
  TORCH_CHECK(output_scale.dim() == 2, "output_scale must be 2D");
  TORCH_CHECK(input.dim() == 2, "input must be 2D");
  TORCH_CHECK(input_global_scale.dim() == 1, "input_global_scale must be 1D");
  TORCH_CHECK(input_offset_by_experts.dim() == 1,
              "input_offset_by_experts must be 1D");
  TORCH_CHECK(output_scale_offset_by_experts.dim() == 1,
              "output_scale_offset_by_experts must be 1D");

  TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16,
              "input must be half or bfloat16");
  TORCH_CHECK(input_global_scale.scalar_type() == FLOAT,
              "input_global_scale must be float32");
  TORCH_CHECK(input_offset_by_experts.scalar_type() == INT,
              "input_offset_by_experts must be int32");
  TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT,
              "output_scale_offset_by_experts must be int32");
  // output is uint8 (two nvfp4 values are packed into one uint8)
  // output_scale is int32 (four fp8 values are packed into one int32)
  TORCH_CHECK(output.scalar_type() == UINT8, "output must be uint8");
  TORCH_CHECK(output_scale.scalar_type() == INT, "output_scale must be int32");

  const int BLOCK_SIZE = 16;
  auto m_topk = input.size(0);
  auto k = input.size(1);
  TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
  auto n_experts = input_global_scale.size(0);
  TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1,
              "input_offset_by_experts must have n_experts + 1 entries");
  TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1,
              "output_scale_offset_by_experts must have n_experts + 1 entries");
  TORCH_CHECK(output.size(0) == m_topk,
              "output must have the same number of rows as input");
  TORCH_CHECK(output.size(1) == k / 2, "output must have k / 2 columns");
  int scales_k = k / BLOCK_SIZE;
  // 4 means the swizzle requirement by nvidia nvfp4.
  int padded_k = (scales_k + (4 - 1)) / 4 * 4;
  // 4 means 4 fp8 values are packed into one int32
  TORCH_CHECK(output_scale.size(1) * 4 == padded_k,
              "output_scale has an unexpected number of columns");

  auto in_dtype = input.dtype();
  // Make input's device current for the launch and use its current stream.
  at::cuda::CUDAGuard device_guard{(char)input.get_device()};
  const cudaStream_t stream =
      at::cuda::getCurrentCUDAStream(input.get_device());
  if (in_dtype == at::ScalarType::Half) {
    quant_impl<half>(output.data_ptr(), output_scale.data_ptr(),
                     input.data_ptr(), input_global_scale.data_ptr(),
                     input_offset_by_experts.data_ptr(),
                     output_scale_offset_by_experts.data_ptr(), m_topk, k,
                     n_experts, stream);
  } else if (in_dtype == at::ScalarType::BFloat16) {
    quant_impl<__nv_bfloat16>(output.data_ptr(), output_scale.data_ptr(),
                              input.data_ptr(), input_global_scale.data_ptr(),
                              input_offset_by_experts.data_ptr(),
                              output_scale_offset_by_experts.data_ptr(), m_topk,
                              k, n_experts, stream);
  } else {
    TORCH_CHECK(false, "Expected input data type to be half or bfloat16");
  }
}
\ No newline at end of file
csrc/quantization/fp4/nvfp4_quant_entry.cu
View file @
0c0fdae8
...
...
@@ -23,10 +23,32 @@ void scaled_fp4_quant_sm100a(torch::Tensor const& output,
torch
::
Tensor
const
&
input_sf
);
#endif
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
void
scaled_fp4_experts_quant_sm100a
(
torch
::
Tensor
&
output
,
torch
::
Tensor
&
output_scale
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
input_global_scale
,
torch
::
Tensor
const
&
input_offset_by_experts
,
torch
::
Tensor
const
&
output_scale_offset_by_experts
);
#endif
// Dispatches NVFP4 quantization of `input` to the SM100 implementation.
//
// `output` and `output_sf` receive the quantized data and its scale factors;
// `input_sf` is forwarded unchanged. Raises a NotImplemented error when the
// extension was built without ENABLE_NVFP4.
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
                      torch::Tensor& output_sf, torch::Tensor const& input_sf) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
  return scaled_fp4_quant_sm100a(output, input, output_sf, input_sf);
#endif
  // A single failure path: the previous second, identical-purpose check could
  // never be reached because the first one always throws.
  TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization kernel");
}
// Public entry for per-expert NVFP4 quantization.
//
// Forwards every argument to scaled_fp4_experts_quant_sm100a when the
// extension is built with ENABLE_NVFP4; otherwise raises a NotImplemented
// error.
void scaled_fp4_experts_quant(
    torch::Tensor& output, torch::Tensor& output_scale,
    torch::Tensor const& input, torch::Tensor const& input_global_scale,
    torch::Tensor const& input_offset_by_experts,
    torch::Tensor const& output_scale_offset_by_experts) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
  scaled_fp4_experts_quant_sm100a(output, output_scale, input,
                                  input_global_scale, input_offset_by_experts,
                                  output_scale_offset_by_experts);
  return;
#else
  TORCH_CHECK_NOT_IMPLEMENTED(false,
                              "No compiled nvfp4 experts quantization kernel");
#endif
}
csrc/torch_bindings.cpp
View file @
0c0fdae8
...
...
@@ -363,6 +363,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
{
stride_tag
});
ops
.
impl
(
"cutlass_scaled_fp4_mm"
,
torch
::
kCUDA
,
&
cutlass_scaled_fp4_mm
);
// cutlass nvfp4 block scaled group GEMM
ops
.
def
(
"cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b,"
" Tensor a_blockscale, Tensor b_blockscales, Tensor alphas,"
" Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()"
,
{
stride_tag
});
ops
.
impl
(
"cutlass_fp4_group_mm"
,
torch
::
kCUDA
,
&
cutlass_fp4_group_mm
);
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization, as well as bias
ops
.
def
(
...
...
@@ -492,6 +500,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor! output_scale, Tensor input_scale) -> ()"
);
ops
.
impl
(
"scaled_fp4_quant"
,
torch
::
kCUDA
,
&
scaled_fp4_quant
);
// Compute NVFP4 experts quantization.
ops
.
def
(
"scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts) -> ()"
);
ops
.
impl
(
"scaled_fp4_experts_quant"
,
torch
::
kCUDA
,
&
scaled_fp4_experts_quant
);
// Check if cutlass_scaled_mm_fp4 is supported for CUDA devices
// of the given capability
ops
.
def
(
"cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool"
);
...
...
tests/kernels/moe/test_nvfp4_moe.py
0 → 100644
View file @
0c0fdae8
# SPDX-License-Identifier: Apache-2.0
import
pytest
import
torch
from
tests.kernels.quantization.nvfp4_utils
import
(
FLOAT4_E2M1_MAX
,
FLOAT8_E4M3_MAX
,
dequantize_nvfp4_to_dtype
)
from
tests.kernels.utils
import
torch_moe
from
vllm
import
_custom_ops
as
ops
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
cutlass_moe_fp4
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
from
vllm.platforms
import
current_platform
# Skip every test in this module when the current platform does not report
# device capability 100 or higher.
if not current_platform.has_device_capability(100):
    pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.",
                allow_module_level=True)

# (m, n, k) combinations used for the "m,n,k" parametrization of
# test_cutlass_fp4_moe_no_graph.
MNK_FACTORS = [
    (2, 1024, 1024),
    (2, 1024, 1536),
    (2, 3072, 1024),
    (2, 3072, 1536),
    (64, 1024, 1024),
    (64, 1024, 1536),
    (64, 3072, 1024),
    (64, 2048, 1536),
    (224, 1024, 1024),
    (224, 1024, 1536),
]
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", [40, 64, 256])
@pytest.mark.parametrize("topk", [1, 6, 8])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@torch.inference_mode()
def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
                                  dtype: torch.dtype):
    """Compare cutlass_moe_fp4 against a dequantized torch_moe reference.

    Random activations and expert weights are created, the weights are
    quantized to NVFP4 per expert, and the fused kernel output is checked
    against torch_moe run on the dequantized tensors (atol = rtol = 1e-1).

    Args:
        m: number of tokens.
        n: intermediate size (w1 has 2 * n rows per expert).
        k: hidden size.
        e: number of experts.
        topk: experts selected per token.
        dtype: activation / weight dtype before quantization.
    """
    current_platform.seed_everything(7)
    with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(
                pipeline_parallel_size=1))):

        quant_blocksize = 16

        def round_up(x, y):
            return (x + y - 1) // y * y

        # NOTE: a, w1, w2 and score are drawn in this order so the random
        # stream matches the seed above.
        a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
        w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10

        sf_w1_2n = round_up(2 * n, 128)
        sf_w1_k = round_up(k // quant_blocksize, 4)
        w1_blockscale = torch.empty((e, sf_w1_2n, sf_w1_k),
                                    device="cuda",
                                    dtype=torch.float8_e4m3fn)

        w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
        sf_w2_k = round_up(k, 128)
        sf_w2_n = round_up(n // quant_blocksize, 4)
        w2_blockscale = torch.empty((e, sf_w2_k, sf_w2_n),
                                    device="cuda",
                                    dtype=torch.float8_e4m3fn)

        w1_q = torch.empty((e, 2 * n, k // 2),
                           device="cuda",
                           dtype=torch.uint8)
        w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8)
        w1_gs = torch.empty((e, ), device="cuda", dtype=torch.float32)
        w2_gs = torch.empty((e, ), device="cuda", dtype=torch.float32)

        # The global scale is derived from the max over ALL experts, so it is
        # the same for every expert: compute it once, not once per iteration.
        w1_amax = torch.abs(w1).max().to(torch.float32)
        w2_amax = torch.abs(w2).max().to(torch.float32)
        w1_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
        w2_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax

        for expert in range(e):
            w1_gs[expert] = w1_global_scale
            w2_gs[expert] = w2_global_scale

            w1_q[expert], w1_blockscale[expert] = ops.scaled_fp4_quant(
                w1[expert], w1_gs[expert])

            w2_q[expert], w2_blockscale[expert] = ops.scaled_fp4_quant(
                w2[expert], w2_gs[expert])

        score = torch.randn((m, e), device="cuda", dtype=dtype)
        topk_weights, topk_ids = fused_topk(a,
                                            score,
                                            topk,
                                            renormalize=False)

        a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
        a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)

        cutlass_output = cutlass_moe_fp4(
            a=a,
            a1_gscale=a1_gs,
            w1_fp4=w1_q,
            w1_blockscale=w1_blockscale,
            w1_alphas=(1 / w1_gs),
            a2_gscale=a2_gs,
            w2_fp4=w2_q,
            w2_blockscale=w2_blockscale,
            w2_alphas=(1 / w2_gs),
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            m=m,
            n=n,
            k=k,
            e=e,
            device=a.device,
        )

        # Reference check:
        # Use the absolute maximum, as for the weights above; the signed
        # maximum understates the range when the largest magnitude is negative.
        a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
                          torch.amax(a.abs().flatten(), dim=-1)).to(
                              torch.float32)
        a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale)
        a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4,
                                               a_scale_interleaved,
                                               a_global_scale,
                                               dtype=a.dtype,
                                               device=a.device,
                                               block_size=quant_blocksize)

        w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype)
        w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)

        for idx in range(e):
            w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx],
                                                  w1_blockscale[idx],
                                                  w1_gs[idx],
                                                  dtype=w1.dtype,
                                                  device=w1.device,
                                                  block_size=quant_blocksize)
            w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx],
                                                  w2_blockscale[idx],
                                                  w2_gs[idx],
                                                  dtype=w2.dtype,
                                                  device=w2.device,
                                                  block_size=quant_blocksize)

        torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk, None)

        torch.testing.assert_close(torch_output,
                                   cutlass_output,
                                   atol=1e-1,
                                   rtol=1e-1)
if __name__ == "__main__":
    # The test takes m, n and k as three separate positional arguments
    # followed by e, topk and dtype; passing (m, n, k) as one tuple left the
    # call two arguments short and raised TypeError.
    test_cutlass_fp4_moe_no_graph(2, 1024, 1024, 40, 1, torch.half)
tests/kernels/quantization/nvfp4_utils.py
0 → 100644
View file @
0c0fdae8
# SPDX-License-Identifier: Apache-2.0
import
torch
from
vllm.scalar_type
import
scalar_types
# Maximum value reported by vLLM's float4_e2m1f scalar type.
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
# Largest finite value of torch.float8_e4m3fn.
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

# Lookup table from the 3-bit e2m1 magnitude code (0..7) to its float value;
# the sign bit is applied separately by break_fp4_bytes.
kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.],
                            dtype=torch.float32)
def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
    """Undo the swizzled scale-factor layout and return a row-major matrix.

    The input is interpreted as
    ``[m_tiles, k_tiles, 32, 4, 4]`` tiles (128 rows by 4 scale factors per
    tile) and rearranged so that row ``r`` / column ``c`` of the result is the
    scale factor of input row ``r`` and block ``c``.

    Args:
        a_sf_swizzled: swizzled scale factors holding
            ``m_tiles * k_tiles * 512`` elements.
        m: number of valid rows.
        k: number of elements (not scale factors) per row.
        block_size: number of elements sharing one scale factor.

    Returns:
        Tensor of shape ``(m, ceil(k / block_size))`` with tile padding
        removed in both dimensions.
    """
    m_tiles = (m + 128 - 1) // 128
    # One K tile holds 4 scale factors, i.e. covers 4 * block_size elements.
    f = block_size * 4
    k_tiles = (k + f - 1) // f
    tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
    tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
    out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)
    # Columns count scale factors, so the valid width is k / block_size and
    # not k; slicing with k kept the padded columns whenever k is not a
    # multiple of 4 * block_size.
    num_sf_cols = (k + block_size - 1) // block_size
    return out[0:m, 0:num_sf_cols]
def dequantize_nvfp4_to_dtype(tensor_fp4,
                              tensor_sf,
                              global_scale,
                              dtype,
                              device,
                              block_size=16):
    """Dequantize the fp4 tensor back to high precision.

    Args:
        tensor_fp4: uint8 tensor of shape (m, k / 2); each byte packs two
            fp4 values.
        tensor_sf: swizzled block scale factors, reinterpreted as
            float8_e4m3fn.
        global_scale: scale the block scale factors are divided by.
        dtype: dtype of the returned tensor.
        device: accepted for API compatibility; not used here.
        block_size: number of values sharing one scale factor.

    Returns:
        Tensor of shape (m, k) in ``dtype``.
    """
    assert tensor_fp4.dtype == torch.uint8
    rows, packed_cols = tensor_fp4.shape
    # Two fp4 values are packed into one uint8.
    cols = packed_cols * 2

    # Unpacked values, grouped per scale-factor block.
    blocks = break_fp4_bytes(tensor_fp4, dtype).reshape(
        rows, cols // block_size, block_size)

    # Linear (row-major) block scales with the global scale divided out.
    linear_sf = convert_swizzled_to_linear(
        tensor_sf.view(torch.float8_e4m3fn), rows, cols, block_size)
    block_scales = linear_sf.to(torch.float32) / global_scale

    # scale the tensor
    scaled = blocks * block_scales.unsqueeze(-1)
    return scaled.reshape(rows, cols).to(dtype=dtype)
def break_fp4_bytes(a, dtype):
    """Unpack a uint8 tensor of packed fp4 pairs into individual values.

    Each byte yields two values, low nibble first. Bit 3 of a nibble is the
    sign and bits 0-2 index the kE2M1ToFloat magnitude table.

    Args:
        a: uint8 tensor of shape (m, n).
        dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape (m, 2 * n) in ``dtype``.
    """
    assert a.dtype == torch.uint8
    rows, cols = a.shape

    packed = a.flatten()
    lower = packed & 0x0F
    upper = (packed & 0xF0) >> 4
    # Interleave so every byte contributes (low nibble, high nibble).
    nibbles = torch.stack((lower, upper), dim=1).flatten()

    is_negative = (nibbles & 0x08).to(torch.bool)
    magnitude_idx = (nibbles & 0x07).to(torch.long)

    # Look the magnitudes up on the input's device, then apply the sign.
    table = kE2M1ToFloat.to(device=a.device)
    sign = torch.where(is_negative, -1.0, 1.0)
    decoded = table[magnitude_idx] * sign

    return decoded.reshape(rows, cols * 2).to(dtype=dtype)
tests/kernels/quantization/test_nvfp4_scaled_mm.py
View file @
0c0fdae8
# SPDX-License-Identifier: Apache-2.0
import
pytest
import
torch
from
nvfp4_utils
import
(
FLOAT4_E2M1_MAX
,
FLOAT8_E4M3_MAX
,
dequantize_nvfp4_to_dtype
)
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
if
not
current_platform
.
has_device_capability
(
100
):
pytest
.
skip
(
reason
=
"Nvfp4 Requires compute capability of 10 or above."
,
...
...
@@ -19,95 +20,24 @@ SHAPES.extend(PAD_SHAPES)
SEEDS
=
[
42
]
CUDA_DEVICES
=
[
'cuda:0'
]
FLOAT4_E2M1_MAX
=
scalar_types
.
float4_e2m1fn
.
max
()
FLOAT8_E4M3_MAX
=
torch
.
finfo
(
torch
.
float8_e4m3fn
).
max
kE2M1ToFloatArray
=
[
0.
,
0.5
,
1.
,
1.5
,
2.
,
3.
,
4.
,
6.
,
]
def e2m1_to_fp32(int4_value):
    """Decode one 4-bit e2m1 code into a Python float.

    Bits 0-2 index kE2M1ToFloatArray for the magnitude; bit 3 is the sign.
    """
    magnitude = kE2M1ToFloatArray[int4_value & 0x7]
    if int4_value & 0x8:
        return -magnitude
    return magnitude
def break_fp4_bytes(a, dtype):
    """Unpack a uint8 tensor of packed fp4 pairs, one element at a time.

    Args:
        a: uint8 tensor of shape (m, n).
        dtype: accepted for API compatibility; not used here.

    Returns:
        Tensor of shape (m, 2 * n); each byte yields its low nibble first.
    """
    assert (a.dtype == torch.uint8)
    rows, cols = a.shape
    packed = a.flatten()

    def decode(nibbles):
        # Element-wise decode on the host, then move to the input's device.
        return torch.tensor([e2m1_to_fp32(x) for x in nibbles]).to(packed.device)

    # Decode the upper 4 bits first, then the lower 4 bits.
    upper = decode((packed & 0xF0) >> 4)
    lower = decode(packed & 0x0F)

    # [0xAB, 0xCD] -> [0xB, 0xA, 0xD, 0xC]
    return torch.stack((lower, upper), dim=-1).reshape(rows, cols * 2)
def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
    """Undo the swizzled scale-factor layout and return a row-major matrix.

    The input is interpreted as ``[m_tiles, k_tiles, 32, 4, 4]`` tiles and
    rearranged so that row ``r`` / column ``c`` of the result is the scale
    factor of input row ``r`` and block ``c``.

    Args:
        a_sf_swizzled: swizzled scale factors holding
            ``m_tiles * k_tiles * 512`` elements.
        m: number of valid rows.
        k: number of elements (not scale factors) per row.
        block_size: number of elements sharing one scale factor.

    Returns:
        Tensor of shape ``(m, ceil(k / block_size))``.
    """
    m_tiles = (m + 128 - 1) // 128
    # One K tile holds 4 scale factors, i.e. covers 4 * block_size elements.
    f = block_size * 4
    k_tiles = (k + f - 1) // f
    tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
    tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
    out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)
    # Columns count scale factors, so the valid width is k / block_size and
    # not k; slicing with k kept the padded columns whenever k is not a
    # multiple of 4 * block_size.
    num_sf_cols = (k + block_size - 1) // block_size
    return out[0:m, 0:num_sf_cols]
def dequantize_to_dtype(tensor_fp4,
                        tensor_sf,
                        global_scale,
                        dtype,
                        device,
                        block_size=16):
    """Dequantize the fp4 tensor back to high precision.

    Args:
        tensor_fp4: uint8 tensor of shape (m, k / 2); each byte packs two
            fp4 values.
        tensor_sf: swizzled block scale factors, reinterpreted as
            float8_e4m3fn.
        global_scale: scale the block scale factors are divided by.
        dtype: forwarded to break_fp4_bytes.
        device: accepted for API compatibility; not used here.
        block_size: number of values sharing one scale factor.

    Returns:
        Tensor of shape (m, k); the result is not cast to ``dtype`` here.
    """
    assert tensor_fp4.dtype == torch.uint8
    rows, packed_cols = tensor_fp4.shape
    # Two fp4 values are packed into one uint8.
    cols = packed_cols * 2

    # Unpacked values, grouped per scale-factor block.
    blocks = break_fp4_bytes(tensor_fp4, dtype).reshape(
        rows, cols // block_size, block_size)

    # Linear (row-major) block scales with the global scale divided out.
    linear_sf = convert_swizzled_to_linear(
        tensor_sf.view(torch.float8_e4m3fn), rows, cols, block_size)
    block_scales = linear_sf.to(torch.float32) / global_scale

    # scale the tensor
    return (blocks * block_scales.unsqueeze(-1)).reshape(rows, cols)
def
get_ref_results
(
a_fp4
,
b_fp4
,
a_sf
,
b_sf
,
a_global_scale
,
b_global_scale
,
m
,
n
,
dtype
,
block_size
,
device
):
_
,
m_k
=
a_fp4
.
shape
_
,
n_k
=
b_fp4
.
shape
assert
(
m_k
==
n_k
)
a_in_dtype
=
dequantize_to_dtype
(
a_fp4
,
a_sf
,
a_global_scale
,
dtype
=
dtype
,
device
=
device
,
block_size
=
block_size
)
b_in_dtype
=
dequantize_to_dtype
(
b_fp4
,
b_sf
,
b_global_scale
,
dtype
=
dtype
,
device
=
device
,
block_size
=
block_size
)
a_in_dtype
=
dequantize_
nvfp4_
to_dtype
(
a_fp4
,
a_sf
,
a_global_scale
,
dtype
=
dtype
,
device
=
device
,
block_size
=
block_size
)
b_in_dtype
=
dequantize_
nvfp4_
to_dtype
(
b_fp4
,
b_sf
,
b_global_scale
,
dtype
=
dtype
,
device
=
device
,
block_size
=
block_size
)
return
torch
.
matmul
(
a_in_dtype
,
b_in_dtype
.
t
())
...
...
vllm/_custom_ops.py
View file @
0c0fdae8
...
...
@@ -745,10 +745,11 @@ def get_cutlass_moe_mm_data(
- output_permutation: Permutation that must be used to shuffle the output
after executing the MMs.
"""
torch
.
ops
.
_C
.
get_cutlass_moe_mm_data
(
topk_ids
,
expert_offsets
,
problem_sizes1
,
problem_sizes2
,
input_permutation
,
output_permutation
,
num_experts
,
n
,
k
)
return
torch
.
ops
.
_C
.
get_cutlass_moe_mm_data
(
topk_ids
,
expert_offsets
,
problem_sizes1
,
problem_sizes2
,
input_permutation
,
output_permutation
,
num_experts
,
n
,
k
)
def
cutlass_moe_mm
(
out_tensors
:
torch
.
Tensor
,
a_tensors
:
torch
.
Tensor
,
...
...
@@ -767,9 +768,41 @@ def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
MMs used in the fused MoE operation.
- a/b/c_strides: The data strides passed to grouped matrix multiplication.
"""
torch
.
ops
.
_C
.
cutlass_moe_mm
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
)
return
torch
.
ops
.
_C
.
cutlass_moe_mm
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
)
def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor,
                       a_scales: torch.Tensor, b_scales: torch.Tensor,
                       alphas: torch.Tensor, problem_sizes: torch.Tensor,
                       expert_offsets: torch.Tensor, sf_offsets: torch.Tensor,
                       out_dtype: torch.dtype, device: torch.device):
    """
    Run the FP4 block-scaled grouped GEMM used by the NVFP4 fused MoE path.

    Allocates the output and forwards everything to the
    ``cutlass_fp4_group_mm`` custom op.

    - a/b_tensors: the NVFP4 quantized inputs and expert weights.
    - a_/b_scales: the block scales in FP8-E4M3 precision.
    - alphas: forwarded unchanged to the kernel.
    - problem_sizes: MxNxK sizes of each expert's multiplication.
    - expert_offsets/sf_offsets: indices marking where each expert's tokens
        and scale factors begin. Expert E covers
        expert_offsets[E + 1] - expert_offsets[E] tokens and
        sf_offsets[E + 1] - sf_offsets[E] scale rows.
    - out_dtype/device: dtype and device of the returned tensor.

    Returns a tensor of shape (a_tensors.shape[0], b_tensors.shape[1]).
    """
    out_shape = (a_tensors.shape[0], b_tensors.shape[1])
    result = torch.empty(out_shape, device=device, dtype=out_dtype)
    # The kernel writes into `result` in place.
    torch.ops._C.cutlass_fp4_group_mm(result, a_tensors, b_tensors, a_scales,
                                      b_scales, alphas, problem_sizes,
                                      expert_offsets, sf_offsets)
    return result.to(out_dtype)
# aqlm
...
...
@@ -960,6 +993,57 @@ def scaled_fp4_quant(
return
output
,
output_scale
def scaled_fp4_experts_quant(
    input_tensor: torch.Tensor,
    input_global_scale: torch.Tensor,
    expert_offsets: torch.Tensor,
    blockscale_offsets: torch.Tensor,
    topk: int,
    expert_map: Optional[torch.Tensor] = None,
    MAX_TOKENS_PER_EXPERT: int = 163840,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize packed MoE inputs to FP4 and return the data with its scales.

    Args:
        input_tensor: 2-D tensor to quantize.
        input_global_scale: Global scale forwarded to the kernel.
        expert_offsets: The expert offsets tensor.
        blockscale_offsets: The blockscale offsets tensor.
        topk: Used together with MAX_TOKENS_PER_EXPERT to bound the number
            of rows and to size the scale buffer.
        expert_map: Optional index tensor; when given, the rows of
            ``input_tensor`` are gathered with it before quantization.
        MAX_TOKENS_PER_EXPERT: Upper bound used to size the scale buffer.
    Outputs:
        output: uint8 tensor of shape (rows, k // 2) with packed FP4 values.
        output_scales: The block scales, viewed as FP8-E4M3.
    """
    assert not current_platform.is_rocm()
    assert input_tensor.ndim == 2, (
        f'input.ndim needs to be == 2, but got {input_tensor.ndim}.')

    if expert_map is not None:
        input_tensor = input_tensor[expert_map]

    m_numtopk, k = input_tensor.shape
    max_rows = MAX_TOKENS_PER_EXPERT * topk
    assert (m_numtopk <= max_rows), (
        f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT * topk for"
        f" scaled_fp4_experts_quant kernel, observed m_numtopk = {m_numtopk}")

    # One scale per 16 inputs; four one-byte scales fit in each int32 slot.
    scales_k = k // 16
    padded_k = (scales_k + (4 - 1)) // 4

    target_device = input_tensor.device
    # output is uint8 and packed fp4 values
    output = torch.empty(m_numtopk,
                         k // 2,
                         device=target_device,
                         dtype=torch.uint8)
    output_scales = torch.empty(max_rows,
                                padded_k,
                                dtype=torch.int32,
                                device=target_device)
    torch.ops._C.scaled_fp4_experts_quant(output, output_scales, input_tensor,
                                          input_global_scale, expert_offsets,
                                          blockscale_offsets)
    return output, output_scales.view(torch.float8_e4m3fn)
# fp8
def
scaled_fp8_quant
(
input
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/fused_moe/__init__.py
View file @
0c0fdae8
...
...
@@ -36,7 +36,7 @@ if HAS_TRITON:
import
vllm.model_executor.layers.fused_moe.fused_marlin_moe
# noqa
import
vllm.model_executor.layers.fused_moe.fused_moe
# noqa
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
(
cutlass_moe_fp8
)
cutlass_moe_fp4
,
cutlass_moe_fp8
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_experts
,
fused_moe
,
fused_topk
,
get_config_file_name
,
grouped_topk
)
...
...
@@ -48,4 +48,5 @@ if HAS_TRITON:
"get_config_file_name"
,
"grouped_topk"
,
"cutlass_moe_fp8"
,
"cutlass_moe_fp4"
,
]
vllm/model_executor/layers/fused_moe/cutlass_moe.py
View file @
0c0fdae8
# SPDX-License-Identifier: Apache-2.0
"""Fused MoE kernel."""
"""
CUTLASS based
Fused MoE kernel
s
."""
from
typing
import
Optional
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.scalar_type
import
scalar_types
#TODO make the grouped gemm kernel consistent with scaled gemm kernel
...
...
@@ -178,3 +179,126 @@ def cutlass_moe_fp8(
if
not
apply_router_weight_on_input
:
c2
=
c2
*
topk_weights
.
view
(
m
,
topk
,
1
).
to
(
out_dtype
)
return
c2
.
sum
(
dim
=
1
)
FLOAT4_E2M1_MAX
=
scalar_types
.
float4_e2m1f
.
max
()
FLOAT8_E4M3_MAX
=
torch
.
finfo
(
torch
.
float8_e4m3fn
).
max
MAX_TOKENS_PER_EXPERT
=
65536
def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
                    w1_fp4: torch.Tensor, w1_blockscale: torch.Tensor,
                    w1_alphas: torch.Tensor, a2_gscale: torch.Tensor,
                    w2_fp4: torch.Tensor, w2_blockscale: torch.Tensor,
                    w2_alphas: torch.Tensor, topk_weights: torch.Tensor,
                    topk_ids: torch.Tensor, m: int, n: int, k: int, e: int,
                    device: torch.device):
    """
    MoE implementation for FP4 Inputs

    # Gemm 1
    a: Input tensor: [m, k] (half/bfloat16)
    a1_gscale: Activation scale per expert: [e]  (float32)
    w1(gate up) (not an argument to cutlass_moe_fp4): [e, 2 * n, k]
    w1_fp4: [e, 2 * n, k // 2], dtype: torch.uint8 (stacked fp4: E2M1)
    (Note: `n` is the up projection output dim, `k` is the input dim in
     full precision)
    w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3)
                   (Block size = 16 for NVFP4)
    w1_alphas: forwarded to the first grouped GEMM.

    # Gemm 2
    a2_gscale: Activation scale per expert: [e]
    w2(down projection) (not an argument to cutlass_moe_fp4): [e, k, n]
    w2_fp4: [e, k, n // 2], dtype: torch.uint8 (stacked E2M1)
    w2_blockscale: [e, k, n // block_size], dtype: float8_e4m3
    w2_alphas: forwarded to the second grouped GEMM.

    topk_weights: [m, topk] routing weights (dtype is not checked here)
    topk_ids: [m, topk], same shape as topk_weights (dtype is not checked
              here)

    m, n, k: Unquantized weight shapes, dtype: int
    e: number of experts, dtype: int
    device: device on which all intermediate tensors are allocated.

    Returns a [m, k] tensor with the dtype of `a`.

    assumes that topk < k < n to satisfy - up/down projection expectations.
    """
    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert w1_fp4.dtype == torch.uint8, "weight 1 must be uint8"
    assert w2_fp4.dtype == torch.uint8, "weight 2 must be uint8"
    assert (w1_fp4.ndim == 3 and w2_fp4.ndim == 3 and w1_blockscale.ndim == 3
            and w2_blockscale.ndim
            == 3), ("All Weights must be of rank 3 for cutlass_moe_fp4")
    m_a, k_a = a.shape
    e_w1, nx2_w1, half_k_w1 = w1_fp4.shape
    e_w2, k_w2, half_n_w2 = w2_fp4.shape

    # The message must be a single string; a stray comma previously turned it
    # into a tuple.
    assert (e_w1 == e_w2 and e_w1 == e), ("Number of experts must match"
                                          " between weights.")
    assert (k_a // 2 == half_k_w1
            and k == k_w2), ("Hidden size mismatch between a, w1 and w2")
    assert (nx2_w1 == n * 2 and half_n_w2 == n // 2), ("mismatch in "
                                                       "expected `n`")
    assert (m == m_a), "input shape mismatch"
    assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1"
    assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
    assert (topk_weights.shape[0] == m and topk_ids.shape[0]
            == m), ("topk must be provided for each row of a")
    assert (m <= MAX_TOKENS_PER_EXPERT), (
        f"m must be less than MAX_TOKENS_PER_EXPERT({MAX_TOKENS_PER_EXPERT})"
        f" for cutlass_moe_fp4, observed m = {m}")
    out_dtype = a.dtype
    num_topk = topk_ids.shape[1]

    # Buffers filled in by get_cutlass_moe_mm_data.
    expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
    # Problem size:  (num_experts, (m,2n,k))
    problem_sizes1 = torch.empty((e, 3), dtype=torch.int32, device=device)
    # Problem size:  (num_experts, (m,n,k))
    problem_sizes2 = torch.empty((e, 3), dtype=torch.int32, device=device)

    a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
    c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)

    # problem shapes should have [m, n, k]
    # Note that problem sizes are based on logical number of elements.
    ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
                                problem_sizes2, a_map, c_map, e, n, k)

    # Each expert's scale rows are padded up to a multiple of 128, so the
    # scale offsets are the running sum of the rounded token counts.
    tokens_per_expert = problem_sizes1[:, 0]
    rounded_tokens_per_expert = (tokens_per_expert + (128 - 1)) // 128 * 128
    blockscale_offsets = torch.zeros(e + 1, dtype=torch.int32, device=device)
    blockscale_offsets[1:] = torch.cumsum(rounded_tokens_per_expert, dim=0)

    # Gather rows with a_map and quantize them for the first GEMM.
    rep_a_fp4, rep_a_blockscale = ops.scaled_fp4_experts_quant(
        a,
        a1_gscale,
        expert_offsets,
        blockscale_offsets,
        num_topk,
        expert_map=a_map,
        MAX_TOKENS_PER_EXPERT=MAX_TOKENS_PER_EXPERT)

    c1 = ops.cutlass_fp4_moe_mm(rep_a_fp4, w1_fp4, rep_a_blockscale,
                                w1_blockscale, w1_alphas, problem_sizes1,
                                expert_offsets[:-1], blockscale_offsets[:-1],
                                out_dtype, device)
    del rep_a_fp4, rep_a_blockscale
    # silu_and_mul writes an output with half of c1's last dimension:
    # [m * topk, 2n] -> [m * topk, n].
    intermediate = torch.empty((m * num_topk, w1_fp4.shape[1] // 2),
                               device=device,
                               dtype=out_dtype)

    torch.ops._C.silu_and_mul(intermediate, c1)

    # Rows are already in expert order here, so no expert_map is passed.
    int_fp4, int_blockscale = ops.scaled_fp4_experts_quant(
        intermediate,
        a2_gscale,
        expert_offsets,
        blockscale_offsets,
        num_topk,
        MAX_TOKENS_PER_EXPERT=MAX_TOKENS_PER_EXPERT)

    c2 = ops.cutlass_fp4_moe_mm(int_fp4, w2_fp4, int_blockscale, w2_blockscale,
                                w2_alphas, problem_sizes2, expert_offsets[:-1],
                                blockscale_offsets[:-1], out_dtype, device)
    del int_fp4, int_blockscale
    # Reorder rows with c_map, weight each expert's output, and sum over the
    # topk dimension.
    out = (c2[c_map].view(m, num_topk, k) *
           topk_weights.view(m, num_topk, 1).half()).sum(dim=1)
    return out.to(dtype=out_dtype)
vllm/model_executor/layers/fused_moe/layer.py
View file @
0c0fdae8
...
...
@@ -643,7 +643,7 @@ class FusedMoE(torch.nn.Module):
expert_id
=
self
.
_map_global_expert_id_to_local_expert_id
(
expert_id
)
if
expert_id
==
-
1
:
return
quant_method_name
=
self
.
quant_method
.
__class__
.
__name__
# compressed-tensors checkpoints with packed weights are stored flipped
# TODO (mgoin): check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality
...
...
@@ -697,8 +697,9 @@ class FusedMoE(torch.nn.Module):
# this is needed for compressed-tensors only
loaded_weight
=
loaded_weight
.
to
(
param
.
data
.
device
)
if
param
.
data
[
expert_id
]
!=
1
and
(
param
.
data
[
expert_id
]
-
loaded_weight
).
abs
()
>
1e-5
:
if
(
"compressed"
in
quant_method_name
.
lower
()
and
param
.
data
[
expert_id
]
!=
1
and
(
param
.
data
[
expert_id
]
-
loaded_weight
).
abs
()
>
1e-5
):
raise
ValueError
(
"input_scales of w1 and w3 of a layer "
f
"must be equal. But got
{
param
.
data
[
expert_id
]
}
"
...
...
@@ -718,6 +719,22 @@ class FusedMoE(torch.nn.Module):
tp_rank
=
self
.
tp_rank
)
return
if
"ModelOpt"
in
quant_method_name
:
if
(
'weight_scale_2'
in
weight_name
or
'input_scale'
in
weight_name
):
self
.
_load_per_tensor_weight_scale
(
shard_id
=
shard_id
,
param
=
param
,
loaded_weight
=
loaded_weight
,
expert_id
=
expert_id
)
elif
"weight"
in
weight_name
:
self
.
_load_model_weight_or_group_weight_scale
(
shard_id
=
shard_id
,
shard_dim
=
shard_dim
,
loaded_weight
=
loaded_weight
,
expert_data
=
expert_data
,
tp_rank
=
self
.
tp_rank
)
return
# Case weight scales, zero_points and offset
if
(
"scale"
in
weight_name
or
"zero"
in
weight_name
or
"offset"
in
weight_name
):
...
...
vllm/model_executor/layers/quantization/modelopt.py
View file @
0c0fdae8
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Union
import
torch
from
torch.nn
import
Module
...
...
@@ -9,6 +9,8 @@ from torch.nn.parameter import Parameter
from
vllm._custom_ops
import
(
cutlass_scaled_fp4_mm
,
cutlass_scaled_mm_supports_fp4
,
scaled_fp4_quant
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.layer
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
...
...
@@ -210,25 +212,37 @@ class ModelOptNvFp4Config(QuantizationConfig):
"`hf_quant_config.json` file for your model's "
"quant configuration."
)
is_checkpoint_nvfp4_serialized
=
(
"NVFP4"
in
quant_method
)
kv_cache_quant_algo
=
quant_config
[
"kv_cache_quant_algo"
]
group_size
=
quant_config
[
"group_size"
]
exclude_modules
=
quant_config
[
"exclude_modules"
]
if
not
(
group_size
and
kv_cache_quant_algo
and
exclude_modules
):
if
(
"group_size"
and
"kv_cache_quant_algo"
and
"exclude_modules"
)
not
in
quant_config
:
raise
ValueError
(
"NVFP4 quantization requires group size and "
"kv_cache_quant_algo specified in "
"hf_quant_config.json"
)
kv_cache_quant_algo
=
quant_config
[
"kv_cache_quant_algo"
]
group_size
=
quant_config
[
"group_size"
]
exclude_modules
=
quant_config
[
"exclude_modules"
]
return
cls
(
is_checkpoint_nvfp4_serialized
,
kv_cache_quant_algo
,
exclude_modules
,
group_size
)
def is_layer_excluded(self, prefix: str, exclude_modules: List):
    """Return True if ``prefix`` fully matches any exclusion pattern.

    Each entry of ``exclude_modules`` is a literal module name in which
    ``*`` is a wildcard matching any run of characters. All other
    characters, including regex metacharacters, are matched literally.

    Args:
        prefix: Layer name to test.
        exclude_modules: Iterable of pattern strings.

    Returns:
        True when some pattern matches the whole of ``prefix``, else False.
    """
    import re
    for pattern in exclude_modules:
        # Escape everything first, then re-enable only the `*` wildcard, so
        # characters such as `+`, `(` or `[` cannot alter or break the regex.
        regex_str = re.escape(pattern).replace(r'\*', r'.*')
        if re.fullmatch(regex_str, prefix):
            return True
    return False
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
from
vllm.attention.layer
import
Attention
# Avoid circular import
if
isinstance
(
layer
,
LinearBase
):
if
is_layer_skipped
(
prefix
,
self
.
exclude_modules
):
if
(
is_layer_skipped
(
prefix
,
self
.
exclude_modules
)
or
self
.
is_layer_excluded
(
prefix
,
self
.
exclude_modules
)):
return
UnquantizedLinearMethod
()
return
ModelOptNvFp4LinearMethod
(
self
)
elif
isinstance
(
layer
,
Attention
):
return
ModelOptFp8KVCacheMethod
(
self
)
elif
isinstance
(
layer
,
FusedMoE
):
return
ModelOptNvFp4FusedMoE
(
self
)
return
None
...
...
@@ -409,3 +423,235 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
if
bias
is
not
None
:
out
=
out
+
bias
return
out
.
view
(
*
output_shape
)
class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
    """
    MoE Method for FP4 Quantization.

    Registers the packed FP4 expert weights and their scales on the layer,
    post-processes them after loading, and runs the forward pass through
    ``cutlass_moe_fp4``.

    Args:
        quant_config: NVFP4 Quant Config
    """

    def __init__(self, quant_config: ModelOptNvFp4Config):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Register the FP4 weights, block scales and per-tensor scales.

        Raises:
            ValueError: If the checkpoint is not NVFP4-serialized.
        """
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError("NVFP4 quantization was selected, "
                             " dynamic quantization is not supported.")

        layer.quant_config = self.quant_config
        weight_dtype = torch.uint8
        weight_scale_dtype = torch.float8_e4m3fn
        weight_loader = extra_weight_attrs.get("weight_loader")
        # GEMM 1
        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                # 2 fp4 items are packed in the input dimension
                hidden_size // 2,
                dtype=weight_dtype),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader)
        layer.register_parameter("w13_weight", w13_weight)

        # GEMM 2
        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                # 2 fp4 items are packed in the input dimension
                intermediate_size_per_partition // 2,
                dtype=weight_dtype),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader)
        layer.register_parameter("w2_weight", w2_weight)

        w13_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                # one block scale per `group_size` input elements
                hidden_size // self.quant_config.group_size,
                dtype=weight_scale_dtype),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader)
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                # one block scale per `group_size` input elements
                intermediate_size_per_partition //
                self.quant_config.group_size,
                dtype=weight_scale_dtype),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value})

        # Second-level (per-tensor) weight scales; w13 keeps one value for
        # each of its two stacked projections.
        w13_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, 2, dtype=torch.float32),
            weight_loader=weight_loader)
        layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)

        w2_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, dtype=torch.float32),
            weight_loader=weight_loader)
        layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})

        w13_input_scale = PerTensorScaleParameter(data=torch.empty(
            num_experts, 2, dtype=torch.float32),
                                                  weight_loader=weight_loader)
        layer.register_parameter("w13_input_scale", w13_input_scale)

        w2_input_scale = PerTensorScaleParameter(data=torch.empty(
            num_experts, dtype=torch.float32),
                                                 weight_loader=weight_loader)
        layer.register_parameter("w2_input_scale", w2_input_scale)

    def swizzle_blockscale(self, scale: torch.Tensor):
        """Pad and block-interleave FP8 block scales.

        Rows are zero-padded to a multiple of 128 and columns to a multiple
        of 4; the padded tensor is then rearranged tile-wise and moved to
        the current CUDA device.

        Args:
            scale: 2-D ``(M, K)`` or 3-D ``(B, M, K)`` float8_e4m3fn tensor.

        Returns:
            The swizzled CUDA tensor with the padded shape
            ``(M_padded, K_padded)`` for 2-D input or
            ``(B, M_padded, K_padded)`` for 3-D input. When no padding is
            needed this equals the input shape.
        """
        assert (scale.dtype == torch.float8_e4m3fn)

        def _round_up(x: int, multiple: int) -> int:
            return (x + multiple - 1) // multiple * multiple

        # Pad and blockwise interleave weight_scale
        scale_ndim = scale.ndim
        if scale.ndim == 2:
            scale = scale.unsqueeze(0)
        assert scale.ndim == 3
        B, M, K = scale.shape
        M_padded = _round_up(M, 128)
        K_padded = _round_up(K, 4)
        # Allocate next to the input so the copy below stays on one device.
        padded_scale = torch.zeros((B, M_padded, K_padded),
                                   dtype=scale.dtype,
                                   device=scale.device)
        padded_scale[:B, :M, :K] = scale
        batches, rows, cols = padded_scale.shape
        assert rows % 128 == 0
        assert cols % 4 == 0
        padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
                                            cols // 4, 4)
        swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
        swizzled_scale = swizzled_scale.contiguous().cuda()
        # The buffer holds the padded number of elements, so reshape to the
        # padded dims; reshaping to (M, K) fails whenever padding was added.
        if scale_ndim == 2:
            return swizzled_scale.reshape(M_padded, K_padded)
        return swizzled_scale.reshape(B, M_padded, K_padded)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Derive the kernel inputs from the loaded checkpoint tensors.

        Sets ``g1_alphas``/``g2_alphas`` (input scale times second-level
        weight scale), the swizzled block scales, and the reciprocal input
        scales used for activation quantization.
        """
        # GEMM 1
        assert torch.allclose(
            layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]), (
                "Expected w1_weight_scale_2 to equal w3_weight_scale_2")

        # Both halves are equal (asserted above), so keep a single column.
        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
        layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2,
                                             requires_grad=False)

        w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(
            torch.float32)
        layer.g1_alphas = Parameter(
            (w13_input_scale * w13_weight_scale_2).to(torch.float32),
            requires_grad=False)

        assert (layer.w13_weight_scale.shape[2] % 16 == 0), (
            "Expected weight_scale.dim(2) to be divisible by 16")
        assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), (
            "Weight Blockscale must be represented as FP8-E4M3")
        w13_blockscale_swizzled = self.swizzle_blockscale(
            layer.w13_weight_scale)

        layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled,
                                                  requires_grad=False)

        # This is for quantization, so we need to invert it.
        layer.w13_input_scale_quant = Parameter(
            (1 / w13_input_scale).to(torch.float32), requires_grad=False)

        # GEMM 2
        layer.g2_alphas = Parameter(
            (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
            requires_grad=False)

        # This is for quantization, so we need to invert it.
        layer.w2_input_scale_quant = Parameter(
            (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False)

        assert (layer.w2_weight_scale.shape[2] % 16 == 0), (
            "Expected weight_scale.dim(2) to be divisible by 16")
        assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), (
            "Weight Blockscale must be represented as FP8-E4M3")
        w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)

        layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
                                                 requires_grad=False)
        return

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
    ):
        """Route tokens to experts and run the FP4 fused MoE forward pass.

        Only SiLU activation is accepted; router weights on the input and
        ``expert_map`` are rejected. Returns the output of
        ``cutlass_moe_fp4`` cast to ``x.dtype``.
        """
        assert activation == "silu", "Only SiLU activation is supported."
        assert not apply_router_weight_on_input, (
            "Router weight on input is not "
            "supported for ModelOptNvFp4FusedMoE.")
        assert expert_map is None, ("Expert Parallelism /expert_map "
                                    "is currently not supported for "
                                    "ModelOptNvFp4FusedMoE.")

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias)

        from vllm.model_executor.layers.fused_moe.cutlass_moe import (
            cutlass_moe_fp4)

        # Cutlass moe takes in activations in BF16/Half precision
        # and fp4 quantized weights loaded from the checkpoint
        return cutlass_moe_fp4(a=x,
                               w1_fp4=layer.w13_weight,
                               w1_blockscale=layer.w13_blockscale_swizzled,
                               w1_alphas=layer.g1_alphas,
                               w2_fp4=layer.w2_weight,
                               w2_blockscale=layer.w2_blockscale_swizzled,
                               w2_alphas=layer.g2_alphas,
                               topk_weights=topk_weights,
                               topk_ids=topk_ids,
                               m=x.shape[0],
                               # w2_weight packs two fp4 values per byte
                               n=layer.w2_weight.shape[2] * 2,
                               k=x.shape[1],
                               e=layer.w13_weight.shape[0],
                               a1_gscale=layer.w13_input_scale_quant,
                               a2_gscale=layer.w2_input_scale_quant,
                               device=x.device).to(x.dtype)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment