# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""TE FP8 extensions and GEMMs"""
from typing import Optional, Tuple, Union
import torch
import transformer_engine_extensions as tex
from .constants import TE_DType


def fp8_gemm(
    A: torch.Tensor,
    A_scale_inv: torch.Tensor,
    A_fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    A_dtype: tex.DType,
    B: torch.Tensor,
    B_scale_inv: torch.Tensor,
    B_fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    B_dtype: tex.DType,
    out_dtype: torch.dtype,
    workspace: torch.Tensor,
    gelu: bool = False,
    accumulate: bool = False,
    out: Optional[torch.Tensor] = None,
    out_index=None,
    fp8_meta_tensor: Optional[tex.FP8TensorMeta] = None,
    bias: Optional[torch.Tensor] = None,
    use_bias: bool = False,
    use_split_accumulator: bool = False,
    D_dtype: Optional[tex.DType] = None,
    ub_algo: Optional[tex.UbufOverlapAlgo] = None,
    ub: Optional[Union[tex.UbufCommOverlap, tex.UbufP2PCommOverlap]] = None,
    extra_output_tensor: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """TN layout GEMM with FP8 inputs."""

    empty_tensor = torch.Tensor()
    if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
        assert fp8_meta_tensor is not None and out_index is not None, \
            "FP8 output requires fp8_meta_tensor and out_index."

    return_output = False
    if out is None:
        out = torch.empty(
            B.shape[0],
            A.shape[0],
            dtype=out_dtype,
            device="cuda",
        )
        return_output = True
    # Use bfloat16 as default bias_dtype
    bias_dtype = torch.bfloat16 if bias is None else bias.dtype
    if gelu:
        gelu_input = torch.empty_like(out, dtype=bias_dtype)
    else:
        gelu_input = empty_tensor
    bias_dtype = TE_DType[bias_dtype]

    out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype

    args = (
        A,
        A_scale_inv,
        A_fp8_tensor,
        A_dtype,
        True,  # transa
        B,
        B_scale_inv,
        B_fp8_tensor,
        B_dtype,
        False,  # transb
        out,
        empty_tensor if out_index is None else fp8_meta_tensor.scale[out_index],
        out_dtype,
        empty_tensor if out_index is None else fp8_meta_tensor.amax_history[0][out_index],
        bias if use_bias else empty_tensor,
        bias_dtype,
        gelu_input,  # this is pre_gelu_out
        False,  # grad
        workspace,
        workspace.shape[0],
        accumulate,
        use_split_accumulator,
    )
    fn = torch.ops.tex_ts.te_gemm_ts
    if ub_algo is not None:
        assert ub is not None, 'ub object is None!'
        if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG:
            fn = ub.bulk_overlap
            args = tuple(args + (1,))  # 1 selects all-gather bulk overlap
        elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS:
            fn = ub.bulk_overlap
            args = tuple(args + (0,))  # 0 selects reduce-scatter bulk overlap
        elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG:
            fn = ub.split_overlap_ag
            extra_output_tensor = (
                empty_tensor if extra_output_tensor is None else extra_output_tensor
            )
            args = tuple(args + (extra_output_tensor,))
        elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS:
            fn = ub.split_overlap_rs
            assert (
                extra_output_tensor is not None
            ), 'SPLIT_PIPELINED_RS requires extra output tensor'
            args = tuple(args + (True, extra_output_tensor,))
    _ = fn(*args)

    if return_output:
        if gelu:
            return out, gelu_input
        return out
    if gelu:
        return gelu_input
    return None

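
# The functions in this module are accompanied below by hedged usage sketches.
# ``_make_example_fp8_meta`` is a hypothetical helper, not part of the API:
# assigning ``scale``/``scale_inv``/``amax_history`` directly follows the
# pattern used in TE's tests, and the slot count of 6 is an assumption.
def _make_example_fp8_meta(num_slots: int = 6) -> tex.FP8TensorMeta:
    """Toy FP8 metadata with unit scales (illustrative only)."""
    meta = tex.FP8TensorMeta()
    meta.scale = torch.ones(num_slots, dtype=torch.float32, device="cuda")
    meta.scale_inv = torch.ones(num_slots, dtype=torch.float32, device="cuda")
    meta.amax_history = torch.zeros(1, num_slots, dtype=torch.float32, device="cuda")
    return meta


# Hedged sketch of ``fp8_gemm``: quantize both operands to FP8 E4M3 via
# ``cast_to_fp8`` (defined below), then run the TN GEMM out = inp @ weight.T.
# The 32 MiB workspace mirrors TE's default cuBLASLt workspace size, but the
# exact size and dtype are assumptions here.
def _example_fp8_gemm() -> torch.Tensor:
    meta = _make_example_fp8_meta()
    inp = torch.randn(128, 256, dtype=torch.float32, device="cuda")
    weight = torch.randn(512, 256, dtype=torch.float32, device="cuda")
    inp_fp8 = cast_to_fp8(inp, meta, tex.FP8FwdTensors.GEMM1_INPUT, tex.DType.kFloat8E4M3)
    weight_fp8 = cast_to_fp8(weight, meta, tex.FP8FwdTensors.GEMM1_WEIGHT, tex.DType.kFloat8E4M3)
    workspace = torch.empty(33_554_432, dtype=torch.int8, device="cuda")
    return fp8_gemm(  # -> (128, 512) in bfloat16
        weight_fp8, meta.scale_inv, tex.FP8FwdTensors.GEMM1_WEIGHT, tex.DType.kFloat8E4M3,
        inp_fp8, meta.scale_inv, tex.FP8FwdTensors.GEMM1_INPUT, tex.DType.kFloat8E4M3,
        torch.bfloat16, workspace,
    )
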

def gemm(
    A: torch.Tensor,
    B: torch.Tensor,
    dtype: torch.dtype,
    workspace: torch.Tensor,
    gelu: bool = False,
    gelu_input: Optional[torch.Tensor] = None,
    grad: bool = False,
    accumulate: bool = False,
    layout: str = "TN",
    out: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    use_bias: bool = False,
    ub_algo: Optional[tex.UbufOverlapAlgo] = None,
    ub: Optional[tex.UbufCommOverlap] = None,
    extra_output_tensor: Optional[torch.Tensor] = None,
) -> Tuple[Union[torch.Tensor, None], ...]:
    """Non-FP8 GEMM."""

    assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported."
    transa = layout[0] == "T"
    transb = layout[1] == "T"
    empty_tensor = torch.Tensor()
    fp8_index = -1  # dummy index

    return_output = False
    if out is None:
        out = torch.empty(
            B.shape[1] if transb else B.shape[0],
            A.shape[0] if transa else A.shape[1],
            dtype=dtype,
            device="cuda",
        )
        return_output = True

    # For GeLU backward (gelu and grad both set), the caller must supply the
    # pre-GeLU activation saved from the forward pass via ``gelu_input``.
    if gelu and not grad:
        gelu_input = torch.empty_like(out, dtype=dtype)
    elif not gelu:
        gelu_input = empty_tensor

    if grad and use_bias:
        grad_bias = torch.empty(B.shape[1], dtype=out.dtype, device="cuda")
    else:
        grad_bias = empty_tensor

    bias = bias if use_bias else empty_tensor

    assert A.dtype == dtype and B.dtype == dtype, \
        f'Expected dtype={dtype}, but found A.dtype={A.dtype} and B.dtype={B.dtype}'
    input_dtype = TE_DType[dtype]
    output_dtype = TE_DType[out.dtype]
    if use_bias:
        bias_dtype = TE_DType[grad_bias.dtype] if grad else TE_DType[bias.dtype]
    else:
        bias_dtype = output_dtype

    args = (
        A,
        empty_tensor,
        fp8_index,
        input_dtype,
        transa,
        B,
        empty_tensor,
        fp8_index,
        input_dtype,
        transb,
        out,
        empty_tensor,  # out_scale
        output_dtype,
        empty_tensor,  # out_amax
        grad_bias if grad else bias,
        bias_dtype,
        gelu_input,
        grad,
        workspace,
        workspace.shape[0],
        accumulate,
        False,  # use_split_accumulator
    )
    fn = torch.ops.tex_ts.te_gemm_ts
    if ub_algo is not None:
        assert ub is not None, 'ub object is None!'
        if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG:
            fn = ub.bulk_overlap
            args = tuple(args + (1,))  # 1 selects all-gather bulk overlap
        elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS:
            fn = ub.bulk_overlap
            args = tuple(args + (0,))  # 0 selects reduce-scatter bulk overlap
        elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG:
            fn = ub.split_overlap_ag
            extra_output_tensor = (
                empty_tensor if extra_output_tensor is None else extra_output_tensor
            )
            args = tuple(args + (extra_output_tensor,))
        elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS:
            fn = ub.split_overlap_rs
            assert (
                extra_output_tensor is not None
            ), 'SPLIT_PIPELINED_RS requires extra output tensor'
            args = tuple(args + (False, extra_output_tensor,))
    _ = fn(*args)

    if return_output:
        return out, grad_bias, gelu_input
    return None, grad_bias, gelu_input

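
# Hedged sketch of the non-FP8 ``gemm``: a bfloat16 linear forward with bias.
# Shapes and the workspace size/dtype are illustrative assumptions.
def _example_gemm() -> torch.Tensor:
    inp = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")
    weight = torch.randn(512, 256, dtype=torch.bfloat16, device="cuda")
    bias = torch.randn(512, dtype=torch.bfloat16, device="cuda")
    workspace = torch.empty(33_554_432, dtype=torch.int8, device="cuda")
    # "TN" layout computes out = inp @ weight.T -> (128, 512).
    out, _, _ = gemm(
        weight, inp, torch.bfloat16, workspace,
        layout="TN", bias=bias, use_bias=True,
    )
    return out
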

def fp8_cast_transpose_fused(
    inp: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
    cast_out: Optional[torch.Tensor] = None,
    transpose_out: Optional[torch.Tensor] = None,
) -> Union[Tuple[torch.Tensor, torch.Tensor], None]:
    """Cast + Transpose with FP8 output"""

    return_outputs = False
    if cast_out is None or transpose_out is None:
        cast_out = torch.empty_like(inp, dtype=torch.int8)
        transpose_out = torch.empty(
            inp.shape[1], inp.shape[0], device="cuda", dtype=torch.int8
        )
        return_outputs = True

    tex.fused_cast_transpose(
        inp,
        fp8_meta_tensor.scale[fp8_tensor],
        fp8_meta_tensor.amax_history[0][fp8_tensor],
        fp8_meta_tensor.scale_inv[fp8_tensor],
        cast_out,
        transpose_out,
        otype,
    )

    if return_outputs:
        return cast_out, transpose_out
    return None

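
# Hedged sketch of ``fp8_cast_transpose_fused``: one kernel produces both the
# row-major FP8 copy (for the forward GEMM) and its transpose (for the
# weight-gradient GEMM in backward). Metadata setup reuses the assumed helper
# defined above.
def _example_fp8_cast_transpose() -> Tuple[torch.Tensor, torch.Tensor]:
    meta = _make_example_fp8_meta()
    inp = torch.randn(128, 256, dtype=torch.float32, device="cuda")
    cast_out, transpose_out = fp8_cast_transpose_fused(
        inp, meta, tex.FP8FwdTensors.GEMM1_INPUT, tex.DType.kFloat8E4M3
    )
    return cast_out, transpose_out  # (128, 256) and (256, 128), int8-backed FP8
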

def fp8_cast_transpose_bgrad_fused(
    inp: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Cast + Transpose + BGRAD with FP8 output"""
    return tex.fused_cast_transpose_bgrad(
        inp,
        fp8_meta_tensor.scale[fp8_tensor],
        fp8_meta_tensor.amax_history[0][fp8_tensor],
        fp8_meta_tensor.scale_inv[fp8_tensor],
        otype,
    )

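
# Hedged sketch of ``fp8_cast_transpose_bgrad_fused``: in a linear backward it
# fuses the bias-gradient reduction with casting and transposing grad_output.
# The (grad_bias, cast_out, transpose_out) unpacking order follows the return
# annotation and is an assumption; E5M2 is the usual FP8 format for gradients.
def _example_fp8_cast_transpose_bgrad() -> torch.Tensor:
    meta = _make_example_fp8_meta()
    grad_output = torch.randn(128, 512, dtype=torch.float32, device="cuda")
    grad_bias, grad_out_fp8, grad_out_t_fp8 = fp8_cast_transpose_bgrad_fused(
        grad_output, meta, tex.FP8BwdTensors.GRAD_OUTPUT1, tex.DType.kFloat8E5M2
    )
    return grad_bias  # (512,), same dtype as grad_output
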

def fp8_transpose_bgrad_fused(
    inp: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
    grad_bias_type: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Transpose + BGRAD with FP8 output"""
    return tex.fused_fp8_transpose_bgrad(
        inp,
        fp8_meta_tensor.scale[fp8_tensor],
        fp8_meta_tensor.amax_history[0][fp8_tensor],
        fp8_meta_tensor.scale_inv[fp8_tensor],
        otype,
        TE_DType[grad_bias_type],
    )

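
# Hedged sketch of ``fp8_transpose_bgrad_fused``: like the cast variant above,
# but the incoming gradient is already FP8 (stored as int8 here, which is an
# assumption), so only the transpose and bias-gradient reduction remain.
def _example_fp8_transpose_bgrad() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    meta = _make_example_fp8_meta()
    grad_output_fp8 = torch.randint(-128, 128, (128, 512), dtype=torch.int8, device="cuda")
    return fp8_transpose_bgrad_fused(
        grad_output_fp8, meta, tex.FP8BwdTensors.GRAD_OUTPUT1,
        tex.DType.kFloat8E5M2, torch.bfloat16,
    )
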

def fp8_cast_transpose_bgrad_dgelu_fused(
    grad_output: torch.Tensor,
    gelu_input: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Cast + Transpose + BGRAD + DGELU with FP8 output"""
    return tex.fused_cast_transpose_bgrad_dgelu(
        grad_output,
        gelu_input,
        fp8_meta_tensor.scale[fp8_tensor],
        fp8_meta_tensor.amax_history[0][fp8_tensor],
        fp8_meta_tensor.scale_inv[fp8_tensor],
        otype,
    )

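
# Hedged sketch of ``fp8_cast_transpose_bgrad_dgelu_fused``: the backward of a
# GEMM + GeLU epilogue. ``gelu_input`` is the pre-GeLU activation saved in the
# forward pass (e.g. the second output of ``fp8_gemm(..., gelu=True)``).
def _example_fp8_bgrad_dgelu() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    meta = _make_example_fp8_meta()
    grad_output = torch.randn(128, 512, dtype=torch.bfloat16, device="cuda")
    gelu_input = torch.randn(128, 512, dtype=torch.bfloat16, device="cuda")
    return fp8_cast_transpose_bgrad_dgelu_fused(
        grad_output, gelu_input, meta, tex.FP8BwdTensors.GRAD_OUTPUT1,
        tex.DType.kFloat8E5M2,
    )
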

def fp8_gelu(
    inp: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
) -> torch.Tensor:
    """GeLU with FP8 output"""
    return torch.ops.tex_ts.fp8_gelu_ts(
        inp,
        fp8_meta_tensor.scale,
        fp8_meta_tensor.amax_history,
        fp8_meta_tensor.scale_inv,
        fp8_tensor,
        otype,
    )

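
# Hedged sketch of ``fp8_gelu``: GeLU evaluated directly into FP8, recording
# the activation amax for this step in the assumed metadata slot.
def _example_fp8_gelu() -> torch.Tensor:
    meta = _make_example_fp8_meta()
    inp = torch.randn(128, 512, dtype=torch.bfloat16, device="cuda")
    return fp8_gelu(inp, meta, tex.FP8FwdTensors.GEMM1_OUTPUT, tex.DType.kFloat8E4M3)
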

def layernorm_fwd_fp8(
    inp: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
    sm_margin: int,
    zero_centered_gamma: bool,
    ln_out: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """LayerNorm with FP8 output"""
    if ln_out is not None:
        return tex.layernorm_fwd_fp8_noalloc(
            inp,
            weight,
            bias,
            eps,
            fp8_meta_tensor.scale[fp8_tensor],
            ln_out,
            fp8_meta_tensor.amax_history[0][fp8_tensor],
            fp8_meta_tensor.scale_inv[fp8_tensor],
            otype,
            sm_margin,
            zero_centered_gamma,
        )

    return tex.layernorm_fwd_fp8(
        inp,
        weight,
        bias,
        eps,
        fp8_meta_tensor.scale[fp8_tensor],
        fp8_meta_tensor.amax_history[0][fp8_tensor],
        fp8_meta_tensor.scale_inv[fp8_tensor],
        otype,
        sm_margin,
        zero_centered_gamma,
    )

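
# Hedged sketch of ``layernorm_fwd_fp8``: normalize and quantize in one pass so
# the following GEMM can consume FP8 directly. Returns (ln_out, mu, rsigma) per
# the annotation; ``sm_margin=0`` reserves no SMs for overlapped communication.
def _example_layernorm_fwd_fp8() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    hidden = 512
    meta = _make_example_fp8_meta()
    inp = torch.randn(128, hidden, dtype=torch.float32, device="cuda")
    weight = torch.ones(hidden, dtype=torch.float32, device="cuda")
    bias = torch.zeros(hidden, dtype=torch.float32, device="cuda")
    return layernorm_fwd_fp8(
        inp, weight, bias, 1e-5, meta, tex.FP8FwdTensors.GEMM1_INPUT,
        tex.DType.kFloat8E4M3, 0, False,
    )
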

def layernorm_fwd_fp8_inf(
    inp: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
    zero_centered_gamma: bool,
) -> torch.Tensor:
    """LayerNorm with FP8 output.

    This version of layernorm_fwd_fp8 is specialized for inference, and returns
    only the normalized output.
    """
    ret = torch.ops.tex_ts.layernorm_fwd_fp8_inf_ts(
        inp,
        weight,
        bias,
        eps,
        fp8_meta_tensor.scale,
        fp8_meta_tensor.amax_history,
        fp8_meta_tensor.scale_inv,
        fp8_tensor,
        otype,
        zero_centered_gamma,
    )
    return ret

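
# Hedged sketch of the inference variant: same inputs as the training-path
# LayerNorm above, but only the FP8-normalized output comes back.
def _example_layernorm_fwd_fp8_inf() -> torch.Tensor:
    meta = _make_example_fp8_meta()
    inp = torch.randn(128, 512, dtype=torch.float32, device="cuda")
    weight = torch.ones(512, dtype=torch.float32, device="cuda")
    bias = torch.zeros(512, dtype=torch.float32, device="cuda")
    return layernorm_fwd_fp8_inf(
        inp, weight, bias, 1e-5, meta, tex.FP8FwdTensors.GEMM1_INPUT,
        tex.DType.kFloat8E4M3, False,
    )
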

def layernorm_fwd_inf(
    inp: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    zero_centered_gamma: bool,
) -> torch.Tensor:
    """LayerNorm (non-FP8), specialized for inference."""
    return torch.ops.tex_ts.layernorm_fwd_inf_ts(
        inp,
        weight,
        bias,
        eps,
        zero_centered_gamma,
    )

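
# Hedged sketch of ``layernorm_fwd_inf``: the plain inference LayerNorm needs
# no FP8 metadata at all.
def _example_layernorm_fwd_inf() -> torch.Tensor:
    inp = torch.randn(128, 512, dtype=torch.bfloat16, device="cuda")
    weight = torch.ones(512, dtype=torch.bfloat16, device="cuda")
    bias = torch.zeros(512, dtype=torch.bfloat16, device="cuda")
    return layernorm_fwd_inf(inp, weight, bias, 1e-5, False)
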

def cast_to_fp8(
    inp: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
    out: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]:
    """Cast input to FP8"""

    if out is not None:
        tex.cast_to_fp8_noalloc(
            inp,
            fp8_meta_tensor.scale[fp8_tensor],
            out,
            fp8_meta_tensor.amax_history[0][fp8_tensor],
            fp8_meta_tensor.scale_inv[fp8_tensor],
            otype,
        )
        return None
    return torch.ops.tex_ts.cast_to_fp8_ts(
        inp,
        fp8_meta_tensor.scale,
        fp8_meta_tensor.amax_history,
        fp8_meta_tensor.scale_inv,
        fp8_tensor,
        otype,
    )


def cast_from_fp8(
    inp: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    itype: tex.DType,
    otype: tex.DType,
) -> torch.Tensor:
    """Cast input from FP8"""
    return torch.ops.tex_ts.cast_from_fp8_ts(
        inp,
        fp8_meta_tensor.scale_inv,
        fp8_tensor,
        itype,
        otype,
    )
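

# Hedged round-trip sketch for ``cast_to_fp8``/``cast_from_fp8``: quantize to
# E4M3 and dequantize back to FP32; with the unit scales assumed above, the
# reconstruction error is bounded by the FP8 format's precision.
def _example_fp8_round_trip() -> torch.Tensor:
    meta = _make_example_fp8_meta()
    x = torch.randn(128, 256, dtype=torch.float32, device="cuda")
    x_fp8 = cast_to_fp8(x, meta, tex.FP8FwdTensors.GEMM1_INPUT, tex.DType.kFloat8E4M3)
    return cast_from_fp8(
        x_fp8, meta, tex.FP8FwdTensors.GEMM1_INPUT,
        tex.DType.kFloat8E4M3, tex.DType.kFloat32,
    )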