"vscode:/vscode.git/clone" did not exist on "81429b80388b3e0b3b7a746ed3568694a2fdd5eb"
Unverified Commit b1a0e0a7 authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

[PyTorch] Refactor FP8 workspaces in linear modules (#820)



* Initial refactor of FP8 workspaces in Linear module
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove extra kernel launch
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Minor perf optimizations

Tensor base class functions in Float8Tensor have significant overhead.
Signed-off-by: Tim Moon <tmoon@nvidia.com>
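
A rough illustration of the overhead in question (the `Wrapper` class and timing harness below are illustrative only, not from this commit): any method call on a torch.Tensor subclass is routed through the subclass dispatch machinery, which is why the refactor prefers operating on the raw `_data` buffer directly.

    import timeit
    import torch

    class Wrapper(torch.Tensor):
        """Trivial torch.Tensor subclass; method calls on it still pass
        through the subclass dispatch machinery."""

    x = torch.zeros(4096, 4096, dtype=torch.uint8)
    w = x.as_subclass(Wrapper)

    # Same reshape, with and without subclass dispatch on the hot path.
    t_sub = timeit.timeit(lambda: w.reshape(-1, w.shape[-1]), number=10_000)
    t_plain = timeit.timeit(lambda: x.reshape(-1, x.shape[-1]), number=10_000)
    print(f"subclass: {t_sub:.3f}s  plain: {t_plain:.3f}s")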

* Debug FP8 recipe test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Refactor FP8 workspaces in LayerNormLinear and LayerNormMLP
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Document FP8 workspace function
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Revert changes to FP8 recipe tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add support for lazy FP8 transpose caching

Previous caching behavior (always fill cache) incorrectly filled cache during CUDA graph warmup steps.
Signed-off-by: Tim Moon <tmoon@nvidia.com>
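
For context, a usage sketch of the reworked `transpose_2d` API as it lands in this PR (the `to_float8` helper and import path are assumed from the Float8Tensor tests, not stated here):

    import torch
    from transformer_engine.pytorch.float8_tensor import Float8Tensor

    x_fp8 = Float8Tensor.to_float8(torch.randn(128, 128, device="cuda"))

    # Default: compute the transpose and leave the cache untouched. This
    # is what CUDA graph warmup steps now get instead of a filled cache.
    x_t = x_fp8.transpose_2d()

    # Opt in to caching; later calls can reuse the cached transpose.
    x_t = x_fp8.transpose_2d(fill_cache=True)

    # Under graph capture, a float32 noop flag lets the same kernel be
    # replayed as a no-op when the cached value is still valid.
    noop = torch.zeros(1, dtype=torch.float32, device="cuda")
    x_t = x_fp8.transpose_2d(fill_cache=True, noop_flag=noop)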

* Fix Pylint warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug ONNX export

ONNX FP8 cast ops assumed that FP8 scales were created during model export (i.e. not initialized during training).
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug fused attention tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Make sure Float8Tensor.transpose_2d is backward compatible
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Revert changes to ONNX export operations

Work around ONNX test failures by filling FP8 scale tensors instead of copying into them.
Signed-off-by: Tim Moon <tmoon@nvidia.com>
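
The distinction being exploited, in a minimal sketch (not the exporter code itself): an in-place `copy_` from another tensor traces to a broadcast/expand plus copy in the ONNX graph, while `fill_` with a Python scalar traces to a constant that the FP8 cast ops can consume.

    import torch

    scale_inv = torch.empty(1, dtype=torch.float32)
    src = torch.tensor(0.125, dtype=torch.float32)

    # Traces to an expand + copy, so the scale is not a graph constant:
    scale_inv.copy_(src)

    # Traces to a constant fill, which is what the export path expects:
    scale_inv.fill_(src.item())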

* Debug scale factor update in Float8Tensor transpose_2d
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 4e30bc4b
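
The net API change, in one sketch (mirroring the forward() code in the diff below, where `self` is a module deriving from TransformerEngineBaseModule): the per-module `weight1_fp8`/`weight1_t_fp8` buffers and `get_fp8_weights_scratchpad` are replaced by a single cached workspace getter.

    # Before: placeholder buffers allocated up front in set_fp8_weights().
    # weight_fp8, weight_t_fp8 = self.get_fp8_weights_scratchpad(is_first_microbatch)

    # After: one call that constructs, caches, and (optionally) updates the
    # FP8 workspace, including its transpose cache.
    weight_fp8 = self.get_fp8_workspace(
        tensor=weight_tensor,             # high-precision weights to cast
        fp8_meta_forward=True,            # use forward-pass FP8 meta tensors
        fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
        cache_name=(None if is_first_microbatch is None else "weight"),
        update_workspace=(is_first_microbatch is None or is_first_microbatch),
        skip_update_flag=skip_fp8_weight_update,
        with_transpose=torch.is_grad_enabled(),
    )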
@@ -19,3 +19,4 @@ NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
pytest -v -s $TE_PATH/tests/pytorch/test_torch_save_load.py
pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py
pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py
@@ -1752,7 +1752,6 @@ class Custom_MHA_FP8(TransformerEngineBaseModule):
dtype=params_dtype,
)
)
self.fp8_weight_shapes.append(self.qkv_weight.shape)
self.qkv_bias = torch.nn.Parameter(
torch.empty(
self.hidden_size * 3,
@@ -1786,9 +1785,3 @@ class Custom_MHA_FP8(TransformerEngineBaseModule):
self.training,
self.mask_type)
return out
def get_fp8_weights_scratchpad(
self,
is_first_microbatch: Union[bool, None],
) -> List[torch.Tensor]:
"""Needs override."""
@@ -294,7 +294,7 @@ class TestFloat8Tensor:
assert x_fp8._transpose_invalid, "Transpose cache must be invalid when not caching."
x_fp8 += 0.5
x = x_fp8.from_float8()
x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8.transpose_2d(cache=True))
x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8.transpose_2d(fill_cache=True))
x_t = x.transpose(0, 1)
torch.testing.assert_close(x_fp8_t, x_t, **tols)
assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly."
@@ -303,7 +303,7 @@ class TestFloat8Tensor:
x_fp8 += 0.5
assert x_fp8._transpose_invalid, "Transpose cache not invalidated properly."
x = x_fp8.from_float8()
x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8.transpose_2d(cache=True))
x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8.transpose_2d(fill_cache=True))
x_t = x.transpose(0, 1)
torch.testing.assert_close(x_fp8_t, x_t, **tols)
assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly."
@@ -5,15 +5,16 @@
"""Tensor class with FP8 data"""
from __future__ import annotations
from typing import Any, Dict, Optional, Tuple, Union
import warnings
import torch
from torch.utils._pytree import tree_map
import transformer_engine_extensions as tex
from .constants import TE_DType
from .cpp_extensions import fp8_cast_transpose_fused
from .fp8 import FP8GlobalStateManager
aten = torch.ops.aten
c10d = torch.ops.c10d
updated_fp8_params = {}
@@ -381,6 +382,7 @@ class Float8Tensor(torch.Tensor):
raise ValueError(
"Float8Tensor requires non-differentiable data buffer"
)
if not data.is_cuda:
data = data.cuda()
# Initialize tensor object
@@ -426,32 +428,38 @@
self._transpose_invalid: bool = True
# FP8 scale-inverse
self._scale_inv: Optional[torch.Tensor] = fp8_scale_inv
if self._scale_inv is None and self._fp8_meta is not None:
if fp8_scale_inv is None and self._fp8_meta is not None:
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=self._fp8_meta_forward,
)
scale_inv = self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index]
self._scale_inv = scale_inv.detach().view(1).clone()
if self._scale_inv is None:
fp8_scale_inv = self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index]
fp8_scale_inv = fp8_scale_inv.detach().view(1).clone()
if fp8_scale_inv is None:
raise ValueError(
"Attempted to initialize Float8Tensor without specifying scale-inverse"
)
if not isinstance(self._scale_inv, torch.Tensor):
self._scale_inv = torch.full(
if not isinstance(fp8_scale_inv, torch.Tensor):
fp8_scale_inv = torch.full(
[1],
self._scale_inv,
fp8_scale_inv,
dtype=torch.float32,
device=self._data.device,
)
if self._scale_inv.numel() != 1:
if fp8_scale_inv.numel() != 1:
raise ValueError(
"Attempted to initialize Float8Tensor with invalid scale-inverse tensor"
)
self._scale_inv = self._scale_inv.to(
if fp8_scale_inv.dim() != 1:
fp8_scale_inv = fp8_scale_inv.reshape(1)
if (
fp8_scale_inv.device != self._data.device
or fp8_scale_inv.dtype != torch.float32
):
fp8_scale_inv = fp8_scale_inv.to(
device=self._data.device,
dtype=torch.float32,
)
self._scale_inv: Optional[torch.Tensor] = fp8_scale_inv
return self
@@ -559,48 +567,177 @@
def transpose_2d(
self,
*,
cache: bool = False,
force_compute: bool = False,
fill_cache: bool = False,
noop_flag: Optional[torch.Tensor] = None,
cache: Optional[bool] = None,
) -> torch.Tensor:
"""
2D transpose with caching support.
Parameters
----------
cache: bool, default = `False`
Whether or not to cache the transpose.
noop_flag: Optional[torch.Tensor], default = `None`
Only used if argument `cache` is `True`, ignored otherwise.
A single element fp32 tensor with a value of 1.0 or 0.0
which is treated as a boolean. `1.0` forces recompute
and `0.0` executes a noop using the same kernel.
force_compute: bool, default = `False`
Force computation of transpose. Otherwise use
cached values, if possible.
fill_cache: bool, default = `False`
Cache output tensor for future function calls.
noop_flag: torch.Tensor, optional
float32 flag indicating whether to avoid updating
cached values, if possible.
cache: bool, deprecated
"""
assert self.dim() == 2, f"{self.dim()}-D transpose not supported."
# Case: no caching.
if not cache:
return tex.fp8_transpose(self._data, self._fp8_dtype)
# Handle deprecated cache kwarg
if cache is not None:
msg = (
"cache kwarg for Float8Tensor.transpose_2d is deprecated, "
"please use force_compute and fill_cache instead"
)
warnings.warn(msg, DeprecationWarning)
if cache:
force_compute = False
fill_cache = True
else:
force_compute = True
fill_cache = False
# Case: reuse cache without calling a kernel.
if not self._transpose_invalid and noop_flag is None:
assert self._transpose is not None, "Tranpose cache is empty."
# Need to compute transpose if cache is invalid
need_compute = force_compute
if self._transpose is None:
need_compute = True
elif self._transpose_invalid:
need_compute = True
# Need to apply transpose kernel if noop flag is applied
if noop_flag is not None:
need_compute = True
# Return cached transpose if possible
if not need_compute:
return self._transpose
# Allocate transpose if needed.
data_2d = self._data.reshape(-1, self._data.shape[-1])
if self._transpose is None:
shape = (data_2d.shape[1], data_2d.shape[0])
self._transpose = torch.empty(shape, dtype=torch.uint8, device=self._data.device)
# Allocate output if needed
data = self._data.contiguous().reshape(-1, self.size(-1))
out = self._transpose
if out is None:
out = torch.empty(
(data.size(1), data.size(0)),
dtype=torch.uint8,
device=data.device,
)
noop_flag = None
else:
self._transpose_invalid = False
# Case: recompute transpose and store cache.
# Apply transpose kernel
fp8_dtype = self._fp8_dtype
if noop_flag is None:
tex.fp8_transpose_noalloc(data_2d, self._transpose, self._fp8_dtype)
tex.fp8_transpose_noalloc(data, out, fp8_dtype)
else:
# Case: cuda graph capture.
tex.fp8_transpose_noalloc_noop(data_2d, self._transpose, noop_flag, self._fp8_dtype)
noop_flag = noop_flag.to(dtype=torch.float32, device=data.device)
tex.fp8_transpose_noalloc_noop(data, out, noop_flag, fp8_dtype)
# Fill cache if needed
if fill_cache:
self._transpose = out
self._transpose_invalid = False
return out
@torch.no_grad()
def cast_transpose_(
self,
tensor: torch.Tensor,
noop_flag: Optional[torch.Tensor] = None,
) -> None:
"""Cast from tensor and populate transpose cache
Only supported for 2D tensors.
Parameters
----------
tensor: torch.Tensor
Tensor to copy from. Must have same dimensions as
destination tensor.
noop_flag: torch.Tensor, optional
float32 flag indicating whether to avoid updating
destination tensor.
"""
# Make sure tensor is in expected format
data = self._data
if (
tensor.device != data.device
or tensor.dtype not in (torch.float32, torch.float16, torch.bfloat16)
or not tensor.is_contiguous()
):
dtype = tensor.dtype
if dtype not in (torch.float32, torch.float16, torch.bfloat16):
dtype = torch.float32
tensor = tensor.to(
device=self.device,
dtype=dtype,
memory_format=torch.contiguous_format,
)
if tensor.size() != data.size() or data.dim() != 2:
raise ValueError(
"Invalid tensor dimensions for FP8 cast-transpose "
f"(src={tuple(tensor.size())}, dst={tuple(data.size())})"
)
if not data.is_contiguous():
raise ValueError(
"FP8 cast-transpose is only supported for "
"`Float8Tensor`s with contiguous data"
)
if self._fp8_meta is None:
raise ValueError(
"FP8 cast-transpose is only supported for "
"`Float8Tensor`s with FP8 metadata "
)
# Construct transpose cache if needed
transpose = self._transpose
if transpose is None or not transpose.is_contiguous():
transpose = torch.empty(
(data.size(1), data.size(0)),
dtype=torch.uint8,
device=data.device,
)
self._transpose = transpose
noop_flag = None
# Launch cast-transpose kernel
fp8_meta_index = int(self._fp8_meta_index)
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=self._fp8_meta_forward,
)
fp8_meta = self._fp8_meta[fp8_meta_key]
fp8_cast_transpose_fused(
tensor,
fp8_meta,
fp8_meta_index,
self._fp8_dtype,
cast_out=data,
transpose_out=transpose,
noop_flag=noop_flag,
)
scale = fp8_meta.scale[fp8_meta_index:fp8_meta_index+1]
scale_inv = self._scale_inv
if noop_flag is None:
torch.reciprocal(scale, out=scale_inv)
else:
torch.where(
noop_flag.bool(),
scale_inv,
scale.reciprocal(),
out=scale_inv,
)
self._transpose_invalid = False
return self._transpose
@torch.no_grad()
def reset_fp8_meta_scale_inv(self) -> None:
@@ -8,7 +8,7 @@ import os
import pickle
import warnings
from abc import ABC, abstractmethod
from typing import Generator, Union, Optional, Tuple, List
from typing import Dict, Generator, List, Optional, Tuple, Union
from contextlib import contextmanager
import torch
@@ -252,9 +252,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.tp_group = None
self.tp_size = 1
self.sequence_parallel = False
self.fp8_weight_shapes = []
self.param_init_meta = {}
self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
self._fp8_workspaces: Dict[str, Float8Tensor] = {}
def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None:
"""Increase or decrease size of amax history based on given `length`.
@@ -452,60 +452,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
)
self.activation_dtype = dtype
def set_fp8_weights(self) -> None:
"""Construct workspace buffers for FP8 weights, if needed
These workspace buffers are used for FP8 training when the
module parameters are not natively in FP8 and there are
multiple microbatches per training step. The buffers, with
names like `weight1_fp8` and `weight1_t_fp8`, cache the FP8
values and transposed FP8 values in between microbatches. They
are not registered as module parameters or buffers since we
don't want them to be affected by `.to` and since they aren't
needed for checkpointing.
"""
if not self.fp8 or self.primary_weights_in_fp8:
return
for i, shape in enumerate(self.fp8_weight_shapes, start=1):
weight_cast_attr = f"weight{i}_fp8"
weight_transpose_attr = f"weight{i}_t_fp8"
if (
hasattr(self, weight_cast_attr)
and getattr(self, weight_cast_attr).shape == shape
):
return
setattr(
self,
weight_cast_attr,
Float8Tensor(
data=torch.empty(
shape,
device=torch.cuda.current_device(),
dtype=torch.uint8,
),
fp8_dtype=tex.DType.kFloat8E4M3,
fp8_scale_inv=1,
)
)
setattr(
self,
weight_transpose_attr,
Float8Tensor(
data=torch.empty(
shape[1],
shape[0],
device=torch.cuda.current_device(),
dtype=torch.uint8,
),
fp8_dtype=tex.DType.kFloat8E4M3,
fp8_scale_inv=1,
)
)
def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
"""
Set the tensor parallel group for the given
@@ -522,7 +468,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
def _get_fp8_params(self) -> Union[List[torch.Tensor], None]:
"""returns the FP8 weights."""
fp8_params = []
for param in self.parameters():
for param in self.parameters(recurse=False):
if isinstance(param, Float8Tensor) and param.requires_grad:
fp8_params.append(param)
if len(fp8_params) == 0:
@@ -569,7 +515,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
def prepare_forward(
self,
inp: torch.Tensor,
is_first_microbatch: Union[bool, None],
is_first_microbatch: Union[bool, None], # pylint: disable=unused-argument
num_gemms: int = 1,
allow_non_contiguous: bool = False,
) -> Generator[torch.Tensor, None, None]:
@@ -591,11 +537,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.set_activation_dtype(inp)
self.init_fp8_metadata(num_gemms=num_gemms)
# Create persistent tensors for fp8 weights and their transposes
# only when fp8 weight caching is used and weights are not in fp8
if is_first_microbatch is not None and not self.primary_weights_in_fp8:
self.set_fp8_weights()
if self.fp8 and self.sequence_parallel:
assert self.fp8_meta["recipe"].reduce_amax, \
"Amax reduction across tensor parallel group is " \
@@ -754,49 +695,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
return grad_output_mat, grad_output_c, grad_output_t, grad_bias
def get_fp8_weights_empty_tensors(
self,
is_first_microbatch: Union[bool, None],
) -> List[Float8Tensor]:
"""
Returns empty tensors to be later used to store fp8 version of weights
and their transposes (for the bwd pass) for this batch (or microbatch).
When `is_first_microbatch` is `None`, this is especially useful since
we then don't need to store the fp8 weights that are needed for one time
only in the forward pass. Note that we still need to store the tensor
for the fp8 weight transpose which is at least needed in the backward
pass but that's taken care of by storing the transpose tensor in
`ctx.save_for_backward`.
"""
assert is_first_microbatch is None, "Should only be here when "\
"`is_first_microbatch` is None!"
fp8_weight_tensors = []
for shape in self.fp8_weight_shapes:
fp8_weight_tensors.append(
Float8Tensor(
data=torch.empty(
shape,
device=torch.cuda.current_device(),
dtype=torch.uint8,
),
fp8_dtype=tex.DType.kFloat8E4M3,
fp8_scale_inv=1,
)
)
fp8_weight_tensors.append(
Float8Tensor(
data=torch.empty(
shape[1],
shape[0],
device=torch.cuda.current_device(),
dtype=torch.uint8,
),
fp8_dtype=tex.DType.kFloat8E4M3,
fp8_scale_inv=1,
)
)
return fp8_weight_tensors
def register_parameter(self, name, param, **kwargs):
"""
Thin wrapper around PyTorch parameter registration to stash additional parameter
@@ -852,12 +750,119 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
def forward(self):
"""Needs override."""
@abstractmethod
def get_fp8_weights_scratchpad(
def get_fp8_workspace(
self,
is_first_microbatch: Union[bool, None],
) -> List[torch.Tensor]:
"""Needs override."""
*,
tensor: Optional[torch.Tensor] = None,
fp8_meta_forward: Optional[bool] = None,
fp8_meta_index: Optional[int] = None,
cache_name: Optional[str] = None,
update_workspace: bool = True,
skip_update_flag: Optional[torch.Tensor] = None,
with_transpose: bool = False,
) -> Float8Tensor:
"""Get FP8 workspace buffer and maybe update its values
The workspace buffer may be cached for future function calls.
Parameters
----------
tensor : torch.Tensor, optional
Values to copy into workspace. Required if the workspace
is being constructed or updated.
fp8_meta_forward: bool, optional
Whether to access FP8 meta tensors for the forward pass or
backward pass. Required if the workspace is being
constructed.
fp8_meta_index: int, optional
Index to access in FP8 meta tensors. Required if the
workspace is being constructed.
cache_name: str, optional
Key for caching.
update_workspace: bool, default = `True`
Update workspace with values from `tensor`.
skip_update_flag: torch.Tensor, optional
GPU flag to skip updating the workspace. Takes precedence
over `update_workspace` if provided.
with_transpose: bool, default = `False`
Whether to initialize cached transpose in workspace.
"""
# Construct workspace if needed
out = None
if cache_name is not None:
out = self._fp8_workspaces.get(cache_name, None)
if out is None:
if (
tensor is None
or fp8_meta_forward is None
or fp8_meta_index is None
):
raise ValueError(
"tensor, fp8_meta_forward, and fp8_meta_index kwargs "
"must be provided to construct FP8 workspace"
)
fp8_dtype = get_fp8_te_dtype(
self.fp8_meta["recipe"],
fprop_tensor=fp8_meta_forward,
)
scale_inv = torch.empty(
[1],
dtype=torch.float32,
device=tensor.device
)
out = Float8Tensor(
data=torch.empty_like(tensor, dtype=torch.uint8),
fp8_meta=self.fp8_meta,
fp8_meta_forward=fp8_meta_forward,
fp8_meta_index=fp8_meta_index,
fp8_dtype=fp8_dtype,
fp8_scale_inv=scale_inv,
dtype=tensor.dtype,
)
if cache_name is not None:
self._fp8_workspaces[cache_name] = out
update_workspace = True
skip_update_flag = None
# Update workspace if needed
if skip_update_flag is not None:
update_workspace = True
if update_workspace:
if tensor is None:
raise ValueError(
"tensor kwarg must be provided to update FP8 workspace"
)
if with_transpose:
out.cast_transpose_(
tensor,
noop_flag=skip_update_flag,
)
else:
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=out._fp8_meta_forward,
)
fp8_meta = out._fp8_meta[fp8_meta_key]
fp8_meta_index = out._fp8_meta_index
cast_to_fp8(
tensor,
fp8_meta,
fp8_meta_index,
out._fp8_dtype,
out=out._data,
)
if is_in_onnx_export_mode():
# ONNX export expects FP8 scales can be
# represented with constant ops. However, copying
# into a buffer involves an expand op for array
# broadcasting. We work around this by filling the
# buffer instead.
out._scale_inv.fill_(fp8_meta.scale_inv[fp8_meta_index].item())
else:
out._scale_inv.copy_(fp8_meta.scale_inv[fp8_meta_index])
return out
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
@@ -5,7 +5,7 @@
"""LayerNormLinear API"""
import os
import warnings
from typing import Union, Optional, Callable, Tuple, List, Dict, Any
from typing import Any, Callable, Dict, Optional, Tuple, Union
import torch
from torch.nn import init
@@ -61,13 +61,11 @@ class _LayerNormLinear(torch.autograd.Function):
ln_weight: torch.Tensor,
ln_bias: Union[torch.Tensor, None],
weight: torch.Tensor,
weight_fp8: Union[torch.Tensor, None],
weight_t_fp8: Union[torch.Tensor, None],
weight_fp8: Optional[torch.Tensor],
bias: torch.Tensor,
use_bias: bool,
eps: float,
is_first_microbatch: Union[bool, None],
skip_fp8_weight_update: Union[torch.Tensor, None],
fp8: bool,
fp8_calibration: bool,
fp8_meta: Dict[str, Any],
@@ -86,7 +84,6 @@
bwd_ln_sm_margin: int,
zero_centered_gamma: bool,
normalization: str,
primary_weights_in_fp8: bool,
ub_bulk_wgrad: bool,
ub_bulk_dgrad: bool,
ub_overlap_rs_dgrad: bool,
@@ -101,12 +98,6 @@
assert_dim_for_fp8_exec(inputmat)
assert_dim_for_fp8_exec(weight)
update_fp8_weights = (
is_first_microbatch is None
or is_first_microbatch
or skip_fp8_weight_update is not None
)
# Cast for native AMP
inputmat = cast_if_needed(inputmat, activation_dtype)
ln_weight = cast_if_needed(ln_weight, activation_dtype)
@@ -202,38 +193,10 @@
)
bias = cast_if_needed(bias, bias_dtype) if use_bias else bias
if primary_weights_in_fp8:
# Weight is already in FP8
weight.reset_fp8_meta_scale_inv()
# Use FP8 weights
if weight_fp8 is None:
weight_fp8 = weight
elif update_fp8_weights:
# Need to cast weights to FP8
weight_fp8 = Float8Tensor(
data=weight_fp8._data,
fp8_meta=fp8_meta,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
)
if (is_grad_enabled
or (is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase())):
tex.fp8_cast_transpose_fused(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
cast_out=weight_fp8._data,
transpose_out=weight_t_fp8._data,
noop_flag=skip_fp8_weight_update,
)
else:
tex.cast_to_fp8(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
out=weight_fp8._data,
)
weight_t_fp8 = None
assert isinstance(weight_fp8, Float8Tensor)
if fp8_meta["recipe"].fp8_mha:
out_index, meta_tensor, output_te_dtype, output_dtype = (
@@ -246,9 +209,9 @@
None, None, None, activation_dtype)
out, _ = tex.fp8_gemm(
weight_fp8._data,
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
weight_fp8._scale_inv,
0,
weight_fp8._fp8_dtype,
ln_out_total,
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_INPUT,
@@ -307,8 +270,8 @@
if cpu_offloading:
if fuse_wgrad_accumulation:
weight.main_grad.weight_offloading = True
if fp8 and weight_t_fp8 is not None:
weight_t_fp8.weight_offloading = True
if fp8 and weight_fp8 is not None:
weight_fp8.weight_offloading = True
ln_weight.weight_offloading = True
weight.weight_offloading = True
@@ -324,11 +287,10 @@
mu,
rsigma,
weight,
weight_fp8,
weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
weight_t_fp8,
ln_out if weight.requires_grad else None,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
skip_fp8_weight_update.clone() if skip_fp8_weight_update is not None else None,
)
ctx.activation_dtype = activation_dtype
@@ -355,7 +317,6 @@
ctx.ub_name = ub_name
ctx.requires_dgrad = inp.requires_grad
ctx.normalization = normalization
ctx.primary_weights_in_fp8 = primary_weights_in_fp8
ctx.reduce_and_update_bwd_fp8_tensors = False
if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias):
ctx.reduce_and_update_bwd_fp8_tensors = (
@@ -394,26 +355,16 @@
mu,
rsigma,
weight,
weight_fp8,
main_grad,
weight_t_fp8,
ln_out,
fwd_scale_inverses,
skip_fp8_weight_update,
) = ctx.saved_tensors
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
weight = torch.nn.Parameter(weight, False)
weight.main_grad = main_grad
# Primary weights are in FP8.
if ctx.primary_weights_in_fp8:
weight_t_fp8 = weight.transpose_2d(
cache=ctx.is_first_microbatch is not None,
noop_flag=skip_fp8_weight_update,
)
elif ctx.fp8:
weight_t_fp8 = weight_t_fp8._data
if ctx.ub_overlap_rs_dgrad:
ctx.ub_bulk_dgrad = False
ctx.ub_bulk_wgrad = False
@@ -520,10 +471,10 @@
# DGRAD: Evaluated unconditionally to feed into Linear backward
_ = tex.fp8_gemm(
weight_t_fp8,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
weight_fp8.transpose_2d(),
weight_fp8._scale_inv,
0,
weight_fp8._fp8_dtype,
grad_output_c._data
if isinstance(grad_output_c, Float8Tensor) else grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv,
@@ -712,37 +663,34 @@
dgamma,
dbeta,
wgrad,
None,
None,
None, # weight_fp8
grad_bias,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None, # use_bias
None, # eps
None, # is_first_microbatch
None, # fp8
None, # fp8_calibration
None, # fp8_meta
None, # fuse_wgrad_accumulation
None, # cpu_offloading
None, # tp_group
None, # tp_size
None, # sequence_parallel
None, # tensor_parallel
None, # activation_dtype
None, # parallel_mode
None, # return_layernorm_output
None, # return_layernorm_output_gathered
None, # is_grad_enabled
None, # fwd_ln_sm_margin
None, # bwd_ln_sm_margin
None, # zero_centered_gamma
None, # normalization
None, # ub_bulk_wgrad
None, # ub_bulk_dgrad
None, # ub_overlap_rs_dgrad
None, # ub_overlap_ag
None, # ub_name
)
@@ -873,7 +821,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.return_layernorm_output = return_layernorm_output
self.return_layernorm_output_gathered = return_layernorm_output_gathered
self.zero_centered_gamma = zero_centered_gamma
self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
self.ub_bulk_wgrad = ub_bulk_wgrad
self.ub_bulk_dgrad = ub_bulk_dgrad
self.ub_overlap_ag = ub_overlap_ag
@@ -926,6 +873,9 @@
else:
self.layer_norm_bias = None
# Initialize params in FP8
with_fp8_params = FP8GlobalStateManager.with_fp8_parameters()
# Contiguous buffers for params
weight_tensor = torch.empty(
self.out_features,
@@ -998,7 +948,7 @@
# Check if parameters are subviews of buffers
is_subview = (split_start, split_end) != (0, self.out_features)
if is_subview and self.primary_weights_in_fp8:
if is_subview and with_fp8_params:
raise RuntimeError(
"Splitting Float8Tensor into multiple params "
"is not supported"
@@ -1030,13 +980,11 @@
bias = torch.Tensor().to(dtype=params_dtype, device=device)
setattr(self, name, bias)
if self.primary_weights_in_fp8:
if with_fp8_params:
self.init_fp8_metadata()
self.reset_parameters(defer_init=(device == 'meta'))
self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))
# For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM
if self.parallel_mode == "row" and self.apply_bias:
@@ -1093,31 +1041,6 @@
elif self.parallel_mode == "column":
set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1)
def get_fp8_weights_scratchpad(
self,
is_first_microbatch: Union[bool, None],
) -> List[torch.Tensor]:
"""
Fetch the fp8 weight tensor placeholders if they exist (when
`is_first_microbatch` is not `None`) or return empty fp8 weight
tensors (if `is_first_microbatch is None`)
"""
if not self.fp8 or self.primary_weights_in_fp8:
return [None, None]
if is_first_microbatch is None:
# Return empty weight placeholders for each fwd/bwd pass
fp8_weight_tensors = self.get_fp8_weights_empty_tensors(
is_first_microbatch
)
else:
# These persistent weight placeholders should've been created in
# `set_fp8_weights` method
fp8_weight_tensors = [self.weight1_fp8, self.weight1_t_fp8]
return fp8_weight_tensors
@no_torch_dynamo()
def forward(
self,
@@ -1151,13 +1074,19 @@
is_first_microbatch = False
with self.prepare_forward(inp, is_first_microbatch) as inp:
assert self.fp8 or not self.primary_weights_in_fp8, \
"Need to run inside fp8_autocast region when weights are stored in FP8."
# Get concatenated weight and bias tensors
weight_tensor = _noop_cat(
[getattr(self, name) for name in self.weight_names],
unfused_weights = [getattr(self, name) for name in self.weight_names]
if any(isinstance(w, Float8Tensor) for w in unfused_weights):
if self.fp8:
if len(unfused_weights) != 1:
raise RuntimeError(
"Splitting Float8Tensor into multiple params "
"is not supported"
)
else:
unfused_weights = [w.from_float8() for w in unfused_weights]
weight_tensor = _noop_cat(unfused_weights)
if self.use_bias:
bias_tensor = _noop_cat(
[getattr(self, name) for name in self.bias_names],
@@ -1165,9 +1094,44 @@
else:
bias_tensor = getattr(self, self.bias_names[0]) # Unused
# Fetch the fp8 weights placeholders (for linear/gemm)
weight1_fp8, weight1_t_fp8 = self.get_fp8_weights_scratchpad(
# Initialize FP8 weights if needed
weight_fp8 = None
if self.fp8:
with_transpose = torch.is_grad_enabled()
if (
not with_transpose
and is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase()
):
with_transpose = True
if isinstance(weight_tensor, Float8Tensor):
# Fill transpose cache in FP8 tensor if needed
update_transpose_cache = with_transpose
if update_transpose_cache:
update_transpose_cache = (
is_first_microbatch
or skip_fp8_weight_update is not None
)
if update_transpose_cache:
weight_tensor.transpose_2d(
fill_cache=True,
noop_flag=skip_fp8_weight_update,
)
else:
# FP8 cast to workspace buffer
update_workspace = (
is_first_microbatch is None
or is_first_microbatch
or skip_fp8_weight_update is not None
)
weight_fp8 = self.get_fp8_workspace(
tensor=weight_tensor,
fp8_meta_forward=True,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
cache_name=(None if is_first_microbatch is None else "weight"),
update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update,
with_transpose=with_transpose,
)
from ..cpu_offload import CPUOffloadEnabled
@@ -1183,13 +1147,11 @@
self.layer_norm_weight,
self.layer_norm_bias,
weight_tensor,
weight1_fp8,
weight1_t_fp8,
weight_fp8,
bias_tensor,
self.apply_bias and not self.gemm_bias_unfused_add,
self.eps,
is_first_microbatch,
skip_fp8_weight_update,
self.fp8,
self.fp8_calibration,
self.fp8_meta,
@@ -1208,7 +1170,6 @@
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.normalization,
self.primary_weights_in_fp8,
self.ub_bulk_wgrad,
self.ub_bulk_dgrad,
self.ub_overlap_rs_dgrad,
@@ -4,7 +4,7 @@
"""Linear API"""
import os
from typing import Union, Optional, Callable, Tuple, List, Dict, Any
from typing import Any, Callable, Dict, Optional, Tuple, Union
import torch
@@ -62,13 +62,11 @@ class _Linear(torch.autograd.Function):
def forward(
ctx,
weight: Union[Float8Tensor, torch.Tensor],
weight_fp8: Union[Float8Tensor, None],
weight_t_fp8: Union[Float8Tensor, None],
weight_fp8: Optional[Float8Tensor],
inp: torch.Tensor,
bias: torch.Tensor,
use_bias: bool,
is_first_microbatch: Union[bool, None],
skip_fp8_weight_update: Union[torch.Tensor, None],
fp8: bool,
fp8_calibration: bool,
fp8_meta: Dict[str, Any],
@@ -81,7 +79,6 @@
activation_dtype: torch.dtype,
parallel_mode: Union[str, None],
is_grad_enabled: bool,
primary_weights_in_fp8: bool,
ub_overlap_rs: bool,
ub_overlap_ag: bool,
ub_name: str,
@@ -99,12 +96,6 @@
assert_dim_for_fp8_exec(inputmat)
assert_dim_for_fp8_exec(weight)
update_fp8_weights = (
is_first_microbatch is None
or is_first_microbatch
or skip_fp8_weight_update is not None
)
tp_world_size = get_distributed_world_size(tp_group)
ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs
@@ -163,38 +154,10 @@
)
bias = cast_if_needed(bias, bias_dtype) if use_bias else bias
if primary_weights_in_fp8:
# Weight is already in FP8
weight.reset_fp8_meta_scale_inv()
# Use FP8 weights
if weight_fp8 is None:
weight_fp8 = weight
elif update_fp8_weights:
# Need to cast weights to FP8
weight_fp8 = Float8Tensor(
data=weight_fp8._data,
fp8_meta=fp8_meta,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
)
if (is_grad_enabled
or (is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase())):
fp8_cast_transpose_fused(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
cast_out=weight_fp8._data,
transpose_out=weight_t_fp8._data,
noop_flag=skip_fp8_weight_update,
)
else:
cast_to_fp8(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
out=weight_fp8._data,
)
weight_t_fp8 = None
assert isinstance(weight_fp8, Float8Tensor)
if is_first_module_in_mha:
proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = (
@@ -211,7 +174,7 @@
out = ub_obj_projout.get_ubuf_output(1)
dim_size = list(inputmat_total.size())
dim_size[0] = dim_size[0] // tp_world_size
dim_size[1] = weight.size(0)
dim_size[1] = weight_fp8.size(0)
rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
if ub_obj_projout.is_p2p_overlap():
if ub_obj_projout.is_atomic_gemm():
@@ -231,14 +194,14 @@
ub_obj_projout.set_ubuf_scale_inv(meta_tensor.scale_inv[proj_out_index])
else:
dim_size = list(inputmat_total.size())
dim_size[1] = weight.size(0)
dim_size[1] = weight_fp8.size(0)
out = torch.empty(dim_size, dtype=proj_out_pttype, device=inputmat_total.device)
_ = fp8_gemm(
weight_fp8._data,
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
weight_fp8._scale_inv,
0,
weight_fp8._fp8_dtype,
inputmat_total._data
if isinstance(inputmat_total, Float8Tensor) else inputmat_total,
fp8_meta["scaling_fwd"].scale_inv,
@@ -329,8 +292,8 @@
if cpu_offloading:
if fuse_wgrad_accumulation:
weight.main_grad.weight_offloading = True
if fp8 and weight_t_fp8 is not None:
weight_t_fp8.weight_offloading = True
if fp8 and weight_fp8 is not None:
weight_fp8.weight_offloading = True
weight.weight_offloading = True
if saved_inputmat is not None:
@@ -340,10 +303,9 @@
saved_inputmat,
saved_inputmat_t,
weight,
weight_fp8,
weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
weight_t_fp8 if fp8 else None,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
skip_fp8_weight_update.clone() if skip_fp8_weight_update is not None else None,
)
ctx.activation_dtype = activation_dtype
ctx.fp8 = fp8
@@ -362,7 +324,6 @@
ctx.tp_size = tp_size
ctx.requires_dgrad = inp.requires_grad
ctx.is_input_fp8 = is_input_fp8
ctx.primary_weights_in_fp8 = primary_weights_in_fp8
ctx.reduce_and_update_bwd_fp8_tensors = False
if ctx.fp8 and requires_grad(inp, weight, bias):
ctx.reduce_and_update_bwd_fp8_tensors = (
@@ -394,25 +355,15 @@
inputmat,
inputmat_t,
weight,
weight_fp8,
main_grad,
weight_t_fp8,
fwd_scale_inverses,
skip_fp8_weight_update,
) = ctx.saved_tensors
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
weight = torch.nn.Parameter(weight, False)
weight.main_grad = main_grad
# Primary weights are in FP8.
if ctx.primary_weights_in_fp8:
weight_t_fp8 = weight.transpose_2d(
cache=ctx.is_first_microbatch is not None,
noop_flag=skip_fp8_weight_update,
)
elif ctx.fp8:
weight_t_fp8 = weight_t_fp8._data
tp_world_size = get_distributed_world_size(ctx.tp_group)
ctx.ub_overlap_ag = False if tp_world_size == 1 else ctx.ub_overlap_ag
if ctx.ub_overlap_ag:
@@ -476,10 +427,10 @@
out_index, meta_tensor, output_te_dtype, output_dtype = (
None, None, None, ctx.activation_dtype)
dgrad, _ = fp8_gemm(
weight_t_fp8,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
weight_fp8.transpose_2d(),
weight_fp8._scale_inv,
0,
weight_fp8._fp8_dtype,
grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT1,
@@ -620,30 +571,27 @@
return (
wgrad,
None,
None,
None, # weight_fp8
dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
grad_bias,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None, # use_bias
None, # is_first_microbatch
None, # fp8
None, # fp8_calibration
None, # fp8_meta
None, # fuse_wgrad_accumulation
None, # cpu_offloading
None, # tp_group
None, # tp_size
None, # sequence_parallel
None, # tensor_parallel
None, # activation_dtype
None, # parallel_mode
None, # is_grad_enabled
None, # ub_overlap_rs
None, # ub_overlap_ag
None, # ub_name
None, # is_first_module_in_mha
)
@@ -747,7 +695,6 @@ class Linear(TransformerEngineBaseModule):
self.use_bias = bias
self.return_bias = return_bias
self.apply_bias = bias and not return_bias
self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
self.ub_overlap_rs = ub_overlap_rs
self.ub_overlap_ag = ub_overlap_ag
if ub_overlap_rs or ub_overlap_ag:
@@ -783,6 +730,9 @@
self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
# Initialize params in FP8
with_fp8_params = FP8GlobalStateManager.with_fp8_parameters()
# Contiguous buffers for params
weight_tensor = torch.empty(
self.out_features,
@@ -855,7 +805,7 @@
# Check if parameters are subviews of buffers
is_subview = (split_start, split_end) != (0, self.out_features)
if is_subview and self.primary_weights_in_fp8:
if is_subview and with_fp8_params:
raise RuntimeError(
"Splitting Float8Tensor into multiple params "
"is not supported"
@@ -887,13 +837,11 @@
bias = torch.Tensor().to(dtype=params_dtype, device=device)
setattr(self, name, bias)
if self.primary_weights_in_fp8:
if with_fp8_params:
self.init_fp8_metadata()
self.reset_parameters(defer_init=(device == 'meta'))
self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))
# For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM
if self.parallel_mode == "row" and self.apply_bias:
@@ -922,30 +870,6 @@
elif self.parallel_mode == "column":
set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1)
def get_fp8_weights_scratchpad(
self,
is_first_microbatch: Union[bool, None],
) -> List[Float8Tensor]:
"""
Fetch the fp8 weight tensor placeholders if they exist (when
`is_first_microbatch` is not `None`) or return empty fp8 weight
tensors (if `is_first_microbatch is None`)
"""
if not self.fp8 or self.primary_weights_in_fp8:
return [None, None]
if is_first_microbatch is None:
# Return empty weight placeholders for each fwd/bwd pass
fp8_weight_tensors = self.get_fp8_weights_empty_tensors(
is_first_microbatch
)
else:
# These persistent weight placeholders should've been created in
# `set_fp8_weights` method
fp8_weight_tensors = [self.weight1_fp8, self.weight1_t_fp8]
return fp8_weight_tensors
@no_torch_dynamo()
def forward(
self,
@@ -979,18 +903,26 @@
if skip_fp8_weight_update is not None:
is_first_microbatch = False
with self.prepare_forward(inp,
with self.prepare_forward(
inp,
is_first_microbatch,
allow_non_contiguous=isinstance(inp,Float8Tensor)) as inp:
assert self.fp8 or not self.primary_weights_in_fp8, \
"Need to run inside fp8_autocast region when weights are stored in FP8."
allow_non_contiguous=isinstance(inp,Float8Tensor),
) as inp:
is_first_module_in_mha = is_first_module_in_mha and self.fp8_meta["recipe"].fp8_mha
# Get concatenated weight and bias tensors
weight_tensor = _noop_cat(
[getattr(self, name) for name in self.weight_names],
unfused_weights = [getattr(self, name) for name in self.weight_names]
if any(isinstance(w, Float8Tensor) for w in unfused_weights):
if self.fp8:
if len(unfused_weights) != 1:
raise RuntimeError(
"Splitting Float8Tensor into multiple params "
"is not supported"
)
else:
unfused_weights = [w.from_float8() for w in unfused_weights]
weight_tensor = _noop_cat(unfused_weights)
if self.use_bias:
bias_tensor = _noop_cat(
[getattr(self, name) for name in self.bias_names],
@@ -998,9 +930,43 @@
else:
bias_tensor = getattr(self, self.bias_names[0]) # Unused
# Fetch the fp8 weights placeholders (for linear/gemm)
weight1_fp8, weight1_t_fp8 = self.get_fp8_weights_scratchpad(
# Initialize FP8 weights if needed
weight_fp8 = None
if self.fp8:
with_transpose = torch.is_grad_enabled()
if (
not with_transpose
and is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase()
):
with_transpose = True
if isinstance(weight_tensor, Float8Tensor):
# Fill transpose cache in FP8 tensor if needed
update_transpose_cache = with_transpose
if update_transpose_cache:
update_transpose_cache = (
is_first_microbatch
or skip_fp8_weight_update is not None
)
if update_transpose_cache:
weight_tensor.transpose_2d(
fill_cache=True,
noop_flag=skip_fp8_weight_update,
)
else:
# FP8 cast to workspace buffer
update_workspace = (
is_first_microbatch is None
or is_first_microbatch
)
weight_fp8 = self.get_fp8_workspace(
tensor=weight_tensor,
fp8_meta_forward=True,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
cache_name=(None if is_first_microbatch is None else "weight"),
update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update,
with_transpose=with_transpose,
)
from ..cpu_offload import CPUOffloadEnabled
......@@ -1013,13 +979,11 @@ class Linear(TransformerEngineBaseModule):
args = [None]
args += (
weight_tensor,
weight1_fp8,
weight1_t_fp8,
weight_fp8,
inp,
bias_tensor,
self.apply_bias and not self.gemm_bias_unfused_add,
is_first_microbatch,
skip_fp8_weight_update,
self.fp8,
self.fp8_calibration,
self.fp8_meta,
@@ -1032,7 +996,6 @@
self.activation_dtype,
self.parallel_mode,
torch.is_grad_enabled(),
self.primary_weights_in_fp8,
self.ub_overlap_rs,
self.ub_overlap_ag,
self.ub_name,
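
End to end, the caching contract the refactor preserves looks like this (a sketch against Transformer Engine's public API; the module, recipe, and `is_first_microbatch` kwarg are the standard ones rather than anything added by this diff):

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import DelayedScaling

    linear = te.Linear(1024, 1024).cuda()
    batches = [torch.randn(32, 1024, device="cuda") for _ in range(4)]

    with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
        for step, x in enumerate(batches):
            # The first microbatch casts the weights into the cached FP8
            # workspace (and fills its transpose cache); later microbatches
            # reuse the cached values untouched.
            y = linear(x, is_first_microbatch=(step == 0))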