jit.py 13.8 KB
Newer Older
1
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
5
#
# See LICENSE for license information.

"""NVFuser functions and JIT utilities"""
6
import os
7
from functools import wraps
ngoyal2707's avatar
ngoyal2707 committed
8
from typing import Callable, Optional, Tuple
Przemek Tredak's avatar
Przemek Tredak committed
9
10
import torch

Paweł Gadziński's avatar
Paweł Gadziński committed
11
from .torch_version import torch_version
12
from .export import is_in_onnx_export_mode
13
from .utils import gpu_autocast_ctx
yuguo's avatar
yuguo committed
14
from torch.utils.cpp_extension import IS_HIP_EXTENSION
Przemek Tredak's avatar
Przemek Tredak committed
15

16
17
# pylint: disable=unnecessary-lambda-assignment

18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36

def lazy_compile(func):
    """Wrap ``func`` so that ``torch.compile`` runs on first use instead of at decoration time.

    Compiling eagerly would slow down importing this module even when the
    decorated function is never called. The wrapper therefore compiles ``func``
    the first time it is invoked, caches the compiled callable in its closure,
    and reuses it for every subsequent call.
    """
    cache = {}

    @wraps(func)
    def wrapper(*args, **kwargs):
        compiled = cache.get("fn")
        if compiled is None:
            compiled = cache["fn"] = torch.compile(func)
        return compiled(*args, **kwargs)

    return wrapper


37
jit_fuser = lambda func: func
38
if torch_version() >= (2, 0, 0) and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
39
    jit_fuser = lazy_compile
40

41

42
43
# See: https://github.com/NVIDIA/TransformerEngine/issues/597
dropout_fuser = torch.jit.script
44
if torch_version() >= (2, 2, 0) and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
45
    dropout_fuser = lazy_compile
46

47

48
49
# Decorator to disable Torch Dynamo
# See: https://github.com/NVIDIA/TransformerEngine/issues/308
# Version gates use torch_version() tuples, like the rest of this module, rather
# than comparing torch.__version__ against strings.
if torch_version() >= (2, 0, 0):
    import torch._dynamo

    def no_torch_dynamo(recursive=True):
        """Decorator to disable Torch Dynamo, except during ONNX export.

        Args:
            recursive: forwarded to ``torch._dynamo.disable`` on PyTorch >= 2.1.
                PyTorch 2.0 has no such option and behaves as if it were True.

        Returns:
            A decorator. The decorated function calls the original ``f`` directly
            while ONNX export mode is active, and the Dynamo-disabled version of
            ``f`` otherwise.
        """

        def decorator(f):
            # Choose the torch._dynamo.disable flavor once, at decoration time.
            if torch_version() >= (2, 1, 0):
                disabled_f = torch._dynamo.disable(f, recursive=recursive)
            else:
                # no "recursive" option in PyTorch 2.0 - it acts as if recursive was True
                disabled_f = torch._dynamo.disable(f)

            @wraps(f)
            def wrapper(*args, **kwargs):
                # During ONNX export, bypass the Dynamo-disabled wrapper entirely.
                if is_in_onnx_export_mode():
                    return f(*args, **kwargs)
                return disabled_f(*args, **kwargs)

            return wrapper

        return decorator

else:
    # Fallback for PyTorch < 2.0: no-op decorator
    def no_torch_dynamo(recursive=True):  # pylint: disable=unused-argument
        """No-op decorator for PyTorch < 2.0."""
        return lambda func: func
79

Przemek Tredak's avatar
Przemek Tredak committed
80
81

def set_jit_fusion_options() -> None:
    """Set PyTorch JIT layer fusion options.

    Configures the TorchScript fuser flags required to enable JIT fusion kernels:

    * ROCm/HIP builds (``IS_HIP_EXTENSION``) and PyTorch >= 2.2: nothing is changed.
    * PyTorch 1.10 up to (but excluding) 2.2: enables the profiling executor and
      nvFuser. It disables the legacy CPU/GPU fusers and the TensorExpr fuser,
      and turns off autodiff subgraph inlining.
    * Older PyTorch: uses the legacy PyTorch fuser on both CPU and GPU with the
      profiling executor disabled.
    """
    if IS_HIP_EXTENSION:
        return
    if torch_version() >= (2, 2, 0):
        return
    if torch_version() >= (1, 10, 0):
        # nvfuser
        torch._C._jit_set_profiling_executor(True)
        torch._C._jit_set_profiling_mode(True)
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(True)
        torch._C._debug_set_autodiff_subgraph_inlining(False)
    else:
        # legacy pytorch fuser
        torch._C._jit_set_profiling_mode(False)
        torch._C._jit_set_profiling_executor(False)
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(True)
Przemek Tredak's avatar
Przemek Tredak committed
102
103


104
@jit_fuser
def bias_gelu_fused_(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    """Apply the tanh-approximated GeLU to ``inp + bias`` as one fused function.

    0.79788456 is sqrt(2/pi) and 0.044715 is the cubic coefficient of the approximation
    gelu(y) ~= 0.5 * y * (1 + tanh(sqrt(2/pi) * y * (1 + 0.044715 * y^2))).
    """
    y = inp + bias
    tanh_arg = 0.79788456 * y * (1 + 0.044715 * y * y)
    return y * 0.5 * (1.0 + torch.tanh(tanh_arg))


111
@jit_fuser
def gelu_fused_(inp: torch.Tensor) -> torch.Tensor:
    """Apply the tanh-approximated GeLU to ``inp`` as one fused function.

    Bias-free copy of ``bias_gelu_fused_``. It exists separately because jit
    fusion doesn't allow conditioning on whether a bias is present.
    """
    tanh_arg = 0.79788456 * inp * (1 + 0.044715 * inp * inp)
    return inp * 0.5 * (1.0 + torch.tanh(tanh_arg))


Przemek Tredak's avatar
Przemek Tredak committed
120
121
122
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@jit_fuser
def bgrad_dgelu_fused_(
    grad_output: torch.Tensor, inp: torch.Tensor, bias: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Backward of bias + tanh-approximated GeLU.

    Returns:
        ``(bgrad, dgelu)``. ``dgelu`` is ``grad_output`` multiplied by the derivative
        of the tanh GeLU approximation evaluated at ``inp + bias``. ``bgrad`` is
        ``dgelu`` summed over dim 0.
    """
    y = inp + bias
    t = torch.tanh(0.79788456 * y * (1 + 0.044715 * y * y))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    local_grad = 0.5 * y * ((1 - t * t) * (0.79788456 + 0.1070322243 * y * y)) + 0.5 * (
        1 + t
    )
    dgelu = local_grad * grad_output
    return dgelu.sum(dim=0), dgelu


139
@jit_fuser
def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor:
    """Backward of the tanh-approximated GeLU without bias.

    Copy of ``bgrad_dgelu_fused_`` without the bias add and bias-gradient
    reduction. It exists separately because jit fusion doesn't allow conditioning.
    Returns ``grad_output`` multiplied by the GeLU-approximation derivative at ``inp``.
    """
    t = torch.tanh(0.79788456 * inp * (1 + 0.044715 * inp * inp))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    local_grad = 0.5 * inp * ((1 - t * t) * (0.79788456 + 0.1070322243 * inp * inp)) + 0.5 * (
        1 + t
    )
    return local_grad * grad_output


154
155
156
@jit_fuser
def l2normalization_fused_(x: torch.Tensor, eps: float) -> torch.Tensor:
    """L2-normalize ``x`` over its last dimension (inference version, fused).

    Computes ``x * rsqrt(sum(x**2, dim=-1) + eps)`` in float32 and casts the
    result back to ``x.dtype``.
    """
    xf = x.float()
    inv_norm = torch.rsqrt(xf.pow(2).sum(dim=-1, keepdim=True) + eps)
    return (xf * inv_norm).to(x.dtype)
163
164
165
166
167


@jit_fuser
def l2normalization_fwd_fused_(x: torch.Tensor, eps: float) -> tuple[torch.Tensor, torch.Tensor]:
    """L2-normalize ``x`` over its last dimension (training version, fused).

    Returns:
        ``(y, rsqrt_norm)``. ``y`` is the normalized tensor cast back to ``x.dtype``.
        ``rsqrt_norm`` is the float32 ``rsqrt(sum(x**2, dim=-1) + eps)``, with the
        reduced dimension kept as size 1, for reuse in the backward pass.
    """
    xf = x.float()
    inv_norm = torch.rsqrt(xf.pow(2).sum(dim=-1, keepdim=True) + eps)
    y = (xf * inv_norm).to(x.dtype)
    return y, inv_norm


@jit_fuser
def l2normalization_backward_fused_(
    grad_output: torch.Tensor,
    x: torch.Tensor,
    rsqrt_norm: torch.Tensor,
    eps: float,
) -> torch.Tensor:
    """Backward of last-dimension L2 normalization (fused).

    Computes, in float32,
    ``dx = rsqrt_norm * (dy - x * sum(x * dy, dim=-1) / (sum(x**2, dim=-1) + eps))``
    and casts the result to ``x.dtype``.
    """
    xf = x.float()
    gf = grad_output.float()
    dot = (xf * gf).sum(dim=-1, keepdim=True)
    # NOTE(review): the squared norm is recomputed from x here rather than
    # derived from rsqrt_norm.
    denom = xf.pow(2).sum(dim=-1, keepdim=True) + eps
    dx = rsqrt_norm * (gf - xf * dot / denom)
    return dx.to(x.dtype)
194
195


Przemek Tredak's avatar
Przemek Tredak committed
196
197
def bias_gelu_fused(inp: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor:
    """Bias + GeLU with native AMP (autocast) disabled.

    Dispatches to ``bias_gelu_fused_`` when ``bias`` is a non-empty tensor.
    Otherwise it dispatches to ``gelu_fused_``, so ``bias`` may be ``None`` or empty.
    The two fused functions are separate because jit fusion doesn't allow
    conditioning inside them.
    """
    with gpu_autocast_ctx(enabled=False):
        if bias is not None and bias.numel() != 0:
            return bias_gelu_fused_(inp, bias)
        return gelu_fused_(inp)
Przemek Tredak's avatar
Przemek Tredak committed
202
203
204
205


def bgrad_dgelu_fused(
    grad_output: torch.Tensor, inp: torch.Tensor, bias: Optional[torch.Tensor]
) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
    """Disable native AMP for `bgrad_dgelu_fused_`.

    Returns:
        ``(bgrad, dgelu)``. When ``bias`` is ``None`` or empty there is no bias
        gradient: ``bgrad`` is ``None`` and ``dgelu`` comes from ``dgelu_fused_``.
    """
    with gpu_autocast_ctx(enabled=False):
        if bias is not None and bias.numel() != 0:
            return bgrad_dgelu_fused_(grad_output, inp, bias)
        return None, dgelu_fused_(grad_output, inp)
Przemek Tredak's avatar
Przemek Tredak committed
212
213


214
215
216
217
218
219
220
221
222
223
224
225
226
def l2normalization_fused(x: torch.Tensor, eps: float) -> torch.Tensor:
    """Run the inference L2-normalization kernel with GPU autocast disabled.

    The kernel itself upcasts ``x`` to float32 and returns the result in ``x.dtype``.
    """
    with gpu_autocast_ctx(enabled=False):
        y = l2normalization_fused_(x, eps)
    return y


def l2normalization_fwd_fused(x: torch.Tensor, eps: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the training L2-normalization kernel with GPU autocast disabled.

    Returns the ``(y, rsqrt_norm)`` pair produced by ``l2normalization_fwd_fused_``.
    """
    with gpu_autocast_ctx(enabled=False):
        y_and_rsqrt = l2normalization_fwd_fused_(x, eps)
    return y_and_rsqrt


def l2normalization_backward_fused(
    grad_output: torch.Tensor,
    x: torch.Tensor,
    rsqrt_norm: torch.Tensor,
    eps: float,
) -> torch.Tensor:
    """Run the L2-normalization backward kernel with GPU autocast disabled."""
    with gpu_autocast_ctx(enabled=False):
        dx = l2normalization_backward_fused_(grad_output, x, rsqrt_norm, eps)
    return dx


Przemek Tredak's avatar
Przemek Tredak committed
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
def bias_dropout_add(
    x: torch.Tensor,
    bias: torch.Tensor,
    residual: torch.Tensor,
    prob: float,
    training: bool,
) -> torch.Tensor:
    """Return ``residual + dropout(x + bias)``.

    Dropout with probability ``prob`` is applied only when ``training`` is True.
    The ``dropout_fuser``-decorated wrappers call this function, and that decorator
    may be ``torch.jit.script``. Keep the body TorchScript-compatible.
    """
    dropped = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    return residual + dropped


def get_bias_dropout_add(training: bool) -> Callable:
    """Return a ``bias_dropout_add`` variant with ``training`` fixed.

    The returned callable takes ``(x, bias, residual, prob)``.
    """

    def _bias_dropout_add(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training=training)

    return _bias_dropout_add


259
@dropout_fuser
def bias_dropout_add_fused_train_(
    x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
) -> torch.Tensor:
    """Fused ``residual + dropout(x + bias)`` with dropout active (training mode).

    ``dropout_fuser`` is either ``torch.jit.script`` or a lazy ``torch.compile``,
    so the body stays TorchScript-compatible.
    """
    out = bias_dropout_add(x, bias, residual, prob, True)
    return out


def bias_dropout_add_fused_train(
    x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
) -> torch.Tensor:
    """Run the training bias-dropout-add kernel with grad enabled and GPU autocast off.

    Grad mode is forced on so the result is differentiable even when the caller
    is inside a ``torch.no_grad()`` region.
    """
    with torch.enable_grad(), gpu_autocast_ctx(enabled=False):
        return bias_dropout_add_fused_train_(x, bias, residual, prob)


276
@dropout_fuser
def bias_dropout_add_fused_inference_(
    x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
) -> torch.Tensor:
    """Fused ``residual + dropout(x + bias)`` with dropout inactive (inference mode).

    ``dropout_fuser`` is either ``torch.jit.script`` or a lazy ``torch.compile``,
    so the body stays TorchScript-compatible.
    """
    out = bias_dropout_add(x, bias, residual, prob, False)
    return out


def bias_dropout_add_fused_inference(
    x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
) -> torch.Tensor:
    """Run the inference bias-dropout-add kernel with GPU autocast disabled."""
    with gpu_autocast_ctx(enabled=False):
        out = bias_dropout_add_fused_inference_(x, bias, residual, prob)
    return out


def warmup_jit_bias_dropout_add(
    hidden_size: int, dtype: torch.dtype, seq_length: int, micro_batch_size: int
) -> None:
    """Compile BDA JIT function before the main training steps.

    Calls ``bias_dropout_add_fused_train`` five times on random CUDA tensors of
    shape ``(seq_length, micro_batch_size, hidden_size)``. This is done for each
    input grad-enable state used by forward prop and recomputation.

    The CUDA RNG state is saved beforehand and restored in a ``finally`` clause.
    Warmup therefore does not affect reproducibility, even if a compilation or
    kernel call raises.
    """
    rng_state = torch.cuda.get_rng_state()
    try:
        shape = (seq_length, micro_batch_size, hidden_size)
        inp = torch.rand(shape, dtype=dtype, device="cuda")
        residual = torch.rand(shape, dtype=dtype, device="cuda")
        bias = torch.rand(hidden_size, dtype=dtype, device="cuda")
        dropout_rate = 0.1
        # Warmup JIT fusions with the input grad_enable state of both forward
        # prop and recomputation
        for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]):
            inp.requires_grad = input_grad
            bias.requires_grad = bias_grad
            residual.requires_grad = residual_grad
            for _ in range(5):
                output = bias_dropout_add_fused_train(inp, bias, residual, dropout_rate)
        del bias, inp, residual, output
        torch.cuda.empty_cache()
    finally:
        torch.cuda.set_rng_state(rng_state)
Przemek Tredak's avatar
Przemek Tredak committed
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331


def warmup_jit_bias_dropout_add_all_dtypes(
    hidden_size: int, seq_length: int, micro_batch_size: int
) -> None:
    """Warm up the BDA JIT function for float32, bfloat16 and float16, in that order."""
    training_dtypes = (torch.float32, torch.bfloat16, torch.float16)
    for dtype in training_dtypes:
        warmup_jit_bias_dropout_add(hidden_size, dtype, seq_length, micro_batch_size)


def warmup_jit_bias_gelu(
    ffn_hidden_size_per_partition: int,
    dtype: torch.dtype,
    seq_length: int,
    micro_batch_size: int,
) -> None:
    """Compile bias-gelu JIT function before the main training steps.

    Calls ``bias_gelu_fused_`` and ``gelu_fused_`` five times each on random CUDA
    tensors. The input has shape
    ``(seq_length * micro_batch_size, ffn_hidden_size_per_partition)``. This is
    done for both input grad-enable states used by forward prop and recomputation.

    The CUDA RNG state is saved beforehand and restored in a ``finally`` clause.
    Warmup therefore does not affect reproducibility, even if a compilation or
    kernel call raises.
    """
    rng_state = torch.cuda.get_rng_state()
    try:
        bias = torch.rand(ffn_hidden_size_per_partition, dtype=dtype, device="cuda")
        inp = torch.rand(
            (seq_length * micro_batch_size, ffn_hidden_size_per_partition),
            dtype=dtype,
            device="cuda",
        )
        # Warmup JIT fusions with the input grad_enable state of both forward
        # prop and recomputation
        for bias_grad, input_grad in zip([True, True], [False, True]):
            bias.requires_grad, inp.requires_grad = bias_grad, input_grad
            for _ in range(5):
                _ = bias_gelu_fused_(inp, bias)
                _ = gelu_fused_(inp)
        del bias, inp
        torch.cuda.empty_cache()
    finally:
        torch.cuda.set_rng_state(rng_state)

Przemek Tredak's avatar
Przemek Tredak committed
355
356
357
358
359
360
361

def warmup_jit_bias_gelu_all_dtypes(
    ffn_hidden_size: int, seq_length: int, micro_batch_size: int
) -> None:
    """Warm up the bias-GeLU JIT functions for float32, bfloat16 and float16, in that order."""
    training_dtypes = (torch.float32, torch.bfloat16, torch.float16)
    for dtype in training_dtypes:
        warmup_jit_bias_gelu(ffn_hidden_size, dtype, seq_length, micro_batch_size)
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403


def warmup_jit_l2normalization(
    hidden_size: int, dtype: torch.dtype, seq_length: int, micro_batch_size: int
) -> None:
    """Compile L2Normalization JIT function before the main training steps.

    Uses a random CUDA input of shape ``(seq_length * micro_batch_size, hidden_size)``.
    Each of the following is run five times:

    * with ``requires_grad=False``, the inference kernel ``l2normalization_fused_``;
    * with ``requires_grad=True``, the training forward ``l2normalization_fwd_fused_``
      followed by ``l2normalization_backward_fused_``.

    The CUDA RNG state is saved beforehand and restored in a ``finally`` clause.
    Warmup therefore does not affect reproducibility, even if a compilation or
    kernel call raises.
    """
    rng_state = torch.cuda.get_rng_state()
    try:
        inp = torch.rand(
            (seq_length * micro_batch_size, hidden_size),
            dtype=dtype,
            device="cuda",
        )
        eps = 1e-6
        # Warmup JIT fusions with the input grad_enable state of both forward
        # prop and recomputation
        for input_grad in [False, True]:
            inp.requires_grad = input_grad
            for _ in range(5):
                if input_grad:
                    # Training version that returns intermediate values, plus backward
                    output, rsqrt_norm = l2normalization_fwd_fused_(inp, eps)
                    grad_out = torch.rand_like(output)
                    _ = l2normalization_backward_fused_(grad_out, inp, rsqrt_norm, eps)
                else:
                    # Inference version
                    output = l2normalization_fused_(inp, eps)
        del inp, output
        torch.cuda.empty_cache()
    finally:
        torch.cuda.set_rng_state(rng_state)


def warmup_jit_l2normalization_all_dtypes(
    hidden_size: int, seq_length: int, micro_batch_size: int
) -> None:
    """Warm up the L2Normalization JIT functions for float32, bfloat16 and float16, in that order."""
    training_dtypes = (torch.float32, torch.bfloat16, torch.float16)
    for dtype in training_dtypes:
        warmup_jit_l2normalization(hidden_size, dtype, seq_length, micro_batch_size)