# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""TE FP8 extensions and GEMMs"""
from typing import Optional, Tuple, Union
import torch
import transformer_engine_extensions as tex
from .constants import TE_DType


def fp8_gemm(
    A: torch.Tensor,
    A_scale_inv: torch.Tensor,
    A_fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    A_dtype: tex.DType,
    B: torch.Tensor,
    B_scale_inv: torch.Tensor,
    B_fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    B_dtype: tex.DType,
    out_dtype: torch.dtype,
    workspace: torch.Tensor,
    accumulate: bool = False,
    out: Optional[torch.Tensor] = None,
    out_index=None,
    fp8_meta_tensor: Optional[tex.FP8TensorMeta] = None,
    bias: Optional[torch.Tensor] = None,
    use_bias: bool = False,
    use_split_accumulator: bool = False,
    D_dtype: Optional[tex.DType] = None,
) -> Optional[torch.Tensor]:
    """TN layout GEMM with fp8 inputs.

    Dispatches to ``torch.ops.tex_ts.te_gemm_ts`` with ``transa=True`` and
    ``transb=False``.

    Args:
        A, B: GEMM operands. When ``out`` is None, the result is allocated with
            shape ``(B.shape[0], A.shape[0])`` on ``A``'s device.
        A_scale_inv, B_scale_inv: inverse scaling factors forwarded to the op.
        A_fp8_tensor, B_fp8_tensor: FP8 tensor identifiers forwarded to the op.
        A_dtype, B_dtype: TE dtypes of the operands.
        out_dtype: torch dtype used only when allocating ``out``.
        workspace: scratch buffer; ``workspace.shape[0]`` is passed as its size.
        accumulate: forwarded to the op (presumably accumulate into ``out``
            instead of overwriting it -- confirm against the extension).
        out: optional preallocated output tensor.
        out_index: index into ``fp8_meta_tensor.scale`` and
            ``fp8_meta_tensor.amax_history[0]`` for the output. When None,
            empty tensors are passed for the output scale/amax.
        fp8_meta_tensor: FP8 metadata; required whenever ``out_index`` is set.
        bias: bias tensor; only passed to the op when ``use_bias`` is True.
        use_bias: whether ``bias`` is passed to the op.
        use_split_accumulator: forwarded to the op.
        D_dtype: TE dtype of the output. Overrides the dtype derived from
            ``out.dtype``. FP8 values require ``fp8_meta_tensor`` and ``out_index``.

    Returns:
        The newly allocated output when ``out`` was None, otherwise None
        (``out`` is written in place).
    """
    empty_tensor = torch.Tensor()

    # FP8 output dtypes read their scale/amax from fp8_meta_tensor[out_index].
    if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
        assert fp8_meta_tensor is not None and out_index is not None
    # out_index indexes into fp8_meta_tensor, so the two must be given together.
    assert out_index is None or fp8_meta_tensor is not None, \
        "fp8_meta_tensor must be provided when out_index is set."
    assert not use_bias or bias is not None, "bias must be provided when use_bias=True."

    return_output = out is None
    if return_output:
        # Allocate on the operands' device rather than the current CUDA device.
        out = torch.empty(
            B.shape[0],
            A.shape[0],
            dtype=out_dtype,
            device=A.device,
        )

    # TE dtype of the GEMM output: an explicit D_dtype wins over out's dtype.
    out_te_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype
    # Use bfloat16 as default bias_dtype
    bias_dtype = tex.DType.kBFloat16 if bias is None else TE_DType[bias.dtype]

    if out_index is None:
        out_scale = out_amax = empty_tensor
    else:
        out_scale = fp8_meta_tensor.scale[out_index]
        out_amax = fp8_meta_tensor.amax_history[0][out_index]

    torch.ops.tex_ts.te_gemm_ts(
        A,
        A_scale_inv,
        A_fp8_tensor,
        A_dtype,
        True,  # transa
        B,
        B_scale_inv,
        B_fp8_tensor,
        B_dtype,
        False,  # transb
        out,
        out_scale,
        out_te_dtype,
        out_amax,
        bias if use_bias else empty_tensor,
        bias_dtype,
        empty_tensor,  # this is pre_gelu_out
        False,  # grad
        workspace,
        workspace.shape[0],
        accumulate,
        use_split_accumulator,
    )

    return out if return_output else None


def gemm(
    A: torch.Tensor,
    B: torch.Tensor,
    dtype: torch.dtype,
    workspace: torch.Tensor,
    gelu: bool = False,
    gelu_input: Optional[torch.Tensor] = None,
    grad: bool = False,
    accumulate: bool = False,
    layout: str = "TN",
    out: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    use_bias: bool = False,
) -> Tuple[Union[torch.Tensor, None], ...]:
    """Non FP8 GEMM.

    Dispatches to ``torch.ops.tex_ts.te_gemm_ts`` with no FP8 scaling
    (empty scale/amax tensors and a dummy FP8 index of -1).

    Args:
        A, B: operands; both must have dtype ``dtype``.
        dtype: torch dtype of the inputs, also used when allocating ``out``.
        workspace: scratch buffer; ``workspace.shape[0]`` is passed as its size.
        gelu: when True, ``gelu_input`` is passed to the op. In the forward
            pass (``grad=False``) it is freshly allocated like ``out``. With
            ``grad=True`` the caller must supply it.
        gelu_input: see ``gelu``.
        grad: forwarded to the op. With ``use_bias`` it also selects the
            bias-gradient path (a ``grad_bias`` buffer is allocated).
        accumulate: forwarded to the op.
        layout: two characters; ``"T"`` in position 0/1 sets transa/transb.
            Only ``"TN"``, ``"NN"`` and ``"NT"`` are accepted.
        out: optional preallocated output.
        bias: bias tensor; required when ``use_bias`` and not ``grad``.
        use_bias: whether a bias (or bias-gradient buffer) is passed.

    Returns:
        ``(out_or_None, grad_bias, gelu_input)``. The first element is the
        allocated output only when ``out`` was None. ``grad_bias`` and
        ``gelu_input`` are empty tensors when unused.
    """
    assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported."
    # Validate inputs before allocating anything.
    assert A.dtype == dtype and B.dtype == dtype, \
        f'Expected dtype={dtype}, but found A.dtype={A.dtype} and B.dtype={B.dtype}'
    transa = layout[0] == "T"
    transb = layout[1] == "T"
    empty_tensor = torch.Tensor()
    fp8_index = -1  # dummy index

    return_output = out is None
    if return_output:
        # Allocate on the operands' device rather than the current CUDA device.
        out = torch.empty(
            B.shape[1] if transb else B.shape[0],
            A.shape[0] if transa else A.shape[1],
            dtype=dtype,
            device=A.device,
        )

    if not gelu:
        gelu_input = empty_tensor
    elif not grad:
        gelu_input = torch.empty_like(out, dtype=dtype)
    else:
        assert gelu_input is not None, \
            "gelu_input must be provided when gelu=True and grad=True."

    if grad and use_bias:
        grad_bias = torch.empty(B.shape[1], dtype=out.dtype, device=out.device)
    else:
        grad_bias = empty_tensor

    if use_bias and not grad:
        assert bias is not None, "bias must be provided when use_bias=True."
    bias = bias if use_bias else empty_tensor

    input_dtype = TE_DType[dtype]
    output_dtype = TE_DType[out.dtype]
    if not use_bias:
        bias_dtype = output_dtype
    elif grad:
        bias_dtype = TE_DType[grad_bias.dtype]
    else:
        bias_dtype = TE_DType[bias.dtype]

    torch.ops.tex_ts.te_gemm_ts(
        A,
        empty_tensor,
        fp8_index,
        input_dtype,
        transa,
        B,
        empty_tensor,
        fp8_index,
        input_dtype,
        transb,
        out,
        empty_tensor,  # out_scale
        output_dtype,
        empty_tensor,  # out_amax
        grad_bias if grad else bias,
        bias_dtype,
        gelu_input,
        grad,
        workspace,
        workspace.shape[0],
        accumulate,
        False,  # use_split_accumulator
    )

    return (out if return_output else None), grad_bias, gelu_input


def fp8_cast_transpose_fused(
    inp: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
    cast_out: Optional[torch.Tensor] = None,
    transpose_out: Optional[torch.Tensor] = None,
) -> Union[Tuple[torch.Tensor, torch.Tensor], None]:
    """Cast + Transpose with FP8 output

    Calls ``tex.fused_cast_transpose`` with the ``fp8_tensor`` entries of the
    scale, amax history (row 0) and inverse-scale buffers of
    ``fp8_meta_tensor``.

    Args:
        inp: input tensor. The transpose buffer is sized from
            ``inp.shape[1], inp.shape[0]``, so this assumes a 2-D input.
        fp8_meta_tensor: FP8 metadata holding scale / amax_history / scale_inv.
        fp8_tensor: index of the entry to use in those buffers.
        otype: TE output dtype forwarded to the extension.
        cast_out, transpose_out: optional int8 output buffers. Any buffer left
            as None is allocated on ``inp``'s device; a buffer the caller
            provides is always used.

    Returns:
        ``(cast_out, transpose_out)`` if at least one buffer was allocated
        here, otherwise None (both caller buffers are filled in place).
    """
    return_outputs = cast_out is None or transpose_out is None
    # Allocate only what is missing, so a caller-provided buffer is never dropped.
    if cast_out is None:
        cast_out = torch.empty_like(inp, dtype=torch.int8)
    if transpose_out is None:
        transpose_out = torch.empty(
            inp.shape[1], inp.shape[0], device=inp.device, dtype=torch.int8
        )

    tex.fused_cast_transpose(
        inp,
        fp8_meta_tensor.scale[fp8_tensor],
        fp8_meta_tensor.amax_history[0][fp8_tensor],
        fp8_meta_tensor.scale_inv[fp8_tensor],
        cast_out,
        transpose_out,
        otype,
    )

    if return_outputs:
        return cast_out, transpose_out
    return None


def fp8_cast_transpose_bgrad_fused(
    inp: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Cast + Transpose + BGRAD with FP8 output

    Picks the ``fp8_tensor`` slot of the scale, amax history (row 0) and
    inverse-scale buffers in ``fp8_meta_tensor``. It then forwards them with
    ``inp`` and ``otype`` to ``tex.fused_cast_transpose_bgrad``.

    Returns:
        Whatever the extension returns. Per the annotation this is three
        tensors; presumably the FP8 cast, its transpose and the bias gradient
        (confirm against the extension).
    """
    scale = fp8_meta_tensor.scale[fp8_tensor]
    amax = fp8_meta_tensor.amax_history[0][fp8_tensor]
    scale_inv = fp8_meta_tensor.scale_inv[fp8_tensor]
    return tex.fused_cast_transpose_bgrad(inp, scale, amax, scale_inv, otype)


def fp8_transpose_bgrad_fused(
    inp: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
    grad_bias_type: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Transpose + BGRAD with FP8 output

    Forwards ``inp`` to ``tex.fused_fp8_transpose_bgrad`` along with:
    the ``fp8_tensor`` slot of the scale, amax history (row 0) and
    inverse-scale buffers of ``fp8_meta_tensor``; ``otype``; and
    ``grad_bias_type`` translated to its TE dtype through ``TE_DType``.

    Returns:
        The extension's result (three tensors per the annotation).
    """
    scale = fp8_meta_tensor.scale[fp8_tensor]
    amax = fp8_meta_tensor.amax_history[0][fp8_tensor]
    scale_inv = fp8_meta_tensor.scale_inv[fp8_tensor]
    bias_grad_te_dtype = TE_DType[grad_bias_type]
    return tex.fused_fp8_transpose_bgrad(
        inp, scale, amax, scale_inv, otype, bias_grad_te_dtype
    )


def fp8_cast_transpose_bgrad_dgelu_fused(
    grad_output: torch.Tensor,
    gelu_input: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Cast + Transpose + BGRAD + DGELU with FP8 output

    Passes ``grad_output`` and ``gelu_input`` to
    ``tex.fused_cast_transpose_bgrad_dgelu``, together with ``otype`` and the
    ``fp8_tensor`` slot of the scale, amax history (row 0) and inverse-scale
    buffers of ``fp8_meta_tensor``.

    Returns:
        The extension's result (three tensors per the annotation).
    """
    meta = fp8_meta_tensor
    slot = fp8_tensor
    scale, amax, scale_inv = (
        meta.scale[slot],
        meta.amax_history[0][slot],
        meta.scale_inv[slot],
    )
    return tex.fused_cast_transpose_bgrad_dgelu(
        grad_output, gelu_input, scale, amax, scale_inv, otype
    )


def fp8_gelu(
    inp: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
) -> torch.Tensor:
    """GeLU with FP8 output

    Unlike the ``tex.*`` wrappers in this module, this calls the TorchScript
    op ``torch.ops.tex_ts.fp8_gelu_ts``. It passes the *whole* scale, amax
    history and inverse-scale buffers plus the ``fp8_tensor`` index, instead
    of pre-selected entries.
    """
    gelu_op = torch.ops.tex_ts.fp8_gelu_ts
    meta = fp8_meta_tensor
    return gelu_op(
        inp, meta.scale, meta.amax_history, meta.scale_inv, fp8_tensor, otype
    )


def layernorm_fwd_fp8(
    inp: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
    sm_margin: int,
    zero_centered_gamma: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """LayerNorm with FP8 output

    Forwards ``inp``, ``weight``, ``bias`` and ``eps`` to
    ``tex.layernorm_fwd_fp8``, together with:
    the ``fp8_tensor`` slot of the scale, amax history (row 0) and
    inverse-scale buffers of ``fp8_meta_tensor``; ``otype``; ``sm_margin``;
    and ``zero_centered_gamma``.

    Returns:
        The extension's result (three tensors per the annotation).
    """
    slot_scale = fp8_meta_tensor.scale[fp8_tensor]
    slot_amax = fp8_meta_tensor.amax_history[0][fp8_tensor]
    slot_scale_inv = fp8_meta_tensor.scale_inv[fp8_tensor]
    return tex.layernorm_fwd_fp8(
        inp, weight, bias, eps,
        slot_scale, slot_amax, slot_scale_inv,
        otype, sm_margin, zero_centered_gamma,
    )


def layernorm_fwd_fp8_inf(
    inp: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
    zero_centered_gamma: bool,
) -> torch.Tensor:
    """LayerNorm with FP8 output.

    This version of layernorm_fwd_fp8 is specialized for inference, and returns
    only the normalized output.

    It calls the TorchScript op ``torch.ops.tex_ts.layernorm_fwd_fp8_inf_ts``.
    The op receives the whole scale, amax history and inverse-scale buffers
    of ``fp8_meta_tensor`` plus the ``fp8_tensor`` index. Unlike
    ``layernorm_fwd_fp8``, it takes no ``sm_margin``.
    """
    return torch.ops.tex_ts.layernorm_fwd_fp8_inf_ts(
        inp,
        weight,
        bias,
        eps,
        fp8_meta_tensor.scale,
        fp8_meta_tensor.amax_history,
        fp8_meta_tensor.scale_inv,
        fp8_tensor,
        otype,
        zero_centered_gamma,
    )


def layernorm_fwd_inf(
    inp: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    zero_centered_gamma: bool,
) -> torch.Tensor:
    """LayerNorm forward for inference, without FP8 quantization.

    Calls the TorchScript op ``torch.ops.tex_ts.layernorm_fwd_inf_ts``. No FP8
    scale / amax / scale_inv metadata is passed, so the output is not FP8-scaled
    here. Presumably it keeps a regular floating dtype; confirm against the
    extension.
    """
    layernorm_op = torch.ops.tex_ts.layernorm_fwd_inf_ts
    return layernorm_op(inp, weight, bias, eps, zero_centered_gamma)


def cast_to_fp8(
    inp: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
) -> torch.Tensor:
    """Cast input to FP8

    Thin wrapper over the TorchScript op ``torch.ops.tex_ts.cast_to_fp8_ts``.
    It forwards the full scale, amax history and inverse-scale buffers of
    ``fp8_meta_tensor`` along with the ``fp8_tensor`` index and ``otype``.
    """
    cast_op = torch.ops.tex_ts.cast_to_fp8_ts
    meta = fp8_meta_tensor
    return cast_op(
        inp, meta.scale, meta.amax_history, meta.scale_inv, fp8_tensor, otype
    )


def cast_from_fp8(
    inp: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    itype: tex.DType,
    otype: tex.DType,
) -> torch.Tensor:
    """Cast input from FP8

    Thin wrapper over the TorchScript op ``torch.ops.tex_ts.cast_from_fp8_ts``.
    Only the inverse-scale buffer of ``fp8_meta_tensor`` is forwarded, together
    with the ``fp8_tensor`` index, the input TE dtype ``itype`` and the output
    TE dtype ``otype``.
    """
    decast_op = torch.ops.tex_ts.cast_from_fp8_ts
    scale_inv_buf = fp8_meta_tensor.scale_inv
    return decast_op(inp, scale_inv_buf, fp8_tensor, itype, otype)