Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9798b2fb
Unverified
Commit
9798b2fb
authored
Jan 30, 2025
by
Lucas Wilkinson
Committed by
GitHub
Jan 30, 2025
Browse files
[Kernel] Update `cutlass_scaled_mm` to support 2d group (blockwise) scaling (#11868)
parent
4078052f
Changes
25
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1736 additions
and
285 deletions
+1736
-285
CMakeLists.txt
CMakeLists.txt
+7
-2
benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
+142
-148
csrc/core/math.hpp
csrc/core/math.hpp
+8
-1
csrc/cutlass_extensions/common.hpp
csrc/cutlass_extensions/common.hpp
+17
-0
csrc/cutlass_extensions/gemm/collective/collective_builder.hpp
...cutlass_extensions/gemm/collective/collective_builder.hpp
+123
-0
csrc/cutlass_extensions/gemm/collective/fp8_accumulation.hpp
csrc/cutlass_extensions/gemm/collective/fp8_accumulation.hpp
+183
-0
csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp
...mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp
+730
-0
csrc/cutlass_extensions/gemm/dispatch_policy.hpp
csrc/cutlass_extensions/gemm/dispatch_policy.hpp
+39
-0
csrc/cutlass_extensions/vllm_collective_builder.cuh
csrc/cutlass_extensions/vllm_collective_builder.cuh
+1
-1
csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh
csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh
+93
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh
csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh
+0
-74
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu
.../quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu
+24
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu
...tization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu
+24
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
...utlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
+168
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
+33
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu
+24
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh
...tization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh
+25
-1
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu
+24
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh
...ization/cutlass_w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh
+24
-1
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
+47
-57
No files found.
CMakeLists.txt
View file @
9798b2fb
...
...
@@ -245,7 +245,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
FetchContent_Declare
(
cutlass
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
GIT_TAG v3.
6
.0
GIT_TAG v3.
7
.0
GIT_PROGRESS TRUE
# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
...
...
@@ -299,7 +299,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# CUDA 12.0 or later (and only work on Hopper, 9.0a for now).
cuda_archs_loose_intersection
(
SCALED_MM_3X_ARCHS
"9.0a"
"
${
CUDA_ARCHS
}
"
)
if
(
${
CMAKE_CUDA_COMPILER_VERSION
}
VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS
)
set
(
SRCS
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
)
set
(
SRCS
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu"
)
set_gencode_flags_for_srcs
(
SRCS
"
${
SRCS
}
"
CUDA_ARCHS
"
${
SCALED_MM_3X_ARCHS
}
"
)
...
...
benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
View file @
9798b2fb
...
...
@@ -3,7 +3,7 @@ import copy
import
itertools
import
pickle
as
pkl
import
time
from
typing
import
Callable
,
Iterable
,
List
,
Tuple
from
typing
import
Callable
,
Iterable
,
List
,
Optional
,
Tuple
import
torch
import
torch.utils.benchmark
as
TBenchmark
...
...
@@ -12,6 +12,8 @@ from utils import make_rand_tensors
from
weight_shapes
import
WEIGHT_SHAPES
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
w8a8_block_fp8_matmul
)
from
vllm.utils
import
FlexibleArgumentParser
DEFAULT_MODELS
=
list
(
WEIGHT_SHAPES
.
keys
())
...
...
@@ -38,8 +40,15 @@ def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
).
blocked_autorange
(
min_run_time
=
min_run_time
)
def
bench_int8
(
dtype
:
torch
.
dtype
,
m
:
int
,
k
:
int
,
n
:
int
,
label
:
str
,
sub_label
:
str
)
->
Iterable
[
TMeasurement
]:
def
bench_int8
(
dtype
:
torch
.
dtype
,
m
:
int
,
k
:
int
,
n
:
int
,
label
:
str
,
sub_label
:
str
,
bench_kernels
:
Optional
[
List
[
str
]]
=
None
)
->
Iterable
[
TMeasurement
]:
"""Benchmark INT8-based kernels."""
assert
dtype
==
torch
.
int8
a
,
b
=
make_rand_tensors
(
torch
.
int8
,
m
,
n
,
k
)
scale_a
=
torch
.
tensor
(
1.0
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
...
...
@@ -48,155 +57,132 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
azp
=
torch
.
zeros
((
m
,
),
device
=
"cuda"
,
dtype
=
torch
.
int32
)
azp_adj
=
torch
.
zeros
((
n
,
),
device
=
"cuda"
,
dtype
=
torch
.
int32
)
bench_fns
=
{
"pytorch_bf16_bf16_bf16_matmul-no-scales"
:
lambda
:
torch
.
mm
(
a
.
to
(
dtype
=
torch
.
bfloat16
),
b
.
to
(
dtype
=
torch
.
bfloat16
)
),
"pytorch_fp16_fp16_fp16_matmul-no-scales"
:
lambda
:
torch
.
mm
(
a
.
to
(
dtype
=
torch
.
float16
),
b
.
to
(
dtype
=
torch
.
float16
)),
"cutlass_i8_i8_bf16_scaled_mm"
:
lambda
:
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
),
"cutlass_i8_i8_bf16_scaled_mm_bias"
:
lambda
:
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
bias
),
"cutlass_i8_i8_bf16_scaled_mm_azp"
:
lambda
:
ops
.
cutlass_scaled_mm_azp
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
azp_adj
),
"cutlass_i8_i8_bf16_scaled_mm_azp_bias"
:
lambda
:
ops
.
cutlass_scaled_mm_azp
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
azp_adj
,
None
,
bias
),
"cutlass_i8_i8_bf16_scaled_mm_azp_pt"
:
lambda
:
ops
.
cutlass_scaled_mm_azp
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
azp_adj
,
azp
),
"cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias"
:
lambda
:
ops
.
cutlass_scaled_mm_azp
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
azp_adj
,
azp
,
bias
),
}
timers
=
[]
# pytorch impl - bfloat16
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"pytorch_bf16_bf16_bf16_matmul-no-scales"
,
torch
.
mm
,
a
.
to
(
dtype
=
torch
.
bfloat16
),
b
.
to
(
dtype
=
torch
.
bfloat16
)))
# pytorch impl - float16
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"pytorch_fp16_fp16_fp16_matmul-no-scales"
,
torch
.
mm
,
a
.
to
(
dtype
=
torch
.
float16
),
b
.
to
(
dtype
=
torch
.
float16
)))
# cutlass impl
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_i8_i8_bf16_scaled_mm"
,
ops
.
cutlass_scaled_mm
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
))
# cutlass with bias
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_i8_i8_bf16_scaled_mm_bias"
,
ops
.
cutlass_scaled_mm
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
bias
))
# cutlass with azp per-tensor
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_i8_i8_bf16_scaled_mm_azp"
,
ops
.
cutlass_scaled_mm_azp
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
azp_adj
))
# cutlass with azp per-tensor + bias
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_i8_i8_bf16_scaled_mm_azp_bias"
,
ops
.
cutlass_scaled_mm_azp
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
azp_adj
,
None
,
bias
))
# cutlass with azp per-token
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_i8_i8_bf16_scaled_mm_azp_pt"
,
ops
.
cutlass_scaled_mm_azp
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
azp_adj
,
azp
))
# cutlass with azp per-token + bias
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias"
,
ops
.
cutlass_scaled_mm_azp
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
azp_adj
,
azp
,
bias
))
for
name
,
fn
in
bench_fns
.
items
():
# If bench_kernels is None, run all. Otherwise, run only exact matches.
if
bench_kernels
is
None
or
name
in
bench_kernels
:
print
(
f
"Running
{
name
}
"
)
timers
.
append
(
bench_fn
(
label
,
sub_label
,
name
,
fn
))
return
timers
def
bench_fp8
(
dtype
:
torch
.
dtype
,
m
:
int
,
k
:
int
,
n
:
int
,
label
:
str
,
sub_label
:
str
)
->
Iterable
[
TMeasurement
]:
def
bench_fp8
(
dtype
:
torch
.
dtype
,
m
:
int
,
k
:
int
,
n
:
int
,
label
:
str
,
sub_label
:
str
,
bench_kernels
:
Optional
[
List
[
str
]]
=
None
)
->
Iterable
[
TMeasurement
]:
"""Benchmark FP8-based kernels."""
assert
dtype
==
torch
.
float8_e4m3fn
a
,
b
=
make_rand_tensors
(
torch
.
float8_e4m3fn
,
m
,
n
,
k
)
a_cont
=
a
.
contiguous
()
scale_a
=
torch
.
tensor
(
1.0
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
scale_b
=
torch
.
tensor
(
1.0
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
block_scale_a
=
torch
.
rand
((
m
,
k
//
128
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
block_scale_b
=
torch
.
rand
((
k
//
128
,
n
//
128
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
block_scale_a_M_major
=
block_scale_a
.
t
().
contiguous
().
t
()
block_scale_b_K_major
=
block_scale_b
.
t
().
contiguous
().
t
()
bias
=
torch
.
zeros
((
n
,
),
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
timers
=
[]
# pytorch impl w. bf16
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"pytorch_bf16_bf16_bf16_matmul-no-scales"
,
torch
.
mm
,
a
.
to
(
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
),
b
.
to
(
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)))
# pytorch impl: bf16 output, without fp8 fast accum
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"pytorch_fp8_fp8_bf16_scaled_mm"
,
torch
.
_scaled_mm
,
a
,
print
(
m
,
k
,
n
)
bench_fns
=
{
"pytorch_bf16_bf16_bf16_matmul-no-scales"
:
lambda
:
torch
.
mm
(
a
.
to
(
dtype
=
torch
.
bfloat16
),
b
.
to
(
dtype
=
torch
.
bfloat16
)
),
"pytorch_fp16_fp16_fp16_matmul-no-scales"
:
lambda
:
torch
.
mm
(
a
.
to
(
dtype
=
torch
.
float16
),
b
.
to
(
dtype
=
torch
.
float16
)),
"pytorch_fp8_fp8_fp16_scaled_mm"
:
lambda
:
torch
.
_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
=
torch
.
float16
),
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum"
:
lambda
:
torch
.
_scaled_mm
(
a
,
b
,
scale_a
=
scale_a
,
scale_b
=
scale_b
,
out_dtype
=
torch
.
bfloat16
))
# pytorch impl: bf16 output, with fp8 fast accum
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum"
,
torch
.
_scaled_mm
,
a
,
scale_a
,
scale_b
,
out_dtype
=
torch
.
float16
,
use_fast_accum
=
True
),
"pytorch_fp8_fp8_bf16_scaled_mm"
:
lambda
:
torch
.
_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
=
torch
.
bfloat16
),
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum"
:
lambda
:
torch
.
_scaled_mm
(
a
,
b
,
scale_a
=
scale_a
,
scale_b
=
scale_b
,
scale_a
,
scale_b
,
out_dtype
=
torch
.
bfloat16
,
use_fast_accum
=
True
))
# pytorch impl: fp16 output, without fp8 fast accum
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"pytorch_fp8_fp8_fp16_scaled_mm"
,
torch
.
_scaled_mm
,
a
,
b
,
scale_a
=
scale_a
,
scale_b
=
scale_b
,
out_dtype
=
torch
.
float16
))
# pytorch impl: fp16 output, with fp8 fast accum
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum"
,
torch
.
_scaled_mm
,
a
,
b
,
scale_a
=
scale_a
,
scale_b
=
scale_b
,
out_dtype
=
torch
.
float16
,
use_fast_accum
=
True
))
# cutlass impl: bf16 output
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_fp8_fp8_bf16_scaled_mm"
,
ops
.
cutlass_scaled_mm
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
))
# cutlass impl: fp16 output
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_fp8_fp8_fp16_scaled_mm"
,
ops
.
cutlass_scaled_mm
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
float16
))
# cutlass impl: bf16 output, with bias
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_fp8_fp8_bf16_scaled_mm_bias"
,
ops
.
cutlass_scaled_mm
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
bias
))
# cutlass impl: fp16 output, with bias
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_fp8_fp8_fp16_scaled_mm_bias"
,
ops
.
cutlass_scaled_mm
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
float16
,
bias
.
to
(
dtype
=
torch
.
float16
)))
use_fast_accum
=
True
),
"cutlass_fp8_fp8_bf16_scaled_mm"
:
lambda
:
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
),
"cutlass_fp8_fp8_fp16_scaled_mm"
:
lambda
:
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
float16
),
"cutlass_fp8_fp8_bf16_scaled_mm_bias"
:
lambda
:
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
bias
),
"cutlass_fp8_fp8_fp16_scaled_mm_bias"
:
lambda
:
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
float16
,
bias
.
to
(
dtype
=
torch
.
float16
)),
"triton_fp8_fp8_fp16_scaled_mm_blockwise"
:
lambda
:
w8a8_block_fp8_matmul
(
a_cont
,
b
.
t
(),
block_scale_a
,
block_scale_b
.
t
(),
(
128
,
128
)),
"cutlass_fp8_fp8_fp16_scaled_mm_blockwise"
:
lambda
:
ops
.
cutlass_scaled_mm
(
a
,
b
,
block_scale_a_M_major
,
block_scale_b_K_major
,
torch
.
float16
),
}
timers
=
[]
for
name
,
fn
in
bench_fns
.
items
():
# If bench_kernels is None, run all. Otherwise, run only exact matches.
if
bench_kernels
is
None
or
name
in
bench_kernels
:
print
(
f
"Running
{
name
}
"
)
timers
.
append
(
bench_fn
(
label
,
sub_label
,
name
,
fn
))
return
timers
def
bench
(
dtype
:
torch
.
dtype
,
m
:
int
,
k
:
int
,
n
:
int
,
label
:
str
,
sub_label
:
str
)
->
Iterable
[
TMeasurement
]:
def
bench
(
dtype
:
torch
.
dtype
,
m
:
int
,
k
:
int
,
n
:
int
,
label
:
str
,
sub_label
:
str
,
bench_kernels
:
Optional
[
List
[
str
]]
=
None
)
->
Iterable
[
TMeasurement
]:
if
dtype
==
torch
.
int8
:
return
bench_int8
(
dtype
,
m
,
k
,
n
,
label
,
sub_label
)
return
bench_int8
(
dtype
,
m
,
k
,
n
,
label
,
sub_label
,
bench_kernels
)
if
dtype
==
torch
.
float8_e4m3fn
:
return
bench_fp8
(
dtype
,
m
,
k
,
n
,
label
,
sub_label
)
return
bench_fp8
(
dtype
,
m
,
k
,
n
,
label
,
sub_label
,
bench_kernels
)
raise
ValueError
(
"unsupported type"
)
...
...
@@ -207,18 +193,22 @@ def print_timers(timers: Iterable[TMeasurement]):
def
run
(
dtype
:
torch
.
dtype
,
MKNs
:
Iterable
[
Tuple
[
int
,
int
,
int
]])
->
Iterable
[
TMeasurement
]:
MKNs
:
Iterable
[
Tuple
[
int
,
int
,
int
]],
bench_kernels
:
Optional
[
List
[
str
]]
=
None
)
->
Iterable
[
TMeasurement
]:
results
=
[]
for
m
,
k
,
n
in
MKNs
:
timers
=
bench
(
dtype
,
m
,
k
,
n
,
f
"scaled-
{
dtype
}
-gemm"
,
f
"MKN=(
{
m
}
x
{
k
}
x
{
n
}
)"
)
timers
=
bench
(
dtype
,
m
,
k
,
n
,
f
"scaled-
{
dtype
}
-gemm"
,
f
"MKN=(
{
m
}
x
{
k
}
x
{
n
}
)"
,
bench_kernels
=
bench_kernels
)
print_timers
(
timers
)
results
.
extend
(
timers
)
return
results
# output makers
def
make_output
(
data
:
Iterable
[
TMeasurement
],
MKNs
:
Iterable
[
Tuple
[
int
,
int
,
int
]],
base_description
:
str
,
...
...
@@ -232,15 +222,11 @@ def make_output(data: Iterable[TMeasurement],
pkl
.
dump
(
data
,
f
)
# argparse runners
def
run_square_bench
(
args
):
dim_sizes
=
list
(
range
(
args
.
dim_start
,
args
.
dim_end
+
1
,
args
.
dim_increment
))
MKNs
=
list
(
zip
(
dim_sizes
,
dim_sizes
,
dim_sizes
))
data
=
run
(
args
.
dtype
,
MKNs
)
data
=
run
(
args
.
dtype
,
MKNs
,
bench_kernels
=
args
.
kernels
)
make_output
(
data
,
MKNs
,
f
"square_bench-
{
args
.
dtype
}
"
)
...
...
@@ -251,8 +237,7 @@ def run_range_bench(args):
Ks
=
[
args
.
k_constant
]
*
n
if
args
.
k_constant
is
not
None
else
dim_sizes
Ns
=
[
args
.
n_constant
]
*
n
if
args
.
n_constant
is
not
None
else
dim_sizes
MKNs
=
list
(
zip
(
Ms
,
Ks
,
Ns
))
data
=
run
(
args
.
dtype
,
MKNs
)
data
=
run
(
args
.
dtype
,
MKNs
,
bench_kernels
=
args
.
kernels
)
make_output
(
data
,
MKNs
,
f
"range_bench-
{
args
.
dtype
}
"
)
...
...
@@ -278,7 +263,7 @@ def run_model_bench(args):
for
k
,
n
in
KNs
:
MKNs
.
append
((
m
,
k
,
n
))
data
=
run
(
args
.
dtype
,
MKNs
)
data
=
run
(
args
.
dtype
,
MKNs
,
bench_kernels
=
args
.
kernels
)
model_bench_data
.
append
(
data
)
# Print all results
...
...
@@ -328,6 +313,15 @@ Benchmark Cutlass GEMM.
type
=
to_torch_dtype
,
required
=
True
,
help
=
"Available options are ['int8', 'fp8']"
)
parser
.
add_argument
(
"--kernels"
,
nargs
=
"+"
,
type
=
str
,
default
=
None
,
help
=
"Exact names of the kernels to benchmark. If not set, runs all kernels."
)
subparsers
=
parser
.
add_subparsers
(
dest
=
"cmd"
)
square_parser
=
subparsers
.
add_parser
(
"square_bench"
)
...
...
csrc/core/math.hpp
View file @
9798b2fb
#pragma once
#include <climits>
#include <iostream>
inline
uint32_t
next_pow_2
(
uint32_t
const
num
)
{
/// Rounds `num` up to the nearest power of two.
///
/// - 0 and 1 are returned unchanged.
/// - Values that are already a power of two are returned unchanged.
/// - Values above 2^31 have no representable answer in 32 bits; 0 is returned
///   for them instead of performing an out-of-range shift (undefined
///   behaviour, and not usable in a constant expression).
inline constexpr uint32_t next_pow_2(uint32_t const num) {
  if (num <= 1) return num;

  constexpr uint32_t kBitWidth = static_cast<uint32_t>(CHAR_BIT * sizeof(num));
  // `num - 1` is non-zero here, so __builtin_clz is well defined.
  uint32_t const leading_zeros =
      static_cast<uint32_t>(__builtin_clz(num - 1));

  // No leading zero bit means num > 2^31: the next power of two is 2^32.
  if (leading_zeros == 0) return 0;

  // Shift an unsigned one so that producing 2^31 never touches a sign bit.
  return uint32_t{1} << (kBitWidth - leading_zeros);
}
/// Integer division of `a` by `b`, rounded towards positive infinity.
///
/// Only participates in overload resolution for integral `T`.
/// Precondition: `b > 0` (a zero divisor is undefined, as with plain `/`).
///
/// Computed as quotient plus a remainder test rather than `(a + b - 1) / b`,
/// so the intermediate value cannot overflow when `a` is close to the maximum
/// of `T`.
template <typename T>
inline constexpr std::enable_if_t<std::is_integral_v<T>, T> ceil_div(T a,
                                                                     T b) {
  // A strictly positive remainder means truncation rounded the quotient down.
  return static_cast<T>(a / b + ((a % b) > 0 ? 1 : 0));
}
\ No newline at end of file
csrc/cutlass_extensions/common.hpp
View file @
9798b2fb
...
...
@@ -32,3 +32,20 @@ inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
}
int32_t
get_sm_version_num
();
/**
 * Kernel wrapper that only emits the wrapped kernel's body when compiling
 * device code for compute capability 9.0 or newer.
 *
 * Architectures that will never run the kernel get an empty call operator,
 * which keeps the compiled binary smaller. __CUDA_ARCH__ is not defined in
 * host code, so the preprocessor check has to live inside the device-side
 * call operator, where the macro is defined.
 */
template <typename Kernel>
struct enable_sm90_or_later : Kernel {
  template <typename... Ts>
  CUTLASS_DEVICE void operator()(Ts&&... params) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    // Forward everything untouched to the real kernel implementation.
    Kernel::operator()(std::forward<Ts>(params)...);
#endif
  }
};
\ No newline at end of file
csrc/cutlass_extensions/gemm/collective/collective_builder.hpp
0 → 100644
View file @
9798b2fb
// Modified from: cutlass/gemm/collective/builders/sm90_gmma_builder.inl
// clang-format off
#pragma once
#include "cutlass/gemm/collective/builders/sm90_gmma_builder.inl"
#include "cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace
cutlass
::
gemm
::
collective
{
/////////////////////////////////////////////////////////////////////////////////////////////////
// GMMA_TMA_WS_SS (BlockScaled Builders)
// Partial specialization of CollectiveBuilder selected when the kernel
// schedule is KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum
// and detail::is_use_rmem_A<...>() is false. It assembles the type aliases and
// constants below and exposes the resulting CollectiveMma as `CollectiveOp`.
//
// ScaleGranularityM is forwarded unchanged to the schedule tag and to the
// mainloop dispatch policy.
template <
  class ElementA,
  class GmemLayoutATag,
  int AlignmentA,
  class ElementB,
  class GmemLayoutBTag,
  int AlignmentB,
  class ElementAccumulator,
  class TileShape_MNK,
  class ClusterShape_MNK,
  class StageCountType,
  int ScaleGranularityM
>
struct CollectiveBuilder<
    arch::Sm90,
    arch::OpClassTensorOp,
    ElementA,
    GmemLayoutATag,
    AlignmentA,
    ElementB,
    GmemLayoutBTag,
    AlignmentB,
    ElementAccumulator,
    TileShape_MNK,
    ClusterShape_MNK,
    StageCountType,
    KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>,
    cute::enable_if_t<
      not detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>()>
> {
  using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>;

  // Tile and cluster shapes must be compile-time constants.
  static_assert(is_static<TileShape_MNK>::value);
  static_assert(is_static<ClusterShape_MNK>::value);
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
  static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
  static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
                "Should meet TMA alignment requirement\n");

  // NOTE(review): KernelScheduleType is fixed above to the block-scaled
  // schedule, so this is presumably always false for this specialization --
  // confirm against cute::is_any_of_v semantics.
  static constexpr bool IsArrayOfPointersGemm = (cute::is_any_of_v<KernelScheduleType,
                                                                   KernelPtrArrayTmaWarpSpecializedCooperative,
                                                                   KernelPtrArrayTmaWarpSpecializedPingpong>);
  static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
  // Rejects the combination "FP8 inputs AND array-of-pointers GEMM".
  static_assert((!IsFP8Input || !IsArrayOfPointersGemm),
                "KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum is only compatible with FP8 Blocked Scaled version right now.");

  // For fp32 types, map to tf32 MMA value type
  using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
  using ElementBMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;

  // GMMA major-ness of each operand, derived from the MMA element type and
  // the global-memory layout tag.
  static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementAMma, GmemLayoutATag>();
  static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementBMma, GmemLayoutBTag>();

  // KernelScheduleType is itself one of the listed types, so the 2x1x1 atom
  // layout is the one selected below.
  static constexpr bool IsCooperative = cute::is_any_of_v<KernelScheduleType,
                                                          KernelTmaWarpSpecializedCooperative,
                                                          KernelPtrArrayTmaWarpSpecializedCooperative,
                                                          KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>;
  using AtomLayoutMNK = cute::conditional_t<IsCooperative,
      Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;

  using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
      ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{}));

  // The copy atom for A is chosen from cluster mode 1 and the one for B from
  // cluster mode 0.
  using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
  using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));

  // Shared-memory layout atoms: A uses tile modes (0, 2), B uses tile modes
  // (1, 2) of TileShape_MNK.
  using SmemLayoutAtomA = decltype(detail::ss_smem_selector<
      GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
      GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());

  // Bytes set aside for two TMA descriptors in the array-of-pointers case,
  // otherwise zero; subtracted from the smem capacity when computing the
  // pipeline stage count.
  static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0;
  static constexpr int KernelSmemCarveout = static_cast<int>(TensorMapStorage);

  static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes - KernelSmemCarveout,
      ElementAMma, ElementBMma, TileShape_MNK>(StageCountType{});
  using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType, ScaleGranularityM>;

  using SmemCopyAtomA = void;
  using SmemCopyAtomB = void;

  // The collective mainloop type produced by this builder.
  using CollectiveOp = CollectiveMma<
      DispatchPolicy,
      TileShape_MNK,
      ElementA,
      TagToStrideA_t<GmemLayoutATag>,
      ElementB,
      TagToStrideB_t<GmemLayoutBTag>,
      TiledMma,
      GmemTiledCopyA,
      SmemLayoutAtomA,
      SmemCopyAtomA,
      cute::identity,
      GmemTiledCopyB,
      SmemLayoutAtomB,
      SmemCopyAtomB,
      cute::identity
    >;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
}
// namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
csrc/cutlass_extensions/gemm/collective/fp8_accumulation.hpp
0 → 100644
View file @
9798b2fb
// clang-format off
// adapted from: https://github.com/soundOfDestiny/cutlass/blob/a4208aa6958864923505cade9c63eb2a6daf16e5/include/cutlass/gemm/collective/fp8_accumulation.hpp
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/algorithm/clear.hpp"
#include "cute/tensor.hpp"
//////////////////////////////////////////////////////////////////////////////
///////////////////////////////////FP8 Accumulation///////////////////////////
//////////////////////////////////////////////////////////////////////////////
/// This class provides an API to promote (add) or scale (multiply-add) the
/// results from the tensor core accumulators into the main accumulators once
/// the number of executed MMAs reaches the user-specified promotion interval;
/// after that, the tensor core accumulators are zeroed.
//////////////////////////////////////////////////////////////////////////////
namespace
cutlass
::
gemm
::
collective
{
// Wraps a caller-owned accumulator tensor together with a temporary
// accumulator of the same type. MMAs write into the temporary (obtained via
// operator()); the *_if_needed methods fold it into the main accumulator,
// either by plain addition (promote_*) or multiplied element-wise by a scale
// tensor (scale_*), once enough MMAs have been counted.
template <
    class EngineAccum,
    class LayoutAccum>
struct GmmaFP8AccumulationWithScale {
  using TensorAccum = cute::Tensor<EngineAccum, LayoutAccum>;
  using ElementAccumulator = typename EngineAccum::value_type;

  static_assert(is_static<LayoutAccum>::value, "Accumulator Layout should be static");
  static_assert(is_rmem<TensorAccum>::value , "Accumulator tensor must be rmem resident.");

private:
  TensorAccum& accum_;      // main accumulator, owned by the caller (held by reference)
  TensorAccum accum_temp_;  // temporary accumulator handed out by operator()

  uint32_t accum_promotion_interval_;         // defines the max num of executed MMAs after which accum should be promoted.
  uint32_t mma_count_per_mainloop_iteration_; // num of MMAs per k_tile of mainloop
  uint32_t mma_count_;                        // current executed MMAs
  uint32_t reset_accum_flag_;                 // accum needs to be zeroed or not.

  // promote or `add` the partial accumulators to main accumulator (FADD).
  // Waits on warpgroup_wait<0>() before reading accum_temp_.
  CUTLASS_DEVICE
  void promote_core() {
    warpgroup_wait<0>();
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < size(accum_); ++i) {
      accum_(i) += accum_temp_(i);
    }
  }

  // `multiply` scale the partial accumulators and `add` to main accumulator (FFMA).
  // The scale tensor must have a static layout with exactly the accumulator's
  // shape; element i of the temporary is multiplied by scale(i).
  template <
    class EngineScale,
    class LayoutScale>
  CUTLASS_DEVICE
  void scale_core(const cute::Tensor<EngineScale, LayoutScale> &scale) {
    using TensorScale = cute::Tensor<EngineScale, LayoutScale>;

    static_assert(is_static<LayoutScale>::value, "Scale Layout should be static");
    static_assert(is_rmem<TensorScale>::value , "Scale tensor must be rmem resident.");

    static_assert(LayoutAccum{}.shape() == LayoutScale{}.shape(), "Accumulator and scale must have same shape.");

    warpgroup_wait<0>();
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < size(accum_); ++i) {
      accum_(i) += accum_temp_(i) * scale(i);
    }
  }

public:
  // Binds to `accum` by reference and starts with a zero MMA count and a
  // cleared reset flag.
  CUTLASS_DEVICE
  GmmaFP8AccumulationWithScale(
      TensorAccum &accum,
      uint32_t accum_promotion_interval,
      uint32_t mma_count_per_mainloop_iteration)
      : accum_(accum),
        accum_promotion_interval_(accum_promotion_interval),
        mma_count_per_mainloop_iteration_(mma_count_per_mainloop_iteration),
        mma_count_(0),
        reset_accum_flag_(0)
  {
    // Temporary accumulator shaped like the main one.
    // NOTE(review): nothing here clears accum_temp_ -- presumably the caller
    // zeroes it before the first MMA; confirm.
    accum_temp_ = cute::make_fragment_like(accum);
  }

  //
  // Methods (Common)
  //

  // Returns the temporary accumulator that MMAs should accumulate into.
  CUTLASS_DEVICE
  TensorAccum& operator()() {
    return accum_temp_;
  }

  /// prepare the MMA accumulators when initialization or zeroing is required.
  /// This method only reports the flag set by the last promote_if_needed /
  /// scale_if_needed call; it does not clear anything itself.
  CUTLASS_DEVICE
  bool prepare_if_needed() {
    return reset_accum_flag_;
  }

  //
  // Methods (for FADD version)
  //

  /// promote (add) the results from the MMA accumulators to main accumulator if needed.
  CUTLASS_DEVICE
  void promote_if_needed() {
    mma_count_ += mma_count_per_mainloop_iteration_;
    // Take the decision from lane 0 so that every lane named in the mask
    // follows the same branch.
    // NOTE(review): the comparison is `==`, so the interval is presumably a
    // multiple of mma_count_per_mainloop_iteration_ -- confirm with callers.
    reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0);
    if (reset_accum_flag_) {
      promote_core();
      mma_count_ = 0;
    }
  }

  /// promote (add) the residue results from the MMA accumulators to main accumulator if needed.
  /// "Residue" here means MMAs counted since the last promotion (mma_count_ > 0 on lane 0).
  CUTLASS_DEVICE
  void promote_residue_if_needed() {
    if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) {
      promote_core();
    }
  }

  //
  // Methods (for FFMA version)
  //

  /// scale (multiply_add) the results from the MMA accumulators to main accumulator if needed.
  /// Same counting logic as promote_if_needed, but folds via scale_core(scale).
  template <
    class EngineScale,
    class LayoutScale>
  CUTLASS_DEVICE
  void scale_if_needed(const cute::Tensor<EngineScale, LayoutScale> &scale) {
    mma_count_ += mma_count_per_mainloop_iteration_;
    reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0);
    if (reset_accum_flag_) {
      scale_core(scale);
      mma_count_ = 0;
    }
  }

  /// scale (multiply_add) the residue results from the MMA accumulators to main accumulator if needed.
  template <
    class EngineScale,
    class LayoutScale>
  CUTLASS_DEVICE
  void scale_residue_if_needed(const cute::Tensor<EngineScale, LayoutScale> &scale) {
    if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) {
      scale_core(scale);
    }
  }
};
}
// namespace cutlass::gemm::collective
csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp
0 → 100644
View file @
9798b2fb
This diff is collapsed.
Click to expand it.
csrc/cutlass_extensions/gemm/dispatch_policy.hpp
0 → 100644
View file @
9798b2fb
#pragma once
#include "cutlass/gemm/dispatch_policy.hpp"
namespace
cutlass
::
gemm
{
//////////////////////////////////////////////////////////////////////////////
// FP8 related policies (including Blocked Scaled Accumulation).
//
// Kernel schedule tag derived from KernelTmaWarpSpecializedCooperative.
// `ScaleGranularityM` is the scaling granularity along M; the default of 0
// means the granularity is `size<0>(TileShape_MNK{})` along M.
template <int ScaleGranularityM = 0>
struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum
    : public KernelTmaWarpSpecializedCooperative {};
// Mainloop dispatch policy for FP8 kernels with block scaling: n-buffer in
// smem (Hopper TMA), pipelined with Hopper GMMA and TMA, warp-specialized
// dynamic schedule.
//
// `ScaleGranularityM` is the scaling granularity along M; 0 means the
// granularity is `size<0>(TileShape_MNK{})` along M.
//
// The static_assert below only accepts the block-scaled cooperative schedule
// with the same `ScaleGranularityM`, so the defaulted `KernelSchedule` has to
// be overridden by the user of this policy.
template <
    int Stages_,
    class ClusterShape_ = Shape<_1, _1, _1>,
    class KernelSchedule = KernelTmaWarpSpecialized,
    int ScaleGranularityM = 0>
struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8
    : public MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_,
                                                KernelSchedule> {
  static_assert(
      cute::is_same_v<
          KernelSchedule,
          KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<
              ScaleGranularityM>>,
      "KernelSchedule must be one of the warp specialized policies");
};
//////////////////////////////////////////////////////////////////////////////
}
// namespace cutlass::gemm
\ No newline at end of file
csrc/cutlass_extensions/vllm_collective_builder.cuh
View file @
9798b2fb
#pragma once
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass
_extensions
/gemm/collective/collective_builder.hpp"
namespace
cutlass
::
gemm
::
collective
{
using
namespace
cute
;
...
...
csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh
0 → 100644
View file @
9798b2fb
#pragma once
// clang-format will break include orders
// clang-format off
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "core/math.hpp"
#include "cutlass_extensions/common.hpp"
// clang-format on
namespace
vllm
::
c3x
{
// Builds the 4-element problem shape {M, N, K, 1} for a GEMM of `a` by `b`:
//   M = a.size(0), N = b.size(1), K = a.size(1); the last entry is fixed at 1.
// Each extent is converted to a 32-bit integer to match the shape's element
// type.
static inline cute::Shape<int, int, int, int> get_problem_shape(
    torch::Tensor const& a, torch::Tensor const& b) {
  int32_t rows_m = static_cast<int32_t>(a.size(0));
  int32_t cols_n = static_cast<int32_t>(b.size(1));
  int32_t depth_k = static_cast<int32_t>(a.size(1));
  return {rows_m, cols_n, depth_k, 1};
}
// Launches `GemmKernel` through `GemmUniversalAdapter`.
//
//   device        - device used for the workspace allocation and for looking
//                   up the current CUDA stream.
//   prob_shape    - 4-element problem shape passed straight into the kernel
//                   arguments.
//   mainloop_args - already-built mainloop arguments for the kernel.
//   epilogue_args - already-built epilogue arguments for the kernel.
//
// `can_implement` and the status returned by `run` are both verified with
// CUTLASS_CHECK.
template <typename GemmKernel>
void cutlass_gemm_caller(torch::Device device,
                         cute::Shape<int, int, int, int> prob_shape,
                         typename GemmKernel::MainloopArguments mainloop_args,
                         typename GemmKernel::EpilogueArguments epilogue_args) {
  using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  typename GemmKernel::Arguments kernel_args{
      cutlass::gemm::GemmUniversalMode::kGemm, prob_shape, mainloop_args,
      epilogue_args};

  // Launch the CUTLASS GEMM kernel.
  GemmOp op;
  CUTLASS_CHECK(op.can_implement(kernel_args));

  // Scratch memory requested by the kernel, allocated as a byte tensor on the
  // target device.
  size_t scratch_bytes = op.get_workspace_size(kernel_args);
  auto const scratch_options =
      torch::TensorOptions().dtype(torch::kUInt8).device(device);
  auto scratch = torch::empty(scratch_bytes, scratch_options);

  auto cuda_stream = at::cuda::getCurrentCUDAStream(device.index());

  cutlass::Status run_status =
      op.run(kernel_args, scratch.data_ptr(), cuda_stream);
  CUTLASS_CHECK(run_status);
}
// Tensor-level entry point: derives strides, problem shape and data pointers
// from `out`, `a` and `b`, asks `Gemm::Epilogue::prepare_args` to turn
// `epilogue_params` into epilogue arguments, and forwards everything to the
// kernel-level `cutlass_gemm_caller<GemmKernel>` overload.
//
// The stride types encode a compile-time unit stride (`Int<1>`) in their
// second mode; only a.stride(0), b.stride(1) and out.stride(0) are read from
// the tensors.
template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
                         torch::Tensor const& b,
                         EpilogueArgs&&... epilogue_params) {
  using GemmKernel = typename Gemm::GemmKernel;
  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;

  using StrideA = cute::Stride<int64_t, cute::Int<1>, int64_t>;
  using StrideB = cute::Stride<int64_t, cute::Int<1>, int64_t>;
  using StrideC = typename Gemm::StrideC;

  // Leading dimensions taken from the tensors.
  int64_t a_ld = a.stride(0);
  int64_t b_ld = b.stride(1);
  int64_t out_ld = out.stride(0);

  StrideA stride_a{a_ld, cute::Int<1>{}, 0};
  StrideB stride_b{b_ld, cute::Int<1>{}, 0};
  StrideC stride_out{out_ld, cute::Int<1>{}, cute::Int<0>{}};

  typename GemmKernel::ProblemShape problem = get_problem_shape(a, b);

  auto a_data = static_cast<ElementAB*>(a.data_ptr());
  auto b_data = static_cast<ElementAB*>(b.data_ptr());
  typename GemmKernel::MainloopArguments mainloop{a_data, stride_a, b_data,
                                                  stride_b};

  // The output buffer is passed for both the C and D operand slots.
  auto out_data = static_cast<ElementD*>(out.data_ptr());
  typename GemmKernel::EpilogueArguments epilogue{
      Gemm::Epilogue::prepare_args(
          std::forward<EpilogueArgs>(epilogue_params)...),
      out_data, stride_out, out_data, stride_out};

  cutlass_gemm_caller<GemmKernel>(a.device(), problem, mainloop, epilogue);
}
}
// namespace vllm::c3x
\ No newline at end of file
csrc/quantization/cutlass_w8a8/scaled_mm
_c3x
.cuh
→
csrc/quantization/cutlass_w8a8/
c3x/
scaled_mm.cuh
View file @
9798b2fb
...
...
@@ -2,9 +2,6 @@
// clang-format will break include orders
// clang-format off
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include "cutlass/cutlass.h"
...
...
@@ -32,21 +29,6 @@ using namespace cute;
namespace
vllm
{
// A wrapper for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
// Wraps `Kernel` so that its call operator only has a body when compiling
// device code for __CUDA_ARCH__ >= 900. On any other compilation pass the
// operator is an empty function, so no kernel code is emitted for it.
template <typename Kernel>
struct enable_sm90_or_later : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    Kernel::operator()(std::forward<Args>(args)...);
#endif
  }
};
template
<
typename
ElementAB_
,
typename
ElementD_
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue_
,
typename
TileShape
,
typename
ClusterShape
,
typename
KernelSchedule
,
...
...
@@ -101,60 +83,4 @@ struct cutlass_3x_gemm {
struct
GemmKernel
:
public
KernelType
{};
};
// Runs the CUTLASS 3.x GEMM described by `Gemm` on tensors `a` and `b`,
// writing into `out`. Strides, problem shape and data pointers are derived
// from the tensors; `epilogue_params` are converted to epilogue arguments via
// `Gemm::Epilogue::prepare_args`. The workspace is allocated on `a`'s device
// and the kernel is launched on the current CUDA stream of that device.
template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
                         torch::Tensor const& b,
                         EpilogueArgs&&... epilogue_params) {
  using GemmKernel = typename Gemm::GemmKernel;
  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;

  using StrideA = Stride<int64_t, Int<1>, int64_t>;
  using StrideB = Stride<int64_t, Int<1>, int64_t>;
  using StrideC = typename Gemm::StrideC;

  // Problem extents: M and K come from `a`, N from `b`.
  int32_t m = a.size(0);
  int32_t n = b.size(1);
  int32_t k = a.size(1);

  // Leading dimensions taken from the tensors.
  int64_t a_ld = a.stride(0);
  int64_t b_ld = b.stride(1);
  int64_t out_ld = out.stride(0);

  StrideA stride_a{a_ld, Int<1>{}, 0};
  StrideB stride_b{b_ld, Int<1>{}, 0};
  StrideC stride_out{out_ld, Int<1>{}, Int<0>{}};

  typename GemmKernel::ProblemShape problem{m, n, k, 1};

  auto a_data = static_cast<ElementAB*>(a.data_ptr());
  auto b_data = static_cast<ElementAB*>(b.data_ptr());
  typename GemmKernel::MainloopArguments mainloop{a_data, stride_a, b_data,
                                                  stride_b};

  // The output buffer is passed for both the C and D operand slots.
  auto out_data = static_cast<ElementD*>(out.data_ptr());
  typename GemmKernel::EpilogueArguments epilogue{
      Gemm::Epilogue::prepare_args(
          std::forward<EpilogueArgs>(epilogue_params)...),
      out_data, stride_out, out_data, stride_out};

  typename GemmKernel::Arguments kernel_args{
      cutlass::gemm::GemmUniversalMode::kGemm, problem, mainloop, epilogue};

  // Launch the CUTLASS GEMM kernel.
  using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  GemmOp op;
  CUTLASS_CHECK(op.can_implement(kernel_args));

  size_t scratch_bytes = op.get_workspace_size(kernel_args);
  auto const scratch_options =
      torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto scratch = torch::empty(scratch_bytes, scratch_options);

  auto cuda_stream = at::cuda::getCurrentCUDAStream(a.get_device());

  cutlass::Status run_status =
      op.run(kernel_args, scratch.data_ptr(), cuda_stream);
  CUTLASS_CHECK(run_status);
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu
0 → 100644
View file @
9798b2fb
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_int8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace
vllm
{
// SM90 int8 scaled matmul with an asymmetric zero-point adjustment.
// When `azp` holds a tensor the `ScaledEpilogueBiasAzpToken` epilogue is used
// and receives `*azp`; otherwise `ScaledEpilogueBiasAzp` is used. `bias` is
// forwarded as an optional in both cases.
void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     torch::Tensor const& a_scales,
                                     torch::Tensor const& b_scales,
                                     torch::Tensor const& azp_adj,
                                     std::optional<torch::Tensor> const& azp,
                                     std::optional<torch::Tensor> const& bias) {
  if (!azp) {
    cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
    return;
  }

  cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBiasAzpToken>(
      out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu
0 → 100644
View file @
9798b2fb
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm90_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace
vllm
{
// SM90 FP8 blockwise-scaled matmul: selects the output element type from
// `out.dtype()`. bfloat16 output is handled first; any other dtype must be
// float16 (enforced with TORCH_CHECK).
void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out,
                                          torch::Tensor const& a,
                                          torch::Tensor const& b,
                                          torch::Tensor const& a_scales,
                                          torch::Tensor const& b_scales) {
  if (out.dtype() == torch::kBFloat16) {
    cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::bfloat16_t>(
        out, a, b, a_scales, b_scales);
    return;
  }

  TORCH_CHECK(out.dtype() == torch::kFloat16);
  cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::half_t>(out, a, b,
                                                            a_scales, b_scales);
}
}
// namespace vllm
\ No newline at end of file
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
0 → 100644
View file @
9798b2fb
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
#include "cutlass_gemm_caller.cuh"
namespace
vllm
{
using
namespace
cute
;
// Type bundle describing an SM90 FP8 (e4m3) GEMM with group/block scaling.
//
// Template parameters:
//   OutType      - element type of the output (D).
//   GroupSizeM_  - scale group size along M; also passed to the kernel
//                  schedule as its scale granularity.
//   GroupSizeN_  - scale group size along N; used as the tile's N extent.
//   GroupSizeK_  - scale group size along K; used as the tile's K extent.
//   TileSizeM_   - tile extent along M; must be a multiple of GroupSizeM_.
//   ClusterShape - cluster shape handed to both collective builders.
template <typename OutType, int GroupSizeM_, int GroupSizeN_, int GroupSizeK_,
          int TileSizeM_ = 128, class ClusterShape = Shape<_1, _2, _1>>
struct cutlass_3x_gemm_fp8_blockwise {
  // Compile-time integer wrappers around the template arguments.
  using GroupSizeM = Int<GroupSizeM_>;
  using GroupSizeN = Int<GroupSizeN_>;
  using GroupSizeK = Int<GroupSizeK_>;
  using TileSizeM = Int<TileSizeM_>;

  static_assert(TileSizeM_ % GroupSizeM_ == 0,
                "TileSizeM must be a multiple of GroupSizeM");

  // Both inputs share the same FP8 element type.
  using ElementAB = cutlass::float_e4m3_t;

  // Operand A: row-major; alignment is the element count in 128 bits.
  using ElementA = ElementAB;
  using LayoutA = cutlass::layout::RowMajor;
  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

  // Operand B: column-major; alignment is the element count in 128 bits.
  using ElementB = ElementAB;
  using LayoutB = cutlass::layout::ColumnMajor;
  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

  // Output D: runtime leading stride, unit inner stride, zero third stride.
  using ElementD = OutType;
  using StrideD = Stride<int64_t, Int<1>, Int<0>>;
  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;

  // C has no element type (void); it reuses D's stride and alignment.
  using ElementC = void;
  using StrideC = StrideD;
  static constexpr int AlignmentC = AlignmentD;

  using ElementAccumulator = float;
  using ElementBlockScale = float;
  using ElementCompute = float;

  using ArchTag = cutlass::arch::Sm90;
  using OperatorClass = cutlass::arch::OpClassTensorOp;
  using TileShape = Shape<TileSizeM, GroupSizeN, GroupSizeK>;

  using KernelSchedule = cutlass::gemm::
      KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<
          GroupSizeM_>;
  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
  using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;

  // Epilogue visitor tree consisting of a single accumulator fetch.
  using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<
      cutlass::epilogue::fusion::Sm90AccFetch>;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
          ElementAccumulator, ElementCompute, ElementC, StrideC, AlignmentC,
          ElementD, StrideD, AlignmentD, EpilogueSchedule,
          StoreEpilogueCompute>::CollectiveOp;

  // The stage count is chosen automatically after carving out the size of
  // the epilogue's shared storage.
  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB,
          LayoutB, AlignmentB, ElementAccumulator, TileShape, ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          KernelSchedule>::CollectiveOp;

  using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
      cutlass::gemm::PersistentScheduler>>;

  // Named wrapper type around KernelType.
  struct GemmKernel : public KernelType {};

  using StrideA = typename GemmKernel::StrideA;
  using StrideB = typename GemmKernel::StrideB;
};
// Validates the scale tensors for a blockwise-scaled FP8 GEMM, assembles the
// mainloop/epilogue arguments and launches the kernel via
// `c3x::cutlass_gemm_caller<GemmKernel>`.
//
// Scale layout requirements enforced here (with TORCH_CHECK):
//   a_scales: sizes (m / GroupSizeM, k / GroupSizeK); stride(0) == 1 or a
//             contiguous vector.
//   b_scales: sizes (k / GroupSizeK, n / GroupSizeN); stride(0) == 1 or a
//             contiguous vector.
// Both scale tensors are read through `float*` pointers.
template <typename Gemm>
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
                                   torch::Tensor const& b,
                                   torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales) {
  using GemmKernel = typename Gemm::GemmKernel;
  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;

  using StrideA = Stride<int64_t, Int<1>, int64_t>;
  using StrideB = Stride<int64_t, Int<1>, int64_t>;
  using StrideC = typename Gemm::StrideC;

  constexpr auto group_m = Gemm::GroupSizeM::value;
  constexpr auto group_n = Gemm::GroupSizeN::value;
  constexpr auto group_k = Gemm::GroupSizeK::value;

  auto prob_shape = c3x::get_problem_shape(a, b);
  int32_t m = get<0>(prob_shape);
  int32_t n = get<1>(prob_shape);
  int32_t k = get<2>(prob_shape);

  // Leading dimensions taken from the tensors.
  int64_t a_ld = a.stride(0);
  int64_t b_ld = b.stride(1);
  int64_t out_ld = out.stride(0);

  StrideA stride_a{a_ld, Int<1>{}, 0};
  StrideB stride_b{b_ld, Int<1>{}, 0};
  StrideC stride_out{out_ld, Int<1>{}, Int<0>{}};

  auto a_data = static_cast<ElementAB*>(a.data_ptr());
  auto b_data = static_cast<ElementAB*>(b.data_ptr());
  auto a_scale_data = static_cast<float*>(a_scales.data_ptr());
  auto b_scale_data = static_cast<float*>(b_scales.data_ptr());

  // Checks whether t is contiguous and is either 1D, or 2D with one of the
  // dimensions being 1 (i.e. a row or column vector).
  auto is_flat_vector = [](const torch::Tensor& t) {
    if (!t.is_contiguous()) {
      return false;
    }
    if (t.dim() == 1) {
      return true;
    }
    if (t.dim() != 2) {
      return false;
    }
    auto dims = t.sizes();
    return *std::min_element(dims.begin(), dims.end()) == 1;
  };

  // TODO(lucas): lets clean-up the kernel so that we pass in Strides so
  //  we don't have to deal with enforcing implicit layouts
  TORCH_CHECK(a_scales.size(0) == m / group_m);
  TORCH_CHECK(a_scales.size(1) == k / group_k);
  TORCH_CHECK(a_scales.stride(0) == 1 || is_flat_vector(a_scales),
              "a_scales must be M major");
  TORCH_CHECK(b_scales.size(0) == k / group_k);
  TORCH_CHECK(b_scales.size(1) == n / group_n);
  TORCH_CHECK(b_scales.stride(0) == 1 || is_flat_vector(b_scales),
              "b_scales must be K major");

  typename GemmKernel::MainloopArguments mainloop{
      a_data, stride_a, b_data, stride_b, a_scale_data, b_scale_data};

  // The output buffer is passed for both the C and D operand slots; the
  // leading `{}` default-initialises the first epilogue argument.
  auto out_data = static_cast<ElementD*>(out.data_ptr());
  typename GemmKernel::EpilogueArguments epilogue{
      {}, out_data, stride_out, out_data, stride_out};

  c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop,
                                       epilogue);
}
// Instantiates the blockwise FP8 GEMM for `OutType` with group sizes
// M = 1, N = 128, K = 128 (remaining template parameters left at their
// defaults) and runs it through `cutlass_gemm_caller_blockwise`.
template <typename OutType>
void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out,
                                              torch::Tensor const& a,
                                              torch::Tensor const& b,
                                              torch::Tensor const& a_scales,
                                              torch::Tensor const& b_scales) {
  using BlockwiseGemm = cutlass_3x_gemm_fp8_blockwise<OutType, 1, 128, 128>;
  cutlass_gemm_caller_blockwise<BlockwiseGemm>(out, a, b, a_scales, b_scales);
}
}
// namespace vllm
\ No newline at end of file
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
0 → 100644
View file @
9798b2fb
#pragma once
#include <torch/all.h>
namespace
vllm
{
// SM90 FP8 scaled matmul writing into `out`; `bias` is optional.
void cutlass_scaled_mm_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                std::optional<torch::Tensor> const& bias);
// SM90 int8 scaled matmul writing into `out`; `bias` is optional.
void cutlass_scaled_mm_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
                                 torch::Tensor const& b,
                                 torch::Tensor const& a_scales,
                                 torch::Tensor const& b_scales,
                                 std::optional<torch::Tensor> const& bias);
// SM90 int8 scaled matmul with zero-point adjustment (`azp_adj`); both `azp`
// and `bias` are optional.
void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     torch::Tensor const& a_scales,
                                     torch::Tensor const& b_scales,
                                     torch::Tensor const& azp_adj,
                                     std::optional<torch::Tensor> const& azp,
                                     std::optional<torch::Tensor> const& bias);
// SM90 FP8 blockwise-scaled matmul writing into `out` (no bias parameter).
void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out,
                                          torch::Tensor const& a,
                                          torch::Tensor const& b,
                                          torch::Tensor const& a_scales,
                                          torch::Tensor const& b_scales);
}
// namespace vllm
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu
0 → 100644
View file @
9798b2fb
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace
vllm
{
// SM90 FP8 scaled matmul. Requires both scale tensors to be contiguous.
// Without a bias the `ScaledEpilogue` epilogue is used; with a bias the bias
// dtype must equal the output dtype and `ScaledEpilogueBias` is used.
void cutlass_scaled_mm_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  if (!bias) {
    cutlass_scaled_mm_sm90_fp8_epilogue<c3x::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
    return;
  }

  TORCH_CHECK(bias->dtype() == out.dtype(),
              "currently bias dtype must match output dtype ", out.dtype());
  cutlass_scaled_mm_sm90_fp8_epilogue<c3x::ScaledEpilogueBias>(
      out, a, b, a_scales, b_scales, *bias);
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/scaled_mm_
c3x_
sm90_fp8_dispatch.cuh
→
csrc/quantization/cutlass_w8a8/
c3x/
scaled_mm_sm90_fp8_dispatch.cuh
View file @
9798b2fb
#pragma once
#include "scaled_mm_c3x.cuh"
#include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh"
/**
* This file defines Gemm kernel configurations for SM90 (fp8) based on the Gemm
...
...
@@ -9,6 +10,8 @@
namespace
vllm
{
using
c3x
::
cutlass_gemm_caller
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_default
{
...
...
@@ -93,4 +96,25 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
}
}
// Checks that `a` and `b` are float8_e4m3fn, then dispatches the SM90 FP8
// GEMM for the requested `Epilogue` based on `out.dtype()`: bfloat16 first,
// otherwise float16 (enforced with TORCH_CHECK). `epilogue_args` are
// perfect-forwarded to the dispatch function.
template <template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm90_fp8_epilogue(torch::Tensor& out,
                                         torch::Tensor const& a,
                                         torch::Tensor const& b,
                                         EpilogueArgs&&... epilogue_args) {
  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

  if (out.dtype() == torch::kBFloat16) {
    cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t, cutlass::bfloat16_t,
                                   Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    return;
  }

  TORCH_CHECK(out.dtype() == torch::kFloat16);
  cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t, cutlass::half_t,
                                 Epilogue>(
      out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
// namespace vllm
\ No newline at end of file
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu
0 → 100644
View file @
9798b2fb
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_int8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace
vllm
{
// SM90 int8 scaled matmul. Requires both scale tensors to be contiguous.
// Without a bias the `ScaledEpilogue` epilogue is used; with a bias the bias
// dtype must equal the output dtype and `ScaledEpilogueBias` is used.
void cutlass_scaled_mm_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
                                 torch::Tensor const& b,
                                 torch::Tensor const& a_scales,
                                 torch::Tensor const& b_scales,
                                 std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  if (!bias) {
    cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
    return;
  }

  TORCH_CHECK(bias->dtype() == out.dtype(),
              "currently bias dtype must match output dtype ", out.dtype());
  cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBias>(
      out, a, b, a_scales, b_scales, *bias);
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/scaled_mm_
c3x_
sm90_int8_dispatch.cuh
→
csrc/quantization/cutlass_w8a8/
c3x/
scaled_mm_sm90_int8_dispatch.cuh
View file @
9798b2fb
#pragma once
#include "scaled_mm_c3x.cuh"
#include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh"
/**
* This file defines Gemm kernel configurations for SM90 (int8) based on the
...
...
@@ -9,6 +10,8 @@
namespace
vllm
{
using
c3x
::
cutlass_gemm_caller
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_default
{
...
...
@@ -137,4 +140,24 @@ inline void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out,
}
}
// Checks that `a` and `b` are int8, then dispatches the SM90 int8 GEMM for
// the requested `Epilogue` based on `out.dtype()`: bfloat16 first, otherwise
// float16 (enforced with TORCH_CHECK). `epilogue_args` are perfect-forwarded
// to the dispatch function.
template <template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm90_int8_epilogue(torch::Tensor& out,
                                          torch::Tensor const& a,
                                          torch::Tensor const& b,
                                          EpilogueArgs&&... epilogue_args) {
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);

  if (out.dtype() == torch::kBFloat16) {
    cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    return;
  }

  TORCH_CHECK(out.dtype() == torch::kFloat16);
  cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
      out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
// namespace vllm
\ No newline at end of file
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
View file @
9798b2fb
#include <cudaTypedefs.h>
#include "c3x/scaled_mm_kernels.hpp"
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
#include "scaled_mm_c3x_sm90_fp8_dispatch.cuh"
#include "scaled_mm_c3x_sm90_int8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
using
namespace
vllm
;
#include "core/math.hpp"
/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm90a (Hopper) or later.
*/
// Dispatches an SM90 scaled matmul on the input and output dtypes.
//   - `a` int8: `b` must be int8; runs the int8 dispatch.
//   - otherwise: `a` and `b` must both be float8_e4m3fn; runs the FP8 dispatch.
// In each case the output must be bfloat16 or float16. `epilogue_args` are
// perfect-forwarded to whichever dispatch function is selected.
template <template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm90_epilogue(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     EpilogueArgs&&... epilogue_args) {
  if (a.dtype() == torch::kInt8) {
    TORCH_CHECK(b.dtype() == torch::kInt8);

    if (out.dtype() == torch::kBFloat16) {
      cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
      return;
    }

    TORCH_CHECK(out.dtype() == torch::kFloat16);
    cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    return;
  }

  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

  if (out.dtype() == torch::kBFloat16) {
    cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t, cutlass::bfloat16_t,
                                   Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    return;
  }

  TORCH_CHECK(out.dtype() == torch::kFloat16);
  cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t, cutlass::half_t,
                                 Epilogue>(
      out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
void
cutlass_scaled_mm_sm90
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
...
...
@@ -54,14 +15,50 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
bias
)
{
TORCH_CHECK
(
bias
->
dtype
()
==
c
.
dtype
(),
"currently bias dtype must match output dtype "
,
c
.
dtype
());
return
cutlass_scaled_mm_sm90_epilogue
<
c3x
::
ScaledEpilogueBias
>
(
c
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
using
GroupShape
=
std
::
array
<
int64_t
,
2
>
;
int
M
=
a
.
size
(
0
),
N
=
b
.
size
(
1
),
K
=
a
.
size
(
1
);
GroupShape
a_scale_group_shape
=
[
&
,
&
s
=
a_scales
]()
->
GroupShape
{
if
(
s
.
numel
()
==
1
)
return
{
M
,
K
};
// tensor-wise
if
(
s
.
dim
()
==
2
)
return
{
ceil_div
(
a
.
size
(
0
),
s
.
size
(
0
)),
ceil_div
(
a
.
size
(
1
),
s
.
size
(
1
))};
TORCH_CHECK
(
false
,
"Unsupported scale shape for scale_a"
);
}();
GroupShape
b_scale_group_shape
=
[
&
,
&
s
=
b_scales
]()
->
GroupShape
{
if
(
s
.
numel
()
==
1
)
return
{
K
,
N
};
// tensor-wise
if
(
s
.
dim
()
==
2
)
return
{
ceil_div
(
b
.
size
(
0
),
s
.
size
(
0
)),
ceil_div
(
b
.
size
(
1
),
s
.
size
(
1
))};
TORCH_CHECK
(
false
,
"Unsupported scale shape for scale_b"
);
}();
if
((
a_scale_group_shape
==
GroupShape
{
M
,
K
}
||
a_scale_group_shape
==
GroupShape
{
1
,
K
})
&&
(
b_scale_group_shape
==
GroupShape
{
K
,
N
}
||
b_scale_group_shape
==
GroupShape
{
K
,
1
}))
{
// "standard per-tensor/per-token/per-channel" scaling
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
if
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
)
{
vllm
::
cutlass_scaled_mm_sm90_fp8
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
else
{
return
cutlass_scaled_mm_sm90_epilogue
<
c3x
::
ScaledEpilogue
>
(
c
,
a
,
b
,
a_scales
,
b_scales
);
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
vllm
::
cutlass_scaled_mm_sm90_int8
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
}
else
if
(
a_scale_group_shape
==
GroupShape
{
1
,
128
}
&&
b_scale_group_shape
==
GroupShape
{
128
,
128
})
{
// 1x128 per-token group scales for activations
// 128x128 blockwise scales for weights
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
&&
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
,
"Currently only FP8 is supported for A group shape 1x128 and "
"B group shape 128x128"
);
TORCH_CHECK
(
!
bias
,
"Bias not yet supported blockwise scaled_mm"
);
vllm
::
cutlass_scaled_mm_blockwise_sm90_fp8
(
c
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
TORCH_CHECK
(
false
,
"Unsupported scale group shapes for CUTLASS 3.x GEMM"
);
}
}
...
...
@@ -75,13 +72,6 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
azp
)
{
return
cutlass_scaled_mm_sm90_epilogue
<
c3x
::
ScaledEpilogueBiasAzpToken
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
*
azp
,
bias
);
}
else
{
return
cutlass_scaled_mm_sm90_epilogue
<
c3x
::
ScaledEpilogueBiasAzp
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
bias
);
}
vllm
::
cutlass_scaled_mm_azp_sm90_int8
(
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
azp
,
bias
);
}
#endif
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment