base.py 22.5 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
7
#
# See LICENSE for license information.
"""Base modules and utilities for TransformerEngine Paddle API"""

from abc import ABC, abstractmethod
from contextlib import contextmanager
8
import os
9
import pickle
10
from typing import Generator, Dict, Tuple, Union, Any, List, Optional
11
12

import numpy as np
13
14

import paddle
15

16
17
18
19
20
21
try:
    from paddle.base import core
    from paddle.base.framework import _dygraph_tracer
except ImportError:
    from paddle.fluid import core
    from paddle.fluid.framework import _dygraph_tracer
22

23
from ..constants import FP8FwdTensors, FP8BwdTensors, dist_group_type
24
from ..cpp_extensions import cast_transpose, cast_transpose_bgrad, cast_to_fp8, transpose
25
from ..fp8 import (
26
27
    FP8State,
    FP8TensorMeta,
28
    amax_and_scale_update,
29
    get_global_fp8_state,
30
31
    get_fp8_te_dtype,
)
32
from ..distributed import allgather, register_pp_fwd_begin_hook, is_pp_enabled
33
from ..profile import nvtx_range
Tian Zheng's avatar
Tian Zheng committed
34
35
from ..recompute import is_in_recompute_phase
from ..fp8_buffer import FP8RecomputeBuffer
36

37
38
39
# Module-level switches for the forward (FPROP), data-gradient (DGRAD) and
# weight-gradient (WGRAD) passes. NOTE(review): they are not read in this chunk;
# the `2X_ACC` naming suggests a 2x / split-accumulation setting for FP8 GEMMs —
# confirm against the layer modules that import them.
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True

# Process-wide cuBLAS workspace, allocated lazily by `get_workspace()`.
_cublas_workspace = None


def get_cublas_workspace_size_bytes() -> int:
    """Return the cuBLAS workspace size, in bytes, for the current CUDA device.

    Devices whose compute-capability major version is 9 or higher (Hopper and
    newer) get 32 MiB; every other architecture gets 4 MiB.

    Returns:
        int: workspace size in bytes (33_554_432 or 4_194_304).
    """
    major_capability = paddle.device.cuda.get_device_capability()[0]
    if major_capability >= 9:
        return 32 * 1024 * 1024  # 33_554_432 bytes
    return 4 * 1024 * 1024  # 4_194_304 bytes


def get_workspace() -> paddle.Tensor:
    """Return the shared cuBLAS workspace tensor, allocating it on first use.

    The workspace is a flat ``uint8`` tensor with
    ``get_cublas_workspace_size_bytes()`` elements. It is cached in the
    module-level ``_cublas_workspace``, so every later call returns the same
    tensor object.
    """
    global _cublas_workspace
    if _cublas_workspace is not None:
        return _cublas_workspace

    num_bytes = get_cublas_workspace_size_bytes()
    _cublas_workspace = paddle.empty([num_bytes], dtype="uint8")
    return _cublas_workspace


class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
    """Base TE Layer."""

    def __init__(self) -> None:
        """Initialise default FP8, tensor-parallel and FP8-weight-cache state.

        Raises:
            AssertionError: if the current Paddle device string does not
                contain ``"gpu"``.
        """
        super().__init__()
        assert "gpu" in paddle.device.get_device(), "TransformerEngine needs CUDA."

        # FP8 status flags; `fp8_init()` refreshes them from the global FP8 state.
        self.fp8_initialized = False
        self.fp8_enabled = False
        self.fp8_calibration = False

        # NVTE_ASYNC_AMAX_REDUCTION is parsed as an integer ("0" when unset).
        async_amax_reduction = bool(int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "0")))
        self.fp8_meta = {
            "fp8_checkpoint": False,
            "fp8_group": None,
            "recipe": FP8State.get_default_fp8_recipe(),
            "scaling_fwd": FP8TensorMeta(is_forward=True),
            "scaling_bwd": FP8TensorMeta(is_forward=False),
            # Forward autocast ids, appended in `prepare_forward` and popped
            # from the front in `prepare_backward`.
            "autocast_id_fwd_stack": [],
            "async_amax_reduction": async_amax_reduction,
        }

        # Tensor/sequence-parallel configuration (single-rank defaults).
        self.tp_group = None
        self.tp_size = 1
        self.sequence_parallel = False

        # Higher-precision weights that get cast into FP8, and the FP8 buffers
        # ("weight{i}_fp8" / "weight{i}_t_fp8") allocated for them.
        self.fp8_weights = []
        self.fp8_weight_cache = {}
        self.registered_pp_start_callback = False

        # 1-element int32 CPU tensor with the current step id. It is overwritten
        # in place by the callback registered below; presumably the hook passes
        # `step_id` as a keyword argument — confirm against the distributed module.
        self.current_step_id = paddle.to_tensor([1], dtype=paddle.int32, place=paddle.CPUPlace())

        def _sync_step_id(step_id=None, **kwargs):  # pylint: disable=unused-argument
            new_step_id = paddle.to_tensor([step_id], dtype=paddle.int32, place=paddle.CPUPlace())
            self.current_step_id.copy_(new_step_id, True)

        register_pp_fwd_begin_hook(_sync_step_id)

    def set_activation_dtype(self, inp: paddle.Tensor) -> None:
        """Choose ``self.activation_dtype`` for the upcoming forward pass.

        Inside a Paddle AMP region (dygraph tracer present with an AMP level
        other than ``O0``), the tracer's AMP dtype is used. Outside AMP, the
        input's dtype is used, and every non-``None`` parameter must already
        have that dtype.

        Args:
            inp: the layer input; its dtype is used outside AMP.

        Raises:
            RuntimeError: if the AMP dtype is not float32, bfloat16 or float16.
            AssertionError: if, outside AMP, a parameter's dtype differs from
                ``inp.dtype``.
        """
        tracer = _dygraph_tracer()
        if tracer and tracer._amp_level != core.AmpLevel.O0:
            amp_dtype_table = {
                "float32": paddle.float32,
                "bfloat16": paddle.bfloat16,
                "float16": paddle.float16,
            }
            if tracer._amp_dtype not in amp_dtype_table:
                raise RuntimeError(f"AMP format {tracer._amp_dtype} is not supported.")
            self.activation_dtype = amp_dtype_table[tracer._amp_dtype]
            return

        # Outside AMP: follow the input dtype. The parameter check is skipped
        # only while the cached dtype already matches, so switching from AMP to
        # plain training still re-validates and updates activation_dtype.
        input_dtype = inp.dtype
        if hasattr(self, "activation_dtype") and self.activation_dtype == input_dtype:
            return

        for name, param in self.named_parameters():
            if param is None:
                continue
            assert input_dtype == param.dtype, (
                "Data types for parameters must match when outside of autocasted region. "
                f" Found input dtype: {input_dtype} and {name!r} dtype: {param.dtype}"
            )

        self.activation_dtype = input_dtype

    # This routine is shared across FP8 and FP8_calibration paths so should not actually
    # assume FP8 execution.
    def fp8_init(self, num_gemms: int = 1) -> None:
        """Sync FP8 flags from the global FP8 state and (re)allocate FP8 metadata.

        ``fp8_meta["fp8_checkpoint"]`` is set whenever FP8 execution or FP8
        calibration is active. Metadata is rebuilt only when it has not been
        initialised yet or when the global recipe differs from the stored one.

        Args:
            num_gemms: number of GEMMs in the layer; passed to
                ``FP8TensorMeta.prepare`` for both the forward and backward meta.
        """
        fp8_state = get_global_fp8_state()
        self.fp8_enabled = fp8_state.is_fp8_enabled()
        self.fp8_calibration = fp8_state.is_fp8_calibration()
        fp8_active = self.fp8_enabled or self.fp8_calibration
        self.fp8_meta["fp8_checkpoint"] = fp8_active

        if not fp8_active:
            # Neither FP8 execution nor calibration: mark as uninitialised.
            self.fp8_initialized = False
            return

        recipe = fp8_state.get_fp8_recipe()
        if self.fp8_initialized and recipe == self.fp8_meta["recipe"]:
            # Already initialised with an equal recipe; nothing to rebuild.
            return

        self.fp8_meta["recipe"] = recipe
        self.fp8_meta["fp8_group"] = fp8_state.get_fp8_group()

        # Per-direction FP8 maxima come from the recipe's FP8 format.
        format_limits = recipe.fp8_format.value
        self.fp8_meta["fp8_max_fwd"] = format_limits.max_fwd
        self.fp8_meta["fp8_max_bwd"] = format_limits.max_bwd

        # Allocate scales and amax histories for both directions.
        history_len = recipe.amax_history_len
        for meta_key in ("scaling_fwd", "scaling_bwd"):
            self.fp8_meta[meta_key].prepare(num_gemms, history_len)
        self.fp8_initialized = True

    def set_fp8_weights(self) -> None:
        """Allocate FP8 buffers for every entry of ``self.fp8_weights``.

        For the i-th weight (1-based), two ``uint8`` tensors are stored in
        ``self.fp8_weight_cache``:

        * ``"weight{i}_fp8"`` with the weight's shape, and
        * ``"weight{i}_t_fp8"`` with shape ``[shape[1], shape[0]]`` (this
          assumes 2-D weights).

        A weight whose cast buffer already exists with a matching shape keeps
        its buffers, and the remaining weights are still checked. Does nothing
        when FP8 is disabled.
        """
        if not self.fp8_enabled:
            return

        for i, weight in enumerate(self.fp8_weights, start=1):
            weight_cast_key = f"weight{i}_fp8"
            weight_transpose_key = f"weight{i}_t_fp8"

            if (
                weight_cast_key in self.fp8_weight_cache
                and self.fp8_weight_cache[weight_cast_key].shape == weight.shape
            ):
                # Reuse this weight's buffers, but keep going: later weights may
                # still be missing or resized.
                continue

            self.fp8_weight_cache[weight_cast_key] = paddle.empty(
                shape=weight.shape,
                dtype=paddle.uint8,
            )
            self.fp8_weight_cache[weight_transpose_key] = paddle.empty(
                shape=[weight.shape[1], weight.shape[0]],
                dtype=paddle.uint8,
            )

    def _get_fp8_state(self) -> paddle.Tensor:
        """Serialise this layer's FP8 state into a 1-D ``uint8`` tensor.

        When ``fp8_meta["fp8_checkpoint"]`` is truthy, the pickled payload is a
        dict with:

        * ``"scaling_fwd"`` / ``"scaling_bwd"``: this layer's meta via ``to_numpy()``,
        * ``"global_fp8_fwd_buffer"`` / ``"global_fp8_bwd_buffer"``: the global
          FP8 buffers via ``to_numpy()``,
        * ``"extra_fp8_variables"``: every ``bool``/``int``/``float``/``str``
          value of ``fp8_meta``.

        Otherwise the payload is a pickled ``None``.
        """
        if not self.fp8_meta["fp8_checkpoint"]:
            payload = None
        else:
            fp8_state = get_global_fp8_state()
            payload = {
                "scaling_fwd": self.fp8_meta["scaling_fwd"].to_numpy(),
                "scaling_bwd": self.fp8_meta["scaling_bwd"].to_numpy(),
                "global_fp8_fwd_buffer": fp8_state.get_fp8_fwd_buffer().to_numpy(),
                "global_fp8_bwd_buffer": fp8_state.get_fp8_bwd_buffer().to_numpy(),
                # Only plain picklable scalars/strings are carried over.
                "extra_fp8_variables": {
                    key: value
                    for key, value in self.fp8_meta.items()
                    if isinstance(value, (bool, int, float, str))
                },
            }

        serialized = pickle.dumps(payload)
        return paddle.to_tensor(np.frombuffer(serialized, dtype=np.uint8))

    @paddle.no_grad()
    def state_dict(
        self,
        destination=None,
        include_sublayers=True,
        structured_name_prefix="",
        use_hook=True,
    ):
        """Return Paddle's regular state dict extended with an ``"fp8_state"`` entry.

        All arguments are forwarded unchanged to ``paddle.nn.Layer.state_dict``.
        ``"fp8_state"`` holds the tensor built by ``_get_fp8_state()``, so FP8
        scaling metadata is saved with the checkpoint.
        """
        checkpoint = super().state_dict(
            destination=destination,
            include_sublayers=include_sublayers,
            structured_name_prefix=structured_name_prefix,
            use_hook=use_hook,
        )
        fp8_blob = self._get_fp8_state()
        checkpoint["fp8_state"] = fp8_blob
        return checkpoint

    def _set_fp8_state(self, state: paddle.Tensor) -> None:
        """Restore FP8 state from a tensor produced by ``_get_fp8_state()``.

        Nothing is restored when ``state`` is ``None`` or unpickles to ``None``.
        Otherwise this layer's scaling meta and the global forward/backward FP8
        buffers are loaded, the extra scalar entries are merged into
        ``fp8_meta``, the recipe's ``amax_history_len`` is synced with the loaded
        forward amax history, and any recompute-buffer position key is dropped.

        Args:
            state: 1-D ``uint8`` tensor holding a pickled payload, or ``None``.
        """
        if state is None:
            return

        # NOTE(security): pickle.loads can execute arbitrary code; only load
        # checkpoints that come from a trusted source.
        payload = pickle.loads(state.numpy().tobytes())
        if payload is None:
            return

        # Per-layer FP8 meta tensors.
        for meta_key in ("scaling_fwd", "scaling_bwd"):
            self.fp8_meta[meta_key].from_numpy(payload[meta_key])

        # Global FP8 amax-reduction buffers.
        fp8_state = get_global_fp8_state()
        fwd_buffer = fp8_state.get_fp8_fwd_buffer()
        bwd_buffer = fp8_state.get_fp8_bwd_buffer()
        fwd_buffer.from_numpy(payload["global_fp8_fwd_buffer"])
        bwd_buffer.from_numpy(payload["global_fp8_bwd_buffer"])

        self.fp8_meta.update(payload["extra_fp8_variables"])
        # Keep the recipe consistent with the restored history length.
        # NOTE(review): this mutates the recipe object in place; if it is shared
        # with other layers they see the change too — confirm this is intended.
        restored_history_len = self.fp8_meta["scaling_fwd"].amax_history.shape[0]
        self.fp8_meta["recipe"].amax_history_len = restored_history_len

        # A stale recompute-buffer position must not survive a reload.
        self.fp8_meta.pop(FP8RecomputeBuffer.get_buffer_position_key(), None)

    @paddle.no_grad()
    def set_state_dict(self, state_dict, use_structured_name=True):
        """Restore FP8 state, then load the remaining entries via ``paddle.nn.Layer``.

        The ``"fp8_state"`` entry written by ``state_dict()`` is popped from
        ``state_dict``, so the caller's mapping is modified, and handed to
        ``_set_fp8_state``. A missing entry is treated as "no FP8 state", so
        checkpoints saved without it still load.

        Args:
            state_dict: mapping from parameter/buffer names to tensors.
            use_structured_name: forwarded to ``paddle.nn.Layer.set_state_dict``
                (previously it was silently dropped).

        Returns:
            Whatever ``paddle.nn.Layer.set_state_dict`` returns.
        """
        fp8_state_tensor = state_dict.pop("fp8_state", None)
        self._set_fp8_state(fp8_state_tensor)

        return super().set_state_dict(state_dict, use_structured_name=use_structured_name)

    @contextmanager
    def prepare_forward(
        self,
        inp: paddle.Tensor,
        is_first_microbatch: Union[bool, None],
        num_gemms: int = 1,
    ) -> Generator[paddle.Tensor, None, None]:
        """Context manager that prepares FP8 state around the layer's forward.

        It yields ``inp`` unchanged inside an NVTX range named
        ``"<ClassName> forward"``.

        A context manager is used because a module cannot know whether it is the
        last FP8 module in the forward autocast. Every module therefore sets up
        the aggregated forward amax reduction, and the autocast exit picks up
        the most recent one.

        Args:
            inp: the layer input; yielded back unchanged.
            is_first_microbatch: ``None`` means FP8 weight caching is off.
                Otherwise FP8 weight buffers are allocated, and weight
                scale-inverses are only refreshed when this value is truthy.
            num_gemms: number of GEMMs in the layer, forwarded to ``fp8_init``.
        """
        if self.fp8_enabled and is_in_recompute_phase():
            # Recompute pass: reuse the FP8 meta tensors stashed during the
            # original forward instead of re-deriving them.
            recompute_buffer = get_global_fp8_state().get_fp8_recompute_buffer()
            recompute_buffer.retrieve_fp8_meta_tensors(self.fp8_meta)
        else:
            self.set_activation_dtype(inp)
            self.fp8_init(num_gemms=num_gemms)

            # FP8 weight (and transpose) buffers only exist when weight caching is used.
            if is_first_microbatch is not None:
                self.set_fp8_weights()

            if self.fp8_enabled and self.sequence_parallel:
                assert self.fp8_meta["recipe"].reduce_amax, (
                    "Amax reduction across tensor parallel group is "
                    "necessary when using sequence parallelism with FP8."
                )

            update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch

            # Apply the forward amax/scale update left pending by the previous
            # grad-enabled call.
            if self.fp8_meta.get("update_amax_and_scale_fwd", False):
                fwd_buffer = get_global_fp8_state().get_fp8_fwd_buffer()
                fwd_buffer.wait()
                reduce_amax = self.fp8_meta["recipe"].reduce_amax
                if reduce_amax:
                    fwd_buffer.copy_amax_from_buffer(self.fp8_meta)
                amax_and_scale_update(
                    self.fp8_meta,
                    fwd_update=True,
                    update_weight_scale_inv=update_weight_scale_inv,
                    current_step_id_tensor=self.current_step_id,
                    use_cudagraph=get_global_fp8_state().is_cudagraph_enabled(),
                )
                if reduce_amax:
                    fwd_buffer.set_for_deletion(self.fp8_meta)

            fp8_training = self.fp8_enabled and self.training
            if fp8_training and self.fp8_meta["recipe"].reduce_amax:
                # Record where this module sits in the current autocast region.
                fp8_state = get_global_fp8_state()
                self.fp8_meta["first_module"] = fp8_state.is_first_fp8_module()
                autocast_id = fp8_state.get_autocast_id()
                self.fp8_meta["autocast_id_fwd"] = autocast_id
                self.fp8_meta["autocast_id_fwd_stack"].append(autocast_id)
            self.fp8_meta["update_amax_and_scale_fwd"] = bool(fp8_training)

            # First forward of an activation-recompute schedule: stash the meta
            # tensors so the recompute pass can retrieve them.
            if fp8_training and get_global_fp8_state().is_fp8_recompute_enabled():
                recompute_buffer = get_global_fp8_state().get_fp8_recompute_buffer()
                recompute_buffer.stash_fp8_meta_tensors(self.fp8_meta)

        with nvtx_range(self.__class__.__name__ + " forward"):
            yield inp

        if self.fp8_enabled and is_in_recompute_phase():
            FP8RecomputeBuffer.restore_fp8_meta_tensors(self.fp8_meta)
            return

        if self.fp8_enabled and self.training and self.fp8_meta["recipe"].reduce_amax:
            # Hand this module's amax to the global buffer for reduction.
            fwd_buffer = get_global_fp8_state().get_fp8_fwd_buffer()
            fwd_buffer.add_amax(self.fp8_meta)
            fwd_buffer.set_for_amax_reduction(self.fp8_meta, self.tp_group, self.tp_size)

    @staticmethod
    @contextmanager
    def prepare_backward(
        fp8_enabled: bool,
        fp8_meta: Dict[str, Any],
        tp_group: dist_group_type,
        tp_size: int,
        name: str = "",
    ) -> Generator[None, None, None]:
        """Context manager that prepares FP8 state around a backward pass.

        With FP8 enabled, it first waits on the global backward FP8 buffer and
        applies the backward amax/scale update. When the recipe reduces amax,
        it also copies amax from the buffer, marks it for deletion and takes
        the oldest forward autocast id as ``fp8_meta["autocast_id_bwd"]``.

        It yields inside an NVTX range named ``name + " backward"``. Afterwards,
        with FP8 and amax reduction, it adds this module's amax to the buffer
        and finalizes the reduction when the module is the first FP8 module.

        Args:
            fp8_enabled: whether the forward pass ran in FP8.
            fp8_meta: the layer's FP8 metadata dict.
            tp_group: tensor-parallel group used when finalizing the reduction.
            tp_size: tensor-parallel world size.
            name: prefix for the NVTX range.
        """
        if fp8_enabled:
            fp8_state = get_global_fp8_state()
            bwd_buffer = fp8_state.get_fp8_bwd_buffer()
            bwd_buffer.wait()

            reduce_amax = fp8_meta["recipe"].reduce_amax
            if reduce_amax:
                bwd_buffer.copy_amax_from_buffer(fp8_meta)
            amax_and_scale_update(
                fp8_meta,
                fwd_update=False,
                use_cudagraph=fp8_state.is_cudagraph_enabled(),
            )
            if reduce_amax:
                bwd_buffer.set_for_deletion(fp8_meta)
                # Forward ids are appended in order; backward consumes the oldest.
                fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd_stack"].pop(0)

        with nvtx_range(name + " backward"):
            yield

        if fp8_enabled and fp8_meta["recipe"].reduce_amax:
            bwd_buffer.add_amax(fp8_meta)
            if fp8_meta["first_module"]:
                bwd_buffer.finalize(fp8_meta, tp_group, tp_size)

    @staticmethod
    def grad_output_preprocess(
        ctx, grad_output: paddle.Tensor, row_parallel_mode: bool
    ) -> Tuple[Union[paddle.Tensor, None], ...]:
        """Prepare ``grad_output`` for the backward GEMMs.

        Returns a 4-tuple, where entries may be ``None`` depending on precision
        and recipe:

            R1: ``grad_output`` viewed as 2-D ``(-1, last_dim)`` in higher
                precision. It is gathered across ``ctx.tp_group`` under sequence
                parallelism, except on the FP8-wgrad path, where only R2 is
                gathered.
            R2: R1 cast to FP8 (the gathered copy on the FP8-wgrad gather path).
            R3: R2 transposed (``None`` when wgrad is not done in FP8 and there
                is no bias).
            R4: bias gradient, or ``None``. Without FP8 it is always ``None``,
                because bgrad is fused with wgrad in that case.
        """
        grad_output_mat = grad_output.reshape((-1, grad_output.shape[-1]))
        gather_grad_output = row_parallel_mode and ctx.sequence_parallel

        if not ctx.fp8_enabled:
            if gather_grad_output:
                grad_output_mat, _ = allgather(grad_output_mat, ctx.tp_group)
            return grad_output_mat, None, None, None

        recipe = ctx.fp8_meta["recipe"]
        fp8_dtype_backward = get_fp8_te_dtype(recipe, fprop_tensor=False)
        scaling_bwd = ctx.fp8_meta["scaling_bwd"]

        if gather_grad_output and not recipe.override_linear_precision.wgrad:
            # FP8 wgrad with gather: bgrad on the local shard, then cast before
            # the all-gather (cheaper gather) and transpose the gathered FP8 data.
            bgrad = grad_output_mat.sum(axis=0) if ctx.use_bias else None
            grad_output_c = cast_to_fp8(
                grad_output_mat,
                scaling_bwd,
                FP8BwdTensors.GRAD_OUTPUT1,
                fp8_dtype_backward,
            )
            grad_output_c, _ = allgather(grad_output_c, ctx.tp_group)
            grad_output_t = transpose(grad_output_c, fp8_dtype_backward)
            return grad_output_mat, grad_output_c, grad_output_t, bgrad

        if gather_grad_output:
            # Higher-precision wgrad: gather first, then proceed as ungathered.
            grad_output_mat, _ = allgather(grad_output_mat, ctx.tp_group)

        if ctx.use_bias:
            # Cast, transpose and bgrad in one fused kernel.
            bgrad, grad_output_c, grad_output_t = cast_transpose_bgrad(
                grad_output_mat,
                scaling_bwd,
                FP8BwdTensors.GRAD_OUTPUT1,
                fp8_dtype_backward,
            )
            return grad_output_mat, grad_output_c, grad_output_t, bgrad

        if recipe.override_linear_precision.wgrad:
            # wgrad is not done in FP8, so no FP8 transpose is produced.
            grad_output_c = cast_to_fp8(
                grad_output_mat,
                scaling_bwd,
                FP8BwdTensors.GRAD_OUTPUT1,
                fp8_dtype_backward,
            )
            return grad_output_mat, grad_output_c, None, None

        grad_output_c, grad_output_t = cast_transpose(
            grad_output_mat,
            scaling_bwd,
            FP8BwdTensors.GRAD_OUTPUT1,
            fp8_dtype_backward,
        )
        return grad_output_mat, grad_output_c, grad_output_t, None

    @abstractmethod
    def forward(self):
        """Run the layer's forward computation; every concrete layer must override this."""

    def get_fp8_weights_scratchpad_and_cast(
        self,
        is_first_microbatch: Union[bool, None],
    ) -> List[Optional[paddle.Tensor]]:
        """Return the cached FP8 weight buffers, registering a cast hook if needed.

        Returns:
            A flat list ``[w1_fp8, w1_t_fp8, w2_fp8, w2_t_fp8, ...]`` with two
            entries per element of ``self.fp8_weights``. Every entry is ``None``
            when FP8 is disabled or ``is_first_microbatch`` is ``None`` (FP8
            weight caching off). Otherwise the entries are the buffers stored in
            ``self.fp8_weight_cache`` by ``set_fp8_weights``.

        With weight caching on, the weights are cast only once, on the first
        microbatch, rather than on every micro step. When CUDA graphs and
        pipeline parallelism are both enabled, the first such call casts the
        weights right away (only if ``is_first_microbatch`` is truthy) and
        registers the same cast via ``register_pp_fwd_begin_hook``. That cast
        writes only when ``step_id == 0``. With grad enabled it fills both the
        cast and the transposed buffer; otherwise it fills only the cast buffer.
        """
        if not self.fp8_enabled or is_first_microbatch is None:
            return [None, None] * len(self.fp8_weights)

        def _cached_buffers(idx):
            # (cast buffer, transposed buffer) for the idx-th (1-based) weight.
            cast_key = f"weight{idx}_fp8"
            assert (
                cast_key in self.fp8_weight_cache
            ), "TE internal error: fp8 weight buffer is not found"
            return self.fp8_weight_cache[cast_key], self.fp8_weight_cache[f"weight{idx}_t_fp8"]

        out_list = []
        for idx, _ in enumerate(self.fp8_weights, start=1):
            out_list.extend(_cached_buffers(idx))

        # The cast hook is registered at most once per layer.
        needs_cast_hook = (
            get_global_fp8_state().is_cudagraph_enabled()
            and not self.registered_pp_start_callback
            and is_pp_enabled()
        )
        if not needs_cast_hook:
            return out_list

        fp8_dtype_forward = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=True)

        def cast_callback(step_id=None, **kwargs):  # pylint: disable=unused-argument
            update_fp8_weights = step_id == 0

            for idx, weight in enumerate(self.fp8_weights, start=1):
                weight_fp8, weight_t_fp8 = _cached_buffers(idx)
                if not update_fp8_weights:
                    continue

                # Only two weight slots exist in the forward scaling meta.
                tensor_slot = (
                    FP8FwdTensors.GEMM1_WEIGHT if idx == 1 else FP8FwdTensors.GEMM2_WEIGHT
                )
                if paddle.is_grad_enabled():
                    cast_transpose(
                        weight,
                        self.fp8_meta["scaling_fwd"],
                        tensor_slot,
                        fp8_dtype_forward,
                        cast_out=weight_fp8,
                        transpose_out=weight_t_fp8,
                    )
                else:
                    cast_to_fp8(
                        weight,
                        self.fp8_meta["scaling_fwd"],
                        tensor_slot,
                        fp8_dtype_forward,
                        out=weight_fp8,
                    )

        cast_callback(0 if is_first_microbatch else 1)
        register_pp_fwd_begin_hook(cast_callback)
        self.registered_pp_start_callback = True
        return out_list