Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8d59dbb0
Unverified
Commit
8d59dbb0
authored
Aug 06, 2024
by
Luka Govedič
Committed by
GitHub
Aug 06, 2024
Browse files
[Kernel] Add per-tensor and per-token AZP epilogues (#5941)
Co-authored-by:
Tyler Michael Smith
<
tyler@neuralmagic.com
>
parent
5c60c8c4
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
1175 additions
and
153 deletions
+1175
-153
benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
+107
-78
csrc/ops.h
csrc/ops.h
+8
-0
csrc/quantization/cutlass_w8a8/Epilogues.md
csrc/quantization/cutlass_w8a8/Epilogues.md
+147
-0
csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp
...quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp
+151
-1
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
+57
-0
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
+217
-36
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
+230
-28
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
+104
-7
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+10
-1
tests/kernels/test_cutlass.py
tests/kernels/test_cutlass.py
+119
-1
vllm/_custom_ops.py
vllm/_custom_ops.py
+25
-1
No files found.
benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
View file @
8d59dbb0
...
...
@@ -32,7 +32,6 @@ def to_int8(tensor: torch.Tensor) -> torch.Tensor:
def
make_rand_tensors
(
dtype
:
torch
.
dtype
,
m
:
int
,
n
:
int
,
k
:
int
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
a
=
torch
.
randn
((
m
,
k
),
device
=
'cuda'
)
*
5
b
=
torch
.
randn
((
n
,
k
),
device
=
'cuda'
).
t
()
*
5
...
...
@@ -44,59 +43,18 @@ def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
raise
ValueError
(
"unsupported dtype"
)
# impl
def
pytorch_mm_impl
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
)
->
torch
.
Tensor
:
return
torch
.
mm
(
a
,
b
)
def
pytorch_fp8_impl
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
)
->
torch
.
Tensor
:
return
torch
.
_scaled_mm
(
a
,
b
,
scale_a
=
scale_a
,
scale_b
=
scale_b
,
out_dtype
=
out_dtype
)
def
pytorch_fp8_impl_fast_accum
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
)
->
torch
.
Tensor
:
return
torch
.
_scaled_mm
(
a
,
b
,
scale_a
=
scale_a
,
scale_b
=
scale_b
,
out_dtype
=
out_dtype
,
use_fast_accum
=
True
)
def
cutlass_impl
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
)
->
torch
.
Tensor
:
return
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
=
out_dtype
)
# bench
def
bench_fn
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
label
:
str
,
sub_label
:
str
,
fn
:
Callable
,
description
:
str
)
->
TMeasurement
:
def
bench_fn
(
label
:
str
,
sub_label
:
str
,
description
:
str
,
fn
:
Callable
,
*
args
,
**
kwargs
)
->
TMeasurement
:
min_run_time
=
1
globals
=
{
"a"
:
a
,
"b"
:
b
,
"scale_a"
:
scale_a
,
"scale_b"
:
scale_b
,
"out_dtype"
:
out_dtype
,
"args"
:
args
,
"kwargs"
:
kwargs
,
"fn"
:
fn
,
}
return
TBenchmark
.
Timer
(
stmt
=
"fn(
a, b, scale_a, scale_b, out_dtype
)"
,
stmt
=
"fn(
*args, **kwargs
)"
,
globals
=
globals
,
label
=
label
,
sub_label
=
sub_label
,
...
...
@@ -110,26 +68,58 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
a
,
b
=
make_rand_tensors
(
torch
.
int8
,
m
,
n
,
k
)
scale_a
=
torch
.
tensor
(
1.0
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
scale_b
=
torch
.
tensor
(
1.0
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
bias
=
torch
.
zeros
((
n
,
),
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
azp
=
torch
.
zeros
((
m
,
),
device
=
"cuda"
,
dtype
=
torch
.
int32
)
azp_adj
=
torch
.
zeros
((
n
,
),
device
=
"cuda"
,
dtype
=
torch
.
int32
)
timers
=
[]
# pytorch impl - bfloat16
timers
.
append
(
bench_fn
(
a
.
to
(
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
),
b
.
to
(
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
),
scale_a
,
scale_b
,
torch
.
bfloat16
,
label
,
sub_label
,
pytorch_mm_impl
,
"pytorch_bf16_bf16_bf16_matmul-no-scales"
))
bench_fn
(
label
,
sub_label
,
"pytorch_bf16_bf16_bf16_matmul-no-scales"
,
torch
.
mm
,
a
.
to
(
dtype
=
torch
.
bfloat16
),
b
.
to
(
dtype
=
torch
.
bfloat16
)))
# pytorch impl - float16
timers
.
append
(
bench_fn
(
a
.
to
(
dtype
=
torch
.
float16
,
device
=
"cuda"
),
b
.
to
(
dtype
=
torch
.
float16
,
device
=
"cuda"
),
scale_a
,
scale_b
,
torch
.
float16
,
label
,
sub_label
,
pytorch_mm_impl
,
"pytorch_fp16_fp16_fp16_matmul-no-scales"
))
bench_fn
(
label
,
sub_label
,
"pytorch_fp16_fp16_fp16_matmul-no-scales"
,
torch
.
mm
,
a
.
to
(
dtype
=
torch
.
float16
),
b
.
to
(
dtype
=
torch
.
float16
)))
# cutlass impl
timers
.
append
(
bench_fn
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
label
,
sub_label
,
cutlass_impl
,
"cutlass_i8_i8_bf16_scaled_mm"
))
bench_fn
(
label
,
sub_label
,
"cutlass_i8_i8_bf16_scaled_mm"
,
ops
.
cutlass_scaled_mm
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
))
# cutlass with bias
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_i8_i8_bf16_scaled_mm_bias"
,
ops
.
cutlass_scaled_mm
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
bias
))
# cutlass with azp per-tensor
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_i8_i8_bf16_scaled_mm_azp"
,
ops
.
cutlass_scaled_mm_azp
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
azp_adj
))
# cutlass with azp per-tensor + bias
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_i8_i8_bf16_scaled_mm_azp_bias"
,
ops
.
cutlass_scaled_mm_azp
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
azp_adj
,
None
,
bias
))
# cutlass with azp per-token
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_i8_i8_bf16_scaled_mm_azp_pt"
,
ops
.
cutlass_scaled_mm_azp
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
azp_adj
,
azp
))
# cutlass with azp per-token + bias
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias"
,
ops
.
cutlass_scaled_mm_azp
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
azp_adj
,
azp
,
bias
))
return
timers
...
...
@@ -140,46 +130,88 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
a
,
b
=
make_rand_tensors
(
torch
.
float8_e4m3fn
,
m
,
n
,
k
)
scale_a
=
torch
.
tensor
(
1.0
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
scale_b
=
torch
.
tensor
(
1.0
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
bias
=
torch
.
zeros
((
n
,
),
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
timers
=
[]
# pytorch impl w. bf16
timers
.
append
(
bench_fn
(
a
.
to
(
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
),
b
.
to
(
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
),
scale_a
,
scale_b
,
torch
.
bfloat16
,
label
,
sub_label
,
pytorch_mm_impl
,
"pytorch_bf16_bf16_bf16_matmul-no-scales"
))
bench_fn
(
label
,
sub_label
,
"pytorch_bf16_bf16_bf16_matmul-no-scales"
,
torch
.
mm
,
a
.
to
(
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
),
b
.
to
(
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)))
# pytorch impl: bf16 output, without fp8 fast accum
timers
.
append
(
bench_fn
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
label
,
sub_label
,
pytorch_fp8_impl
,
"pytorch_fp8_fp8_bf16_scaled_mm"
))
bench_fn
(
label
,
sub_label
,
"pytorch_fp8_fp8_bf16_scaled_mm"
,
torch
.
_scaled_mm
,
a
,
b
,
scale_a
=
scale_a
,
scale_b
=
scale_b
,
out_dtype
=
torch
.
bfloat16
))
# pytorch impl: bf16 output, with fp8 fast accum
timers
.
append
(
bench_fn
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
label
,
sub_label
,
pytorch_fp8_impl_fast_accum
,
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum"
))
bench_fn
(
label
,
sub_label
,
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum"
,
torch
.
_scaled_mm
,
a
,
b
,
scale_a
=
scale_a
,
scale_b
=
scale_b
,
out_dtype
=
torch
.
bfloat16
,
use_fast_accum
=
True
))
# pytorch impl: fp16 output, without fp8 fast accum
timers
.
append
(
bench_fn
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
float16
,
label
,
sub_label
,
pytorch_fp8_impl
,
"pytorch_fp8_fp8_fp16_scaled_mm"
))
bench_fn
(
label
,
sub_label
,
"pytorch_fp8_fp8_fp16_scaled_mm"
,
torch
.
_scaled_mm
,
a
,
b
,
scale_a
=
scale_a
,
scale_b
=
scale_b
,
out_dtype
=
torch
.
float16
))
# pytorch impl: fp16 output, with fp8 fast accum
timers
.
append
(
bench_fn
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
float16
,
label
,
sub_label
,
pytorch_fp8_impl_fast_accum
,
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum"
))
bench_fn
(
label
,
sub_label
,
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum"
,
torch
.
_scaled_mm
,
a
,
b
,
scale_a
=
scale_a
,
scale_b
=
scale_b
,
out_dtype
=
torch
.
float16
,
use_fast_accum
=
True
))
# cutlass impl: bf16 output
timers
.
append
(
bench_fn
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
label
,
sub_label
,
cutlass_impl
,
"cutlass_fp8_fp8_bf16_scaled_mm"
))
bench_fn
(
label
,
sub_label
,
"cutlass_fp8_fp8_bf16_scaled_mm"
,
ops
.
cutlass_scaled_mm
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
))
# cutlass impl: fp16 output
timers
.
append
(
bench_fn
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
float16
,
label
,
sub_label
,
cutlass_impl
,
"cutlass_fp8_fp8_fp16_scaled_mm"
))
bench_fn
(
label
,
sub_label
,
"cutlass_fp8_fp8_fp16_scaled_mm"
,
ops
.
cutlass_scaled_mm
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
float16
))
# cutlass impl: bf16 output, with bias
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_fp8_fp8_bf16_scaled_mm_bias"
,
ops
.
cutlass_scaled_mm
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
bias
))
# cutlass impl: fp16 output, with bias
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_fp8_fp8_fp16_scaled_mm_bias"
,
ops
.
cutlass_scaled_mm
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
float16
,
bias
.
to
(
dtype
=
torch
.
float16
)))
return
timers
...
...
@@ -200,7 +232,6 @@ def print_timers(timers: Iterable[TMeasurement]):
def
run
(
dtype
:
torch
.
dtype
,
MKNs
:
Iterable
[
Tuple
[
int
,
int
,
int
]])
->
Iterable
[
TMeasurement
]:
results
=
[]
for
m
,
k
,
n
in
MKNs
:
timers
=
bench
(
dtype
,
m
,
k
,
n
,
f
"scaled-
{
dtype
}
-gemm"
,
...
...
@@ -216,7 +247,6 @@ def make_output(data: Iterable[TMeasurement],
MKNs
:
Iterable
[
Tuple
[
int
,
int
,
int
]],
base_description
:
str
,
timestamp
=
None
):
print
(
f
"== All Results
{
base_description
}
===="
)
print_timers
(
data
)
...
...
@@ -251,7 +281,6 @@ def run_range_bench(args):
def
run_model_bench
(
args
):
print
(
"Benchmarking models:"
)
for
i
,
model
in
enumerate
(
args
.
models
):
print
(
f
"[
{
i
}
]
{
model
}
"
)
...
...
csrc/ops.h
View file @
8d59dbb0
...
...
@@ -128,6 +128,14 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
// Scaled matrix multiply with an activation zero-point (AZP) correction.
//
// Writes the result into `out`. `a_scales` / `b_scales` are the activation and
// weight scales. `azp_adj` is the precomputed zero-point adjustment row-vector.
// `azp` is optional: when present the per-token AZP epilogue is selected and
// `azp` supplies the zero points; when absent the per-tensor AZP epilogue is
// selected and `azp_adj` is expected to already include the scalar zero point.
// `bias` is optional; a missing bias is treated as zero by the epilogue.
void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
                           torch::Tensor const& b,
                           torch::Tensor const& a_scales,
                           torch::Tensor const& b_scales,
                           torch::Tensor const& azp_adj,
                           c10::optional<torch::Tensor> const& azp,
                           c10::optional<torch::Tensor> const& bias);
torch
::
Tensor
marlin_qqq_gemm
(
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b_q_weight
,
torch
::
Tensor
const
&
s_tok
,
...
...
csrc/quantization/cutlass_w8a8/Epilogues.md
0 → 100644
View file @
8d59dbb0
# CUTLASS Epilogues
## Introduction
This document describes the various CUTLASS epilogues implemented for fusing de-quantization operations onto GEMMs.
Currently, we only support symmetric quantization for weights,
and symmetric and asymmetric quantization for activations.
Both can be quantized per-tensor or per-channel (weights) / per-token (activations).
There are 4 epilogues:
1.
ScaledEpilogue: symmetric quantization for activations, no bias.
1.
ScaledEpilogueBias: symmetric quantization for activations, supports bias.
1.
ScaledEpilogueAzp: asymmetric per-tensor quantization for activations, supports bias.
1.
ScaledEpilogueAzpPerToken: asymmetric per-token quantization for activations, supports bias.
We do not have epilogues for asymmetric quantization of activations without bias in order to reduce final binary size.
Instead, if no bias is passed, the epilogue will use 0 as the bias.
That induces a redundant addition operation (and runtime check), but the performance impact is minor.
## Underlying Linear Algebra
More details available in the
[
Activation Quantization RFC
](
https://github.com/vllm-project/vllm/issues/3975
)
.
If $
` \widehat X `
$ is the quantized $
` X `
$, our matrices become the following
```
math
A = s_a (\widehat A - J_a z_a)
```
```
math
B = s_b \widehat B
```
```
math
D = A B + C
```
```
math
D = s_a s_b \widehat D + C
```
Here, D is the output of the GEMM, and C is the bias.
A is the activations and supports asymmetric quantization,
and B is the weights and only supports symmetric quantization.
$` s_a `$ and $` s_b `$ are the scales for activations and weights, respectively.
$` z_a `$ is the zero-point for activations, and $` J_a `$ is the matrix of all ones with dimensions of A.
Additional epilogues would be required to support asymmetric quantization for weights.
Expanding further, we can calculate $
` \widehat D `
$ as follows:
```
math
A B = s_a ( \widehat A - J_a z_a ) s_b \widehat B
```
```
math
A B = s_a s_b \left( \widehat A \widehat B - J_a z_a \widehat B \right)
```
```
math
\widehat D = \widehat A \widehat B - z_a J_a \widehat B
```
Note that $
` \widehat A \widehat B `
$ is the raw output of the GEMM,
and $
` J_a \widehat B `
$ is known ahead of time.
Each row of it is equal to $
` \mathbf 1 \widehat B `
$, which is a row-vector of column sums of $
` \widehat B `
$.
## Epilogues
### ScaledEpilogue
This epilogue computes the symmetric quantization for activations without bias, meaning $
` C = 0 `
$ and $
` z_a = 0 `
$.
The output of the GEMM is:
```
math
\widehat D = \widehat A \widehat B
```
```
math
D = s_a s_b \widehat D
```
```
math
D = s_a s_b \widehat A \widehat B
```
Epilogue parameters:
-
`scale_a`
is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
-
`scale_b`
is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
### ScaledEpilogueBias
This epilogue computes the symmetric quantization for activations with bias, meaning $
` z_a = 0 `
$.
The output of the GEMM is:
```
math
\widehat D = \widehat A \widehat B
```
```
math
D = s_a s_b \widehat D + C
```
```
math
D = s_a s_b \widehat A \widehat B + C
```
Epilogue parameters:
-
`scale_a`
is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
-
`scale_b`
is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
-
`bias`
is the bias, is always per-channel (row-vector).
### ScaledEpilogueAzp
This epilogue computes the asymmetric per-tensor quantization for activations with bias.
The output of the GEMM is:
```
math
\widehat D = \widehat A \widehat B - z_a J_a \widehat B
```
```
math
D = s_a s_b \widehat D + C
```
```
math
D = s_a s_b \left( \widehat A \widehat B - z_a J_a \widehat B \right) + C
```
Because $
` z_a `
$ is a scalar, the zero-point term $
` z_a J_a \widehat B `
$ has every row equal to $
` z_a \mathbf 1 \widehat B `
$.
That is precomputed and stored in
`azp_with_adj`
as a row-vector.
Epilogue parameters:
-
`scale_a`
is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
-
Generally this will be per-tensor as the zero-points are per-tensor.
-
`scale_b`
is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
-
`azp_with_adj`
is the precomputed zero-point term ($
` z_a J_a \widehat B `
$), is per-channel (row-vector).
-
`bias`
is the bias, is always per-channel (row-vector).
To use these kernels efficiently, users must precompute the
`azp_with_adj`
term offline and pass it to the kernel.
### ScaledEpilogueAzpPerToken
This epilogue computes the asymmetric per-token quantization for activations with bias.
The output of the GEMM is the same as above, but the $
` z_a `
$ is a column-vector.
That means the zero-point term $
` z_a J_a \widehat B `
$ becomes an outer product of $
` z_a `
$ and $
` \mathbf 1 \widehat B `
$.
Epilogue parameters:
-
`scale_a`
is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
-
Generally this will be per-token as the zero-points are per-token.
-
`scale_b`
is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
-
`azp_adj`
is the precomputed zero-point adjustment term ($
` \mathbf 1 \widehat B `
$), is per-channel (row-vector).
-
`azp`
is the zero-point (
`z_a`
), is per-token (column-vector).
-
`bias`
is the bias, is always per-channel (row-vector).
To use these kernels efficiently, users must precompute the
`azp_adj`
term offline and pass it to the kernel.
The epilogue performs the following computation (where
`Dq`
is the raw quantized output of the GEMM):
```
out = scale_a * scale_b * (Dq - azp_adj * azp) + bias
```
csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp
View file @
8d59dbb0
...
...
@@ -207,6 +207,156 @@ struct VisitorRowOrScalarBroadcast {
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// This is a modified RowBroadcast that will broadcast 0 if ptr_row is null.
//
// Epilogue visitor node that yields, for every output fragment, values read
// from a row vector at `ptr_row`. When `ptr_row` is nullptr nothing is read
// from memory and the node yields zeros instead, so an optional operand can
// be expressed without a separate epilogue type.
template<
  class ThreadMap,
  class Element,
  class StrideMNL
>
struct VisitorRowOrZeroBroadcast {

  // This struct has been modified to remove null_default (because it's always 0)
  struct Arguments {
    // Start of the row vector; nullptr selects the "broadcast zero" path.
    Element const* ptr_row = nullptr;
    // Stride handed to make_tensor() when wrapping ptr_row in get_callbacks().
    StrideMNL dRow = {};
  };

  // Params are identical to Arguments; no host-side transformation is needed.
  using Params = Arguments;

  // Returns `args` unchanged; `problem_shape` and `workspace` are unused.
  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  // This visitor needs no workspace, so the size is always 0.
  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  // Empty: this visitor keeps no shared storage.
  struct SharedStorage {};

  // Global load type
  // vec_bits is the bit width of one ThreadMap access of Element values;
  // VecType is an unsigned bit-vector of that width, capped at 128 bits.
  static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
  using VecType = uint_bit_t<cute::min(128, vec_bits)>;
  // Number of Element values that fit into one VecType.
  static int constexpr VecLength = sizeof(VecType) / sizeof(Element);

  // Default constructor; note that params_ptr is left uninitialized here.
  CUTLASS_HOST_DEVICE
  VisitorRowOrZeroBroadcast() { }

  // Stores the address of `params`; `shared_storage` is unused.
  CUTLASS_HOST_DEVICE
  VisitorRowOrZeroBroadcast(Params const& params, SharedStorage const& shared_storage)
    : params_ptr(&params) { }

  Params const* params_ptr;

  // Per-thread callbacks object returned by get_callbacks().
  // Holds the source tensor (tC_gRow), the destination buffer that visit()
  // reads from (tC_rRow), and a coordinate tensor (tC_cRow) used for bounds
  // checks against `n`.
  template <class GTensor, class RTensor, class CTensor, class ProblemShape>
  struct Callbacks : EmptyCallbacks {
    // Takes ownership of the three tensors and caches n = get<1>(problem_shape).
    CUTLASS_DEVICE
    Callbacks(
      GTensor&& tC_gRow,
      RTensor&& tC_rRow,
      CTensor&& tC_cRow,
      ProblemShape problem_shape,
      Params const* params_ptr
    ):
      tC_gRow(cute::forward<GTensor>(tC_gRow)),
      tC_rRow(cute::forward<RTensor>(tC_rRow)),
      tC_cRow(cute::forward<CTensor>(tC_cRow)),
      n(get<1>(problem_shape)),
      params_ptr(params_ptr) { }

    GTensor tC_gRow;
    RTensor tC_rRow;
    CTensor tC_cRow;
    Params const* params_ptr;
    // Extent at index 1 of the problem shape; upper bound for the guards below.
    // NOTE(review): presumably the N (column) dimension of the GEMM - confirm.
    int n;

    // This function is modified from VisitorRowBroadcast
    // Fills tC_rRow once, before visit() is called: either with values loaded
    // from the row vector, or with zeros when no row pointer was supplied.
    CUTLASS_DEVICE void
    begin_epilogue() {
      // Start from an all-zero buffer so guarded-out entries hold 0.
      clear(tC_rRow);
      auto src_v = filter(tC_gRow);
      auto coord_v = filter(tC_cRow);
      auto dst_v = filter(tC_rRow);

      if (params_ptr->ptr_row != nullptr) {
        // In this case we are loading from a row vector and broadcasting
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(src_v); ++i) {
          // Only load vectors whose index-1 coordinate lies inside the problem.
          bool guard = get<1>(coord_v(i)) < n;
          cutlass::arch::global_load<VecType, sizeof(VecType)>(
              dst_v(i), (void const*)&src_v(i), guard);
        }
      } else {
        // In this case we are broadcasting 0
        // Build one VecType whose Element lanes are all zero ...
        VecType filled_vec;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < VecLength; i++) {
          reinterpret_cast<Element*>(&filled_vec)[i] = Element{0};
        }

        // ... and store it to every in-bounds destination vector.
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(src_v); ++i) {
          if (get<1>(coord_v(i)) < n) {
            dst_v(i) = filled_vec;
          }
        }
      }
    }

    // Returns the fragment of tC_rRow selected by `column_idx`.
    // `iter_idx`, `row_idx`, `frg_idx` and `frg_acc` are not used: the value
    // depends only on the column fragment index.
    template <class ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE auto // returns an Array
    visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
          Array<ElementAccumulator, FragmentSize> const& frg_acc) {
      // View the buffer as an array of FragmentSize-wide Element fragments.
      Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow));
      return rRow_frg(column_idx);
    }
  };

  // Builds the Callbacks object for one thread: partitions the row tensor and
  // a matching identity (coordinate) tensor with ThreadMap for `thread_idx`
  // and `threadblock_tile_offset`, and allocates a like-shaped buffer.
  template <class ProblemShape>
  CUTLASS_DEVICE auto
  get_callbacks(
    gemm::GemmCoord threadblock_tile_offset,
    int thread_idx,
    ProblemShape problem_shape
  ) {
    // Tensor over ptr_row with the problem shape and the dRow stride.
    // ptr_row may be nullptr here; begin_epilogue() never loads in that case.
    Tensor mRow = make_tensor(
      make_gmem_ptr(params_ptr->ptr_row),
      problem_shape,
      params_ptr->dRow);

    // VECTOR, FRAGMENT_COLUMN
    Tensor tC_gRow = recast<VecType>(
      ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset)
    )(_,_,_0{},_0{},_0{},_0{});
    Tensor tC_rRow = make_tensor_like(tC_gRow);

    // Generate the pred tensor
    // Identity tensor: each element holds its own coordinate, so the guards in
    // begin_epilogue() can compare against the problem extent.
    Tensor cRow = make_identity_tensor(mRow.shape());
    Tensor tC_cRow = outer_partition(
      ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}),
      Shape<Int<VecLength>>{},
      (_0{})
    );

    return Callbacks<
      decltype(tC_gRow), decltype(tC_rRow),
      decltype(tC_cRow), ProblemShape>(
      cute::move(tC_gRow),
      cute::move(tC_rRow),
      cute::move(tC_cRow),
      problem_shape,
      params_ptr
    );
  }

};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Column vector broadcast
...
...
@@ -217,7 +367,7 @@ template<
>
struct
VisitorColOrScalarBroadcast
{
// This struct has been modified to have a bool indicating that ptr_col is a
// This struct has been modified to have a bool indicating that ptr_col is a
// scalar that must be broadcast.
struct
Arguments
{
Element
const
*
ptr_col
=
nullptr
;
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
View file @
8d59dbb0
...
...
@@ -50,6 +50,25 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
}
}
// SM75 entry point for the scaled matmul with activation zero-point (AZP)
// correction. Both scale tensors must be float32. The presence of `azp`
// selects the per-token epilogue; otherwise the per-tensor epilogue is used
// and `azp_adj` alone carries the zero-point correction.
void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                c10::optional<torch::Tensor> const& azp,
                                c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (!azp.has_value()) {
    // Per-tensor zero point: no separate azp tensor is forwarded.
    cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
    return;
  }

  // Per-token zero points: forward the azp tensor alongside azp_adj.
  cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogueBiasAzpToken>(
      out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
}
template
<
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_scaled_mm_sm80_epilogue
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
...
...
@@ -87,6 +106,25 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
}
}
// SM80 entry point for the scaled matmul with activation zero-point (AZP)
// correction. Both scale tensors must be float32. The presence of `azp`
// selects the per-token epilogue; otherwise the per-tensor epilogue is used
// and `azp_adj` alone carries the zero-point correction.
void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                c10::optional<torch::Tensor> const& azp,
                                c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (!azp.has_value()) {
    // Per-tensor zero point: no separate azp tensor is forwarded.
    cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
    return;
  }

  // Per-token zero points: forward the azp tensor alongside azp_adj.
  cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogueBiasAzpToken>(
      out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
}
template
<
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_scaled_mm_sm89_epilogue
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
...
...
@@ -139,3 +177,22 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
// SM89 entry point for the scaled matmul with activation zero-point (AZP)
// correction. Both scale tensors must be float32. The presence of `azp`
// selects the per-token epilogue; otherwise the per-tensor epilogue is used
// and `azp_adj` alone carries the zero-point correction.
void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                c10::optional<torch::Tensor> const& azp,
                                c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (!azp.has_value()) {
    // Per-tensor zero point: no separate azp tensor is forwarded.
    cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
    return;
  }

  // Per-token zero points: forward the azp tensor alongside azp_adj.
  cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogueBiasAzpToken>(
      out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
}
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
View file @
8d59dbb0
...
...
@@ -73,19 +73,63 @@ struct enable_sm89_to_sm90 : Kernel {
};
/*
* This class provides the common
ScaleA and ScaleB
descriptors for the
* ScaledEpilogue
and ScaledEpilogueBias
classes
.
* This class provides the common
load
descriptors for the
* ScaledEpilogue
[...]
classes
*/
template
<
typename
ElementD
,
typename
OutputTileThreadMap
>
struct
ScaledEpilogueBase
{
protected:
using
Accum
=
cutlass
::
epilogue
::
threadblock
::
VisitorAccFetch
;
using
ScaleA
=
cutlass
::
epilogue
::
threadblock
::
VisitorColOrScalarBroadcast
<
OutputTileThreadMap
,
float
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
template
<
typename
T
>
using
ColOrScalarLoad
=
cutlass
::
epilogue
::
threadblock
::
VisitorColOrScalarBroadcast
<
OutputTileThreadMap
,
T
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
template
<
typename
T
>
using
RowOrScalarLoad
=
cutlass
::
epilogue
::
threadblock
::
VisitorRowOrScalarBroadcast
<
OutputTileThreadMap
,
T
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
template
<
typename
T
>
using
ColLoad
=
cutlass
::
epilogue
::
threadblock
::
VisitorColBroadcast
<
OutputTileThreadMap
,
T
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
template
<
typename
T
>
using
RowLoad
=
cutlass
::
epilogue
::
threadblock
::
VisitorRowBroadcast
<
OutputTileThreadMap
,
T
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
template
<
typename
T
>
using
RowOrZeroLoad
=
cutlass
::
epilogue
::
threadblock
::
VisitorRowOrZeroBroadcast
<
OutputTileThreadMap
,
T
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
// This utility function constructs the arguments for the load descriptors
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
template
<
typename
Descriptor
,
typename
T
>
static
auto
args_from_tensor
(
torch
::
Tensor
const
&
tensor
)
{
using
Arguments
=
typename
Descriptor
::
Arguments
;
auto
*
data_ptr
=
static_cast
<
T
*>
(
tensor
.
data_ptr
());
if
constexpr
(
std
::
is_same_v
<
Descriptor
,
ColOrScalarLoad
<
T
>>
||
std
::
is_same_v
<
Descriptor
,
RowOrScalarLoad
<
T
>>
)
{
return
Arguments
{
data_ptr
,
tensor
.
numel
()
!=
1
};
}
else
{
// it would technically work but no use case as data_ptr is never nullptr
static_assert
(
!
std
::
is_same_v
<
Descriptor
,
RowOrZeroLoad
<
T
>>
);
return
Arguments
{
data_ptr
};
}
}
using
ScaleB
=
cutlass
::
epilogue
::
threadblock
::
VisitorRowOrScalarBroadcast
<
OutputTileThreadMap
,
float
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template
<
typename
Descriptor
,
typename
T
>
static
auto
args_from_tensor
(
c10
::
optional
<
torch
::
Tensor
>
const
&
tensor
)
{
static_assert
(
std
::
is_same_v
<
Descriptor
,
RowOrZeroLoad
<
T
>>
);
using
Arguments
=
typename
Descriptor
::
Arguments
;
auto
*
data_ptr
=
tensor
?
static_cast
<
T
*>
(
tensor
->
data_ptr
())
:
nullptr
;
return
Arguments
{
data_ptr
};
}
};
/*
...
...
@@ -110,8 +154,8 @@ struct ScaledEpilogue
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
ScaleA
;
using
ScaleB
=
typename
SUPER
::
ScaleB
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>
;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>
;
using
Compute0
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
float
,
float
,
...
...
@@ -131,28 +175,32 @@ struct ScaledEpilogue
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
using
ScaleAArgs
=
typename
ScaleA
::
Arguments
;
using
ScaleBArgs
=
typename
ScaleB
::
Arguments
;
ScaleBArgs
b_args
{
b_scales
.
data_ptr
<
float
>
(),
b_scales
.
numel
()
!=
1
,
{}};
ScaleAArgs
a_args
{
a_scales
.
data_ptr
<
float
>
(),
a_scales
.
numel
()
!=
1
,
{}};
auto
a_args
=
SUPER
::
template
args_from_tensor
<
ScaleA
,
float
>(
a_scales
);
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
typename
EVTCompute0
::
Arguments
evt0_compute_args
{
b_args
};
typename
EVTCompute
::
Arguments
evt_compute_args
{
a_args
,
evt0_compute_args
};
return
evt_compute_args
;
typename
EVTCompute0
::
Arguments
evt0_args
{
b_args
};
return
ArgumentType
{
a_args
,
evt0_args
};
}
};
/*
* This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
* This bias can also be used in the per-tensor azp case, where the activation
* zero point (azp) is used to compute an azp correction term,
* which is folded into the bias.
*
* The bias tensor must be per-output channel.
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
*/
template
<
typename
ElementD
,
typename
OutputTileThreadMap
>
struct
ScaledEpilogueBias
:
pr
ivate
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
{
pr
ivate
:
:
pr
otected
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
{
pr
otected
:
using
SUPER
=
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
ScaleA
;
using
ScaleB
=
typename
SUPER
::
ScaleB
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>
;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>
;
using
Bias
=
typename
SUPER
::
template
RowLoad
<
ElementD
>;
using
Compute0
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
float
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
...
...
@@ -164,30 +212,163 @@ struct ScaledEpilogueBias
cutlass
::
multiply_add
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
Bias
=
cutlass
::
epilogue
::
threadblock
::
VisitorRowBroadcast
<
OutputTileThreadMap
,
ElementD
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
public:
using
EVTCompute
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute1
,
ScaleA
,
EVTCompute0
,
Bias
>
;
using
ArgumentType
=
typename
EVTCompute
::
Arguments
;
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
bias
)
{
using
ScaleAArgs
=
typename
ScaleA
::
Arguments
;
using
ScaleBArgs
=
typename
ScaleB
::
Arguments
;
using
BiasArgs
=
typename
Bias
::
Arguments
;
auto
a_args
=
SUPER
::
template
args_from_tensor
<
ScaleA
,
float
>(
a_scales
);
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
auto
bias_args
=
SUPER
::
template
args_from_tensor
<
Bias
,
ElementD
>(
bias
);
typename
EVTCompute0
::
Arguments
evt0_args
{
b_args
};
return
ArgumentType
{
a_args
,
evt0_args
,
bias_args
};
}
};
/*
* This epilogue directly supports per-tensor azp in int32 form.
* As opposed to the per-token epilogue below, this epilogue only has an azp_adj
* term, which should already be multiplied with the scalar azp.
* The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzp
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;

  // Load descriptors (leaves of the epilogue visitor tree), all taken from
  // the shared base class.
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;

  // The complete AZP correction term, azp * J @ B, shape (1,n).
  using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;

  // Innermost node: float(accum - azp_adj); both operands are int32_t.
  using SubtractAzpOp = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using AccMinusAzp =
      cutlass::epilogue::threadblock::Sm80EVT<SubtractAzpOp, Accum,
                                              AzpWithAdj>;

  // Middle node: multiplies the azp-corrected accumulator with ScaleB.
  using ScaleBOp = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using ScaledByB =
      cutlass::epilogue::threadblock::Sm80EVT<ScaleBOp, ScaleB, AccMinusAzp>;

  // Root operation: multiply_add over ScaleA, the ScaleB-scaled term and Bias.
  using ScaleABiasOp = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  // Full epilogue visitor tree and the argument bundle it expects.
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<ScaleABiasOp, ScaleA, ScaledByB,
                                              Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  // Builds the nested argument structure for EVTCompute from the tensors.
  // `bias` is optional; its absence is handled by the optional overload of
  // SUPER::args_from_tensor.
  static ArgumentType prepare_args(
      torch::Tensor const& a_scales, torch::Tensor const& b_scales,
      torch::Tensor const& azp_adj,
      c10::optional<torch::Tensor> const& bias) {
    // Leaf arguments, one per load descriptor.
    auto scale_a_leaf =
        SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto scale_b_leaf =
        SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_leaf = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_adj_leaf =
        SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);

    // Assemble inside-out; the accumulator leaf takes no arguments, hence {}.
    typename AccMinusAzp::Arguments acc_minus_azp{{}, azp_adj_leaf};
    typename ScaledByB::Arguments scaled_by_b{scale_b_leaf, acc_minus_azp};
    return ArgumentType{scale_a_leaf, scaled_by_b, bias_leaf};
  }
};
/*
* This epilogue supports per-token azp by computing and applying
* the correction term using a rank-1 update. If the term were materialized,
* it would require O(m*n) space, and this way it only requires O(m+n) space.
* The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
* point for each row of A.
* The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template
<
typename
ElementD
,
typename
OutputTileThreadMap
>
struct
ScaledEpilogueBiasAzpToken
:
protected
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
{
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
using
Bias
=
typename
SUPER
::
template
RowOrZeroLoad
<
ElementD
>;
// Per-token azp term, shape (m,1)
using
Azp
=
typename
SUPER
::
template
ColLoad
<
int32_t
>;
ScaleBArgs
b_args
{
b_scales
.
data_ptr
<
float
>
(),
b_scales
.
numel
()
!=
1
,
{}};
ScaleAArgs
a_args
{
a_scales
.
data_ptr
<
float
>
(),
a_scales
.
numel
()
!=
1
,
{}};
BiasArgs
bias_args
{
static_cast
<
ElementD
*>
(
bias
.
data_ptr
()),
{}};
// This is the AZP adjustment term, J @ B, shape (1,n)
using
AzpAdj
=
typename
SUPER
::
template
RowLoad
<
int32_t
>;
// Compute azp * azp_adj
using
ComputeAzp
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
int32_t
,
int32_t
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
typename
EVTCompute0
::
Arguments
evt0_compute_args
{
b_args
};
using
EVTComputeAzp
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
ComputeAzp
,
Azp
,
AzpAdj
>
;
typename
EVTCompute
::
Arguments
evt_compute_args
{
a_args
,
evt0_compute_args
,
bias_args
};
return
evt_compute_args
;
// Compute float(accum - azp*azp_adj), all operands are int32_t
using
ComputeAcc
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
minus
,
float
,
int32_t
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTComputeAcc
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
ComputeAcc
,
Accum
,
EVTComputeAzp
>
;
using
ComputeScaleB
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
float
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTComputeScaleB
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
ComputeScaleB
,
ScaleB
,
EVTComputeAcc
>
;
using
ComputeScaleBiasA
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiply_add
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
public:
using
EVTCompute
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
ComputeScaleBiasA
,
ScaleA
,
EVTComputeScaleB
,
Bias
>
;
using
ArgumentType
=
typename
EVTCompute
::
Arguments
;
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
torch
::
Tensor
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
auto
a_args
=
SUPER
::
template
args_from_tensor
<
ScaleA
,
float
>(
a_scales
);
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
auto
bias_args
=
SUPER
::
template
args_from_tensor
<
Bias
,
ElementD
>(
bias
);
auto
azp_args
=
SUPER
::
template
args_from_tensor
<
Azp
,
int32_t
>(
azp
);
auto
azp_adj_args
=
SUPER
::
template
args_from_tensor
<
AzpAdj
,
int32_t
>(
azp_adj
);
typename
EVTComputeAzp
::
Arguments
evt_azp_args
{
azp_args
,
azp_adj_args
};
typename
EVTComputeAcc
::
Arguments
evt_acc_args
{{},
evt_azp_args
};
typename
EVTComputeScaleB
::
Arguments
evt_scale_b_args
{
b_args
,
evt_acc_args
};
return
ArgumentType
{
a_args
,
evt_scale_b_args
,
bias_args
};
}
};
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
View file @
8d59dbb0
...
...
@@ -58,21 +58,63 @@ struct enable_sm90_or_later : Kernel {
};
/*
* This class provides the common
ScaleA and ScaleB
descriptors for the
* ScaledEpilogue
and ScaledEpilogueBias
classes
.
* This class provides the common
load
descriptors for the
* ScaledEpilogue
[...]
classes
*/
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
struct
ScaledEpilogueBase
{
protected:
using
Accum
=
cutlass
::
epilogue
::
fusion
::
Sm90AccFetch
;
using
ScaleA
=
cutlass
::
epilogue
::
fusion
::
Sm90ColOrScalarBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
float
,
template
<
typename
T
>
using
ColOrScalarLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90ColOrScalarBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
T
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
using
ScaleB
=
cutlass
::
epilogue
::
fusion
::
Sm90RowOrScalarBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
float
,
template
<
typename
T
>
using
RowOrScalarLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90RowOrScalarBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
T
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
// Don't want to support nullptr by default
template
<
typename
T
,
bool
EnableNullPtr
=
false
>
using
ColLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90ColBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
T
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>
,
128
/
sizeof_bits_v
<
T
>
,
EnableNullPtr
>
;
// Don't want to support nullptr by default
template
<
typename
T
,
bool
EnableNullPtr
=
false
>
using
RowLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90RowBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
T
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>
,
128
/
sizeof_bits_v
<
T
>
,
EnableNullPtr
>
;
// This utility function constructs the arguments for the load descriptors
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
template
<
typename
Descriptor
,
typename
T
>
static
auto
args_from_tensor
(
torch
::
Tensor
const
&
tensor
)
{
using
Arguments
=
typename
Descriptor
::
Arguments
;
auto
*
data_ptr
=
static_cast
<
T
*>
(
tensor
.
data_ptr
());
if
constexpr
(
std
::
is_same_v
<
Descriptor
,
ColOrScalarLoad
<
T
>>
||
std
::
is_same_v
<
Descriptor
,
RowOrScalarLoad
<
T
>>
)
{
return
Arguments
{
data_ptr
,
tensor
.
numel
()
!=
1
};
}
else
{
static_assert
(
!
std
::
is_same_v
<
Descriptor
,
ColLoad
<
T
,
true
>>
&&
!
std
::
is_same_v
<
Descriptor
,
RowLoad
<
T
,
true
>>
);
return
Arguments
{
data_ptr
};
}
}
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template
<
typename
Descriptor
,
typename
T
>
static
auto
args_from_tensor
(
c10
::
optional
<
torch
::
Tensor
>
const
&
tensor
)
{
using
Arguments
=
typename
Descriptor
::
Arguments
;
auto
*
data_ptr
=
tensor
?
static_cast
<
T
*>
(
tensor
->
data_ptr
())
:
nullptr
;
static_assert
(
std
::
is_same_v
<
Descriptor
,
ColLoad
<
T
,
true
>>
||
std
::
is_same_v
<
Descriptor
,
RowLoad
<
T
,
true
>>
);
return
Arguments
{
data_ptr
};
}
};
/*
...
...
@@ -97,8 +139,8 @@ struct ScaledEpilogue
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
ScaleA
;
using
ScaleB
=
typename
SUPER
::
ScaleB
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>
;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>
;
using
Compute0
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
multiplies
,
float
,
float
,
...
...
@@ -118,24 +160,32 @@ struct ScaledEpilogue
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
using
ScaleA_Args
=
typename
ScaleA
::
Arguments
;
using
ScaleB_Args
=
typename
ScaleB
::
Arguments
;
ScaleA_Args
a_args
{
a_scales
.
data_ptr
<
float
>
(),
a_scales
.
numel
()
!=
1
,
{}};
ScaleB_Args
b_args
{
b_scales
.
data_ptr
<
float
>
(),
b_scales
.
numel
()
!=
1
,
{}};
auto
a_args
=
SUPER
::
template
args_from_tensor
<
ScaleA
,
float
>(
a_scales
);
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
return
ArgumentType
{
a_args
,
{
b_args
}};
typename
EVTCompute0
::
Arguments
evt0_args
{
b_args
};
return
ArgumentType
{
a_args
,
evt0_args
};
}
};
/*
* This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
* This bias can also be used in the per-tensor azp case, where the activation
* zero point (azp) is used to compute an azp correction term,
* which is folded into the bias.
*
* The bias tensor must be per-output channel.
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
*/
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
struct
ScaledEpilogueBias
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
{
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
ScaleA
;
using
ScaleB
=
typename
SUPER
::
ScaleB
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
using
Bias
=
typename
SUPER
::
template
RowLoad
<
ElementD
>;
using
Compute0
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
multiplies
,
float
,
float
,
...
...
@@ -148,27 +198,160 @@ struct ScaledEpilogueBias
cutlass
::
multiply_add
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
Bias
=
cutlass
::
epilogue
::
fusion
::
Sm90RowBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
ElementD
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>
,
128
/
sizeof_bits_v
<
ElementD
>
,
false
>
;
public:
using
EVTCompute
=
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
Compute1
,
ScaleA
,
EVTCompute0
,
Bias
>
;
using
ArgumentType
=
typename
EVTCompute
::
Arguments
;
using
ArgumentType
=
typename
EVTCompute
::
Arguments
;
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
bias
)
{
using
ScaleA_Args
=
typename
ScaleA
::
Arguments
;
using
ScaleB_Args
=
typename
ScaleB
::
Arguments
;
using
Bias_Args
=
typename
Bias
::
Arguments
;
auto
a_args
=
SUPER
::
template
args_from_tensor
<
ScaleA
,
float
>(
a_scales
);
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
auto
bias_args
=
SUPER
::
template
args_from_tensor
<
Bias
,
ElementD
>(
bias
);
typename
EVTCompute0
::
Arguments
evt0_args
{
b_args
};
return
ArgumentType
{
a_args
,
evt0_args
,
bias_args
};
}
};
/*
* This epilogue directly supports per-tensor azp in int32 form.
* As opposed to the per-token epilogue below, this epilogue only has an azp_adj
* term, which should already be multiplied with the scalar azp.
* The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBiasAzp
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;

  // Load descriptors (leaves of the epilogue visitor tree), all taken from
  // the shared base class.
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  // Second template argument enables the nullptr case, which the optional
  // overload of args_from_tensor requires.
  using Bias = typename SUPER::template RowLoad<ElementD, true>;

  // The complete AZP correction term, azp * J @ B, shape (1,n).
  using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;

  // Innermost node: float(accum - azp_adj); both operands are int32_t.
  using SubtractAzpOp = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using AccMinusAzp =
      cutlass::epilogue::fusion::Sm90EVT<SubtractAzpOp, Accum, AzpWithAdj>;

  // Middle node: multiplies the azp-corrected accumulator with ScaleB.
  using ScaleBOp = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using ScaledByB =
      cutlass::epilogue::fusion::Sm90EVT<ScaleBOp, ScaleB, AccMinusAzp>;

  // Root operation: multiply_add over ScaleA, the ScaleB-scaled term and Bias.
  using ScaleABiasOp = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  // Full epilogue visitor tree and the argument bundle it expects.
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<ScaleABiasOp, ScaleA, ScaledByB,
                                         Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  // Builds the nested argument structure for EVTCompute from the tensors.
  static ArgumentType prepare_args(
      torch::Tensor const& a_scales, torch::Tensor const& b_scales,
      torch::Tensor const& azp_adj,
      c10::optional<torch::Tensor> const& bias) {
    // Leaf arguments, one per load descriptor.
    auto scale_a_leaf =
        SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto scale_b_leaf =
        SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_leaf = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_adj_leaf =
        SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);

    // Assemble inside-out; the accumulator leaf takes no arguments, hence {}.
    typename AccMinusAzp::Arguments acc_minus_azp{{}, azp_adj_leaf};
    typename ScaledByB::Arguments scaled_by_b{scale_b_leaf, acc_minus_azp};
    return ArgumentType{scale_a_leaf, scaled_by_b, bias_leaf};
  }
};
/*
* This epilogue supports per-token azp by computing and applying
* the correction term using a rank-1 update. If the term were materialized,
* it would require O(m*n) space, and this way it only requires O(m+n) space.
* The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
* point for each row of A.
* The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
struct
ScaledEpilogueBiasAzpToken
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
{
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
using
Bias
=
typename
SUPER
::
template
RowLoad
<
ElementD
,
true
>;
// Per-token azp term, shape (m,1)
using
Azp
=
typename
SUPER
::
template
ColLoad
<
int32_t
>;
// This is the AZP adjustment term, J @ B, shape (1,n)
using
AzpAdj
=
typename
SUPER
::
template
RowLoad
<
int32_t
>;
// Compute azp * azp_adj
using
ComputeAzp
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
multiplies
,
int32_t
,
int32_t
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTComputeAzp
=
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
ComputeAzp
,
Azp
,
AzpAdj
>
;
ScaleA_Args
a_args
{
a_scales
.
data_ptr
<
float
>
(),
a_scales
.
numel
()
!=
1
,
{}};
ScaleB_Args
b_args
{
b_scales
.
data_ptr
<
float
>
(),
b_scales
.
numel
()
!=
1
,
{}};
Bias_Args
bias_args
{
static_cast
<
ElementD
*>
(
bias
.
data_ptr
())};
// Compute float(accum - azp*azp_adj), all operands are int32_t
using
ComputeAcc
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
minus
,
float
,
int32_t
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
return
ArgumentType
{
a_args
,
{
b_args
},
bias_args
};
using
EVTComputeAcc
=
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
ComputeAcc
,
Accum
,
EVTComputeAzp
>
;
using
ComputeScaleB
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
multiplies
,
float
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTComputeScaleB
=
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
ComputeScaleB
,
ScaleB
,
EVTComputeAcc
>
;
using
ComputeScaleBiasA
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
multiply_add
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
public:
using
EVTCompute
=
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
ComputeScaleBiasA
,
ScaleA
,
EVTComputeScaleB
,
Bias
>
;
using
ArgumentType
=
typename
EVTCompute
::
Arguments
;
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
torch
::
Tensor
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
auto
a_args
=
SUPER
::
template
args_from_tensor
<
ScaleA
,
float
>(
a_scales
);
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
auto
bias_args
=
SUPER
::
template
args_from_tensor
<
Bias
,
ElementD
>(
bias
);
auto
azp_args
=
SUPER
::
template
args_from_tensor
<
Azp
,
int32_t
>(
azp
);
auto
azp_adj_args
=
SUPER
::
template
args_from_tensor
<
AzpAdj
,
int32_t
>(
azp_adj
);
typename
EVTComputeAzp
::
Arguments
evt_azp_args
{
azp_args
,
azp_adj_args
};
typename
EVTComputeAcc
::
Arguments
evt_acc_args
{{},
evt_azp_args
};
typename
EVTComputeScaleB
::
Arguments
evt_scale_b_args
{
b_args
,
evt_acc_args
};
return
ArgumentType
{
a_args
,
evt_scale_b_args
,
bias_args
};
}
};
...
...
@@ -546,4 +729,23 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
}
}
// SM90 entry point for the asymmetric (azp) scaled matmul. Validates the
// scale dtypes, then picks the epilogue: per-token when `azp` is supplied,
// per-tensor otherwise.
void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                c10::optional<torch::Tensor> const& azp,
                                c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (!azp) {
    // No separate azp tensor: use the epilogue that only takes azp_adj.
    cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
    return;
  }

  // An azp tensor was supplied: use the per-token epilogue.
  cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBiasAzpToken>(
      out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
}
#endif
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
View file @
8d59dbb0
...
...
@@ -29,6 +29,40 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#endif
void
cutlass_scaled_mm_azp_sm75
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_azp_sm80
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_azp_sm89
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
void
cutlass_scaled_mm_azp_sm90
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#endif
bool
cutlass_scaled_mm_supports_fp8
(
int64_t
cuda_device_capability
)
{
// CUTLASS FP8 kernels need at least
// CUDA 12.0 on SM90 systems (Hopper)
...
...
@@ -45,18 +79,20 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
return
false
;
}
void
cutlass_scaled_mm
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
int32_t
major_capability
;
int32_t
minor_capability
;
int32_t
get_sm_version_num
()
{
int32_t
major_capability
,
minor_capability
;
cudaDeviceGetAttribute
(
&
major_capability
,
cudaDevAttrComputeCapabilityMajor
,
0
);
cudaDeviceGetAttribute
(
&
minor_capability
,
cudaDevAttrComputeCapabilityMinor
,
0
);
int32_t
version_num
=
major_capability
*
10
+
minor_capability
;
return
version_num
;
}
void
cutlass_scaled_mm
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
// Checks for conformality
TORCH_CHECK
(
a
.
dim
()
==
2
&&
b
.
dim
()
==
2
&&
c
.
dim
()
==
2
);
TORCH_CHECK
(
c
.
size
(
0
)
==
a
.
size
(
0
)
&&
a
.
size
(
1
)
==
b
.
size
(
0
)
&&
...
...
@@ -77,7 +113,7 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
}
at
::
cuda
::
OptionalCUDAGuard
const
device_guard
(
device_of
(
a
));
int32_t
version_num
=
get_sm_version_num
();
if
(
version_num
>=
90
)
{
// Hopper
...
...
@@ -99,3 +135,64 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
cutlass_scaled_mm_sm75
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
}
// Architecture-dispatching entry point for the asymmetric (azp) w8a8 scaled
// matmul. Validates shapes, strides and dtypes of all operands, switches to
// the device of `a`, and forwards to the kernel family matching the device's
// compute capability.
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
                           torch::Tensor const& b,
                           torch::Tensor const& a_scales,
                           torch::Tensor const& b_scales,
                           torch::Tensor const& azp_adj,
                           c10::optional<torch::Tensor> const& azp,
                           c10::optional<torch::Tensor> const& bias) {
  // --- Conformality: c is (m,n), a is (m,k), b is (k,n) ---
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
              b.size(1) == c.size(1));
  // Scales are either a single value or one per row of a / column of b.
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));

  // --- Strides and alignment ---
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
              b.stride(1) % 16 == 0);  // 16 Byte Alignment
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  // --- Vector operands: bias, azp and azp_adj are all 1d ---
  // bias and azp_adj have n elements, azp has m elements.
  if (bias) {
    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous());
  }
  if (azp) {
    TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous());
  }
  TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous());

  // --- Element types of the azp terms and the bias ---
  TORCH_CHECK(azp_adj.dtype() == torch::kInt32);
  TORCH_CHECK(!azp || azp->dtype() == torch::kInt32);
  TORCH_CHECK(!bias || bias->dtype() == c.dtype(),
              "currently bias dtype must match output dtype ", c.dtype());

  // Make the device holding `a` current for the rest of this call.
  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));

  int32_t const version_num = get_sm_version_num();

  if (version_num < 80) {
    // Turing is the oldest architecture with a kernel here.
    TORCH_CHECK(version_num >= 75);
    cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp,
                               bias);
    return;
  }

  if (version_num == 89) {
    // Ada Lovelace
    cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp,
                               bias);
    return;
  }

  // Guard against compilation issues for sm90 kernels: they are only
  // referenced when building with CUDA 12.0 or newer.
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
  if (version_num >= 90) {
    // Hopper
    cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp,
                               bias);
    return;
  }
#endif

  // Ampere, and Hopper when the sm90 kernels are not compiled in.
  cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
}
\ No newline at end of file
csrc/torch_bindings.cpp
View file @
8d59dbb0
...
...
@@ -166,13 +166,22 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
impl
(
"marlin_qqq_gemm"
,
torch
::
kCUDA
,
&
marlin_qqq_gemm
);
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization
.
// quantization
, as well as bias
ops
.
def
(
"cutlass_scaled_mm(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor? bias) -> ()"
);
ops
.
impl
(
"cutlass_scaled_mm"
,
torch
::
kCUDA
,
&
cutlass_scaled_mm
);
// CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
// quantization.
ops
.
def
(
"cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor azp_adj,"
" Tensor? azp, Tensor? bias) -> ()"
);
ops
.
impl
(
"cutlass_scaled_mm_azp"
,
torch
::
kCUDA
,
&
cutlass_scaled_mm_azp
);
// Check if cutlass scaled_mm is supported for CUDA devices of the given
// capability
ops
.
def
(
"cutlass_scaled_mm_supports_fp8"
,
&
cutlass_scaled_mm_supports_fp8
);
...
...
tests/kernels/test_cutlass.py
View file @
8d59dbb0
...
...
@@ -28,13 +28,16 @@ def to_int8(tensor: torch.Tensor):
return
torch
.
round
(
tensor
.
clamp
(
min
=-
128
,
max
=
127
)).
to
(
dtype
=
torch
.
int8
)
def rand_int8(shape: tuple, device: str = "cuda"):
    """Return a random int8 tensor of the given shape on ``device``.

    Uniform samples from ``torch.rand`` are stretched by 255 and shifted
    by -128 before being handed to ``to_int8`` for conversion.
    """
    uniform = torch.rand(shape, device=device)
    stretched = uniform * 255 - 128
    return to_int8(stretched)
def
baseline_scaled_mm
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
Type
[
torch
.
dtype
],
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
output
=
(
scale_a
*
(
scale_b
*
(
torch
.
mm
(
a
.
to
(
dtype
=
torch
.
float32
),
b
.
to
(
dtype
=
torch
.
float32
))))).
to
(
out_dtype
)
if
bias
is
not
None
:
...
...
@@ -221,6 +224,121 @@ def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
use_bias
)
@pytest.mark.parametrize("m", [32, 64, 128])
@pytest.mark.parametrize("n", [16, 32, 64])
@pytest.mark.parametrize("k", [64, 128, 256])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.skip
def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
                                    out_dtype: torch.dtype):
    """Fold a per-tensor azp term into the bias of ``cutlass_scaled_mm``.

    Skipped unconditionally: the test currently fails because folding the
    azp term into a 16-bit bias loses too much precision.
    """
    f32_cuda = dict(device="cuda", dtype=torch.float32)

    scale_a = torch.randn((1, 1), **f32_cuda) / 10
    scale_b = torch.randn((1, n), **f32_cuda) / 10

    aq_i8 = rand_int8((m, k))
    bq_i8 = rand_int8((n, k)).t()

    aq_i32, bq_i32 = (q.to(dtype=torch.int32) for q in (aq_i8, bq_i8))
    aq_f32, bq_f32 = (q.to(dtype=torch.float32) for q in (aq_i8, bq_i8))

    b_dq = scale_b * bq_f32

    # Draw a float zero point, quantize it to int8, then dequantize it
    # again so the float value corresponds to the int8 one (this corrects
    # for rounding).
    azp_a = torch.rand((1, ), **f32_cuda) * 10 + 1.5
    azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
    azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a

    a_shifted_i32 = aq_i32 + azp_aq_i8
    a_dq = scale_a * a_shifted_i32.to(dtype=torch.float32)
    assert torch.allclose(a_dq, scale_a * aq_f32 + azp_a)

    # Float reference computed from the dequantized operands.
    baseline_dq = torch.mm(a_dq, b_dq).to(out_dtype)

    # The azp contribution expressed as a (1, n) bias.
    ones_row = torch.ones((1, k), **f32_cuda)
    azp_bias = (azp_a * scale_b * (ones_row @ bq_f32)).to(out_dtype)
    assert azp_bias.shape == (1, n)
    bias_1d = azp_bias[0, :]
    assert bias_1d.shape == (n, )

    # Integer reference; the int32 matmul is carried out on the CPU.
    cq_cpu = a_shifted_i32.to(device='cpu') @ bq_i32.to(device='cpu')
    baseline_q = (scale_a.to(device='cpu') * scale_b.to(device='cpu') *
                  cq_cpu).to(dtype=out_dtype, device='cuda')

    out = ops.cutlass_scaled_mm(aq_i8,
                                bq_i8,
                                scale_a,
                                scale_b,
                                out_dtype=out_dtype,
                                bias=bias_1d)
    for baseline in (baseline_dq, baseline_q):
        assert torch.allclose(out, baseline, rtol=1e-2, atol=1e0)
@pytest.mark.parametrize("m", [32, 64, 128])
@pytest.mark.parametrize("n", [16, 32, 64])
@pytest.mark.parametrize("k", [64, 128, 256])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("azp_per_token", [True, False])
def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
                          use_bias: bool, azp_per_token: bool):
    """Compare ``cutlass_scaled_mm_azp`` against two reference results.

    One reference multiplies the dequantized operands in float, the other
    performs the int32 matmul and applies the scales afterwards. Both the
    per-token path (separate ``azp`` tensor) and the per-tensor path
    (``azp`` pre-multiplied into ``azp_adj``) are exercised, with and
    without a bias.
    """
    f32_cuda = dict(device="cuda", dtype=torch.float32)

    # With per-token azp there is one scale/zero point per row of A,
    # otherwise a single one.
    m_azp = m if azp_per_token else 1
    scale_a = torch.randn((m_azp, 1), **f32_cuda) / 10
    scale_b = torch.randn((1, n), **f32_cuda) / 10

    aq_i8 = rand_int8((m, k))
    aq_i32 = aq_i8.to(dtype=torch.int32)
    aq_f32 = aq_i8.to(dtype=torch.float32)

    bq_i8 = rand_int8((n, k)).t()
    bq_i32 = bq_i8.to(dtype=torch.int32)
    bq_f32 = bq_i8.to(dtype=torch.float32)
    b_dq = scale_b * bq_f32

    # Draw a float zero point, quantize it to int8, then dequantize it
    # again so the float value corresponds to the int8 one (this corrects
    # for rounding).
    azp_a = torch.rand((m_azp, 1), **f32_cuda) * 10 + 1.5
    azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
    azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a

    a_noazp_i32 = aq_i32 - azp_aq_i8
    a_dq = scale_a * a_noazp_i32.to(dtype=torch.float32)
    assert torch.allclose(a_dq,
                          scale_a * aq_f32 - azp_a,
                          rtol=1e-4,
                          atol=1e-3)

    # The references always add a bias tensor; it is all zeros when the
    # kernel is called without one.
    if use_bias:
        bias = torch.rand((1, n), device="cuda", dtype=out_dtype) * 10 + 2.5
    else:
        bias = torch.zeros((1, n), device="cuda", dtype=out_dtype)
    func_bias = bias if use_bias else None

    baseline_dq = (torch.mm(a_dq, b_dq) + bias).to(out_dtype)

    # int32 mm not supported on CUDA, so take a round trip through the CPU.
    cq = (a_noazp_i32.to(device='cpu') @ bq_i32.to(device='cpu')).to(
        device='cuda')
    baseline_q = (scale_a * scale_b * cq + bias).to(dtype=out_dtype)

    # J @ B with J a row of ones is simply the column sums of B.
    azp_adj_i32 = bq_i32.sum(dim=0, keepdim=True, dtype=torch.int32)
    azp_i32 = azp_aq_i8.to(dtype=torch.int32)

    if azp_per_token:
        # The kernel receives azp and azp_adj separately.
        azp_adj_arg, azp_arg = azp_adj_i32, azp_i32
    else:
        # The single zero point is pre-multiplied into the adjustment term.
        azp_adj_arg, azp_arg = azp_i32 * azp_adj_i32, None

    out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b,
                                    out_dtype, azp_adj_arg, azp_arg,
                                    func_bias)

    # bfloat16 precision is 7-bit mantissa -> 2^-8 ~ 0.4%
    # float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05%
    rtol = 1e-2 if out_dtype == torch.bfloat16 else 1e-3
    atol = 1e-3
    for baseline in (baseline_dq, baseline_q):
        assert torch.allclose(out, baseline, rtol=rtol, atol=atol)
# Test working with a subset of A and B
def
test_cutlass_subset
():
big_m
,
big_n
,
big_k
=
1024
,
1024
,
1024
...
...
vllm/_custom_ops.py
View file @
8d59dbb0
...
...
@@ -241,6 +241,8 @@ def cutlass_scaled_mm(a: torch.Tensor,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
assert
(
b
.
shape
[
0
]
%
16
==
0
and
b
.
shape
[
1
]
%
16
==
0
)
assert
(
out_dtype
is
torch
.
bfloat16
or
out_dtype
is
torch
.
float16
)
assert
bias
is
None
or
bias
.
shape
[
0
]
==
b
.
shape
[
1
]
and
bias
.
dtype
==
out_dtype
m
=
a
.
shape
[
0
]
n
=
b
.
shape
[
1
]
...
...
@@ -251,6 +253,28 @@ def cutlass_scaled_mm(a: torch.Tensor,
return
out
def cutlass_scaled_mm_azp(a: torch.Tensor,
                          b: torch.Tensor,
                          scale_a: torch.Tensor,
                          scale_b: torch.Tensor,
                          out_dtype: torch.dtype,
                          azp_adj: torch.Tensor,
                          azp: Optional[torch.Tensor] = None,
                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Invoke the ``_C.cutlass_scaled_mm_azp`` custom op and return its output.

    An output tensor of shape ``(a.shape[0], b.shape[1])`` with dtype
    ``out_dtype`` is allocated on ``a.device`` and handed to the op together
    with ``a``, ``b``, ``scale_a``, ``scale_b``, ``azp_adj``, ``azp`` and
    ``bias``; the op fills it in place.

    The following preconditions are asserted before the op is called:

    * both dimensions of ``b`` are multiples of 16;
    * ``out_dtype`` is ``torch.bfloat16`` or ``torch.float16``;
    * ``bias``, when given, has ``b.shape[1]`` elements and dtype
      ``out_dtype``.

    Returns:
        The freshly allocated output tensor.
    """
    # Checked lazily: the second dimension is only inspected when the first
    # one already passed.
    first_dim_aligned = b.shape[0] % 16 == 0
    assert first_dim_aligned and b.shape[1] % 16 == 0

    is_supported_dtype = (out_dtype is torch.bfloat16
                          or out_dtype is torch.float16)
    assert is_supported_dtype

    if bias is not None:
        assert bias.numel() == b.shape[1] and bias.dtype == out_dtype

    out_shape = (a.shape[0], b.shape[1])
    result = torch.empty(out_shape, dtype=out_dtype, device=a.device)

    torch.ops._C.cutlass_scaled_mm_azp(result, a, b, scale_a, scale_b,
                                       azp_adj, azp, bias)
    return result
# aqlm
def
aqlm_gemm
(
input
:
torch
.
Tensor
,
codes
:
torch
.
Tensor
,
codebooks
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
...
...
@@ -572,7 +596,7 @@ for k, v in names_and_values.items():
if
isinstance
(
v
,
fn_type
)
\
and
v
.
__code__
.
co_filename
==
__file__
\
and
any
(
arg
is
torch
.
Tensor
or
arg
==
"torch.Tensor"
for
arg
in
v
.
__annotations__
.
values
()):
for
arg
in
v
.
__annotations__
.
values
()):
names_and_values_to_update
[
k
]
=
hint_on_error
(
v
)
names_and_values
.
update
(
names_and_values_to_update
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment