Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9798b2fb
Unverified
Commit
9798b2fb
authored
Jan 30, 2025
by
Lucas Wilkinson
Committed by
GitHub
Jan 30, 2025
Browse files
[Kernel] Update `cutlass_scaled_mm` to support 2d group (blockwise) scaling (#11868)
parent
4078052f
Changes
25
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1736 additions
and
285 deletions
+1736
-285
CMakeLists.txt
CMakeLists.txt
+7
-2
benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
+142
-148
csrc/core/math.hpp
csrc/core/math.hpp
+8
-1
csrc/cutlass_extensions/common.hpp
csrc/cutlass_extensions/common.hpp
+17
-0
csrc/cutlass_extensions/gemm/collective/collective_builder.hpp
...cutlass_extensions/gemm/collective/collective_builder.hpp
+123
-0
csrc/cutlass_extensions/gemm/collective/fp8_accumulation.hpp
csrc/cutlass_extensions/gemm/collective/fp8_accumulation.hpp
+183
-0
csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp
...mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp
+730
-0
csrc/cutlass_extensions/gemm/dispatch_policy.hpp
csrc/cutlass_extensions/gemm/dispatch_policy.hpp
+39
-0
csrc/cutlass_extensions/vllm_collective_builder.cuh
csrc/cutlass_extensions/vllm_collective_builder.cuh
+1
-1
csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh
csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh
+93
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh
csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh
+0
-74
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu
.../quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu
+24
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu
...tization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu
+24
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
...utlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
+168
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
+33
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu
+24
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh
...tization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh
+25
-1
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu
+24
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh
...ization/cutlass_w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh
+24
-1
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
+47
-57
No files found.
CMakeLists.txt
View file @
9798b2fb
...
@@ -245,7 +245,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
...
@@ -245,7 +245,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
FetchContent_Declare
(
FetchContent_Declare
(
cutlass
cutlass
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
GIT_TAG v3.
6
.0
GIT_TAG v3.
7
.0
GIT_PROGRESS TRUE
GIT_PROGRESS TRUE
# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
...
@@ -299,7 +299,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
...
@@ -299,7 +299,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# CUDA 12.0 or later (and only work on Hopper, 9.0a for now).
# CUDA 12.0 or later (and only work on Hopper, 9.0a for now).
cuda_archs_loose_intersection
(
SCALED_MM_3X_ARCHS
"9.0a"
"
${
CUDA_ARCHS
}
"
)
cuda_archs_loose_intersection
(
SCALED_MM_3X_ARCHS
"9.0a"
"
${
CUDA_ARCHS
}
"
)
if
(
${
CMAKE_CUDA_COMPILER_VERSION
}
VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS
)
if
(
${
CMAKE_CUDA_COMPILER_VERSION
}
VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS
)
set
(
SRCS
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
)
set
(
SRCS
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu"
)
set_gencode_flags_for_srcs
(
set_gencode_flags_for_srcs
(
SRCS
"
${
SRCS
}
"
SRCS
"
${
SRCS
}
"
CUDA_ARCHS
"
${
SCALED_MM_3X_ARCHS
}
"
)
CUDA_ARCHS
"
${
SCALED_MM_3X_ARCHS
}
"
)
...
...
benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
View file @
9798b2fb
...
@@ -3,7 +3,7 @@ import copy
...
@@ -3,7 +3,7 @@ import copy
import
itertools
import
itertools
import
pickle
as
pkl
import
pickle
as
pkl
import
time
import
time
from
typing
import
Callable
,
Iterable
,
List
,
Tuple
from
typing
import
Callable
,
Iterable
,
List
,
Optional
,
Tuple
import
torch
import
torch
import
torch.utils.benchmark
as
TBenchmark
import
torch.utils.benchmark
as
TBenchmark
...
@@ -12,6 +12,8 @@ from utils import make_rand_tensors
...
@@ -12,6 +12,8 @@ from utils import make_rand_tensors
from
weight_shapes
import
WEIGHT_SHAPES
from
weight_shapes
import
WEIGHT_SHAPES
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
w8a8_block_fp8_matmul
)
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
import
FlexibleArgumentParser
DEFAULT_MODELS
=
list
(
WEIGHT_SHAPES
.
keys
())
DEFAULT_MODELS
=
list
(
WEIGHT_SHAPES
.
keys
())
...
@@ -38,8 +40,15 @@ def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
...
@@ -38,8 +40,15 @@ def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
).
blocked_autorange
(
min_run_time
=
min_run_time
)
).
blocked_autorange
(
min_run_time
=
min_run_time
)
def
bench_int8
(
dtype
:
torch
.
dtype
,
m
:
int
,
k
:
int
,
n
:
int
,
label
:
str
,
def
bench_int8
(
sub_label
:
str
)
->
Iterable
[
TMeasurement
]:
dtype
:
torch
.
dtype
,
m
:
int
,
k
:
int
,
n
:
int
,
label
:
str
,
sub_label
:
str
,
bench_kernels
:
Optional
[
List
[
str
]]
=
None
)
->
Iterable
[
TMeasurement
]:
"""Benchmark INT8-based kernels."""
assert
dtype
==
torch
.
int8
assert
dtype
==
torch
.
int8
a
,
b
=
make_rand_tensors
(
torch
.
int8
,
m
,
n
,
k
)
a
,
b
=
make_rand_tensors
(
torch
.
int8
,
m
,
n
,
k
)
scale_a
=
torch
.
tensor
(
1.0
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
scale_a
=
torch
.
tensor
(
1.0
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
...
@@ -48,155 +57,132 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
...
@@ -48,155 +57,132 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
azp
=
torch
.
zeros
((
m
,
),
device
=
"cuda"
,
dtype
=
torch
.
int32
)
azp
=
torch
.
zeros
((
m
,
),
device
=
"cuda"
,
dtype
=
torch
.
int32
)
azp_adj
=
torch
.
zeros
((
n
,
),
device
=
"cuda"
,
dtype
=
torch
.
int32
)
azp_adj
=
torch
.
zeros
((
n
,
),
device
=
"cuda"
,
dtype
=
torch
.
int32
)
bench_fns
=
{
"pytorch_bf16_bf16_bf16_matmul-no-scales"
:
lambda
:
torch
.
mm
(
a
.
to
(
dtype
=
torch
.
bfloat16
),
b
.
to
(
dtype
=
torch
.
bfloat16
)
),
"pytorch_fp16_fp16_fp16_matmul-no-scales"
:
lambda
:
torch
.
mm
(
a
.
to
(
dtype
=
torch
.
float16
),
b
.
to
(
dtype
=
torch
.
float16
)),
"cutlass_i8_i8_bf16_scaled_mm"
:
lambda
:
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
),
"cutlass_i8_i8_bf16_scaled_mm_bias"
:
lambda
:
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
bias
),
"cutlass_i8_i8_bf16_scaled_mm_azp"
:
lambda
:
ops
.
cutlass_scaled_mm_azp
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
azp_adj
),
"cutlass_i8_i8_bf16_scaled_mm_azp_bias"
:
lambda
:
ops
.
cutlass_scaled_mm_azp
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
azp_adj
,
None
,
bias
),
"cutlass_i8_i8_bf16_scaled_mm_azp_pt"
:
lambda
:
ops
.
cutlass_scaled_mm_azp
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
azp_adj
,
azp
),
"cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias"
:
lambda
:
ops
.
cutlass_scaled_mm_azp
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
azp_adj
,
azp
,
bias
),
}
timers
=
[]
timers
=
[]
# pytorch impl - bfloat16
for
name
,
fn
in
bench_fns
.
items
():
timers
.
append
(
# If bench_kernels is None, run all. Otherwise, run only exact matches.
bench_fn
(
label
,
sub_label
,
"pytorch_bf16_bf16_bf16_matmul-no-scales"
,
if
bench_kernels
is
None
or
name
in
bench_kernels
:
torch
.
mm
,
a
.
to
(
dtype
=
torch
.
bfloat16
),
print
(
f
"Running
{
name
}
"
)
b
.
to
(
dtype
=
torch
.
bfloat16
)))
timers
.
append
(
bench_fn
(
label
,
sub_label
,
name
,
fn
))
# pytorch impl - float16
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"pytorch_fp16_fp16_fp16_matmul-no-scales"
,
torch
.
mm
,
a
.
to
(
dtype
=
torch
.
float16
),
b
.
to
(
dtype
=
torch
.
float16
)))
# cutlass impl
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_i8_i8_bf16_scaled_mm"
,
ops
.
cutlass_scaled_mm
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
))
# cutlass with bias
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_i8_i8_bf16_scaled_mm_bias"
,
ops
.
cutlass_scaled_mm
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
bias
))
# cutlass with azp per-tensor
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_i8_i8_bf16_scaled_mm_azp"
,
ops
.
cutlass_scaled_mm_azp
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
azp_adj
))
# cutlass with azp per-tensor + bias
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_i8_i8_bf16_scaled_mm_azp_bias"
,
ops
.
cutlass_scaled_mm_azp
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
azp_adj
,
None
,
bias
))
# cutlass with azp per-token
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_i8_i8_bf16_scaled_mm_azp_pt"
,
ops
.
cutlass_scaled_mm_azp
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
azp_adj
,
azp
))
# cutlass with azp per-token + bias
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias"
,
ops
.
cutlass_scaled_mm_azp
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
azp_adj
,
azp
,
bias
))
return
timers
return
timers
def
bench_fp8
(
dtype
:
torch
.
dtype
,
m
:
int
,
k
:
int
,
n
:
int
,
label
:
str
,
def
bench_fp8
(
sub_label
:
str
)
->
Iterable
[
TMeasurement
]:
dtype
:
torch
.
dtype
,
m
:
int
,
k
:
int
,
n
:
int
,
label
:
str
,
sub_label
:
str
,
bench_kernels
:
Optional
[
List
[
str
]]
=
None
)
->
Iterable
[
TMeasurement
]:
"""Benchmark FP8-based kernels."""
assert
dtype
==
torch
.
float8_e4m3fn
assert
dtype
==
torch
.
float8_e4m3fn
a
,
b
=
make_rand_tensors
(
torch
.
float8_e4m3fn
,
m
,
n
,
k
)
a
,
b
=
make_rand_tensors
(
torch
.
float8_e4m3fn
,
m
,
n
,
k
)
a_cont
=
a
.
contiguous
()
scale_a
=
torch
.
tensor
(
1.0
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
scale_a
=
torch
.
tensor
(
1.0
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
scale_b
=
torch
.
tensor
(
1.0
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
scale_b
=
torch
.
tensor
(
1.0
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
block_scale_a
=
torch
.
rand
((
m
,
k
//
128
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
block_scale_b
=
torch
.
rand
((
k
//
128
,
n
//
128
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
block_scale_a_M_major
=
block_scale_a
.
t
().
contiguous
().
t
()
block_scale_b_K_major
=
block_scale_b
.
t
().
contiguous
().
t
()
bias
=
torch
.
zeros
((
n
,
),
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
bias
=
torch
.
zeros
((
n
,
),
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
timers
=
[]
print
(
m
,
k
,
n
)
bench_fns
=
{
"pytorch_bf16_bf16_bf16_matmul-no-scales"
:
lambda
:
torch
.
mm
(
a
.
to
(
dtype
=
torch
.
bfloat16
),
b
.
to
(
dtype
=
torch
.
bfloat16
)
),
"pytorch_fp16_fp16_fp16_matmul-no-scales"
:
lambda
:
torch
.
mm
(
a
.
to
(
dtype
=
torch
.
float16
),
b
.
to
(
dtype
=
torch
.
float16
)),
"pytorch_fp8_fp8_fp16_scaled_mm"
:
lambda
:
torch
.
_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
=
torch
.
float16
),
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum"
:
lambda
:
torch
.
_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
=
torch
.
float16
,
use_fast_accum
=
True
),
"pytorch_fp8_fp8_bf16_scaled_mm"
:
lambda
:
torch
.
_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
=
torch
.
bfloat16
),
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum"
:
lambda
:
torch
.
_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
=
torch
.
bfloat16
,
use_fast_accum
=
True
),
"cutlass_fp8_fp8_bf16_scaled_mm"
:
lambda
:
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
),
"cutlass_fp8_fp8_fp16_scaled_mm"
:
lambda
:
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
float16
),
"cutlass_fp8_fp8_bf16_scaled_mm_bias"
:
lambda
:
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
bias
),
"cutlass_fp8_fp8_fp16_scaled_mm_bias"
:
lambda
:
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
float16
,
bias
.
to
(
dtype
=
torch
.
float16
)),
"triton_fp8_fp8_fp16_scaled_mm_blockwise"
:
lambda
:
w8a8_block_fp8_matmul
(
a_cont
,
b
.
t
(),
block_scale_a
,
block_scale_b
.
t
(),
(
128
,
128
)),
"cutlass_fp8_fp8_fp16_scaled_mm_blockwise"
:
lambda
:
ops
.
cutlass_scaled_mm
(
a
,
b
,
block_scale_a_M_major
,
block_scale_b_K_major
,
torch
.
float16
),
}
# pytorch impl w. bf16
timers
=
[]
timers
.
append
(
for
name
,
fn
in
bench_fns
.
items
():
bench_fn
(
label
,
sub_label
,
"pytorch_bf16_bf16_bf16_matmul-no-scales"
,
# If bench_kernels is None, run all. Otherwise, run only exact matches.
torch
.
mm
,
a
.
to
(
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
),
if
bench_kernels
is
None
or
name
in
bench_kernels
:
b
.
to
(
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)))
print
(
f
"Running
{
name
}
"
)
timers
.
append
(
bench_fn
(
label
,
sub_label
,
name
,
fn
))
# pytorch impl: bf16 output, without fp8 fast accum
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"pytorch_fp8_fp8_bf16_scaled_mm"
,
torch
.
_scaled_mm
,
a
,
b
,
scale_a
=
scale_a
,
scale_b
=
scale_b
,
out_dtype
=
torch
.
bfloat16
))
# pytorch impl: bf16 output, with fp8 fast accum
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum"
,
torch
.
_scaled_mm
,
a
,
b
,
scale_a
=
scale_a
,
scale_b
=
scale_b
,
out_dtype
=
torch
.
bfloat16
,
use_fast_accum
=
True
))
# pytorch impl: fp16 output, without fp8 fast accum
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"pytorch_fp8_fp8_fp16_scaled_mm"
,
torch
.
_scaled_mm
,
a
,
b
,
scale_a
=
scale_a
,
scale_b
=
scale_b
,
out_dtype
=
torch
.
float16
))
# pytorch impl: fp16 output, with fp8 fast accum
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum"
,
torch
.
_scaled_mm
,
a
,
b
,
scale_a
=
scale_a
,
scale_b
=
scale_b
,
out_dtype
=
torch
.
float16
,
use_fast_accum
=
True
))
# cutlass impl: bf16 output
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_fp8_fp8_bf16_scaled_mm"
,
ops
.
cutlass_scaled_mm
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
))
# cutlass impl: fp16 output
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_fp8_fp8_fp16_scaled_mm"
,
ops
.
cutlass_scaled_mm
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
float16
))
# cutlass impl: bf16 output, with bias
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_fp8_fp8_bf16_scaled_mm_bias"
,
ops
.
cutlass_scaled_mm
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
bias
))
# cutlass impl: fp16 output, with bias
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_fp8_fp8_fp16_scaled_mm_bias"
,
ops
.
cutlass_scaled_mm
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
float16
,
bias
.
to
(
dtype
=
torch
.
float16
)))
return
timers
return
timers
def
bench
(
dtype
:
torch
.
dtype
,
m
:
int
,
k
:
int
,
n
:
int
,
label
:
str
,
def
bench
(
dtype
:
torch
.
dtype
,
sub_label
:
str
)
->
Iterable
[
TMeasurement
]:
m
:
int
,
k
:
int
,
n
:
int
,
label
:
str
,
sub_label
:
str
,
bench_kernels
:
Optional
[
List
[
str
]]
=
None
)
->
Iterable
[
TMeasurement
]:
if
dtype
==
torch
.
int8
:
if
dtype
==
torch
.
int8
:
return
bench_int8
(
dtype
,
m
,
k
,
n
,
label
,
sub_label
)
return
bench_int8
(
dtype
,
m
,
k
,
n
,
label
,
sub_label
,
bench_kernels
)
if
dtype
==
torch
.
float8_e4m3fn
:
if
dtype
==
torch
.
float8_e4m3fn
:
return
bench_fp8
(
dtype
,
m
,
k
,
n
,
label
,
sub_label
)
return
bench_fp8
(
dtype
,
m
,
k
,
n
,
label
,
sub_label
,
bench_kernels
)
raise
ValueError
(
"unsupported type"
)
raise
ValueError
(
"unsupported type"
)
...
@@ -207,18 +193,22 @@ def print_timers(timers: Iterable[TMeasurement]):
...
@@ -207,18 +193,22 @@ def print_timers(timers: Iterable[TMeasurement]):
def
run
(
dtype
:
torch
.
dtype
,
def
run
(
dtype
:
torch
.
dtype
,
MKNs
:
Iterable
[
Tuple
[
int
,
int
,
int
]])
->
Iterable
[
TMeasurement
]:
MKNs
:
Iterable
[
Tuple
[
int
,
int
,
int
]],
bench_kernels
:
Optional
[
List
[
str
]]
=
None
)
->
Iterable
[
TMeasurement
]:
results
=
[]
results
=
[]
for
m
,
k
,
n
in
MKNs
:
for
m
,
k
,
n
in
MKNs
:
timers
=
bench
(
dtype
,
m
,
k
,
n
,
f
"scaled-
{
dtype
}
-gemm"
,
timers
=
bench
(
dtype
,
f
"MKN=(
{
m
}
x
{
k
}
x
{
n
}
)"
)
m
,
k
,
n
,
f
"scaled-
{
dtype
}
-gemm"
,
f
"MKN=(
{
m
}
x
{
k
}
x
{
n
}
)"
,
bench_kernels
=
bench_kernels
)
print_timers
(
timers
)
print_timers
(
timers
)
results
.
extend
(
timers
)
results
.
extend
(
timers
)
return
results
return
results
# output makers
def
make_output
(
data
:
Iterable
[
TMeasurement
],
def
make_output
(
data
:
Iterable
[
TMeasurement
],
MKNs
:
Iterable
[
Tuple
[
int
,
int
,
int
]],
MKNs
:
Iterable
[
Tuple
[
int
,
int
,
int
]],
base_description
:
str
,
base_description
:
str
,
...
@@ -232,15 +222,11 @@ def make_output(data: Iterable[TMeasurement],
...
@@ -232,15 +222,11 @@ def make_output(data: Iterable[TMeasurement],
pkl
.
dump
(
data
,
f
)
pkl
.
dump
(
data
,
f
)
# argparse runners
def
run_square_bench
(
args
):
def
run_square_bench
(
args
):
dim_sizes
=
list
(
dim_sizes
=
list
(
range
(
args
.
dim_start
,
args
.
dim_end
+
1
,
args
.
dim_increment
))
range
(
args
.
dim_start
,
args
.
dim_end
+
1
,
args
.
dim_increment
))
MKNs
=
list
(
zip
(
dim_sizes
,
dim_sizes
,
dim_sizes
))
MKNs
=
list
(
zip
(
dim_sizes
,
dim_sizes
,
dim_sizes
))
data
=
run
(
args
.
dtype
,
MKNs
)
data
=
run
(
args
.
dtype
,
MKNs
,
bench_kernels
=
args
.
kernels
)
make_output
(
data
,
MKNs
,
f
"square_bench-
{
args
.
dtype
}
"
)
make_output
(
data
,
MKNs
,
f
"square_bench-
{
args
.
dtype
}
"
)
...
@@ -251,8 +237,7 @@ def run_range_bench(args):
...
@@ -251,8 +237,7 @@ def run_range_bench(args):
Ks
=
[
args
.
k_constant
]
*
n
if
args
.
k_constant
is
not
None
else
dim_sizes
Ks
=
[
args
.
k_constant
]
*
n
if
args
.
k_constant
is
not
None
else
dim_sizes
Ns
=
[
args
.
n_constant
]
*
n
if
args
.
n_constant
is
not
None
else
dim_sizes
Ns
=
[
args
.
n_constant
]
*
n
if
args
.
n_constant
is
not
None
else
dim_sizes
MKNs
=
list
(
zip
(
Ms
,
Ks
,
Ns
))
MKNs
=
list
(
zip
(
Ms
,
Ks
,
Ns
))
data
=
run
(
args
.
dtype
,
MKNs
)
data
=
run
(
args
.
dtype
,
MKNs
,
bench_kernels
=
args
.
kernels
)
make_output
(
data
,
MKNs
,
f
"range_bench-
{
args
.
dtype
}
"
)
make_output
(
data
,
MKNs
,
f
"range_bench-
{
args
.
dtype
}
"
)
...
@@ -278,7 +263,7 @@ def run_model_bench(args):
...
@@ -278,7 +263,7 @@ def run_model_bench(args):
for
k
,
n
in
KNs
:
for
k
,
n
in
KNs
:
MKNs
.
append
((
m
,
k
,
n
))
MKNs
.
append
((
m
,
k
,
n
))
data
=
run
(
args
.
dtype
,
MKNs
)
data
=
run
(
args
.
dtype
,
MKNs
,
bench_kernels
=
args
.
kernels
)
model_bench_data
.
append
(
data
)
model_bench_data
.
append
(
data
)
# Print all results
# Print all results
...
@@ -328,6 +313,15 @@ Benchmark Cutlass GEMM.
...
@@ -328,6 +313,15 @@ Benchmark Cutlass GEMM.
type
=
to_torch_dtype
,
type
=
to_torch_dtype
,
required
=
True
,
required
=
True
,
help
=
"Available options are ['int8', 'fp8']"
)
help
=
"Available options are ['int8', 'fp8']"
)
parser
.
add_argument
(
"--kernels"
,
nargs
=
"+"
,
type
=
str
,
default
=
None
,
help
=
"Exact names of the kernels to benchmark. If not set, runs all kernels."
)
subparsers
=
parser
.
add_subparsers
(
dest
=
"cmd"
)
subparsers
=
parser
.
add_subparsers
(
dest
=
"cmd"
)
square_parser
=
subparsers
.
add_parser
(
"square_bench"
)
square_parser
=
subparsers
.
add_parser
(
"square_bench"
)
...
@@ -362,4 +356,4 @@ Benchmark Cutlass GEMM.
...
@@ -362,4 +356,4 @@ Benchmark Cutlass GEMM.
model_parser
.
set_defaults
(
func
=
run_model_bench
)
model_parser
.
set_defaults
(
func
=
run_model_bench
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
args
.
func
(
args
)
args
.
func
(
args
)
\ No newline at end of file
csrc/core/math.hpp
View file @
9798b2fb
#pragma once
#include <climits>
#include <climits>
#include <iostream>
#include <iostream>
inline
uint32_t
next_pow_2
(
uint32_t
const
num
)
{
inline
constexpr
uint32_t
next_pow_2
(
uint32_t
const
num
)
{
if
(
num
<=
1
)
return
num
;
if
(
num
<=
1
)
return
num
;
return
1
<<
(
CHAR_BIT
*
sizeof
(
num
)
-
__builtin_clz
(
num
-
1
));
return
1
<<
(
CHAR_BIT
*
sizeof
(
num
)
-
__builtin_clz
(
num
-
1
));
}
template
<
typename
T
>
inline
constexpr
std
::
enable_if_t
<
std
::
is_integral_v
<
T
>
,
T
>
ceil_div
(
T
a
,
T
b
)
{
return
(
a
+
b
-
1
)
/
b
;
}
}
\ No newline at end of file
csrc/cutlass_extensions/common.hpp
View file @
9798b2fb
...
@@ -32,3 +32,20 @@ inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
...
@@ -32,3 +32,20 @@ inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
}
}
int32_t
get_sm_version_num
();
int32_t
get_sm_version_num
();
/**
* A wrapper for a kernel that is used to guard against compilation on
* architectures that will never use the kernel. The purpose of this is to
* reduce the size of the compiled binary.
* __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
* into code that will be executed on the device where it is defined.
*/
template
<
typename
Kernel
>
struct
enable_sm90_or_later
:
Kernel
{
template
<
typename
...
Args
>
CUTLASS_DEVICE
void
operator
()(
Args
&&
...
args
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
Kernel
::
operator
()(
std
::
forward
<
Args
>
(
args
)...);
#endif
}
};
\ No newline at end of file
csrc/cutlass_extensions/gemm/collective/collective_builder.hpp
0 → 100644
View file @
9798b2fb
// Modified from: cutlass/gemm/collective/builders/sm90_gmma_builder.inl
// clang-format off
#pragma once
#include "cutlass/gemm/collective/builders/sm90_gmma_builder.inl"
#include "cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace
cutlass
::
gemm
::
collective
{
/////////////////////////////////////////////////////////////////////////////////////////////////
// GMMA_TMA_WS_SS (BlockScaled Builders)
template
<
class
ElementA
,
class
GmemLayoutATag
,
int
AlignmentA
,
class
ElementB
,
class
GmemLayoutBTag
,
int
AlignmentB
,
class
ElementAccumulator
,
class
TileShape_MNK
,
class
ClusterShape_MNK
,
class
StageCountType
,
int
ScaleGranularityM
>
struct
CollectiveBuilder
<
arch
::
Sm90
,
arch
::
OpClassTensorOp
,
ElementA
,
GmemLayoutATag
,
AlignmentA
,
ElementB
,
GmemLayoutBTag
,
AlignmentB
,
ElementAccumulator
,
TileShape_MNK
,
ClusterShape_MNK
,
StageCountType
,
KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum
<
ScaleGranularityM
>
,
cute
::
enable_if_t
<
not
detail
::
is_use_rmem_A
<
ElementA
,
GmemLayoutATag
,
ElementB
,
GmemLayoutBTag
>
()
>
>
{
using
KernelScheduleType
=
KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum
<
ScaleGranularityM
>
;
static_assert
(
is_static
<
TileShape_MNK
>::
value
);
static_assert
(
is_static
<
ClusterShape_MNK
>::
value
);
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
static_assert
(
cutlass
::
detail
::
dependent_false
<
ElementA
>
,
"Unsupported Toolkit for SM90 Collective Builder
\n
"
);
#endif
static_assert
(
detail
::
is_aligned
<
ElementA
,
AlignmentA
,
ElementB
,
AlignmentB
,
detail
::
tma_alignment_bytes
>
(),
"Should meet TMA alignment requirement
\n
"
);
static
constexpr
bool
IsArrayOfPointersGemm
=
(
cute
::
is_any_of_v
<
KernelScheduleType
,
KernelPtrArrayTmaWarpSpecializedCooperative
,
KernelPtrArrayTmaWarpSpecializedPingpong
>
);
static
constexpr
bool
IsFP8Input
=
detail
::
is_input_fp8
<
ElementA
,
ElementB
>
();
static_assert
((
!
IsFP8Input
||
!
IsArrayOfPointersGemm
),
"KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum is only compatible with FP8 Blocked Scaled version right now."
);
// For fp32 types, map to tf32 MMA value type
using
ElementAMma
=
cute
::
conditional_t
<
cute
::
is_same_v
<
ElementA
,
float
>
,
tfloat32_t
,
ElementA
>
;
using
ElementBMma
=
cute
::
conditional_t
<
cute
::
is_same_v
<
ElementB
,
float
>
,
tfloat32_t
,
ElementB
>
;
static
constexpr
cute
::
GMMA
::
Major
GmmaMajorA
=
detail
::
gmma_ss_tag_to_major_A
<
ElementAMma
,
GmemLayoutATag
>
();
static
constexpr
cute
::
GMMA
::
Major
GmmaMajorB
=
detail
::
gmma_ss_tag_to_major_B
<
ElementBMma
,
GmemLayoutBTag
>
();
static
constexpr
bool
IsCooperative
=
cute
::
is_any_of_v
<
KernelScheduleType
,
KernelTmaWarpSpecializedCooperative
,
KernelPtrArrayTmaWarpSpecializedCooperative
,
KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum
<
ScaleGranularityM
>>
;
using
AtomLayoutMNK
=
cute
::
conditional_t
<
IsCooperative
,
Layout
<
Shape
<
_2
,
_1
,
_1
>>
,
Layout
<
Shape
<
_1
,
_1
,
_1
>>>
;
using
TiledMma
=
decltype
(
cute
::
make_tiled_mma
(
cute
::
GMMA
::
ss_op_selector
<
ElementAMma
,
ElementBMma
,
ElementAccumulator
,
TileShape_MNK
,
GmmaMajorA
,
GmmaMajorB
>
(),
AtomLayoutMNK
{}));
using
GmemTiledCopyA
=
decltype
(
detail
::
sm90_cluster_shape_to_tma_atom
(
shape
<
1
>
(
ClusterShape_MNK
{})));
using
GmemTiledCopyB
=
decltype
(
detail
::
sm90_cluster_shape_to_tma_atom
(
shape
<
0
>
(
ClusterShape_MNK
{})));
using
SmemLayoutAtomA
=
decltype
(
detail
::
ss_smem_selector
<
GmmaMajorA
,
ElementAMma
,
decltype
(
cute
::
get
<
0
>
(
TileShape_MNK
{})),
decltype
(
cute
::
get
<
2
>
(
TileShape_MNK
{}))
>
());
using
SmemLayoutAtomB
=
decltype
(
detail
::
ss_smem_selector
<
GmmaMajorB
,
ElementBMma
,
decltype
(
cute
::
get
<
1
>
(
TileShape_MNK
{})),
decltype
(
cute
::
get
<
2
>
(
TileShape_MNK
{}))
>
());
static
constexpr
size_t
TensorMapStorage
=
IsArrayOfPointersGemm
?
sizeof
(
cute
::
TmaDescriptor
)
*
2
/* for A and B */
:
0
;
static
constexpr
int
KernelSmemCarveout
=
static_cast
<
int
>
(
TensorMapStorage
);
static
constexpr
int
PipelineStages
=
detail
::
compute_stage_count_or_override
<
detail
::
sm90_smem_capacity_bytes
-
KernelSmemCarveout
,
ElementAMma
,
ElementBMma
,
TileShape_MNK
>
(
StageCountType
{});
using
DispatchPolicy
=
MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8
<
PipelineStages
,
ClusterShape_MNK
,
KernelScheduleType
,
ScaleGranularityM
>
;
using
SmemCopyAtomA
=
void
;
using
SmemCopyAtomB
=
void
;
using
CollectiveOp
=
CollectiveMma
<
DispatchPolicy
,
TileShape_MNK
,
ElementA
,
TagToStrideA_t
<
GmemLayoutATag
>
,
ElementB
,
TagToStrideB_t
<
GmemLayoutBTag
>
,
TiledMma
,
GmemTiledCopyA
,
SmemLayoutAtomA
,
SmemCopyAtomA
,
cute
::
identity
,
GmemTiledCopyB
,
SmemLayoutAtomB
,
SmemCopyAtomB
,
cute
::
identity
>
;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
}
// namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
csrc/cutlass_extensions/gemm/collective/fp8_accumulation.hpp
0 → 100644
View file @
9798b2fb
// clang-format off
// adapted from: https://github.com/soundOfDestiny/cutlass/blob/a4208aa6958864923505cade9c63eb2a6daf16e5/include/cutlass/gemm/collective/fp8_accumulation.hpp
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/algorithm/clear.hpp"
#include "cute/tensor.hpp"
//////////////////////////////////////////////////////////////////////////////
///////////////////////////////////FP8 Accumulation///////////////////////////
//////////////////////////////////////////////////////////////////////////////
/// This class provides API to promote (add) or scale (multiply_add) the results
/// from the tensor core accumulators to the main accumulators when the number
/// of MMAs reaches the max number of MMA interval specified by user, after that
/// the tensor core accumulators are zeroed.
//////////////////////////////////////////////////////////////////////////////
namespace
cutlass
::
gemm
::
collective
{
template
<
class
EngineAccum
,
class
LayoutAccum
>
struct
GmmaFP8AccumulationWithScale
{
using
TensorAccum
=
cute
::
Tensor
<
EngineAccum
,
LayoutAccum
>
;
using
ElementAccumulator
=
typename
EngineAccum
::
value_type
;
static_assert
(
is_static
<
LayoutAccum
>::
value
,
"Accumulator Layout should be static"
);
static_assert
(
is_rmem
<
TensorAccum
>::
value
,
"Accumulator tensor must be rmem resident."
);
private:
TensorAccum
&
accum_
;
TensorAccum
accum_temp_
;
uint32_t
accum_promotion_interval_
;
// defines the max num of executed MMAs after which accum should be promoted.
uint32_t
mma_count_per_mainloop_iteration_
;
// num of MMAs per k_tile of mainloop
uint32_t
mma_count_
;
// current executed MMAs
uint32_t
reset_accum_flag_
;
// accum needs to be zeroed or not.
// promote or `add` the partial accumulators to main accumulator (FADD).
CUTLASS_DEVICE
void
promote_core
()
{
warpgroup_wait
<
0
>
();
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
size
(
accum_
);
++
i
)
{
accum_
(
i
)
+=
accum_temp_
(
i
);
}
}
// `multiply` scale the partial accumulators and `add` to main accumulator (FFMA).
template
<
class
EngineScale
,
class
LayoutScale
>
CUTLASS_DEVICE
void
scale_core
(
const
cute
::
Tensor
<
EngineScale
,
LayoutScale
>
&
scale
)
{
using
TensorScale
=
cute
::
Tensor
<
EngineScale
,
LayoutScale
>
;
static_assert
(
is_static
<
LayoutScale
>::
value
,
"Scale Layout should be static"
);
static_assert
(
is_rmem
<
TensorScale
>::
value
,
"Scale tensor must be rmem resident."
);
static_assert
(
LayoutAccum
{}.
shape
()
==
LayoutScale
{}.
shape
(),
"Accumulator and scale must have same shape."
);
warpgroup_wait
<
0
>
();
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
size
(
accum_
);
++
i
)
{
accum_
(
i
)
+=
accum_temp_
(
i
)
*
scale
(
i
);
}
}
public:
CUTLASS_DEVICE
GmmaFP8AccumulationWithScale
(
TensorAccum
&
accum
,
uint32_t
accum_promotion_interval
,
uint32_t
mma_count_per_mainloop_iteration
)
:
accum_
(
accum
),
accum_promotion_interval_
(
accum_promotion_interval
),
mma_count_per_mainloop_iteration_
(
mma_count_per_mainloop_iteration
),
mma_count_
(
0
),
reset_accum_flag_
(
0
)
{
accum_temp_
=
cute
::
make_fragment_like
(
accum
);
}
//
// Methods (Common)
//
CUTLASS_DEVICE
TensorAccum
&
operator
()()
{
return
accum_temp_
;
}
/// prepare the MMA accumulators when initialization or zeroing is required.
CUTLASS_DEVICE
bool
prepare_if_needed
()
{
return
reset_accum_flag_
;
}
//
// Methods (for FADD version)
//
/// promote (add) the results from the MMA accumulators to main accumulator if needed.
CUTLASS_DEVICE
void
promote_if_needed
()
{
mma_count_
+=
mma_count_per_mainloop_iteration_
;
reset_accum_flag_
=
__shfl_sync
(
0xffffffff
,
mma_count_
==
accum_promotion_interval_
,
0
);
if
(
reset_accum_flag_
)
{
promote_core
();
mma_count_
=
0
;
}
}
/// promote (add) the residue results from the MMA accumulators to main accumulator if needed.
CUTLASS_DEVICE
void
promote_residue_if_needed
()
{
if
(
__shfl_sync
(
0xffffffff
,
mma_count_
>
0
,
0
))
{
promote_core
();
}
}
//
// Methods (for FFMA version)
//
/// scale (multiply_add) the results from the MMA accumulators to main accumulator if needed.
template
<
class
EngineScale
,
class
LayoutScale
>
CUTLASS_DEVICE
void
scale_if_needed
(
const
cute
::
Tensor
<
EngineScale
,
LayoutScale
>
&
scale
)
{
mma_count_
+=
mma_count_per_mainloop_iteration_
;
reset_accum_flag_
=
__shfl_sync
(
0xffffffff
,
mma_count_
==
accum_promotion_interval_
,
0
);
if
(
reset_accum_flag_
)
{
scale_core
(
scale
);
mma_count_
=
0
;
}
}
/// scale (multiply_add) the residue results from the MMA accumulators to main accumulator if needed.
template
<
class
EngineScale
,
class
LayoutScale
>
CUTLASS_DEVICE
void
scale_residue_if_needed
(
const
cute
::
Tensor
<
EngineScale
,
LayoutScale
>
&
scale
)
{
if
(
__shfl_sync
(
0xffffffff
,
mma_count_
>
0
,
0
))
{
scale_core
(
scale
);
}
}
};
}
// namespace cutlass::gemm::collective
csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp
0 → 100644
View file @
9798b2fb
This diff is collapsed.
Click to expand it.
csrc/cutlass_extensions/gemm/dispatch_policy.hpp
0 → 100644
View file @
9798b2fb
#pragma once
#include "cutlass/gemm/dispatch_policy.hpp"
namespace
cutlass
::
gemm
{
//////////////////////////////////////////////////////////////////////////////
// FP8 related policies (including Blocked Scaled Accumulation)
// `ScaleGranularityM` specifies scaling granularity along M, while zero-value
// `ScaleGranularityM` indicates that scaling granularity is
// `size<0>(TileShape_MNK{})` along M.
template
<
int
ScaleGranularityM
=
0
>
struct
KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum
:
KernelTmaWarpSpecializedCooperative
{};
// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp
// specialized dynamic schedule For FP8 kernels with Block Scaling
template
<
int
Stages_
,
class
ClusterShape_
=
Shape
<
_1
,
_1
,
_1
>,
class
KernelSchedule
=
KernelTmaWarpSpecialized
,
int
ScaleGranularityM
=
0
// `ScaleGranularityM` specifies scaling granularity along M,
// while zero-value `ScaleGranularityM` indicates that scaling
// granularity is `size<0>(TileShape_MNK{})` along M.
>
struct
MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8
:
MainloopSm90TmaGmmaWarpSpecialized
<
Stages_
,
ClusterShape_
,
KernelSchedule
>
{
static_assert
(
cute
::
is_same_v
<
KernelSchedule
,
KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum
<
ScaleGranularityM
>>
,
"KernelSchedule must be one of the warp specialized policies"
);
};
//////////////////////////////////////////////////////////////////////////////
}
// namespace cutlass::gemm
\ No newline at end of file
csrc/cutlass_extensions/vllm_collective_builder.cuh
View file @
9798b2fb
#pragma once
#pragma once
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass
_extensions
/gemm/collective/collective_builder.hpp"
namespace
cutlass
::
gemm
::
collective
{
namespace
cutlass
::
gemm
::
collective
{
using
namespace
cute
;
using
namespace
cute
;
...
...
csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh
0 → 100644
View file @
9798b2fb
#pragma once
// clang-format will break include orders
// clang-format off
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "core/math.hpp"
#include "cutlass_extensions/common.hpp"
// clang-format on
namespace
vllm
::
c3x
{
static
inline
cute
::
Shape
<
int
,
int
,
int
,
int
>
get_problem_shape
(
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
)
{
int32_t
m
=
a
.
size
(
0
),
n
=
b
.
size
(
1
),
k
=
a
.
size
(
1
);
return
{
m
,
n
,
k
,
1
};
}
template
<
typename
GemmKernel
>
void
cutlass_gemm_caller
(
torch
::
Device
device
,
cute
::
Shape
<
int
,
int
,
int
,
int
>
prob_shape
,
typename
GemmKernel
::
MainloopArguments
mainloop_args
,
typename
GemmKernel
::
EpilogueArguments
epilogue_args
)
{
typename
GemmKernel
::
Arguments
args
{
cutlass
::
gemm
::
GemmUniversalMode
::
kGemm
,
prob_shape
,
mainloop_args
,
epilogue_args
};
// Launch the CUTLASS GEMM kernel.
using
GemmOp
=
cutlass
::
gemm
::
device
::
GemmUniversalAdapter
<
GemmKernel
>
;
GemmOp
gemm_op
;
CUTLASS_CHECK
(
gemm_op
.
can_implement
(
args
));
size_t
workspace_size
=
gemm_op
.
get_workspace_size
(
args
);
auto
const
workspace_options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
device
);
auto
workspace
=
torch
::
empty
(
workspace_size
,
workspace_options
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
device
.
index
());
cutlass
::
Status
status
=
gemm_op
.
run
(
args
,
workspace
.
data_ptr
(),
stream
);
CUTLASS_CHECK
(
status
);
}
template
<
typename
Gemm
,
typename
...
EpilogueArgs
>
void
cutlass_gemm_caller
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
epilogue_params
)
{
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementD
=
typename
Gemm
::
ElementD
;
using
GemmKernel
=
typename
Gemm
::
GemmKernel
;
int64_t
lda
=
a
.
stride
(
0
);
int64_t
ldb
=
b
.
stride
(
1
);
int64_t
ldc
=
out
.
stride
(
0
);
using
StrideA
=
cute
::
Stride
<
int64_t
,
cute
::
Int
<
1
>
,
int64_t
>
;
using
StrideB
=
cute
::
Stride
<
int64_t
,
cute
::
Int
<
1
>
,
int64_t
>
;
using
StrideC
=
typename
Gemm
::
StrideC
;
StrideA
a_stride
{
lda
,
cute
::
Int
<
1
>
{},
0
};
StrideB
b_stride
{
ldb
,
cute
::
Int
<
1
>
{},
0
};
StrideC
c_stride
{
ldc
,
cute
::
Int
<
1
>
{},
cute
::
Int
<
0
>
{}};
typename
GemmKernel
::
ProblemShape
prob_shape
=
get_problem_shape
(
a
,
b
);
auto
a_ptr
=
static_cast
<
ElementAB
*>
(
a
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
*>
(
b
.
data_ptr
());
typename
GemmKernel
::
MainloopArguments
mainloop_args
{
a_ptr
,
a_stride
,
b_ptr
,
b_stride
};
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{
Gemm
::
Epilogue
::
prepare_args
(
std
::
forward
<
EpilogueArgs
>
(
epilogue_params
)...),
c_ptr
,
c_stride
,
c_ptr
,
c_stride
};
cutlass_gemm_caller
<
GemmKernel
>
(
a
.
device
(),
prob_shape
,
mainloop_args
,
epilogue_args
);
}
}
// namespace vllm::c3x
\ No newline at end of file
csrc/quantization/cutlass_w8a8/scaled_mm
_c3x
.cuh
→
csrc/quantization/cutlass_w8a8/
c3x/
scaled_mm.cuh
View file @
9798b2fb
...
@@ -2,9 +2,6 @@
...
@@ -2,9 +2,6 @@
// clang-format will break include orders
// clang-format will break include orders
// clang-format off
// clang-format off
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include "cutlass/cutlass.h"
#include "cutlass/cutlass.h"
...
@@ -32,21 +29,6 @@ using namespace cute;
...
@@ -32,21 +29,6 @@ using namespace cute;
namespace
vllm
{
namespace
vllm
{
// A wrapper for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template
<
typename
Kernel
>
struct
enable_sm90_or_later
:
Kernel
{
template
<
typename
...
Args
>
CUTLASS_DEVICE
void
operator
()(
Args
&&
...
args
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
Kernel
::
operator
()(
std
::
forward
<
Args
>
(
args
)...);
#endif
}
};
template
<
typename
ElementAB_
,
typename
ElementD_
,
template
<
typename
ElementAB_
,
typename
ElementD_
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue_
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue_
,
typename
TileShape
,
typename
ClusterShape
,
typename
KernelSchedule
,
typename
TileShape
,
typename
ClusterShape
,
typename
KernelSchedule
,
...
@@ -101,60 +83,4 @@ struct cutlass_3x_gemm {
...
@@ -101,60 +83,4 @@ struct cutlass_3x_gemm {
struct
GemmKernel
:
public
KernelType
{};
struct
GemmKernel
:
public
KernelType
{};
};
};
template
<
typename
Gemm
,
typename
...
EpilogueArgs
>
void
cutlass_gemm_caller
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
epilogue_params
)
{
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementD
=
typename
Gemm
::
ElementD
;
int32_t
m
=
a
.
size
(
0
);
int32_t
n
=
b
.
size
(
1
);
int32_t
k
=
a
.
size
(
1
);
int64_t
lda
=
a
.
stride
(
0
);
int64_t
ldb
=
b
.
stride
(
1
);
int64_t
ldc
=
out
.
stride
(
0
);
using
StrideA
=
Stride
<
int64_t
,
Int
<
1
>
,
int64_t
>
;
using
StrideB
=
Stride
<
int64_t
,
Int
<
1
>
,
int64_t
>
;
using
StrideC
=
typename
Gemm
::
StrideC
;
StrideA
a_stride
{
lda
,
Int
<
1
>
{},
0
};
StrideB
b_stride
{
ldb
,
Int
<
1
>
{},
0
};
StrideC
c_stride
{
ldc
,
Int
<
1
>
{},
Int
<
0
>
{}};
using
GemmKernel
=
typename
Gemm
::
GemmKernel
;
typename
GemmKernel
::
ProblemShape
prob_shape
{
m
,
n
,
k
,
1
};
auto
a_ptr
=
static_cast
<
ElementAB
*>
(
a
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
*>
(
b
.
data_ptr
());
typename
GemmKernel
::
MainloopArguments
mainloop_args
{
a_ptr
,
a_stride
,
b_ptr
,
b_stride
};
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{
Gemm
::
Epilogue
::
prepare_args
(
std
::
forward
<
EpilogueArgs
>
(
epilogue_params
)...),
c_ptr
,
c_stride
,
c_ptr
,
c_stride
};
typename
GemmKernel
::
Arguments
args
{
cutlass
::
gemm
::
GemmUniversalMode
::
kGemm
,
prob_shape
,
mainloop_args
,
epilogue_args
};
// Launch the CUTLASS GEMM kernel.
using
GemmOp
=
cutlass
::
gemm
::
device
::
GemmUniversalAdapter
<
GemmKernel
>
;
GemmOp
gemm_op
;
CUTLASS_CHECK
(
gemm_op
.
can_implement
(
args
));
size_t
workspace_size
=
gemm_op
.
get_workspace_size
(
args
);
auto
const
workspace_options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
a
.
device
());
auto
workspace
=
torch
::
empty
(
workspace_size
,
workspace_options
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
a
.
get_device
());
cutlass
::
Status
status
=
gemm_op
.
run
(
args
,
workspace
.
data_ptr
(),
stream
);
CUTLASS_CHECK
(
status
);
}
}
// namespace vllm
}
// namespace vllm
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu
0 → 100644
View file @
9798b2fb
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_int8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace
vllm
{
void
cutlass_scaled_mm_azp_sm90_int8
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
std
::
optional
<
torch
::
Tensor
>
const
&
azp
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
if
(
azp
)
{
return
cutlass_scaled_mm_sm90_int8_epilogue
<
c3x
::
ScaledEpilogueBiasAzpToken
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
*
azp
,
bias
);
}
else
{
return
cutlass_scaled_mm_sm90_int8_epilogue
<
c3x
::
ScaledEpilogueBiasAzp
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
bias
);
}
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu
0 → 100644
View file @
9798b2fb
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm90_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace
vllm
{
void
cutlass_scaled_mm_blockwise_sm90_fp8
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
cutlass_gemm_blockwise_sm90_fp8_dispatch
<
cutlass
::
bfloat16_t
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
cutlass_gemm_blockwise_sm90_fp8_dispatch
<
cutlass
::
half_t
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
// namespace vllm
\ No newline at end of file
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
0 → 100644
View file @
9798b2fb
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
#include "cutlass_gemm_caller.cuh"
namespace
vllm
{
using
namespace
cute
;
template
<
typename
OutType
,
int
GroupSizeM_
,
int
GroupSizeN_
,
int
GroupSizeK_
,
int
TileSizeM_
=
128
,
class
ClusterShape
=
Shape
<
_1
,
_2
,
_1
>
>
struct
cutlass_3x_gemm_fp8_blockwise
{
using
GroupSizeM
=
Int
<
GroupSizeM_
>
;
using
GroupSizeN
=
Int
<
GroupSizeN_
>
;
using
GroupSizeK
=
Int
<
GroupSizeK_
>
;
using
TileSizeM
=
Int
<
TileSizeM_
>
;
static_assert
(
TileSizeM_
%
GroupSizeM_
==
0
,
"TileSizeM must be a multiple of GroupSizeM"
);
using
ElementAB
=
cutlass
::
float_e4m3_t
;
using
ElementA
=
ElementAB
;
using
LayoutA
=
cutlass
::
layout
::
RowMajor
;
static
constexpr
int
AlignmentA
=
128
/
cutlass
::
sizeof_bits
<
ElementA
>::
value
;
using
ElementB
=
ElementAB
;
using
LayoutB
=
cutlass
::
layout
::
ColumnMajor
;
static
constexpr
int
AlignmentB
=
128
/
cutlass
::
sizeof_bits
<
ElementB
>::
value
;
using
ElementD
=
OutType
;
using
StrideD
=
Stride
<
int64_t
,
Int
<
1
>
,
Int
<
0
>>
;
static
constexpr
int
AlignmentD
=
128
/
cutlass
::
sizeof_bits
<
ElementD
>::
value
;
using
ElementC
=
void
;
using
StrideC
=
StrideD
;
static
constexpr
int
AlignmentC
=
AlignmentD
;
using
ElementAccumulator
=
float
;
using
ElementBlockScale
=
float
;
using
ElementCompute
=
float
;
using
ArchTag
=
cutlass
::
arch
::
Sm90
;
using
OperatorClass
=
cutlass
::
arch
::
OpClassTensorOp
;
using
TileShape
=
Shape
<
TileSizeM
,
GroupSizeN
,
GroupSizeK
>
;
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum
<
GroupSizeM_
>
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
TmaWarpSpecializedCooperative
;
using
EpilogueTileType
=
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
;
using
StoreEpilogueCompute
=
typename
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
cutlass
::
epilogue
::
fusion
::
Sm90AccFetch
>
;
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
TileShape
,
ClusterShape
,
EpilogueTileType
,
ElementAccumulator
,
ElementCompute
,
ElementC
,
StrideC
,
AlignmentC
,
ElementD
,
StrideD
,
AlignmentD
,
EpilogueSchedule
,
StoreEpilogueCompute
>::
CollectiveOp
;
using
CollectiveMainloop
=
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
ElementA
,
LayoutA
,
AlignmentA
,
ElementB
,
LayoutB
,
AlignmentB
,
ElementAccumulator
,
TileShape
,
ClusterShape
,
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
))
>
,
KernelSchedule
>::
CollectiveOp
;
using
KernelType
=
enable_sm90_or_later
<
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
Shape
<
int
,
int
,
int
,
int
>
,
CollectiveMainloop
,
CollectiveEpilogue
,
cutlass
::
gemm
::
PersistentScheduler
>>
;
struct
GemmKernel
:
public
KernelType
{};
using
StrideA
=
typename
GemmKernel
::
StrideA
;
using
StrideB
=
typename
GemmKernel
::
StrideB
;
};
template
<
typename
Gemm
>
void
cutlass_gemm_caller_blockwise
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
using
GemmKernel
=
typename
Gemm
::
GemmKernel
;
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementD
=
typename
Gemm
::
ElementD
;
auto
prob_shape
=
c3x
::
get_problem_shape
(
a
,
b
);
int32_t
m
=
get
<
0
>
(
prob_shape
),
n
=
get
<
1
>
(
prob_shape
),
k
=
get
<
2
>
(
prob_shape
);
int64_t
lda
=
a
.
stride
(
0
);
int64_t
ldb
=
b
.
stride
(
1
);
int64_t
ldc
=
out
.
stride
(
0
);
using
StrideA
=
Stride
<
int64_t
,
Int
<
1
>
,
int64_t
>
;
using
StrideB
=
Stride
<
int64_t
,
Int
<
1
>
,
int64_t
>
;
using
StrideC
=
typename
Gemm
::
StrideC
;
StrideA
a_stride
{
lda
,
Int
<
1
>
{},
0
};
StrideB
b_stride
{
ldb
,
Int
<
1
>
{},
0
};
StrideC
c_stride
{
ldc
,
Int
<
1
>
{},
Int
<
0
>
{}};
auto
a_ptr
=
static_cast
<
ElementAB
*>
(
a
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
*>
(
b
.
data_ptr
());
auto
a_scales_ptr
=
static_cast
<
float
*>
(
a_scales
.
data_ptr
());
auto
b_scales_ptr
=
static_cast
<
float
*>
(
b_scales
.
data_ptr
());
// Check is the t is contiguous and is 1D or 2D with one of the dimensions
// being 1 (i.e. a row or column vector)
auto
is_contiguous_vector
=
[](
const
torch
::
Tensor
&
t
)
{
auto
t_sizes
=
t
.
sizes
();
return
t
.
is_contiguous
()
&&
(
t
.
dim
()
==
1
||
(
t
.
dim
()
==
2
&&
*
std
::
min_element
(
t_sizes
.
begin
(),
t_sizes
.
end
())
==
1
));
};
// TODO(lucas): lets clean-up the kernel so that we pass in Strides so
// we don't have to deal with enforcing implicit layouts
TORCH_CHECK
(
a_scales
.
size
(
0
)
==
m
/
Gemm
::
GroupSizeM
::
value
);
TORCH_CHECK
(
a_scales
.
size
(
1
)
==
k
/
Gemm
::
GroupSizeK
::
value
);
TORCH_CHECK
(
a_scales
.
stride
(
0
)
==
1
||
is_contiguous_vector
(
a_scales
),
"a_scales must be M major"
);
TORCH_CHECK
(
b_scales
.
size
(
0
)
==
k
/
Gemm
::
GroupSizeK
::
value
);
TORCH_CHECK
(
b_scales
.
size
(
1
)
==
n
/
Gemm
::
GroupSizeN
::
value
);
TORCH_CHECK
(
b_scales
.
stride
(
0
)
==
1
||
is_contiguous_vector
(
b_scales
),
"b_scales must be K major"
);
typename
GemmKernel
::
MainloopArguments
mainloop_args
{
a_ptr
,
a_stride
,
b_ptr
,
b_stride
,
a_scales_ptr
,
b_scales_ptr
};
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{
{},
c_ptr
,
c_stride
,
c_ptr
,
c_stride
};
c3x
::
cutlass_gemm_caller
<
GemmKernel
>
(
a
.
device
(),
prob_shape
,
mainloop_args
,
epilogue_args
);
}
template
<
typename
OutType
>
void
cutlass_gemm_blockwise_sm90_fp8_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
cutlass_gemm_caller_blockwise
<
cutlass_3x_gemm_fp8_blockwise
<
OutType
,
1
,
128
,
128
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
// namespace vllm
\ No newline at end of file
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
0 → 100644
View file @
9798b2fb
#pragma once
#include <torch/all.h>
namespace
vllm
{
void
cutlass_scaled_mm_sm90_fp8
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_sm90_int8
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_azp_sm90_int8
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
std
::
optional
<
torch
::
Tensor
>
const
&
azp
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_blockwise_sm90_fp8
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
}
// namespace vllm
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu
0 → 100644
View file @
9798b2fb
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace
vllm
{
void
cutlass_scaled_mm_sm90_fp8
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
if
(
bias
)
{
TORCH_CHECK
(
bias
->
dtype
()
==
out
.
dtype
(),
"currently bias dtype must match output dtype "
,
out
.
dtype
());
return
cutlass_scaled_mm_sm90_fp8_epilogue
<
c3x
::
ScaledEpilogueBias
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
}
else
{
return
cutlass_scaled_mm_sm90_fp8_epilogue
<
c3x
::
ScaledEpilogue
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/scaled_mm_
c3x_
sm90_fp8_dispatch.cuh
→
csrc/quantization/cutlass_w8a8/
c3x/
scaled_mm_sm90_fp8_dispatch.cuh
View file @
9798b2fb
#pragma once
#pragma once
#include "scaled_mm_c3x.cuh"
#include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh"
/**
/**
* This file defines Gemm kernel configurations for SM90 (fp8) based on the Gemm
* This file defines Gemm kernel configurations for SM90 (fp8) based on the Gemm
...
@@ -9,6 +10,8 @@
...
@@ -9,6 +10,8 @@
namespace
vllm
{
namespace
vllm
{
using
c3x
::
cutlass_gemm_caller
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_default
{
struct
sm90_fp8_config_default
{
...
@@ -93,4 +96,25 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
...
@@ -93,4 +96,25 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
}
}
}
}
template
<
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_scaled_mm_sm90_fp8_epilogue
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
epilogue_args
)
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_sm90_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_sm90_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
}
// namespace vllm
}
// namespace vllm
\ No newline at end of file
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu
0 → 100644
View file @
9798b2fb
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_int8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace
vllm
{
void
cutlass_scaled_mm_sm90_int8
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
if
(
bias
)
{
TORCH_CHECK
(
bias
->
dtype
()
==
out
.
dtype
(),
"currently bias dtype must match output dtype "
,
out
.
dtype
());
return
cutlass_scaled_mm_sm90_int8_epilogue
<
c3x
::
ScaledEpilogueBias
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
}
else
{
return
cutlass_scaled_mm_sm90_int8_epilogue
<
c3x
::
ScaledEpilogue
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/scaled_mm_
c3x_
sm90_int8_dispatch.cuh
→
csrc/quantization/cutlass_w8a8/
c3x/
scaled_mm_sm90_int8_dispatch.cuh
View file @
9798b2fb
#pragma once
#pragma once
#include "scaled_mm_c3x.cuh"
#include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh"
/**
/**
* This file defines Gemm kernel configurations for SM90 (int8) based on the
* This file defines Gemm kernel configurations for SM90 (int8) based on the
...
@@ -9,6 +10,8 @@
...
@@ -9,6 +10,8 @@
namespace
vllm
{
namespace
vllm
{
using
c3x
::
cutlass_gemm_caller
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_default
{
struct
sm90_int8_config_default
{
...
@@ -137,4 +140,24 @@ inline void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out,
...
@@ -137,4 +140,24 @@ inline void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out,
}
}
}
}
template
<
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_scaled_mm_sm90_int8_epilogue
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
epilogue_args
)
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_sm90_int8_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_sm90_int8_dispatch
<
int8_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
}
// namespace vllm
}
// namespace vllm
\ No newline at end of file
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
View file @
9798b2fb
#include <cudaTypedefs.h>
#include <cudaTypedefs.h>
#include "c3x/scaled_mm_kernels.hpp"
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
#include "core/math.hpp"
#include "scaled_mm_c3x_sm90_fp8_dispatch.cuh"
#include "scaled_mm_c3x_sm90_int8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
using
namespace
vllm
;
/*
/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm90a (Hopper) or later.
NVIDIA GPUs with sm90a (Hopper) or later.
*/
*/
template
<
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_scaled_mm_sm90_epilogue
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
epilogue_args
)
{
if
(
a
.
dtype
()
==
torch
::
kInt8
)
{
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_sm90_int8_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_sm90_int8_dispatch
<
int8_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
else
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_sm90_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_sm90_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
}
void
cutlass_scaled_mm_sm90
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm_sm90
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
...
@@ -54,14 +15,50 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
...
@@ -54,14 +15,50 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
bias
)
{
TORCH_CHECK
(
bias
->
dtype
()
==
c
.
dtype
(),
using
GroupShape
=
std
::
array
<
int64_t
,
2
>
;
"currently bias dtype must match output dtype "
,
c
.
dtype
());
return
cutlass_scaled_mm_sm90_epilogue
<
c3x
::
ScaledEpilogueBias
>
(
int
M
=
a
.
size
(
0
),
N
=
b
.
size
(
1
),
K
=
a
.
size
(
1
);
c
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
GroupShape
a_scale_group_shape
=
[
&
,
&
s
=
a_scales
]()
->
GroupShape
{
if
(
s
.
numel
()
==
1
)
return
{
M
,
K
};
// tensor-wise
if
(
s
.
dim
()
==
2
)
return
{
ceil_div
(
a
.
size
(
0
),
s
.
size
(
0
)),
ceil_div
(
a
.
size
(
1
),
s
.
size
(
1
))};
TORCH_CHECK
(
false
,
"Unsupported scale shape for scale_a"
);
}();
GroupShape
b_scale_group_shape
=
[
&
,
&
s
=
b_scales
]()
->
GroupShape
{
if
(
s
.
numel
()
==
1
)
return
{
K
,
N
};
// tensor-wise
if
(
s
.
dim
()
==
2
)
return
{
ceil_div
(
b
.
size
(
0
),
s
.
size
(
0
)),
ceil_div
(
b
.
size
(
1
),
s
.
size
(
1
))};
TORCH_CHECK
(
false
,
"Unsupported scale shape for scale_b"
);
}();
if
((
a_scale_group_shape
==
GroupShape
{
M
,
K
}
||
a_scale_group_shape
==
GroupShape
{
1
,
K
})
&&
(
b_scale_group_shape
==
GroupShape
{
K
,
N
}
||
b_scale_group_shape
==
GroupShape
{
K
,
1
}))
{
// "standard per-tensor/per-token/per-channel" scaling
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
if
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
)
{
vllm
::
cutlass_scaled_mm_sm90_fp8
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
else
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
vllm
::
cutlass_scaled_mm_sm90_int8
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
}
else
if
(
a_scale_group_shape
==
GroupShape
{
1
,
128
}
&&
b_scale_group_shape
==
GroupShape
{
128
,
128
})
{
// 1x128 per-token group scales for activations
// 128x128 blockwise scales for weights
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
&&
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
,
"Currently only FP8 is supported for A group shape 1x128 and "
"B group shape 128x128"
);
TORCH_CHECK
(
!
bias
,
"Bias not yet supported blockwise scaled_mm"
);
vllm
::
cutlass_scaled_mm_blockwise_sm90_fp8
(
c
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
}
else
{
return
cutlass_scaled_mm_sm90_epilogue
<
c3x
::
ScaledEpilogue
>
(
TORCH_CHECK
(
false
,
"Unsupported scale group shapes for CUTLASS 3.x GEMM"
);
c
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
}
...
@@ -75,13 +72,6 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
...
@@ -75,13 +72,6 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
azp
)
{
vllm
::
cutlass_scaled_mm_azp_sm90_int8
(
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
return
cutlass_scaled_mm_sm90_epilogue
<
c3x
::
ScaledEpilogueBiasAzpToken
>
(
azp
,
bias
);
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
*
azp
,
bias
);
}
else
{
return
cutlass_scaled_mm_sm90_epilogue
<
c3x
::
ScaledEpilogueBiasAzp
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
bias
);
}
}
}
#endif
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment