jit.py 12.9 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
5
#
# See LICENSE for license information.

"""NVFuser functions and JIT utilities"""
6
import os
7
from functools import wraps
ngoyal2707's avatar
ngoyal2707 committed
8
from typing import Callable, Optional, Tuple
Przemek Tredak's avatar
Przemek Tredak committed
9
10
import torch

11
from . import torch_version
12
from .export import is_in_onnx_export_mode
13
from .utils import gpu_autocast_ctx
yuguo's avatar
yuguo committed
14
from torch.utils.cpp_extension import IS_HIP_EXTENSION
Przemek Tredak's avatar
Przemek Tredak committed
15

16
17
# pylint: disable=unnecessary-lambda-assignment

18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36

def lazy_compile(func):
    """Wrap ``func`` so that ``torch.compile`` is applied on its first call.

    Compiling at decoration time would slow down importing the module even when
    the decorated function is never used, so compilation is postponed until the
    wrapper is first invoked. The compiled callable is then reused for every
    later call.
    """
    # Holds the compiled callable once it has been built (empty until first call).
    compiled = []

    @wraps(func)
    def wrapper(*args, **kwargs):
        if not compiled:
            compiled.append(torch.compile(func))
        return compiled[0](*args, **kwargs)

    return wrapper


37
# Decorator applied to the fused element-wise kernels in this module.
# Default: leave the function untouched. On PyTorch >= 2.0, and unless the
# NVTE_TORCH_COMPILE environment variable parses to 0 (it defaults to "1"),
# use lazily-applied torch.compile instead.
jit_fuser = lambda func: func
if torch_version() >= (2, 0, 0) and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
    jit_fuser = lazy_compile
40

41

42
43
# Decorator applied to the bias-dropout-add kernels in this module.
# TorchScript is the default; lazily-applied torch.compile is used only on
# PyTorch >= 2.2 and unless NVTE_TORCH_COMPILE parses to 0 (default "1").
# See: https://github.com/NVIDIA/TransformerEngine/issues/597
dropout_fuser = torch.jit.script
if torch_version() >= (2, 2, 0) and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
    dropout_fuser = lazy_compile
46

47

48
49
# Decorator to disable Torch Dynamo
# See: https://github.com/NVIDIA/TransformerEngine/issues/308
50
51
52
53
54
55
56
57
58
59
60
# Fallback used when the version checks below do not match: a decorator factory
# whose decorator returns the function unchanged (``recursive`` is ignored).
no_torch_dynamo = lambda recursive=True: lambda func: func
if torch.__version__ >= "2":
    import torch._dynamo

    if torch.__version__ >= "2.1":
        # The ONNX-export check runs when the decorator is applied: in export mode
        # the function is returned as-is, otherwise it is wrapped with
        # torch._dynamo.disable using the requested ``recursive`` setting.
        no_torch_dynamo = lambda recursive=True: lambda f: (
            f if is_in_onnx_export_mode() else torch._dynamo.disable(f, recursive=recursive)
        )
    else:
        # no "recursive" option in pyTorch 2.0 - it acts as if recursive was True
        no_torch_dynamo = lambda recursive=True: torch._dynamo.disable
61

Przemek Tredak's avatar
Przemek Tredak committed
62
63

def set_jit_fusion_options() -> None:
    """Set PyTorch JIT layer fusion options.

    Nothing is changed when ``IS_HIP_EXTENSION`` is true or when PyTorch is 2.2
    or newer. For PyTorch >= 1.10 the TorchScript profiling executor and nvfuser
    are enabled; for older versions the legacy PyTorch fuser is enabled instead.
    """
    if IS_HIP_EXTENSION:
        return

    # flags required to enable jit fusion kernels
    version = torch_version()
    if version >= (2, 2, 0):
        # No TorchScript fuser flags are touched for these versions.
        return

    if version >= (1, 10, 0):
        # nvfuser
        torch._C._jit_set_profiling_executor(True)
        torch._C._jit_set_profiling_mode(True)
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(True)
        torch._C._debug_set_autodiff_subgraph_inlining(False)
    else:
        # legacy pytorch fuser
        torch._C._jit_set_profiling_mode(False)
        torch._C._jit_set_profiling_executor(False)
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(True)
Przemek Tredak's avatar
Przemek Tredak committed
84
85


86
@jit_fuser
Przemek Tredak's avatar
Przemek Tredak committed
87
88
89
90
91
92
def bias_gelu_fused_(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    """Add ``bias`` to ``inp`` and apply the tanh approximation of GeLU."""
    shifted = inp + bias
    # Polynomial factor of the tanh argument: 1 + 0.044715 * x^2
    cubic_factor = 1 + 0.044715 * shifted * shifted
    tanh_arg = 0.79788456 * shifted * cubic_factor
    return shifted * 0.5 * (1.0 + torch.tanh(tanh_arg))


93
@jit_fuser
ngoyal2707's avatar
ngoyal2707 committed
94
95
96
97
98
99
100
101
def gelu_fused_(inp: torch.Tensor) -> torch.Tensor:
    """
    Tanh approximation of GeLU without a bias term.

    Kept as a separate copy of bias_gelu_fused_ because jit fusion doesn't allow
    conditioning on whether a bias is present.
    """
    # Polynomial factor of the tanh argument: 1 + 0.044715 * x^2
    cubic_factor = 1 + 0.044715 * inp * inp
    tanh_arg = 0.79788456 * inp * cubic_factor
    return inp * 0.5 * (1.0 + torch.tanh(tanh_arg))


Przemek Tredak's avatar
Przemek Tredak committed
102
103
104
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
105
@jit_fuser
Przemek Tredak's avatar
Przemek Tredak committed
106
107
108
109
110
111
112
def bgrad_dgelu_fused_(
    grad_output: torch.Tensor, inp: torch.Tensor, bias: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Backward of the bias + tanh-approximated GeLU.

    Returns a tuple ``(bgrad, dgelu)``: ``dgelu`` is ``grad_output`` multiplied by
    the derivative of the GeLU approximation evaluated at ``inp + bias``, and
    ``bgrad`` is ``dgelu`` summed over dimension 0.
    """
    x = inp + bias
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )
    dgelu = ff * grad_output
    bgrad = dgelu.sum(dim=0)
    return bgrad, dgelu


121
@jit_fuser
122
def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor:
    """
    Backward of the tanh-approximated GeLU without a bias term.

    This is a copy of bgrad_dgelu_fused_ cause jit fusion doesn't allow conditioning.
    Returns ``grad_output`` multiplied by the derivative of the GeLU approximation
    evaluated at ``inp``.
    """
    x = inp
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )
    dgelu = ff * grad_output
    return dgelu


136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
@jit_fuser
def l2normalization_fused_(x: torch.Tensor, eps: float) -> torch.Tensor:
    """Inference-time L2 normalization: scale ``x`` by rsqrt(sum(x^2) + eps).

    The sum of squares is taken over the last dimension (kept for broadcasting).
    """
    sum_of_squares = x.pow(2).sum(dim=-1, keepdim=True)
    inv_norm = torch.rsqrt(sum_of_squares + eps)
    return x * inv_norm


@jit_fuser
def l2normalization_fwd_fused_(x: torch.Tensor, eps: float) -> Tuple[torch.Tensor, torch.Tensor]:
    """L2 normalization fused - training version that returns intermediate values.

    Returns ``(y, rsqrt_norm)`` where ``rsqrt_norm`` is rsqrt(sum(x^2) + eps) taken
    over the last dimension (kept for broadcasting) and ``y = x * rsqrt_norm``.
    ``rsqrt_norm`` is returned so it can be passed to the backward kernel.
    """
    # NOTE: annotated with typing.Tuple (as the other kernels in this file are);
    # the builtin generic ``tuple[...]`` is evaluated at definition time and is
    # not subscriptable before Python 3.9.
    x_squared = x.pow(2)
    l2_norm_squared = x_squared.sum(dim=-1, keepdim=True)
    rsqrt_norm = torch.rsqrt(l2_norm_squared + eps)
    y = x * rsqrt_norm
    return y, rsqrt_norm


@jit_fuser
def l2normalization_backward_fused_(
    grad_output: torch.Tensor, x: torch.Tensor, rsqrt_norm: torch.Tensor, eps: float
) -> torch.Tensor:
    """Backward of L2 normalization.

    Computes ``rsqrt_norm * (grad_output - x * sum(x * grad_output) / (sum(x^2) + eps))``
    with both sums taken over the last dimension (kept for broadcasting).
    """
    # Projection of the incoming gradient onto x.
    projection = (x * grad_output).sum(dim=-1, keepdim=True)
    # Squared norm is recomputed here rather than derived from rsqrt_norm.
    squared_norm = x.pow(2).sum(dim=-1, keepdim=True) + eps
    correction = x * projection / squared_norm
    return rsqrt_norm * (grad_output - correction)


Przemek Tredak's avatar
Przemek Tredak committed
165
166
def bias_gelu_fused(inp: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor:
    """Disable native AMP for bias_gelu_fused_.

    Dispatches to ``bias_gelu_fused_`` when ``bias`` is a non-empty tensor and to
    ``gelu_fused_`` when ``bias`` is ``None`` or has no elements.
    """
    with gpu_autocast_ctx(enabled=False):
        if bias is not None and bias.numel() != 0:
            return bias_gelu_fused_(inp, bias)
        return gelu_fused_(inp)
Przemek Tredak's avatar
Przemek Tredak committed
171
172
173
174


def bgrad_dgelu_fused(
    grad_output: torch.Tensor, inp: torch.Tensor, bias: Optional[torch.Tensor]
) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
    """Disable native AMP for `bgrad_dgelu_fused_`.

    Returns ``(bgrad, dgelu)``. When ``bias`` is ``None`` or has no elements, the
    bias-free kernel ``dgelu_fused_`` is used and ``bgrad`` is ``None``.
    """
    with gpu_autocast_ctx(enabled=False):
        if bias is not None and bias.numel() != 0:
            return bgrad_dgelu_fused_(grad_output, inp, bias)
        return None, dgelu_fused_(grad_output, inp)
Przemek Tredak's avatar
Przemek Tredak committed
181
182


183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
def l2normalization_fused(x: torch.Tensor, eps: float) -> torch.Tensor:
    """Run the inference L2-normalization kernel with native AMP disabled."""
    with gpu_autocast_ctx(enabled=False):
        normalized = l2normalization_fused_(x, eps)
        return normalized


def l2normalization_fwd_fused(x: torch.Tensor, eps: float) -> Tuple[torch.Tensor, torch.Tensor]:
    """Disable native AMP for l2normalization_fwd_fused_ - training version.

    Returns whatever ``l2normalization_fwd_fused_`` returns for ``(x, eps)``.
    """
    # NOTE: annotated with typing.Tuple (as the GeLU wrappers in this file are);
    # the builtin generic ``tuple[...]`` is not subscriptable before Python 3.9.
    with gpu_autocast_ctx(enabled=False):
        return l2normalization_fwd_fused_(x, eps)


def l2normalization_backward_fused(
    grad_output: torch.Tensor, x: torch.Tensor, rsqrt_norm: torch.Tensor, eps: float
) -> torch.Tensor:
    """Run the L2-normalization backward kernel with native AMP disabled."""
    with gpu_autocast_ctx(enabled=False):
        grad_input = l2normalization_backward_fused_(grad_output, x, rsqrt_norm, eps)
        return grad_input


Przemek Tredak's avatar
Przemek Tredak committed
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
def bias_dropout_add(
    x: torch.Tensor,
    bias: torch.Tensor,
    residual: torch.Tensor,
    prob: float,
    training: bool,
) -> torch.Tensor:
    """Compute ``residual + dropout(x + bias)``.

    ``prob`` is the dropout probability and ``training`` is forwarded to
    ``torch.nn.functional.dropout``.
    """
    dropped = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    return residual + dropped


def get_bias_dropout_add(training: bool) -> Callable:
    """Return a ``bias_dropout_add`` variant with ``training`` already bound.

    The returned callable takes ``(x, bias, residual, prob)``.
    """

    def _bias_dropout_add(x, bias, residual, prob):
        # ``training`` is captured from the enclosing call.
        result = bias_dropout_add(x, bias, residual, prob, training)
        return result

    return _bias_dropout_add


225
@dropout_fuser
def bias_dropout_add_fused_train_(
    x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
) -> torch.Tensor:
    """Jit fused bias_dropout_add for training (``training`` fixed to True)"""
    return bias_dropout_add(x, bias, residual, prob, True)


def bias_dropout_add_fused_train(
    x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
) -> torch.Tensor:
    """Disable native AMP and enable grad for BDA.

    Calls ``bias_dropout_add_fused_train_`` inside ``torch.enable_grad()`` and
    with autocast disabled.
    """
    with torch.enable_grad():
        with gpu_autocast_ctx(enabled=False):
            return bias_dropout_add_fused_train_(x, bias, residual, prob)


242
@dropout_fuser
def bias_dropout_add_fused_inference_(
    x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
) -> torch.Tensor:
    """Jit fused bias_dropout_add for inference (``training`` fixed to False)"""
    return bias_dropout_add(x, bias, residual, prob, False)


def bias_dropout_add_fused_inference(
    x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
) -> torch.Tensor:
    """Disable native AMP for BDA.

    Calls ``bias_dropout_add_fused_inference_`` with autocast disabled.
    """
    with gpu_autocast_ctx(enabled=False):
        return bias_dropout_add_fused_inference_(x, bias, residual, prob)


def warmup_jit_bias_dropout_add(
    hidden_size: int, dtype: torch.dtype, seq_length: int, micro_batch_size: int
) -> None:
    """Compile BDA JIT function before the main training steps.

    Runs ``bias_dropout_add_fused_train`` several times on random CUDA tensors of
    shape ``(seq_length, micro_batch_size, hidden_size)`` (bias: ``hidden_size``)
    in the given ``dtype``, then frees the tensors and restores the CUDA RNG state.
    """

    # Save cuda RNG state to ensure warmup does not affect reproducibility.
    rng_state = torch.cuda.get_rng_state()

    inp = torch.rand((seq_length, micro_batch_size, hidden_size), dtype=dtype, device="cuda")
    residual = torch.rand((seq_length, micro_batch_size, hidden_size), dtype=dtype, device="cuda")
    bias = torch.rand((hidden_size), dtype=dtype, device="cuda")
    dropout_rate = 0.1
    # Warmup JIT fusions with the input grad_enable state of both forward
    # prop and recomputation
    for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]):
        inp.requires_grad = input_grad
        bias.requires_grad = bias_grad
        residual.requires_grad = residual_grad
        for _ in range(5):
            output = bias_dropout_add_fused_train(inp, bias, residual, dropout_rate)
    del bias, inp, residual, output

    torch.cuda.empty_cache()
    torch.cuda.set_rng_state(rng_state)
Przemek Tredak's avatar
Przemek Tredak committed
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297


def warmup_jit_bias_dropout_add_all_dtypes(
    hidden_size: int, seq_length: int, micro_batch_size: int
) -> None:
    """Run `warmup_jit_bias_dropout_add` once per training dtype (fp32, bf16, fp16)"""
    training_dtypes = (torch.float32, torch.bfloat16, torch.float16)
    for warmup_dtype in training_dtypes:
        warmup_jit_bias_dropout_add(hidden_size, warmup_dtype, seq_length, micro_batch_size)


def warmup_jit_bias_gelu(
    ffn_hidden_size_per_partition: int,
    dtype: torch.dtype,
    seq_length: int,
    micro_batch_size: int,
) -> None:
    """Compile bias-gelu JIT function before the main training steps.

    Runs ``bias_gelu_fused_`` and ``gelu_fused_`` several times on random CUDA
    tensors of shape ``(seq_length * micro_batch_size, ffn_hidden_size_per_partition)``
    (bias: ``ffn_hidden_size_per_partition``) in the given ``dtype``, then frees the
    tensors and restores the CUDA RNG state.
    """

    # Save cuda RNG state to ensure warmup does not affect reproducibility.
    rng_state = torch.cuda.get_rng_state()

    bias = torch.rand(ffn_hidden_size_per_partition, dtype=dtype, device="cuda")
    inp = torch.rand(
        (seq_length * micro_batch_size, ffn_hidden_size_per_partition),
        dtype=dtype,
        device="cuda",
    )
    # Warmup JIT fusions with the input grad_enable state of both forward
    # prop and recomputation
    for bias_grad, input_grad in zip([True, True], [False, True]):
        bias.requires_grad, inp.requires_grad = bias_grad, input_grad
        for _ in range(5):
            _ = bias_gelu_fused_(inp, bias)
            _ = gelu_fused_(inp)
    del bias, inp

    torch.cuda.empty_cache()
    torch.cuda.set_rng_state(rng_state)

Przemek Tredak's avatar
Przemek Tredak committed
321
322
323
324
325
326
327

def warmup_jit_bias_gelu_all_dtypes(
    ffn_hidden_size: int, seq_length: int, micro_batch_size: int
) -> None:
    """Run `warmup_jit_bias_gelu` once per training dtype (fp32, bf16, fp16)"""
    training_dtypes = (torch.float32, torch.bfloat16, torch.float16)
    for warmup_dtype in training_dtypes:
        warmup_jit_bias_gelu(ffn_hidden_size, warmup_dtype, seq_length, micro_batch_size)
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369


def warmup_jit_l2normalization(
    hidden_size: int, dtype: torch.dtype, seq_length: int, micro_batch_size: int
) -> None:
    """Warm up the L2Normalization JIT kernels ahead of the main training steps.

    Uses a random CUDA tensor of shape ``(seq_length * micro_batch_size, hidden_size)``
    in the given ``dtype``; the CUDA RNG state is restored afterwards.
    """

    # Snapshot the CUDA RNG state so that warmup leaves reproducibility intact.
    saved_rng_state = torch.cuda.get_rng_state()

    num_rows = seq_length * micro_batch_size
    inp = torch.rand((num_rows, hidden_size), dtype=dtype, device="cuda")
    eps = 1e-6
    warmup_iters = 5

    # Warmup JIT fusions with the input grad_enable state of both forward
    # prop and recomputation.
    # Pass 1 - input without grad: exercise the inference kernel.
    inp.requires_grad = False
    for _ in range(warmup_iters):
        output = l2normalization_fused_(inp, eps)

    # Pass 2 - input with grad: exercise the training forward (which also returns
    # the intermediate rsqrt_norm) followed by the backward kernel.
    inp.requires_grad = True
    for _ in range(warmup_iters):
        output, rsqrt_norm = l2normalization_fwd_fused_(inp, eps)
        grad_out = torch.rand_like(output)
        _ = l2normalization_backward_fused_(grad_out, inp, rsqrt_norm, eps)
    del inp, output

    torch.cuda.empty_cache()
    torch.cuda.set_rng_state(saved_rng_state)


def warmup_jit_l2normalization_all_dtypes(
    hidden_size: int, seq_length: int, micro_batch_size: int
) -> None:
    """Run `warmup_jit_l2normalization` once per training dtype (fp32, bf16, fp16)"""
    training_dtypes = (torch.float32, torch.bfloat16, torch.float16)
    for warmup_dtype in training_dtypes:
        warmup_jit_l2normalization(hidden_size, warmup_dtype, seq_length, micro_batch_size)