# CUTLASS Epilogues

## Introduction

This document describes the various CUTLASS epilogues implemented for fusing de-quantization operations onto GEMMs.

Currently, we only support symmetric quantization for weights,
and symmetric and asymmetric quantization for activations.
Both can be quantized per-tensor or per-channel (weights) / per-token (activations).

There are 4 epilogues:

1. `ScaledEpilogue`: symmetric quantization for activations, no bias.
1. `ScaledEpilogueBias`: symmetric quantization for activations, supports bias.
1. `ScaledEpilogueAzp`: asymmetric per-tensor quantization for activations, supports bias.
1. `ScaledEpilogueAzpPerToken`: asymmetric per-token quantization for activations, supports bias.

We do not have epilogues for asymmetric quantization of activations without bias in order to reduce final binary size.
Instead, if no bias is passed, the epilogue will use 0 as the bias.
That induces a redundant addition operation (and runtime check), but the performance impact is minor.
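
For illustration only, the rules above can be summarized as a small selection function. This is a hypothetical Python sketch, not vLLM's dispatch code: the function name `select_epilogue` and the string encoding of the zero-point granularity are assumptions.

```python
from typing import Optional


def select_epilogue(azp: Optional[str], has_bias: bool) -> str:
    """Hypothetical mapping from quantization scheme to epilogue name.

    azp: None for symmetric activations, "per_tensor" or "per_token"
    for asymmetric activations.
    """
    if azp is None:
        # Symmetric activations: the bias/no-bias variants are separate epilogues.
        return "ScaledEpilogueBias" if has_bias else "ScaledEpilogue"
    if azp == "per_tensor":
        # Asymmetric epilogues always take a bias; a missing bias becomes 0.
        return "ScaledEpilogueAzp"
    if azp == "per_token":
        return "ScaledEpilogueAzpPerToken"
    raise ValueError(f"unknown zero-point granularity: {azp!r}")


print(select_epilogue(None, has_bias=False))       # ScaledEpilogue
print(select_epilogue("per_token", has_bias=False))  # ScaledEpilogueAzpPerToken
```

Note that `has_bias` only matters in the symmetric branch, which reflects the binary-size trade-off described above.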

## Underlying Linear Algebra

More details are available in the [Activation Quantization RFC](https://github.com/vllm-project/vllm/issues/3975).

If $` \widehat X `$ is the quantized $` X `$, our matrices become the following:

```math
A = s_a (\widehat A - J_a z_a)
```

```math
B = s_b \widehat B
```

```math
D = A B + C
```

```math
D = s_a s_b \widehat D + C
```

Here, $` D `$ is the output of the GEMM, and $` C `$ is the bias.
$` A `$ holds the activations and supports asymmetric quantization,
and $` B `$ holds the weights and only supports symmetric quantization.
$` s_a `$ and $` s_b `$ are the scales for activations and weights, respectively.
$` z_a `$ is the zero-point for activations, and $` J_a `$ is the all-ones matrix with the same dimensions as $` A `$.
Additional epilogues would be required to support asymmetric quantization for weights.

Expanding further, we can calculate $` \widehat D `$ as follows:

```math
A B = s_a ( \widehat A - J_a z_a ) s_b \widehat B
```

```math
A B = s_a s_b \left( \widehat A \widehat B - J_a z_a \widehat B \right)
```

```math
\widehat D = \widehat A \widehat B - z_a J_a \widehat B
```

Note that $` \widehat A \widehat B `$ is the raw output of the GEMM,
and $` J_a \widehat B `$ is known ahead of time.
Each row of it is equal to $` \mathbf 1 \widehat B `$, which is a row-vector of column sums of $` \widehat B `$.
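
As a quick numerical check of this derivation, the sketch below uses small, made-up integer matrices (hypothetical values) with a per-tensor zero-point. It confirms that every row of $` J_a \widehat B `$ equals the column sums of $` \widehat B `$, and that $` s_a s_b \widehat D `$ matches the product $` A B `$ of the dequantized matrices (with $` C = 0 `$). It is plain Python with no dependencies; `matmul` and `column_sums` are illustrative helpers, not part of CUTLASS or vLLM.

```python
def matmul(x, y):
    """Naive matrix product of two matrices stored as lists of rows."""
    return [[sum(x[i][k] * y[k][j] for k in range(len(y)))
             for j in range(len(y[0]))]
            for i in range(len(x))]


def column_sums(m):
    """The row-vector 1 * M, i.e. the sum of each column of M."""
    return [sum(row[j] for row in m) for j in range(len(m[0]))]


# Hypothetical quantized operands: A_hat is M x K (M=2, K=3), B_hat is K x N (N=2).
A_hat = [[3, 7, 1], [0, 5, 9]]
B_hat = [[1, -2], [4, 0], [-3, 6]]
s_a, s_b, z_a = 0.5, 0.25, 2  # per-tensor scales and zero-point

# J_a is the all-ones matrix with the same shape as A_hat.
J_a = [[1] * len(A_hat[0]) for _ in A_hat]
col_sums = column_sums(B_hat)

# Every row of J_a * B_hat is the row-vector of column sums of B_hat.
assert all(row == col_sums for row in matmul(J_a, B_hat))

# D_hat = A_hat B_hat - z_a J_a B_hat, using the column sums for the second term.
raw = matmul(A_hat, B_hat)
D_hat = [[raw[i][j] - z_a * col_sums[j] for j in range(len(col_sums))]
         for i in range(len(raw))]

# Reference: dequantize A and B first, then multiply in floating point.
A = [[s_a * (a - z_a) for a in row] for row in A_hat]
B = [[s_b * b for b in row] for row in B_hat]
AB = matmul(A, B)
assert all(abs(AB[i][j] - s_a * s_b * D_hat[i][j]) < 1e-9
           for i in range(len(AB)) for j in range(len(AB[0])))

print(D_hat)  # [[24, -8], [-11, 46]]
```

Because the column sums depend only on $` \widehat B `$, they can be computed once per weight matrix rather than once per GEMM.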

## Epilogues

### `ScaledEpilogue`

This epilogue computes the symmetric quantization for activations without bias, meaning $` C = 0 `$ and $` z_a = 0 `$.
The output of the GEMM is:

```math
\widehat D = \widehat A \widehat B
```

```math
D = s_a s_b \widehat D
```

```math
D = s_a s_b \widehat A \widehat B
```

Epilogue parameters:

- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
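
A pure-Python reference of what this epilogue computes can be handy when testing a kernel. This is a sketch, not the CUTLASS implementation: the helper names `expand` and `scaled_epilogue` are hypothetical, and `D_hat` stands in for the integer GEMM accumulator. A Python number selects per-tensor scaling, while a list selects per-token (`scale_a`) or per-channel (`scale_b`) scaling.

```python
from numbers import Real


def expand(scale, length):
    """Broadcast a per-tensor scalar to `length` entries, or check a vector's length."""
    if isinstance(scale, Real):
        return [float(scale)] * length
    if len(scale) != length:
        raise ValueError(f"expected {length} scale entries, got {len(scale)}")
    return [float(s) for s in scale]


def scaled_epilogue(D_hat, scale_a, scale_b):
    """D = s_a * s_b * D_hat; s_a varies per row (token), s_b per column (channel)."""
    m, n = len(D_hat), len(D_hat[0])
    sa = expand(scale_a, m)  # column-vector: one scale per row
    sb = expand(scale_b, n)  # row-vector: one scale per column
    return [[sa[i] * sb[j] * D_hat[i][j] for j in range(n)] for i in range(m)]


D_hat = [[10, -4], [6, 8]]  # hypothetical integer GEMM accumulator

# Per-tensor scale_a, per-channel scale_b.
print(scaled_epilogue(D_hat, 0.5, [0.25, 2.0]))  # [[1.25, -4.0], [0.75, 8.0]]

# Per-token scale_a, per-tensor scale_b.
print(scaled_epilogue(D_hat, [0.5, 1.5], 0.25))  # [[1.25, -0.5], [2.25, 3.0]]
```

Since the scale product $` s_a s_b `$ factors into a column-vector times a row-vector, each output element needs only one multiply by a precomputable pair of values.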

### `ScaledEpilogueBias`

This epilogue computes the symmetric quantization for activations with bias, meaning $` z_a = 0 `$.
The output of the GEMM is:

```math
\widehat D = \widehat A \widehat B
```

```math
D = s_a s_b \widehat D + C 
```

```math
D = s_a s_b \widehat A \widehat B + C
```

Epilogue parameters:

- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
- `bias` is the bias, is always per-channel (row-vector).
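
The bias variant only adds the per-channel row-vector $` C `$ to every row after scaling. A minimal sketch under the same assumptions as a plain-Python reference (the names `expand` and `scaled_epilogue_bias` are hypothetical, not CUTLASS or vLLM API):

```python
from numbers import Real


def expand(scale, length):
    """Broadcast a per-tensor scalar to `length` entries, or check a vector's length."""
    if isinstance(scale, Real):
        return [float(scale)] * length
    if len(scale) != length:
        raise ValueError(f"expected {length} entries, got {len(scale)}")
    return [float(s) for s in scale]


def scaled_epilogue_bias(D_hat, scale_a, scale_b, bias):
    """D = s_a * s_b * D_hat + C, with C a per-channel row-vector added to every row."""
    m, n = len(D_hat), len(D_hat[0])
    sa = expand(scale_a, m)  # per-tensor or per-token
    sb = expand(scale_b, n)  # per-tensor or per-channel
    if len(bias) != n:
        raise ValueError("bias must be a per-channel row-vector")
    return [[sa[i] * sb[j] * D_hat[i][j] + bias[j] for j in range(n)]
            for i in range(m)]


D_hat = [[10, -4], [6, 8]]  # hypothetical integer GEMM accumulator
out = scaled_epilogue_bias(D_hat, [0.5, 1.5], 0.25, [1.0, -2.0])
print(out)  # [[2.25, -2.5], [3.25, 1.0]]
```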

### `ScaledEpilogueAzp`

This epilogue computes the asymmetric per-tensor quantization for activations with bias.
The output of the GEMM is:

```math
\widehat D = \widehat A \widehat B - z_a J_a \widehat B
```

```math
D = s_a s_b \widehat D + C 
```

```math
D = s_a s_b \left( \widehat A \widehat B - z_a J_a \widehat B \right) + C
```

Because $` z_a `$ is a scalar, every row of the zero-point term $` z_a J_a \widehat B `$ equals $` z_a \mathbf 1 \widehat B `$.
That row-vector is precomputed and stored in `azp_with_adj`.

Epilogue parameters:

- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
    - Generally this will be per-tensor as the zero-points are per-tensor.
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
- `azp_with_adj` is the precomputed zero-point term ($` z_a \mathbf 1 \widehat B `$, i.e. one row of $` z_a J_a \widehat B `$), is per-channel (row-vector).
- `bias` is the bias, is always per-channel (row-vector).

To use these kernels efficiently, users must precompute the `azp_with_adj` term offline and pass it to the kernel.
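
The offline precompute and the epilogue itself can be sketched in plain Python as follows. This is a hedged reference, not the kernel: `precompute_azp_with_adj`, `scaled_epilogue_azp`, `matmul` and `expand` are hypothetical helpers, and the input values are made up. A missing bias is treated as zeros, mirroring the zero-bias fallback described earlier; the last lines cross-check the result against dequantizing first and multiplying afterwards.

```python
from numbers import Real


def matmul(x, y):
    """Naive matrix product of two matrices stored as lists of rows."""
    return [[sum(x[i][k] * y[k][j] for k in range(len(y)))
             for j in range(len(y[0]))]
            for i in range(len(x))]


def expand(scale, length):
    """Broadcast a per-tensor scalar to `length` entries, or check a vector's length."""
    if isinstance(scale, Real):
        return [float(scale)] * length
    if len(scale) != length:
        raise ValueError(f"expected {length} entries, got {len(scale)}")
    return [float(s) for s in scale]


def precompute_azp_with_adj(B_hat, z_a):
    """Offline: z_a * (1 B_hat), the scalar zero-point times the column sums of B_hat."""
    return [z_a * sum(row[j] for row in B_hat) for j in range(len(B_hat[0]))]


def scaled_epilogue_azp(D_raw, scale_a, scale_b, azp_with_adj, bias=None):
    """D = s_a * s_b * (D_raw - azp_with_adj) + C, where D_raw = A_hat B_hat.

    A missing bias is treated as a row of zeros.
    """
    m, n = len(D_raw), len(D_raw[0])
    sa, sb = expand(scale_a, m), expand(scale_b, n)
    c = [0.0] * n if bias is None else bias
    return [[sa[i] * sb[j] * (D_raw[i][j] - azp_with_adj[j]) + c[j]
             for j in range(n)]
            for i in range(m)]


# Hypothetical quantized operands with a per-tensor zero-point z_a = 2.
A_hat = [[3, 7, 1], [0, 5, 9]]
B_hat = [[1, -2], [4, 0], [-3, 6]]
z_a, s_a, s_b, bias = 2, 0.5, [0.25, 0.5], [1.0, -1.0]

azp_with_adj = precompute_azp_with_adj(B_hat, z_a)  # [4, 8], computed once per weight
out = scaled_epilogue_azp(matmul(A_hat, B_hat), s_a, s_b, azp_with_adj, bias)
print(out)  # [[4.0, -3.0], [-0.375, 10.5]]

# Cross-check: dequantize first, then multiply, then add the bias.
A = [[s_a * (a - z_a) for a in row] for row in A_hat]
B = [[s_b[j] * row[j] for j in range(len(row))] for row in B_hat]
ref = [[v + bias[j] for j, v in enumerate(r)] for r in matmul(A, B)]
assert all(abs(x - y) < 1e-9 for r1, r2 in zip(out, ref) for x, y in zip(r1, r2))
```

Folding $` z_a `$ into the precomputed vector means the epilogue needs a single subtraction per element, with no extra per-tensor operand at runtime.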

### `ScaledEpilogueAzpPerToken`

This epilogue computes the asymmetric per-token quantization for activations with bias.

The output of the GEMM is the same as above, but $` z_a `$ is now a column-vector with one zero-point per token.
That means the zero-point term $` z_a J_a \widehat B `$ becomes an outer product of $` z_a `$ and $` \mathbf 1 \widehat B `$.

Epilogue parameters:

- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
    - Generally this will be per-token as the zero-points are per-token.
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
- `azp_adj` is the precomputed zero-point adjustment term ($` \mathbf 1 \widehat B `$), is per-channel (row-vector).
- `azp` is the zero-point ($` z_a `$), is per-token (column-vector).
- `bias` is the bias, is always per-channel (row-vector).

To use these kernels efficiently, users must precompute the `azp_adj` term offline and pass it to the kernel.

The epilogue performs the following computation (where `Dq` is the raw quantized output of the GEMM):

```
out = scale_a * scale_b * (Dq - azp_adj * azp) + bias
```
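
The pseudocode above relies on elementwise broadcasting: `azp` has one entry per row and `azp_adj` one entry per column, so `azp_adj * azp` is their outer product. The sketch below spells that out in plain Python under stated assumptions: `precompute_azp_adj`, `scaled_epilogue_azp_per_token`, `matmul` and `expand` are hypothetical helpers (not vLLM or CUTLASS API), and the inputs are made-up values. It ends with a cross-check against dequantizing with per-token zero-points and scales before multiplying.

```python
from numbers import Real


def matmul(x, y):
    """Naive matrix product of two matrices stored as lists of rows."""
    return [[sum(x[i][k] * y[k][j] for k in range(len(y)))
             for j in range(len(y[0]))]
            for i in range(len(x))]


def expand(scale, length):
    """Broadcast a per-tensor scalar to `length` entries, or check a vector's length."""
    if isinstance(scale, Real):
        return [float(scale)] * length
    if len(scale) != length:
        raise ValueError(f"expected {length} entries, got {len(scale)}")
    return [float(s) for s in scale]


def precompute_azp_adj(B_hat):
    """Offline: 1 B_hat, the per-channel row-vector of column sums of B_hat."""
    return [sum(row[j] for row in B_hat) for j in range(len(B_hat[0]))]


def scaled_epilogue_azp_per_token(Dq, scale_a, scale_b, azp_adj, azp, bias=None):
    """out = scale_a * scale_b * (Dq - azp_adj * azp) + bias, elementwise.

    azp (one zero-point per row) times azp_adj (one value per column) is an
    outer product, so the correction differs for every output element.
    """
    m, n = len(Dq), len(Dq[0])
    sa, sb = expand(scale_a, m), expand(scale_b, n)
    c = [0.0] * n if bias is None else bias
    return [[sa[i] * sb[j] * (Dq[i][j] - azp_adj[j] * azp[i]) + c[j]
             for j in range(n)]
            for i in range(m)]


# Hypothetical operands with per-token zero-points and per-token scales.
A_hat = [[3, 7, 1], [0, 5, 9]]
B_hat = [[1, -2], [4, 0], [-3, 6]]
azp, s_a, s_b, bias = [2, 1], [0.5, 0.25], 0.25, [1.0, -1.0]

azp_adj = precompute_azp_adj(B_hat)  # [2, 4], independent of the activations
out = scaled_epilogue_azp_per_token(matmul(A_hat, B_hat), s_a, s_b, azp_adj, azp, bias)
print(out)  # [[4.0, -2.0], [0.4375, 2.125]]

# Cross-check against dequantize-then-multiply.
A = [[s_a[i] * (a - azp[i]) for a in A_hat[i]] for i in range(len(A_hat))]
B = [[s_b * b for b in row] for row in B_hat]
ref = [[v + bias[j] for j, v in enumerate(r)] for r in matmul(A, B)]
assert all(abs(x - y) < 1e-9 for r1, r2 in zip(out, ref) for x, y in zip(r1, r2))
```

Unlike `azp_with_adj` in the per-tensor case, `azp_adj` cannot absorb the zero-point, because $` z_a `$ changes with every token; only the column sums of $` \widehat B `$ are precomputed.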