base.py 34.3 KB
Newer Older
1
2
3
4
5
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Base modules and utilities for TransformerEngine PyTorch API"""
6
import io
7
8
9
10
import os
import pickle
import warnings
from abc import ABC, abstractmethod
Jan Bielak's avatar
Jan Bielak committed
11
from typing import Generator, Union, Optional, Tuple, Dict, Any, List
12
13
14
15
16
17
18
19
from functools import partial
from contextlib import contextmanager

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

import transformer_engine_extensions as tex
20
from ..export import is_in_onnx_export_mode
21
22
23
from ..fp8 import (
    get_default_fp8_recipe,
    get_fp8_te_dtype,
24
    FP8GlobalStateManager,
25
26
27
28
29
30
    amax_and_scale_update,
)
from ..distributed import (
    gather_along_first_dim,
    is_fp8_activation_recompute_enabled,
    in_fp8_activation_recompute_phase,
31
    get_distributed_world_size,
32
33
34
35
36
37
38
)
from ..cpp_extensions import (
    fp8_cast_transpose_fused,
    fp8_cast_transpose_bgrad_fused,
    cast_to_fp8,
)
from ..constants import dist_group_type
39
from ..float8_tensor import Float8Tensor
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72

# Split-accumulation switches for the forward-prop, dgrad and wgrad GEMMs.
# NOTE(review): "2X_ACC" presumably refers to two-level (split) accumulation in
# FP8 GEMMs — inferred from the names only; confirm at the call sites.
_2X_ACC_FPROP: bool = False
_2X_ACC_DGRAD: bool = True
_2X_ACC_WGRAD: bool = True

# Upper bound on concurrent GEMM streams used by userbuffers overlap; also the
# factor by which `initialize_ub` enlarges the cuBLAS workspace.
_NUM_MAX_UB_STREAMS: int = 3

# Cached cuBLAS workspace tensor; allocated lazily by `get_workspace`.
_cublas_workspace: Optional[torch.Tensor] = None

# Layer name -> userbuffers communicator; populated by `initialize_ub`.
_ub_communicators: Optional[Dict[str, Any]] = None

# Handle of the backward amax reduction launched in `_prepare_backward`;
# waited on (and cleared) at the start of the next FP8 backward.
_amax_reduce_handle_bwd: Any = None


def get_cublas_workspace_size_bytes() -> int:
    """Return the cuBLAS workspace size in bytes for the current CUDA device.

    Devices whose compute capability major version is 9 or newer (Hopper and
    later) get 32 MiB; all other architectures get 4 MiB.

    Returns:
        int: workspace size in bytes.
    """
    major = torch.cuda.get_device_properties(torch.cuda.current_device()).major
    if major >= 9:
        return 33_554_432  # 32 MiB
    return 4_194_304  # 4 MiB


def get_workspace() -> torch.Tensor:
    """Return the shared cuBLAS workspace, allocating it on first use.

    On first allocation the workspace is a 1-D ``torch.uint8`` CUDA tensor of
    ``get_cublas_workspace_size_bytes()`` bytes. It is cached in the
    module-level ``_cublas_workspace``, so later calls return the same object.
    """
    global _cublas_workspace
    if _cublas_workspace is not None:
        return _cublas_workspace
    _cublas_workspace = torch.empty(
        get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda"
    )
    return _cublas_workspace

@contextmanager
def _prepare_backward(
    fp8: bool,
    fp8_meta: Dict[str, Any],
    tp_group: dist_group_type,
    tp_size: int,
    name: str = ""
) -> Generator[None, None, None]:
    """Checks and prep for BWD.

    Context manager wrapped around a module's backward computation.

    On entry, when ``fp8`` is set, it first waits on any pending backward amax
    reduction handle from a previous call. It then updates the backward amax
    and scale. When ``fp8_meta["recipe"].reduce_amax`` is set and the FP8
    group has more than one rank, it also:

    * copies amaxes from the global buffer before the update;
    * pops the oldest forward autocast id from
      ``fp8_meta["autocast_id_fwd_stack"]`` into ``fp8_meta["autocast_id_bwd"]``;
    * registers this module's amaxes in the global buffer.

    On normal exit, under the same conditions, the module whose
    ``fp8_meta["first_module"]`` is true launches the global backward amax
    reduction. Its handle is stored in ``_amax_reduce_handle_bwd`` and waited
    on at the next entry.

    Args:
        fp8: whether FP8 is enabled for this backward pass.
        fp8_meta: the module's FP8 metadata dictionary.
        tp_group: tensor parallel group passed to the amax reduction.
        tp_size: tensor parallel world size passed to the amax reduction.
        name: prefix for the NVTX range label ``"<name> backward"``.
    """
    # This function both reads and rebinds the module-level handle.
    global _amax_reduce_handle_bwd

    if fp8:
        if _amax_reduce_handle_bwd is not None:
            _amax_reduce_handle_bwd.wait()
            _amax_reduce_handle_bwd = None

        # Update amax and scale; the global-buffer bookkeeping below is only
        # needed when amaxes are reduced across more than one rank.
        if fp8_meta["recipe"].reduce_amax and get_distributed_world_size(fp8_meta["fp8_group"]) > 1:
            # From previous iteration
            FP8GlobalStateManager.copy_amax_from_global_buffer(fp8_meta, forward=False)
            amax_and_scale_update(fp8_meta, False)
            FP8GlobalStateManager.set_amax_buffer_key_deletion(fp8_meta, forward=False)

            # Get new backward key.
            fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd_stack"].pop(0)

            FP8GlobalStateManager.add_amax_to_global_buffer(fp8_meta, forward=False)
        else:
            amax_and_scale_update(fp8_meta, False)

    with torch.cuda.nvtx.range(name + " backward"):
        yield

    if (fp8 and fp8_meta["recipe"].reduce_amax
            and get_distributed_world_size(fp8_meta["fp8_group"]) > 1):
        if fp8_meta["first_module"]:
            _amax_reduce_handle_bwd = FP8GlobalStateManager.global_amax_reduction(
                fp8_meta,
                tp_group,
                tp_size,
                forward=False
            )
            FP8GlobalStateManager.delete_key_from_amax_buffer(forward=False)


def initialize_ub(
    shape: list,
    tp_size: int,
    use_fp8: bool = False,
    ub_cfgs: Optional[dict] = None
) -> None:
    """Initialize communicators for TP comm overlap using userbuffers.

    Creates one userbuffers communicator per supported layer name and stores it
    in the module-level ``_ub_communicators`` dict (retrieve with ``get_ub``).
    The cuBLAS workspace is also enlarged ``_NUM_MAX_UB_STREAMS``-fold.

    Args:
        shape: shape of the sample buffer allocated for every communicator.
        tp_size: tensor parallel world size.
        use_fp8: when True, layers listed as FP8 buffers get ``torch.uint8``
            sample buffers; every other buffer is ``torch.bfloat16``. Setting
            the environment variable ``NVTE_UB_FP8_RS`` to a non-zero integer
            adds ``proj_fprop`` to the FP8 buffer list.
        ub_cfgs: optional per-layer overrides keyed by layer name. Recognised
            keys are ``method``, ``num_sm``, ``cga_size``, ``num_splits``,
            ``set_sm_margin`` and ``aggregate``. Missing keys fall back to the
            same defaults used for unconfigured layers.

    Raises:
        AssertionError: if the communicators were already initialized.
    """
    global _ub_communicators, _cublas_workspace
    assert _ub_communicators is None, "UB communicators are already initialized."
    _ub_communicators = {}
    rank_id = torch.distributed.get_rank()

    # Increase the workspace by the number of maximum concurrent streams
    _cublas_workspace = get_workspace().repeat(_NUM_MAX_UB_STREAMS)

    # Default buffer precision: AllGather buffers use fp8 when using fp8 recipe
    fp8_buf = [
        "qkv_fprop", "qkv_dgrad", "proj_dgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad"
    ]
    if bool(int(os.getenv("NVTE_UB_FP8_RS", "0"))):
        fp8_buf.append("proj_fprop")

    # Default overlap methods for layers
    methods = {
        "ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"],
        "pipeline": ["proj_fprop", "fc2_fprop"],
        "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
    }

    def get_method(name):
        """Return the default overlap method registered for layer ``name``."""
        for method, names in methods.items():
            if name in names:
                return method
        raise KeyError(f"Given layer name {name} does not exist.")

    def add_ub(
        name: str,
        method: str,
        num_sm: int = 16,
        cga_size: int = 2,
        set_sm_margin: int = 0,
        num_splits: int = 4,
        aggregate: int = 0,
    ) -> None:
        """Build the communicator for layer ``name`` and register it."""
        dtype = torch.uint8 if (use_fp8 and name in fp8_buf) else torch.bfloat16
        sample_buffer = torch.empty(shape, dtype=dtype, device='cuda')
        if method == 'ring_exchange':
            ub_obj = tex.UbufP2PCommOverlap(
                sample_buffer,          # Sample userbuffer
                rank_id,                # Rank id
                tp_size,                # TP size
                num_sm,                 # Number of communication SMs
                cga_size,               # CGA cluster size
                set_sm_margin,          # Set SM margin
                aggregate,              # Aggregate 2X GEMM chunks
                _NUM_MAX_UB_STREAMS,    # Max concurrent GEMM streams
                torch.Tensor(),         # empty tensor to pass to counters
            )
        else:
            ub_obj = tex.UbufCommOverlap(
                sample_buffer,          # Sample userbuffer
                rank_id,                # Rank id
                tp_size,                # TP size
                num_sm,                 # Number of communication SMs
                cga_size,               # CGA cluster size
                num_splits,             # Number of communication splits
                set_sm_margin,          # Set SM margin
                _NUM_MAX_UB_STREAMS,    # Max concurrent GEMM streams
                torch.Tensor(),         # empty tensor to pass to counters
            )
        _ub_communicators[name] = ub_obj

    for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]:
        if ub_cfgs is not None and name in ub_cfgs:
            ub_cfg = ub_cfgs[name]
            method = ub_cfg.get("method", get_method(name))
            # Same split default as the unconfigured path below: pipeline
            # overlaps use 4 splits, the other methods 0. Previously a
            # configured pipeline layer without "num_splits" silently got 0.
            num_splits = ub_cfg.get("num_splits", 4 if method == "pipeline" else 0)
            add_ub(
                name,
                method,
                num_sm=ub_cfg.get("num_sm", 16),
                cga_size=ub_cfg.get("cga_size", 2),
                set_sm_margin=ub_cfg.get("set_sm_margin", 0),
                num_splits=num_splits,
                aggregate=ub_cfg.get("aggregate", 0),
            )
        else:
            method = get_method(name)
            add_ub(name, method, num_splits=4 if method == "pipeline" else 0)


def get_ub(name: str):
    """Return the userbuffers communicator registered under ``name``.

    Raises:
        AssertionError: if ``initialize_ub`` has not run yet, or if no
            communicator was registered for ``name``.
    """
    # Read-only access to the module-level registry; no `global` needed.
    registry = _ub_communicators
    assert registry is not None, "UB manager is not initialized."
    assert name in registry, f"UB for {name} is not registered."
    return registry[name]


class _NoopCat(torch.autograd.Function):
    """This class is a no-op replacement for `torch.cat`.

    ``forward`` returns a tensor that shares the storage of
    ``full_param_buffer``. The buffer's dim-0 length must equal the sum of
    the dim-0 lengths of ``params_split``. ``backward`` slices the incoming
    gradient along dim 0 into one piece per entry of ``params_split``, in
    order.
    """

    @staticmethod
    def forward(ctx,
                full_param_buffer: torch.Tensor,
                *params_split: torch.Tensor,
    ) -> torch.Tensor:
        assert not full_param_buffer.requires_grad, "Buffers should not require gradient"
        split_sizes = [p.shape[0] for p in params_split]
        assert (
            full_param_buffer.shape[0] == sum(split_sizes)
        ), "Dimensions not compatible for concatenation"

        # New tensor pointed at the buffer's storage via `set_` (same offset,
        # size and stride) instead of a view op on the buffer.
        param_temp = full_param_buffer.new()
        param_temp.set_(full_param_buffer.storage(),
                        full_param_buffer.storage_offset(),
                        full_param_buffer.size(),
                        full_param_buffer.stride())
        param_temp.requires_grad = True

        # Backward only needs the dim-0 sizes, so keep those rather than
        # saving the parameter tensors themselves.
        ctx.split_sizes = split_sizes
        return param_temp

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        # No gradient for the buffer; one dim-0 slice of grad_output per param.
        return (None, *torch.split(grad_output, ctx.split_sizes, dim=0))


class TransformerEngineBaseModule(torch.nn.Module, ABC):
    """Base TE module."""

    def __init__(self) -> None:
        """Set up default (FP8-disabled, no tensor parallelism) module state.

        Raises:
            AssertionError: if CUDA is not available.
        """
        super().__init__()
        assert torch.cuda.is_available(), "TransformerEngine needs CUDA."
        self.fp8_initialized = False
        self.fp8 = False
        self.fp8_calibration = False
        self.fp8_meta = {}
        self.fp8_meta["fp8_checkpoint"] = False
        self.fp8_meta["fp8_group"] = None
        self.fp8_meta["recipe"] = get_default_fp8_recipe()
        self.fp8_meta_tensors_initialized = False
        self.tp_group = None
        self.tp_size = 1
        # Flipped to True by `set_tensor_parallel_group`; initialised here so a
        # missing TP group is reported by prepare_forward's assertion rather
        # than by an AttributeError.
        self.tp_group_initialized = False
        self.sequence_parallel = False
        self.fp8_weight_shapes = []
        self.fp8_meta["autocast_id_fwd_stack"] = []
        # Opt-in via NVTE_ASYNC_AMAX_REDUCTION=<non-zero int>.
        self.fp8_meta["async_amax_reduction"] = bool(
            int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "0"))
        )

    def set_meta_tensor(self, fwd: bool) -> None:
        """Init scales and amaxes for fwd | bwd.

        If meta tensors already exist, only the amax history of the selected
        direction is truncated or zero-padded to the recipe's
        ``amax_history_len``. Otherwise, a fresh ``tex.FP8TensorMeta`` is created
        with unit scales, unit inverse scales and a zeroed amax history, plus a
        boolean "non-weight" mask.

        Args:
            fwd: True for the forward meta (``scaling_fwd``), False for backward
                (``scaling_bwd``).
        """
        key = "scaling_fwd" if fwd else "scaling_bwd"

        if self.fp8_meta_tensors_initialized:
            # Handle changed amax history size.
            meta = self.fp8_meta[key]
            history = meta.amax_history
            have = history.shape[0]
            want = self.fp8_meta["recipe"].amax_history_len
            if want < have:
                meta.amax_history = history[:want].clone()
            elif want > have:
                meta.amax_history = F.pad(history, pad=(0, 0, 0, want - have))
            return

        num_gemms = self.fp8_meta["num_gemms"]
        # Max. number of fp8 tensors per GEMM: 3 in fwd (input, weight, output),
        # 2 in bwd (grad_output, grad_input).
        tensors_per_gemm = 3 if fwd else 2
        num_fp8_tensors = num_gemms * tensors_per_gemm

        meta = self.fp8_meta[key] = tex.FP8TensorMeta()
        meta.scale = torch.ones(num_fp8_tensors, dtype=torch.float32, device="cuda")
        meta.scale_inv = torch.ones(num_fp8_tensors, dtype=torch.float32, device="cuda")
        meta.amax_history = torch.zeros(
            self.fp8_meta["recipe"].amax_history_len,
            num_fp8_tensors,
            dtype=torch.float32,
            device="cuda",
        )

        # Needed for calculation of scale inverses to preserve scale_inv when
        # caching FP8 weights. Per GEMM: fwd -> [input, weight, output],
        # bwd -> [grad_output, grad_input]; True marks non-weight slots.
        per_gemm_mask = [True, False, True] if fwd else [True, True]
        self.fp8_meta[key + "_non_weight_mask"] = torch.BoolTensor(
            per_gemm_mask * num_gemms
        ).cuda()

    def init_fp8_meta_tensors(self) -> None:
        """Init scales and amaxes.

        Sets up (or resizes) the forward meta tensors, then the backward ones,
        and records that the meta tensors are initialized.
        """
        for is_forward in (True, False):
            self.set_meta_tensor(is_forward)
        self.fp8_meta_tensors_initialized = True

    def get_extra_state(self) -> Union[torch.Tensor, io.BytesIO]:
        """Save before checkpointing.

        When ``fp8_meta["fp8_checkpoint"]``, ``self.fp8`` or
        ``self.fp8_calibration`` is set, the serialized state is a dict with:

        * forward/backward scales, scale inverses and amax histories;
        * the global FP8 buffer and state checkpoints;
        * the ``bool``/``int``/``float``/``str``/``list`` entries of ``fp8_meta``.

        Otherwise the serialized state is ``None``.

        Returns:
            A ``torch.uint8`` tensor of pickled bytes when in ONNX export mode;
            otherwise an ``io.BytesIO`` written with ``torch.save``.
        """
        state = None

        fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration

        if fp8_checkpoint:
            state = {}
            state["scale_fwd"] = self.fp8_meta["scaling_fwd"].scale
            state["scale_inv_fwd"] = self.fp8_meta["scaling_fwd"].scale_inv
            state["amax_history_fwd"] = self.fp8_meta["scaling_fwd"].amax_history
            state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale
            state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv
            state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history
            state["global_fp8_buffer"] = FP8GlobalStateManager.get_global_fp8_buffer_checkpoint()
            state["global_fp8_state"] = FP8GlobalStateManager.get_global_fp8_state_checkpoint()

            # Store other picklable values.
            state["extra_fp8_variables"] = {
                k: v
                for k, v in self.fp8_meta.items()
                if isinstance(v, (bool, int, float, str, list))
            }

        if is_in_onnx_export_mode():
            state_serialized = torch.frombuffer(pickle.dumps(state), dtype=torch.uint8)
        else:
            state_serialized = io.BytesIO()
            torch.save(state, state_serialized)

        return state_serialized

    def set_extra_state(self, state: Optional[Union[torch.Tensor, io.BytesIO]]) -> None:
        """Load previous state.

        Accepts either format produced by ``get_extra_state``: a ``torch.uint8``
        tensor of pickled bytes, or an ``io.BytesIO`` written by ``torch.save``
        (loaded with ``map_location='cuda'``). A ``None`` argument, or a
        payload that deserializes to ``None``, is a no-op.

        Restores the global FP8 buffer/state, merges the saved scalar entries
        into ``fp8_meta``, re-creates the meta tensors, and copies the saved
        scales, scale inverses and amax histories into them.

        Raises:
            RuntimeError: if ``state`` is neither ``None``, a tensor nor a BytesIO.
        """
        if state is None:
            return

        if isinstance(state, torch.Tensor):
            # NOTE(review): pickle.loads can execute arbitrary code embedded in
            # the payload; only load checkpoints from trusted sources.
            state = pickle.loads(state.detach().cpu().numpy().tobytes())
        elif isinstance(state, io.BytesIO):
            state.seek(0)
            state = torch.load(state, map_location='cuda')
        else:
            raise RuntimeError("Unsupported checkpoint format.")

        # Checkpoint was saved without FP8 state.
        if state is None:
            return

        # Restore global FP8 amax buffer.
        FP8GlobalStateManager.set_global_fp8_buffer_checkpoint(state["global_fp8_buffer"])
        # Restore global FP8 state.
        FP8GlobalStateManager.set_global_fp8_state_checkpoint(state["global_fp8_state"])

        # Load extra items.
        self.fp8_meta.update(state["extra_fp8_variables"])
        self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0]
        self.fp8_meta.pop("global_fp8_buffer_pos_fwd_recompute", None)

        # Initialize before loading.
        self.init_fp8_meta_tensors()
        self.fp8_meta["scaling_fwd"].scale.copy_(state["scale_fwd"])
        self.fp8_meta["scaling_fwd"].amax_history.copy_(state["amax_history_fwd"])
        self.fp8_meta["scaling_bwd"].scale.copy_(state["scale_bwd"])
        self.fp8_meta["scaling_bwd"].amax_history.copy_(state["amax_history_bwd"])
        self.fp8_meta["scaling_fwd"].scale_inv.copy_(state["scale_inv_fwd"])
        self.fp8_meta["scaling_bwd"].scale_inv.copy_(state["scale_inv_bwd"])

    def set_activation_dtype(self, inp: torch.Tensor) -> None:
        """Get activation data type for AMP.

        Sets ``self.activation_dtype`` as follows:

        * inside ``torch.autocast``, to the autocast GPU dtype;
        * otherwise, to ``inp.dtype``, after checking that every parameter and
          buffer of the module has that same dtype.

        The checks are skipped when ``activation_dtype`` already equals
        ``inp.dtype``.

        Raises:
            AssertionError: outside autocast, if any parameter or buffer dtype
                differs from ``inp.dtype``.
        """
        # Native AMP (`torch.autocast`) gets highest priority
        if torch.is_autocast_enabled():
            self.activation_dtype = torch.get_autocast_gpu_dtype()
            return

        # All checks after this have already been performed once, thus skip
        if hasattr(self, "activation_dtype") and self.activation_dtype == inp.dtype:
            return

        dtype = inp.dtype
        for name, param in self.named_parameters():
            if param is not None:
                assert dtype == param.dtype, (
                    "Data types for parameters must match when outside of autocasted region. "
                    f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
                )
        for name, buf in self.named_buffers():
            if buf is not None:
                assert dtype == buf.dtype, (
                    "Data types for buffers must match when outside of autocasted region. "
                    f" Found input dtype: {dtype} and {name!r} dtype: {buf.dtype}"
                )
        self.activation_dtype = dtype

    def set_fp8_weights(self) -> None:
        """Initializes FP8 weights for the module as instance attributes. These
        are not parameters or buffers since we do not want functions such as
        `.to(dtype)` or `.to(device)` to affect them. These also do not need
        to be checkpointed. During `init` phase of the module, the attribute
        `fp8_weight_shapes` must be populated with the tensor shapes for FP8
        weights. This function will iterate over those shapes and initialize
        the respective attributes named `weight1_fp8`, `weight2_fp8`, ...
        together with their transposes `weight1_t_fp8`, `weight2_t_fp8`, ...

        A weight whose cached tensor already has the requested shape is left
        as is. Does nothing when FP8 is disabled.
        """
        if not self.fp8:
            return

        for i, shape in enumerate(self.fp8_weight_shapes, start=1):
            weight_cast_attr = f"weight{i}_fp8"
            weight_transpose_attr = f"weight{i}_t_fp8"

            if (
                hasattr(self, weight_cast_attr)
                and getattr(self, weight_cast_attr).shape == shape
            ):
                # Keep this cached weight but still visit the remaining shapes;
                # returning here would skip (re)allocating later weights.
                continue

            setattr(
                self,
                weight_cast_attr,
                Float8Tensor(
                    data=torch.empty(
                        shape,
                        device=torch.cuda.current_device(),
                        dtype=torch.uint8,
                    ),
                    fp8_dtype=tex.DType.kFloat8E4M3,
                    fp8_scale_inv=1,
                )
            )
            setattr(
                self,
                weight_transpose_attr,
                Float8Tensor(
                    data=torch.empty(
                        shape[1],
                        shape[0],
                        device=torch.cuda.current_device(),
                        dtype=torch.uint8,
                    ),
                    fp8_dtype=tex.DType.kFloat8E4M3,
                    fp8_scale_inv=1,
                )
            )

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
        """
        Set the tensor parallel group for the given
        module before executing the forward pass.

        Also marks the group as initialized; `prepare_forward` checks this
        flag when `tp_size > 1`.

        Parameters
        ----------
        tp_group : ProcessGroup, default = `None`
                  tensor parallel process group.
        """
        self.tp_group = tp_group
        self.tp_group_initialized = True

    # This routine is shared across FP8 and FP8_calibration paths so should not actually
    # assume FP8 execution.
    def init_fp8_metadata(self, num_gemms: int = 1) -> None:
        """Initialize fp8 related metadata and tensors during fprop.

        Refreshes ``fp8_parameters``, ``fp8`` and ``fp8_calibration`` from
        ``FP8GlobalStateManager``. When FP8 or calibration is active and the
        module is not already initialized with the current recipe, it stores the
        recipe, GEMM count, FP8 group and FP8 max values in ``fp8_meta`` and
        (re)builds the meta tensors. When neither is active, it clears
        ``fp8_initialized``.

        Args:
            num_gemms: number of GEMMs this module runs per forward pass.
        """
        self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
        self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
        self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
        self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration

        if self.fp8_parameters and not self.fp8_initialized:
            self.fp8_meta["num_gemms"] = num_gemms
            self.init_fp8_meta_tensors()

        if not (self.fp8 or self.fp8_calibration):
            # If fp8 isn't enabled, turn off and return.
            self.fp8_initialized = False
            return

        # FP8 init has already been run and recipe is the same, don't do anything.
        if (self.fp8_initialized
                and FP8GlobalStateManager.get_fp8_recipe() == self.fp8_meta["recipe"]):
            return

        # Set FP8, recipe, and other FP8 metadata
        self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
        self.fp8_meta["num_gemms"] = num_gemms
        self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()

        # Set FP8_MAX per tensor according to recipe
        self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
        self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd

        # Allocate scales and amaxes
        self.init_fp8_meta_tensors()
        self.fp8_initialized = True

    @contextmanager
    def prepare_forward(
        self,
        inp: torch.Tensor,
        is_first_microbatch: Union[bool, None],
        num_gemms: int = 1,
    ) -> Generator[torch.Tensor, None, None]:
        """Checks and prep for FWD.
        The context manager is needed because there isn't a way for a module to know
        if it's the last FP8 module in the forward autocast. It is useful
        to setup the forward aggregated amax reduction for every module
        just in case. The autocast exit will pick up the most recent one.

        Args:
            inp: input activation. It must be a CUDA tensor except during the
                FP8 activation-recompute phase. A contiguous version of it is
                yielded.
            is_first_microbatch: ``None`` means FP8 weight caching is not used.
                Otherwise, persistent FP8 weight tensors are created. Weight
                scale-inverses are updated only when this is ``None`` or true.
            num_gemms: number of GEMMs the module runs, forwarded to
                ``init_fp8_metadata``.

        Yields:
            ``inp.contiguous()``, inside an NVTX range named
            ``"<ClassName> forward"``.
        """

        # Activation recomputation is used and this is the second forward phase.
        if self.fp8 and in_fp8_activation_recompute_phase():
            FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
        else:
            assert inp.is_cuda, "TransformerEngine needs CUDA."

            if self.tp_size > 1:
                # getattr: report a missing TP group via this assertion even if
                # `set_tensor_parallel_group` was never called.
                assert getattr(self, "tp_group_initialized", False), "TP group not initialized."

            self.set_activation_dtype(inp)
            self.init_fp8_metadata(num_gemms=num_gemms)

            # Create persistent tensors for fp8 weights and their transposes
            # only when fp8 weight caching is used.
            if is_first_microbatch is not None:
                self.set_fp8_weights()

            update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch
            if self.fp8 and self.sequence_parallel:
                assert self.fp8_meta["recipe"].reduce_amax, (
                    "Amax reduction across tensor parallel group is "
                    "necessary when using sequence parallelism with FP8."
                )

            # Previous iteration was grad_enabled
            if self.fp8_meta.get("update_amax_and_scale_fwd", False):
                if (self.fp8_meta["recipe"].reduce_amax
                        and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1):
                    FP8GlobalStateManager.copy_amax_from_global_buffer(self.fp8_meta, forward=True)
                    amax_and_scale_update(
                        self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
                    )
                    FP8GlobalStateManager.set_amax_buffer_key_deletion(self.fp8_meta, forward=True)
                else:
                    amax_and_scale_update(
                        self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
                    )

            if self.fp8 and self.training:
                # Setup for amax reduction
                if (self.fp8_meta["recipe"].reduce_amax
                        and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1):
                    self.fp8_meta["first_module"] = FP8GlobalStateManager.is_first_fp8_module()
                    if self.fp8_meta["first_module"]:
                        # Wait for the prior AMAX reduction to finish
                        amax_reduce_handle_fwd = FP8GlobalStateManager.get_amax_reduce_handle_fwd()
                        if amax_reduce_handle_fwd is not None:
                            amax_reduce_handle_fwd.wait()
                        self.fp8_meta["autocast_id_fwd"] = (
                            FP8GlobalStateManager.new_fp8_context_id())
                        FP8GlobalStateManager.set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
                    else:
                        self.fp8_meta["autocast_id_fwd"] = (
                            FP8GlobalStateManager.get_fp8_context_id())
                    self.fp8_meta["autocast_id_fwd_stack"].append(
                        self.fp8_meta["autocast_id_fwd"]
                    )
                    FP8GlobalStateManager.add_amax_to_global_buffer(self.fp8_meta, forward=True)
                self.fp8_meta["update_amax_and_scale_fwd"] = True
            else:
                self.fp8_meta["update_amax_and_scale_fwd"] = False

            # Activation recomputation is used and this is the first forward phase.
            if (
                self.fp8
                and self.training
                and is_fp8_activation_recompute_enabled()
            ):
                FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)

        with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
            yield inp.contiguous()

        if self.fp8 and in_fp8_activation_recompute_phase():
            FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
            return

        if (self.fp8 and self.training and self.fp8_meta["recipe"].reduce_amax
                and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1):
            FP8GlobalStateManager.set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
            reduce_func = partial(
                FP8GlobalStateManager.global_amax_reduction,
                self.fp8_meta,
                self.tp_group,
                self.tp_size,
                forward=True
            )
            FP8GlobalStateManager.setup_amax_forward_global_reduce_func(reduce_func)

    def set_nccl_overlap_warning_if_tp(self) -> None:
        """When using TP, the NCCL communication needs to be scheduled
        before the GEMM for there to be a guaranteed overlap. From the
        host side in TE, the comm calls are always launched first, but
        to ensure that the GEMM isn't scheduled first, the environment
        variable `CUDA_DEVICE_MAX_CONNECTIONS` needs to be set to 1 to
        force a single channel.

        Emits a warning when `tp_size > 1` and the variable is unset or is
        not 1; does nothing otherwise.
        """
        if self.tp_size == 1:
            return
        num_cuda_work_queues = int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0"))
        if num_cuda_work_queues != 1:
            warnings.warn(
                # Trailing space needed: the two literals are concatenated.
                "To guarantee overlapping TP and SP collectives with the backward "
                "GEMMs, set environment variable CUDA_DEVICE_MAX_CONNECTIONS = 1"
            )

    @staticmethod
    def grad_output_preprocess(
        ctx, grad_output: torch.Tensor, row_parallel_mode: bool
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Utility function for backward.
        Returns tuple in order (all optional/None based on training precion/recipe):
            R1: gathered `grad_output` in higher precision.
            R2: gathered `grad_output` in FP8.
            R3: R2 transposed.
            R4: bias gradient on R1.

        """
        grad_output = grad_output.contiguous()
        grad_output_mat = grad_output.view((-1, grad_output.shape[-1]))
        gather_grad_output = row_parallel_mode and ctx.sequence_parallel

664
665
        if gather_grad_output:
            ub_overlap_ag = ctx.ub_split_ag or ctx.ub_atomic_gemm_ag
666
667
668
        # No-FP8 case: bgrad is fused with wgrad for this case.
        if not ctx.fp8:
            if gather_grad_output:
669
                if not ub_overlap_ag:
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
                    grad_output_mat, _ = gather_along_first_dim(
                        grad_output_mat, ctx.tp_group
                    )
                else:
                    ctx.ub_obj_gradout.copy_input_to_ubuf(grad_output, True)
                    grad_output_mat = ctx.ub_obj_gradout.get_ubuf_output(1)
            return grad_output_mat, None, None, None

        fp8_dtype_backward = get_fp8_te_dtype(
            ctx.fp8_meta["recipe"], fprop_tensor=False
        )

        # FP8 case with non-FP8 wgrad
        if (
            gather_grad_output
            and ctx.fp8_meta["recipe"].override_linear_precision.wgrad
        ):
            assert (
688
689
                not ub_overlap_ag
            ), "override_linear_precision.wgrad not supported with UB AG overlap"
690
691
692
693
694
695
696
            grad_output_mat, _ = gather_along_first_dim(grad_output_mat, ctx.tp_group)
        # FP8 case with gather: unfused bgrad, cast, transpose for efficient gather
        elif gather_grad_output:
            if ctx.use_bias:
                grad_bias = grad_output_mat.sum(dim=0)
            else:
                grad_bias = None
697
            if ub_overlap_ag:
698
699
700
701
702
703
704
705
706
707
                grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(0)
            else:
                grad_output_c = torch.empty_like(grad_output_mat, dtype=torch.uint8)
            cast_to_fp8(
                grad_output_mat,
                ctx.fp8_meta["scaling_bwd"],
                tex.FP8BwdTensors.GRAD_OUTPUT1,
                fp8_dtype_backward,
                out=grad_output_c,
            )
708
            if not ub_overlap_ag:
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
                grad_output_c, _ = gather_along_first_dim(grad_output_c, ctx.tp_group)
                grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
            else:
                grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(1)
                grad_output_t = None

            return grad_output_mat, grad_output_c, grad_output_t, grad_bias

        # FP8 case without gather: cast, transpose, bgrad fused
        if ctx.use_bias:
            grad_bias, grad_output_c, grad_output_t = fp8_cast_transpose_bgrad_fused(
                grad_output_mat,
                ctx.fp8_meta["scaling_bwd"],
                tex.FP8BwdTensors.GRAD_OUTPUT1,
                fp8_dtype_backward,
            )
        else:
            if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                grad_output_c, grad_output_t = fp8_cast_transpose_fused(
                    grad_output_mat,
                    ctx.fp8_meta["scaling_bwd"],
                    tex.FP8BwdTensors.GRAD_OUTPUT1,
                    fp8_dtype_backward,
                )
            else:
                grad_output_t = None
                grad_output_c = cast_to_fp8(
                    grad_output_mat,
                    ctx.fp8_meta["scaling_bwd"],
                    tex.FP8BwdTensors.GRAD_OUTPUT1,
                    fp8_dtype_backward,
                )
            grad_bias = None

        return grad_output_mat, grad_output_c, grad_output_t, grad_bias

cyanguwa's avatar
cyanguwa committed
745
746
747
748
749
    def noop_cat(self,
        buffer_name: str,
        pnames: List[str],
        parameters_split: Dict[str, int]
        ) -> torch.Tensor:
        """No-op replacement of `torch.cat`.

        The buffer named `buffer_name` and the split parameters named in
        `pnames` are expected to share memory: the parameters are laid out
        back to back along dim 0 of the buffer, in `pnames` order. If any
        parameter does not alias its expected slice of the buffer, the
        parameters are concatenated into a fresh buffer and every parameter
        is re-created as a view into that buffer, so later calls find the
        memory already shared and skip the concatenation.

        Parameters
        ----------
        buffer_name : str
            Name of the module attribute holding the fused buffer.
        pnames : List[str]
            Names of the split parameter attributes, in buffer order.
        parameters_split : Dict[str, int]
            Dim-0 size of each split, keyed by the parameter-name prefix up to
            and including the first underscore (e.g. ``"query_"`` for
            ``"query_weight"``).

        Returns
        -------
        torch.Tensor
            Result of ``_NoopCat.apply`` on the buffer and the split parameters.
        """

        assert hasattr(self, buffer_name), f"No buffer named {buffer_name}"

        def _split_size(pname: str) -> int:
            # Split sizes are keyed by the name prefix through the first '_'.
            return parameters_split[pname.split('_')[0] + '_']

        full_param_buffer = getattr(self, buffer_name)
        params = [getattr(self, name) for name in pnames]

        slice_begin = 0
        for p, pname in zip(params, pnames):
            slice_end = slice_begin + _split_size(pname)
            if p.data.data_ptr() != full_param_buffer[slice_begin:slice_end].data_ptr():
                # The parameters no longer alias the buffer. Rebuild the buffer
                # from them and re-point every parameter at its slice.
                with torch.no_grad():
                    setattr(self, buffer_name, torch.cat(params))
                    new_buffer = getattr(self, buffer_name)
                    slice_begin_j = 0
                    for old_param, pname_j in zip(params, pnames):
                        slice_end_j = slice_begin_j + _split_size(pname_j)
                        # Parameter() defaults to requires_grad=True. Carry the
                        # original flag over so frozen weights stay frozen.
                        setattr(self, pname_j,
                                Parameter(new_buffer[slice_begin_j:slice_end_j],
                                          requires_grad=old_param.requires_grad))
                        slice_begin_j = slice_end_j
                break
            slice_begin = slice_end

        return _NoopCat.apply(getattr(self, buffer_name), *[getattr(self, name) for name in pnames])

779
780
781
    def get_fp8_weights_empty_tensors(
        self,
        is_first_microbatch: Union[bool, None],
    ) -> List[Float8Tensor]:
        """
        Allocate uninitialized FP8 weight tensors for this batch (or microbatch).

        For each shape in ``self.fp8_weight_shapes``, two Float8Tensors are
        appended in sequence:

        - one shaped like the weight;
        - one with dims ``(shape[1], shape[0])`` to hold the weight transpose.

        Each is backed by an uninitialized ``torch.uint8`` buffer on the
        current CUDA device, tagged ``kFloat8E4M3`` with ``fp8_scale_inv=1``.

        Only valid when ``is_first_microbatch`` is None (asserted). In that
        case the FP8 weights are needed just once, in the forward pass, so
        they need not be cached. The transpose, needed again in backward, is
        expected to be kept alive by the caller via ``ctx.save_for_backward``.
        """
        assert is_first_microbatch is None, "Should only be here when "\
                                            "`is_first_microbatch` is None!"

        def _alloc(*dims) -> Float8Tensor:
            # Wrap one uninitialized uint8 buffer as an E4M3 Float8Tensor.
            storage = torch.empty(
                *dims,
                device=torch.cuda.current_device(),
                dtype=torch.uint8,
            )
            return Float8Tensor(
                data=storage,
                fp8_dtype=tex.DType.kFloat8E4M3,
                fp8_scale_inv=1,
            )

        tensors: List[Float8Tensor] = []
        for shape in self.fp8_weight_shapes:
            tensors.append(_alloc(shape))               # weight
            tensors.append(_alloc(shape[1], shape[0]))  # its transpose
        return tensors

822
823
824
825
826
827
828
829
830
831
832
833
    def state_dict(self, *args, **kwargs) -> Dict:
        """Return the module state with Float8Tensors replaced by plain tensors.

        All arguments are forwarded unchanged to the parent ``state_dict``.
        Every entry holding a ``Float8Tensor`` is overwritten in place with
        the result of its ``from_float8()``, because Float8Tensors don't
        serialize well, especially when they reference FP8 metadata.
        """
        state = super().state_dict(*args, **kwargs)

        # Collect the keys first, then convert them in the same iteration order.
        fp8_keys = [key for key, val in state.items() if isinstance(val, Float8Tensor)]
        for key in fp8_keys:
            state[key] = state[key].from_float8()

        return state
834

835
836
837
    @abstractmethod
    def forward(self):
        """Needs override.

        Declared with ``@abstractmethod``: concrete subclasses must provide
        their own ``forward`` implementation.
        """
838
839
840
841
842
843
844

    @abstractmethod
    def get_fp8_weights_scratchpad(
        self,
        is_first_microbatch: Union[bool, None],
    ) -> List[torch.Tensor]:
        """Needs override.

        Declared with ``@abstractmethod``: concrete subclasses must return the
        list of tensors used as FP8 weight scratch space.

        NOTE(review): ``is_first_microbatch`` presumably follows the same
        True/False/None microbatch convention as
        ``get_fp8_weights_empty_tensors`` (None = no caching across
        microbatches). Confirm against the subclass implementations.
        """