Unverified commit b1a0e0a7 authored by Tim Moon, committed by GitHub
Browse files

[PyTorch] Refactor FP8 workspaces in linear modules (#820)



* Initial refactor of FP8 workspaces in Linear module
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove extra kernel launch
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Minor perf optimizations

Tensor base class functions in Float8Tensor have significant overhead.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug FP8 recipe test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Refactor FP8 workspaces in LayerNormLinear and LayerNormMLP
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Document FP8 workspace function
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Revert changes to FP8 recipe tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add support for lazy FP8 transpose caching

Previous caching behavior (always fill cache) incorrectly filled cache during CUDA graph warmup steps.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix Pylint warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug ONNX export

ONNX FP8 cast ops assumed that FP8 scales were created during model export (i.e. not initialized during training).
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug fused attention tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Make sure Float8Tensor.transpose_2d is backward compatible
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Revert changes to ONNX export operations

Work around ONNX test failures by filling FP8 scale tensors instead of copying into them.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug scale factor update in Float8Tensor transpose_2d
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 4e30bc4b
...@@ -19,3 +19,4 @@ NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py ...@@ -19,3 +19,4 @@ NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
pytest -v -s $TE_PATH/tests/pytorch/test_torch_save_load.py pytest -v -s $TE_PATH/tests/pytorch/test_torch_save_load.py
pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py
pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py
...@@ -1752,7 +1752,6 @@ class Custom_MHA_FP8(TransformerEngineBaseModule): ...@@ -1752,7 +1752,6 @@ class Custom_MHA_FP8(TransformerEngineBaseModule):
dtype=params_dtype, dtype=params_dtype,
) )
) )
self.fp8_weight_shapes.append(self.qkv_weight.shape)
self.qkv_bias = torch.nn.Parameter( self.qkv_bias = torch.nn.Parameter(
torch.empty( torch.empty(
self.hidden_size * 3, self.hidden_size * 3,
...@@ -1786,9 +1785,3 @@ class Custom_MHA_FP8(TransformerEngineBaseModule): ...@@ -1786,9 +1785,3 @@ class Custom_MHA_FP8(TransformerEngineBaseModule):
self.training, self.training,
self.mask_type) self.mask_type)
return out return out
def get_fp8_weights_scratchpad(
self,
is_first_microbatch: Union[bool, None],
) -> List[torch.Tensor]:
"""Needs override."""
...@@ -294,7 +294,7 @@ class TestFloat8Tensor: ...@@ -294,7 +294,7 @@ class TestFloat8Tensor:
assert x_fp8._transpose_invalid, "Transpose cache must be invalid when not caching." assert x_fp8._transpose_invalid, "Transpose cache must be invalid when not caching."
x_fp8 += 0.5 x_fp8 += 0.5
x = x_fp8.from_float8() x = x_fp8.from_float8()
x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8.transpose_2d(cache=True)) x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8.transpose_2d(fill_cache=True))
x_t = x.transpose(0, 1) x_t = x.transpose(0, 1)
torch.testing.assert_close(x_fp8_t, x_t, **tols) torch.testing.assert_close(x_fp8_t, x_t, **tols)
assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly." assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly."
...@@ -303,7 +303,7 @@ class TestFloat8Tensor: ...@@ -303,7 +303,7 @@ class TestFloat8Tensor:
x_fp8 += 0.5 x_fp8 += 0.5
assert x_fp8._transpose_invalid, "Transpose cache not invalidated properly." assert x_fp8._transpose_invalid, "Transpose cache not invalidated properly."
x = x_fp8.from_float8() x = x_fp8.from_float8()
x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8.transpose_2d(cache=True)) x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8.transpose_2d(fill_cache=True))
x_t = x.transpose(0, 1) x_t = x.transpose(0, 1)
torch.testing.assert_close(x_fp8_t, x_t, **tols) torch.testing.assert_close(x_fp8_t, x_t, **tols)
assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly." assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly."
......
...@@ -5,15 +5,16 @@ ...@@ -5,15 +5,16 @@
"""Tensor class with FP8 data""" """Tensor class with FP8 data"""
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, Optional, Tuple, Union from typing import Any, Dict, Optional, Tuple, Union
import warnings
import torch import torch
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
import transformer_engine_extensions as tex import transformer_engine_extensions as tex
from .constants import TE_DType from .constants import TE_DType
from .cpp_extensions import fp8_cast_transpose_fused
from .fp8 import FP8GlobalStateManager from .fp8 import FP8GlobalStateManager
aten = torch.ops.aten aten = torch.ops.aten
c10d = torch.ops.c10d c10d = torch.ops.c10d
updated_fp8_params = {} updated_fp8_params = {}
...@@ -381,6 +382,7 @@ class Float8Tensor(torch.Tensor): ...@@ -381,6 +382,7 @@ class Float8Tensor(torch.Tensor):
raise ValueError( raise ValueError(
"Float8Tensor requires non-differentiable data buffer" "Float8Tensor requires non-differentiable data buffer"
) )
if not data.is_cuda:
data = data.cuda() data = data.cuda()
# Initialize tensor object # Initialize tensor object
...@@ -426,32 +428,38 @@ class Float8Tensor(torch.Tensor): ...@@ -426,32 +428,38 @@ class Float8Tensor(torch.Tensor):
self._transpose_invalid: bool = True self._transpose_invalid: bool = True
# FP8 scale-inverse # FP8 scale-inverse
self._scale_inv: Optional[torch.Tensor] = fp8_scale_inv if fp8_scale_inv is None and self._fp8_meta is not None:
if self._scale_inv is None and self._fp8_meta is not None:
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=self._fp8_meta_forward, forward=self._fp8_meta_forward,
) )
scale_inv = self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index] fp8_scale_inv = self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index]
self._scale_inv = scale_inv.detach().view(1).clone() fp8_scale_inv = fp8_scale_inv.detach().view(1).clone()
if self._scale_inv is None: if fp8_scale_inv is None:
raise ValueError( raise ValueError(
"Attempted to initialize Float8Tensor without specifying scale-inverse" "Attempted to initialize Float8Tensor without specifying scale-inverse"
) )
if not isinstance(self._scale_inv, torch.Tensor): if not isinstance(fp8_scale_inv, torch.Tensor):
self._scale_inv = torch.full( fp8_scale_inv = torch.full(
[1], [1],
self._scale_inv, fp8_scale_inv,
dtype=torch.float32, dtype=torch.float32,
device=self._data.device, device=self._data.device,
) )
if self._scale_inv.numel() != 1: if fp8_scale_inv.numel() != 1:
raise ValueError( raise ValueError(
"Attempted to initialize Float8Tensor with invalid scale-inverse tensor" "Attempted to initialize Float8Tensor with invalid scale-inverse tensor"
) )
self._scale_inv = self._scale_inv.to( if fp8_scale_inv.dim() != 1:
fp8_scale_inv = fp8_scale_inv.reshape(1)
if (
fp8_scale_inv.device != self._data.device
or fp8_scale_inv.dtype != torch.float32
):
fp8_scale_inv = fp8_scale_inv.to(
device=self._data.device, device=self._data.device,
dtype=torch.float32, dtype=torch.float32,
) )
self._scale_inv: Optional[torch.Tensor] = fp8_scale_inv
return self return self
...@@ -559,48 +567,177 @@ class Float8Tensor(torch.Tensor): ...@@ -559,48 +567,177 @@ class Float8Tensor(torch.Tensor):
def transpose_2d( def transpose_2d(
self, self,
*, *,
cache: bool = False, force_compute: bool = False,
fill_cache: bool = False,
noop_flag: Optional[torch.Tensor] = None, noop_flag: Optional[torch.Tensor] = None,
cache: Optional[bool] = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
2D transpose with caching support. 2D transpose with caching support.
Parameters Parameters
---------- ----------
cache: bool, default = `False` force_compute: bool, default = `False`
Whether or not to cache the transpose. Force computation of transpose. Otherwise use
noop_flag: Optional[torch.Tensor], default = `None` cached values, if possible.
Only used if argument `cache` is `True`, ignored otherwise. fill_cache: bool, default = `False`
A single element fp32 tensor with a value of 1.0 or 0.0 Cache output tensor for future function calls.
which is treated as a boolean. `1.0` forces recompute noop_flag: torch.Tensor, optional
and `0.0` executes a noop using the same kernel. float32 flag indicating whether to avoid updating
cached values, if possible.
cache: bool, deprecated
""" """
assert self.dim() == 2, f"{self.dim()}-D transpose not supported." assert self.dim() == 2, f"{self.dim()}-D transpose not supported."
# Case: no caching. # Handle deprecated cache kwarg
if not cache: if cache is not None:
return tex.fp8_transpose(self._data, self._fp8_dtype) msg = (
"cache kwarg for Float8Tensor.transpose_2d is deprecated, "
"please use force_compute and fill_cache instead"
)
warnings.warn(msg, DeprecationWarning)
if cache:
force_compute = False
fill_cache = True
else:
force_compute = True
fill_cache = False
# Case: reuse cache without calling a kernel. # Need to compute transpose if cache is invalid
if not self._transpose_invalid and noop_flag is None: need_compute = force_compute
assert self._transpose is not None, "Tranpose cache is empty." if self._transpose is None:
need_compute = True
elif self._transpose_invalid:
need_compute = True
# Need to apply transpose kernel if noop flag is applied
if noop_flag is not None:
need_compute = True
# Return cached transpose if possible
if not need_compute:
return self._transpose return self._transpose
# Allocate transpose if needed. # Allocate output if needed
data_2d = self._data.reshape(-1, self._data.shape[-1]) data = self._data.contiguous().reshape(-1, self.size(-1))
if self._transpose is None: out = self._transpose
shape = (data_2d.shape[1], data_2d.shape[0]) if out is None:
self._transpose = torch.empty(shape, dtype=torch.uint8, device=self._data.device) out = torch.empty(
(data.size(1), data.size(0)),
dtype=torch.uint8,
device=data.device,
)
noop_flag = None
else:
self._transpose_invalid = False
# Case: recompute transpose and store cache. # Apply transpose kernel
fp8_dtype = self._fp8_dtype
if noop_flag is None: if noop_flag is None:
tex.fp8_transpose_noalloc(data_2d, self._transpose, self._fp8_dtype) tex.fp8_transpose_noalloc(data, out, fp8_dtype)
else: else:
# Case: cuda graph capture. noop_flag = noop_flag.to(dtype=torch.float32, device=data.device)
tex.fp8_transpose_noalloc_noop(data_2d, self._transpose, noop_flag, self._fp8_dtype) tex.fp8_transpose_noalloc_noop(data, out, noop_flag, fp8_dtype)
# Fill cache if needed
if fill_cache:
self._transpose = out
self._transpose_invalid = False
return out
@torch.no_grad()
def cast_transpose_(
self,
tensor: torch.Tensor,
noop_flag: Optional[torch.Tensor] = None,
) -> None:
"""Cast from tensor and populate transpose cache
Only supported for 2D tensors.
Parameters
----------
tensor: torch.Tensor
Tensor to copy from. Must have same dimensions as
destination tensor.
noop_flag: torch.Tensor, optional
float32 flag indicating whether to avoid updating
destination tensor.
"""
# Make sure tensor is in expected format
data = self._data
if (
tensor.device != data.device
or tensor.dtype not in (torch.float32, torch.float16, torch.bfloat16)
or not tensor.is_contiguous()
):
dtype = tensor.dtype
if dtype not in (torch.float32, torch.float16, torch.bfloat16):
dtype = torch.float32
tensor = tensor.to(
device=self.device,
dtype=dtype,
memory_format=torch.contiguous_format,
)
if tensor.size() != data.size() or data.dim() != 2:
raise ValueError(
"Invalid tensor dimensions for FP8 cast-transpose "
f"(src={tuple(tensor.size())}, dst={tuple(data.size())})"
)
if not data.is_contiguous():
raise ValueError(
"FP8 cast-transpose is only supported for "
"`Float8Tensor`s with contiguous data"
)
if self._fp8_meta is None:
raise ValueError(
"FP8 cast-transpose is only supported for "
"`Float8Tensor`s with FP8 metadata "
)
# Construct transpose cache if needed
transpose = self._transpose
if transpose is None or not transpose.is_contiguous():
transpose = torch.empty(
(data.size(1), data.size(0)),
dtype=torch.uint8,
device=data.device,
)
self._transpose = transpose
noop_flag = None
# Launch cast-transpose kernel
fp8_meta_index = int(self._fp8_meta_index)
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=self._fp8_meta_forward,
)
fp8_meta = self._fp8_meta[fp8_meta_key]
fp8_cast_transpose_fused(
tensor,
fp8_meta,
fp8_meta_index,
self._fp8_dtype,
cast_out=data,
transpose_out=transpose,
noop_flag=noop_flag,
)
scale = fp8_meta.scale[fp8_meta_index:fp8_meta_index+1]
scale_inv = self._scale_inv
if noop_flag is None:
torch.reciprocal(scale, out=scale_inv)
else:
torch.where(
noop_flag.bool(),
scale_inv,
scale.reciprocal(),
out=scale_inv,
)
self._transpose_invalid = False self._transpose_invalid = False
return self._transpose
@torch.no_grad() @torch.no_grad()
def reset_fp8_meta_scale_inv(self) -> None: def reset_fp8_meta_scale_inv(self) -> None:
......
...@@ -8,7 +8,7 @@ import os ...@@ -8,7 +8,7 @@ import os
import pickle import pickle
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Generator, Union, Optional, Tuple, List from typing import Dict, Generator, List, Optional, Tuple, Union
from contextlib import contextmanager from contextlib import contextmanager
import torch import torch
...@@ -252,9 +252,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -252,9 +252,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.tp_group = None self.tp_group = None
self.tp_size = 1 self.tp_size = 1
self.sequence_parallel = False self.sequence_parallel = False
self.fp8_weight_shapes = []
self.param_init_meta = {} self.param_init_meta = {}
self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
self._fp8_workspaces: Dict[str, Float8Tensor] = {}
def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None: def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None:
"""Increase or decrease size of amax history based on given `length`. """Increase or decrease size of amax history based on given `length`.
...@@ -452,60 +452,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -452,60 +452,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
) )
self.activation_dtype = dtype self.activation_dtype = dtype
def set_fp8_weights(self) -> None:
"""Construct workspace buffers for FP8 weights, if needed
These workspace buffers are used for FP8 training when the
module parameters are not natively in FP8 and there are
multiple microbatches per training step. The buffers, with
names like `weight1_fp8` and `weight1_t_fp8`, cache the FP8
values and transposed FP8 values in between microbatches. They
are not registered as module parameters or buffers since we
don't want them to be affected by `.to` and since they aren't
needed for checkpointing.
"""
if not self.fp8 or self.primary_weights_in_fp8:
return
for i, shape in enumerate(self.fp8_weight_shapes, start=1):
weight_cast_attr = f"weight{i}_fp8"
weight_transpose_attr = f"weight{i}_t_fp8"
if (
hasattr(self, weight_cast_attr)
and getattr(self, weight_cast_attr).shape == shape
):
return
setattr(
self,
weight_cast_attr,
Float8Tensor(
data=torch.empty(
shape,
device=torch.cuda.current_device(),
dtype=torch.uint8,
),
fp8_dtype=tex.DType.kFloat8E4M3,
fp8_scale_inv=1,
)
)
setattr(
self,
weight_transpose_attr,
Float8Tensor(
data=torch.empty(
shape[1],
shape[0],
device=torch.cuda.current_device(),
dtype=torch.uint8,
),
fp8_dtype=tex.DType.kFloat8E4M3,
fp8_scale_inv=1,
)
)
def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
""" """
Set the tensor parallel group for the given Set the tensor parallel group for the given
...@@ -522,7 +468,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -522,7 +468,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
def _get_fp8_params(self) -> Union[List[torch.Tensor], None]: def _get_fp8_params(self) -> Union[List[torch.Tensor], None]:
"""returns the FP8 weights.""" """returns the FP8 weights."""
fp8_params = [] fp8_params = []
for param in self.parameters(): for param in self.parameters(recurse=False):
if isinstance(param, Float8Tensor) and param.requires_grad: if isinstance(param, Float8Tensor) and param.requires_grad:
fp8_params.append(param) fp8_params.append(param)
if len(fp8_params) == 0: if len(fp8_params) == 0:
...@@ -569,7 +515,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -569,7 +515,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
def prepare_forward( def prepare_forward(
self, self,
inp: torch.Tensor, inp: torch.Tensor,
is_first_microbatch: Union[bool, None], is_first_microbatch: Union[bool, None], # pylint: disable=unused-argument
num_gemms: int = 1, num_gemms: int = 1,
allow_non_contiguous: bool = False, allow_non_contiguous: bool = False,
) -> Generator[torch.Tensor, None, None]: ) -> Generator[torch.Tensor, None, None]:
...@@ -591,11 +537,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -591,11 +537,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.set_activation_dtype(inp) self.set_activation_dtype(inp)
self.init_fp8_metadata(num_gemms=num_gemms) self.init_fp8_metadata(num_gemms=num_gemms)
# Create persistent tensors for fp8 weights and their transposes
# only when fp8 weight caching is used and weights are not in fp8
if is_first_microbatch is not None and not self.primary_weights_in_fp8:
self.set_fp8_weights()
if self.fp8 and self.sequence_parallel: if self.fp8 and self.sequence_parallel:
assert self.fp8_meta["recipe"].reduce_amax, \ assert self.fp8_meta["recipe"].reduce_amax, \
"Amax reduction across tensor parallel group is " \ "Amax reduction across tensor parallel group is " \
...@@ -754,49 +695,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -754,49 +695,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
return grad_output_mat, grad_output_c, grad_output_t, grad_bias return grad_output_mat, grad_output_c, grad_output_t, grad_bias
def get_fp8_weights_empty_tensors(
self,
is_first_microbatch: Union[bool, None],
) -> List[Float8Tensor]:
"""
Returns empty tensors to be later used to store fp8 version of weights
and their transposes (for the bwd pass) for this batch (or microbatch).
When `is_first_microbatch` is `None`, this is especially useful since
we then don't need to store the fp8 weights that are needed for one time
only in the forward pass. Note that we still need to store the tensor
for the fp8 weight transpose which is at least needed in the backward
pass but that's taken care of by storing the transpose tensor in
`ctx.save_for_backward`.
"""
assert is_first_microbatch is None, "Should only be here when "\
"`is_first_microbatch` is None!"
fp8_weight_tensors = []
for shape in self.fp8_weight_shapes:
fp8_weight_tensors.append(
Float8Tensor(
data=torch.empty(
shape,
device=torch.cuda.current_device(),
dtype=torch.uint8,
),
fp8_dtype=tex.DType.kFloat8E4M3,
fp8_scale_inv=1,
)
)
fp8_weight_tensors.append(
Float8Tensor(
data=torch.empty(
shape[1],
shape[0],
device=torch.cuda.current_device(),
dtype=torch.uint8,
),
fp8_dtype=tex.DType.kFloat8E4M3,
fp8_scale_inv=1,
)
)
return fp8_weight_tensors
def register_parameter(self, name, param, **kwargs): def register_parameter(self, name, param, **kwargs):
""" """
Thin wrapper around PyTorch parameter registration to stash additional parameter Thin wrapper around PyTorch parameter registration to stash additional parameter
...@@ -852,12 +750,119 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -852,12 +750,119 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
def forward(self): def forward(self):
"""Needs override.""" """Needs override."""
@abstractmethod def get_fp8_workspace(
def get_fp8_weights_scratchpad(
self, self,
is_first_microbatch: Union[bool, None], *,
) -> List[torch.Tensor]: tensor: Optional[torch.Tensor] = None,
"""Needs override.""" fp8_meta_forward: Optional[bool] = None,
fp8_meta_index: Optional[int] = None,
cache_name: Optional[str] = None,
update_workspace: bool = True,
skip_update_flag: Optional[torch.Tensor] = None,
with_transpose: bool = False,
) -> Float8Tensor:
"""Get FP8 workspace buffer and maybe update its values
The workspace buffer may be cached for future function calls.
Parameters
----------
tensor : torch.Tensor, optional
Values to copy into workspace. Required if the workspace
is being constructed or updated.
fp8_meta_forward: bool, optional
Whether to access FP8 meta tensors for the forward pass or
backward pass. Required if the workspace is being
constructed.
fp8_meta_index: int, optional
Index to access in FP8 meta tensors. Required if the
workspace is being constructed.
cache_name: str, optional
Key for caching.
update_workspace: bool, default = `True`
Update workspace with values from `tensor`.
skip_update_flag: torch.Tensor, optional
GPU flag to skip updating the workspace. Take precedence
over `update_workspace` if provided.
with_transpose: bool, default = `False`
Whether to initialize cached transpose in workspace.
"""
# Construct workspace if needed
out = None
if cache_name is not None:
out = self._fp8_workspaces.get(cache_name, None)
if out is None:
if (
tensor is None
or fp8_meta_forward is None
or fp8_meta_index is None
):
raise ValueError(
"tensor, fp8_meta_forward, and fp8_meta_index kwargs "
"must be provided to construct FP8 workspace"
)
fp8_dtype = get_fp8_te_dtype(
self.fp8_meta["recipe"],
fprop_tensor=fp8_meta_forward,
)
scale_inv = torch.empty(
[1],
dtype=torch.float32,
device=tensor.device
)
out = Float8Tensor(
data=torch.empty_like(tensor, dtype=torch.uint8),
fp8_meta=self.fp8_meta,
fp8_meta_forward=fp8_meta_forward,
fp8_meta_index=fp8_meta_index,
fp8_dtype=fp8_dtype,
fp8_scale_inv=scale_inv,
dtype=tensor.dtype,
)
if cache_name is not None:
self._fp8_workspaces[cache_name] = out
update_workspace = True
skip_update_flag = None
# Update workspace if needed
if skip_update_flag is not None:
update_workspace = True
if update_workspace:
if tensor is None:
raise ValueError(
"tensor kwarg must be provided to update FP8 workspace"
)
if with_transpose:
out.cast_transpose_(
tensor,
noop_flag=skip_update_flag,
)
else:
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=out._fp8_meta_forward,
)
fp8_meta = out._fp8_meta[fp8_meta_key]
fp8_meta_index = out._fp8_meta_index
cast_to_fp8(
tensor,
fp8_meta,
fp8_meta_index,
out._fp8_dtype,
out=out._data,
)
if is_in_onnx_export_mode():
# ONNX export expects FP8 scales can be
# represented with constant ops. However, copying
# into a buffer involves an expand op for array
# broadcasting. We work around this by filling the
# buffer instead.
out._scale_inv.fill_(fp8_meta.scale_inv[fp8_meta_index].item())
else:
out._scale_inv.copy_(fp8_meta.scale_inv[fp8_meta_index])
return out
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs): missing_keys, unexpected_keys, error_msgs):
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
"""LayerNormLinear API""" """LayerNormLinear API"""
import os import os
import warnings import warnings
from typing import Union, Optional, Callable, Tuple, List, Dict, Any from typing import Any, Callable, Dict, Optional, Tuple, Union
import torch import torch
from torch.nn import init from torch.nn import init
...@@ -61,13 +61,11 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -61,13 +61,11 @@ class _LayerNormLinear(torch.autograd.Function):
ln_weight: torch.Tensor, ln_weight: torch.Tensor,
ln_bias: Union[torch.Tensor, None], ln_bias: Union[torch.Tensor, None],
weight: torch.Tensor, weight: torch.Tensor,
weight_fp8: Union[torch.Tensor, None], weight_fp8: Optional[torch.Tensor],
weight_t_fp8: Union[torch.Tensor, None],
bias: torch.Tensor, bias: torch.Tensor,
use_bias: bool, use_bias: bool,
eps: float, eps: float,
is_first_microbatch: Union[bool, None], is_first_microbatch: Union[bool, None],
skip_fp8_weight_update: Union[torch.Tensor, None],
fp8: bool, fp8: bool,
fp8_calibration: bool, fp8_calibration: bool,
fp8_meta: Dict[str, Any], fp8_meta: Dict[str, Any],
...@@ -86,7 +84,6 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -86,7 +84,6 @@ class _LayerNormLinear(torch.autograd.Function):
bwd_ln_sm_margin: int, bwd_ln_sm_margin: int,
zero_centered_gamma: bool, zero_centered_gamma: bool,
normalization: str, normalization: str,
primary_weights_in_fp8: bool,
ub_bulk_wgrad: bool, ub_bulk_wgrad: bool,
ub_bulk_dgrad: bool, ub_bulk_dgrad: bool,
ub_overlap_rs_dgrad: bool, ub_overlap_rs_dgrad: bool,
...@@ -101,12 +98,6 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -101,12 +98,6 @@ class _LayerNormLinear(torch.autograd.Function):
assert_dim_for_fp8_exec(inputmat) assert_dim_for_fp8_exec(inputmat)
assert_dim_for_fp8_exec(weight) assert_dim_for_fp8_exec(weight)
update_fp8_weights = (
is_first_microbatch is None
or is_first_microbatch
or skip_fp8_weight_update is not None
)
# Cast for native AMP # Cast for native AMP
inputmat = cast_if_needed(inputmat, activation_dtype) inputmat = cast_if_needed(inputmat, activation_dtype)
ln_weight = cast_if_needed(ln_weight, activation_dtype) ln_weight = cast_if_needed(ln_weight, activation_dtype)
...@@ -202,38 +193,10 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -202,38 +193,10 @@ class _LayerNormLinear(torch.autograd.Function):
) )
bias = cast_if_needed(bias, bias_dtype) if use_bias else bias bias = cast_if_needed(bias, bias_dtype) if use_bias else bias
if primary_weights_in_fp8: # Use FP8 weights
# Weight is already in FP8 if weight_fp8 is None:
weight.reset_fp8_meta_scale_inv()
weight_fp8 = weight weight_fp8 = weight
elif update_fp8_weights: assert isinstance(weight_fp8, Float8Tensor)
# Need to cast weights to FP8
weight_fp8 = Float8Tensor(
data=weight_fp8._data,
fp8_meta=fp8_meta,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
)
if (is_grad_enabled
or (is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase())):
tex.fp8_cast_transpose_fused(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
cast_out=weight_fp8._data,
transpose_out=weight_t_fp8._data,
noop_flag=skip_fp8_weight_update,
)
else:
tex.cast_to_fp8(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
out=weight_fp8._data,
)
weight_t_fp8 = None
if fp8_meta["recipe"].fp8_mha: if fp8_meta["recipe"].fp8_mha:
out_index, meta_tensor, output_te_dtype, output_dtype = ( out_index, meta_tensor, output_te_dtype, output_dtype = (
...@@ -246,9 +209,9 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -246,9 +209,9 @@ class _LayerNormLinear(torch.autograd.Function):
None, None, None, activation_dtype) None, None, None, activation_dtype)
out, _ = tex.fp8_gemm( out, _ = tex.fp8_gemm(
weight_fp8._data, weight_fp8._data,
fp8_meta["scaling_fwd"].scale_inv, weight_fp8._scale_inv,
tex.FP8FwdTensors.GEMM1_WEIGHT, 0,
fp8_dtype_forward, weight_fp8._fp8_dtype,
ln_out_total, ln_out_total,
fp8_meta["scaling_fwd"].scale_inv, fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_INPUT, tex.FP8FwdTensors.GEMM1_INPUT,
...@@ -307,8 +270,8 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -307,8 +270,8 @@ class _LayerNormLinear(torch.autograd.Function):
if cpu_offloading: if cpu_offloading:
if fuse_wgrad_accumulation: if fuse_wgrad_accumulation:
weight.main_grad.weight_offloading = True weight.main_grad.weight_offloading = True
if fp8 and weight_t_fp8 is not None: if fp8 and weight_fp8 is not None:
weight_t_fp8.weight_offloading = True weight_fp8.weight_offloading = True
ln_weight.weight_offloading = True ln_weight.weight_offloading = True
weight.weight_offloading = True weight.weight_offloading = True
...@@ -324,11 +287,10 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -324,11 +287,10 @@ class _LayerNormLinear(torch.autograd.Function):
mu, mu,
rsigma, rsigma,
weight, weight,
weight_fp8,
weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None, weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
weight_t_fp8,
ln_out if weight.requires_grad else None, ln_out if weight.requires_grad else None,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
skip_fp8_weight_update.clone() if skip_fp8_weight_update is not None else None,
) )
ctx.activation_dtype = activation_dtype ctx.activation_dtype = activation_dtype
...@@ -355,7 +317,6 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -355,7 +317,6 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.ub_name = ub_name ctx.ub_name = ub_name
ctx.requires_dgrad = inp.requires_grad ctx.requires_dgrad = inp.requires_grad
ctx.normalization = normalization ctx.normalization = normalization
ctx.primary_weights_in_fp8 = primary_weights_in_fp8
ctx.reduce_and_update_bwd_fp8_tensors = False ctx.reduce_and_update_bwd_fp8_tensors = False
if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias): if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias):
ctx.reduce_and_update_bwd_fp8_tensors = ( ctx.reduce_and_update_bwd_fp8_tensors = (
...@@ -394,26 +355,16 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -394,26 +355,16 @@ class _LayerNormLinear(torch.autograd.Function):
mu, mu,
rsigma, rsigma,
weight, weight,
weight_fp8,
main_grad, main_grad,
weight_t_fp8,
ln_out, ln_out,
fwd_scale_inverses, fwd_scale_inverses,
skip_fp8_weight_update,
) = ctx.saved_tensors ) = ctx.saved_tensors
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
weight = torch.nn.Parameter(weight, False) weight = torch.nn.Parameter(weight, False)
weight.main_grad = main_grad weight.main_grad = main_grad
# Primary weights are in FP8.
if ctx.primary_weights_in_fp8:
weight_t_fp8 = weight.transpose_2d(
cache=ctx.is_first_microbatch is not None,
noop_flag=skip_fp8_weight_update,
)
elif ctx.fp8:
weight_t_fp8 = weight_t_fp8._data
if ctx.ub_overlap_rs_dgrad: if ctx.ub_overlap_rs_dgrad:
ctx.ub_bulk_dgrad = False ctx.ub_bulk_dgrad = False
ctx.ub_bulk_wgrad = False ctx.ub_bulk_wgrad = False
...@@ -520,10 +471,10 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -520,10 +471,10 @@ class _LayerNormLinear(torch.autograd.Function):
# DGRAD: Evaluated unconditionally to feed into Linear backward # DGRAD: Evaluated unconditionally to feed into Linear backward
_ = tex.fp8_gemm( _ = tex.fp8_gemm(
weight_t_fp8, weight_fp8.transpose_2d(),
fwd_scale_inverses, weight_fp8._scale_inv,
tex.FP8FwdTensors.GEMM1_WEIGHT, 0,
fp8_dtype_forward, weight_fp8._fp8_dtype,
grad_output_c._data grad_output_c._data
if isinstance(grad_output_c, Float8Tensor) else grad_output_c, if isinstance(grad_output_c, Float8Tensor) else grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv, ctx.fp8_meta["scaling_bwd"].scale_inv,
...@@ -712,37 +663,34 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -712,37 +663,34 @@ class _LayerNormLinear(torch.autograd.Function):
dgamma, dgamma,
dbeta, dbeta,
wgrad, wgrad,
None, None, # weight_fp8
None,
grad_bias, grad_bias,
None, None, # use_bias
None, None, # eps
None, None, # is_first_microbatch
None, None, # fp8
None, None, # fp8_calibration
None, None, # fp8_meta
None, None, # fuse_wgrad_accumulation
None, None, # cpu_offloading
None, None, # tp_group
None, None, # tp_size
None, None, # sequence_parallel
None, None, # tensor_parallel
None, None, # activation_dtype
None, None, # parallel_mode
None, None, # return_layernorm_output
None, None, # return_layernorm_output_gathered
None, None, # is_grad_enabled
None, None, # fwd_ln_sm_margin
None, None, # bwd_ln_sm_margin
None, None, # zero_centered_gamma
None, None, # normalization
None, None, # ub_bulk_wgrad
None, None, # ub_bulk_dgrad
None, None, # ub_overlap_rs_dgrad
None, None, # ub_overlap_ag
None, None, # ub_name
None,
None,
) )
...@@ -873,7 +821,6 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -873,7 +821,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.return_layernorm_output = return_layernorm_output self.return_layernorm_output = return_layernorm_output
self.return_layernorm_output_gathered = return_layernorm_output_gathered self.return_layernorm_output_gathered = return_layernorm_output_gathered
self.zero_centered_gamma = zero_centered_gamma self.zero_centered_gamma = zero_centered_gamma
self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
self.ub_bulk_wgrad = ub_bulk_wgrad self.ub_bulk_wgrad = ub_bulk_wgrad
self.ub_bulk_dgrad = ub_bulk_dgrad self.ub_bulk_dgrad = ub_bulk_dgrad
self.ub_overlap_ag = ub_overlap_ag self.ub_overlap_ag = ub_overlap_ag
...@@ -926,6 +873,9 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -926,6 +873,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
else: else:
self.layer_norm_bias = None self.layer_norm_bias = None
# Initialize params in FP8
with_fp8_params = FP8GlobalStateManager.with_fp8_parameters()
# Contiguous buffers for params # Contiguous buffers for params
weight_tensor = torch.empty( weight_tensor = torch.empty(
self.out_features, self.out_features,
...@@ -998,7 +948,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -998,7 +948,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
# Check if parameters are subviews of buffers # Check if parameters are subviews of buffers
is_subview = (split_start, split_end) != (0, self.out_features) is_subview = (split_start, split_end) != (0, self.out_features)
if is_subview and self.primary_weights_in_fp8: if is_subview and with_fp8_params:
raise RuntimeError( raise RuntimeError(
"Splitting Float8Tensor into multiple params " "Splitting Float8Tensor into multiple params "
"is not supported" "is not supported"
...@@ -1030,13 +980,11 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1030,13 +980,11 @@ class LayerNormLinear(TransformerEngineBaseModule):
bias = torch.Tensor().to(dtype=params_dtype, device=device) bias = torch.Tensor().to(dtype=params_dtype, device=device)
setattr(self, name, bias) setattr(self, name, bias)
if self.primary_weights_in_fp8: if with_fp8_params:
self.init_fp8_metadata() self.init_fp8_metadata()
self.reset_parameters(defer_init=(device == 'meta')) self.reset_parameters(defer_init=(device == 'meta'))
self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))
# For RPL, bias has to be added after TP collectives # For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM # So it cannot be fused with the GEMM
if self.parallel_mode == "row" and self.apply_bias: if self.parallel_mode == "row" and self.apply_bias:
...@@ -1093,31 +1041,6 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1093,31 +1041,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
elif self.parallel_mode == "column": elif self.parallel_mode == "column":
set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1) set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1)
def get_fp8_weights_scratchpad(
self,
is_first_microbatch: Union[bool, None],
) -> List[torch.Tensor]:
"""
Fetch the fp8 weight tensor placeholders if they exist (when
`is_first_microbatch` is not `None`) or return empty fp8 weight
tensors (if `is_first_microbatch is None`)
"""
if not self.fp8 or self.primary_weights_in_fp8:
return [None, None]
if is_first_microbatch is None:
# Return empty weight placeholders for each fwd/bwd pass
fp8_weight_tensors = self.get_fp8_weights_empty_tensors(
is_first_microbatch
)
else:
# These persistent weight placeholders should've been created in
# `set_fp8_weights` method
fp8_weight_tensors = [self.weight1_fp8, self.weight1_t_fp8]
return fp8_weight_tensors
@no_torch_dynamo() @no_torch_dynamo()
def forward( def forward(
self, self,
...@@ -1151,13 +1074,19 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1151,13 +1074,19 @@ class LayerNormLinear(TransformerEngineBaseModule):
is_first_microbatch = False is_first_microbatch = False
with self.prepare_forward(inp, is_first_microbatch) as inp: with self.prepare_forward(inp, is_first_microbatch) as inp:
assert self.fp8 or not self.primary_weights_in_fp8, \
"Need to run inside fp8_autocast region when weights are stored in FP8."
# Get concatenated weight and bias tensors # Get concatenated weight and bias tensors
weight_tensor = _noop_cat( unfused_weights = [getattr(self, name) for name in self.weight_names]
[getattr(self, name) for name in self.weight_names], if any(isinstance(w, Float8Tensor) for w in unfused_weights):
if self.fp8:
if len(unfused_weights) != 1:
raise RuntimeError(
"Splitting Float8Tensor into multiple params "
"is not supported"
) )
else:
unfused_weights = [w.from_float8() for w in unfused_weights]
weight_tensor = _noop_cat(unfused_weights)
if self.use_bias: if self.use_bias:
bias_tensor = _noop_cat( bias_tensor = _noop_cat(
[getattr(self, name) for name in self.bias_names], [getattr(self, name) for name in self.bias_names],
...@@ -1165,9 +1094,44 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1165,9 +1094,44 @@ class LayerNormLinear(TransformerEngineBaseModule):
else: else:
bias_tensor = getattr(self, self.bias_names[0]) # Unused bias_tensor = getattr(self, self.bias_names[0]) # Unused
# Fetch the fp8 weights placeholders (for linear/gemm) # Initialize FP8 weights if needed
weight1_fp8, weight1_t_fp8 = self.get_fp8_weights_scratchpad( weight_fp8 = None
if self.fp8:
with_transpose = torch.is_grad_enabled()
if (
not with_transpose
and is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase()
):
with_transpose = True
if isinstance(weight_tensor, Float8Tensor):
# Fill transpose cache in FP8 tensor if needed
update_transpose_cache = with_transpose
if update_transpose_cache:
update_transpose_cache = (
is_first_microbatch is_first_microbatch
or skip_fp8_weight_update is not None
)
if update_transpose_cache:
weight_tensor.transpose_2d(
fill_cache=True,
noop_flag=skip_fp8_weight_update,
)
else:
# FP8 cast to workspace buffer
update_workspace = (
is_first_microbatch is None
or is_first_microbatch
or skip_fp8_weight_update is not None
)
weight_fp8 = self.get_fp8_workspace(
tensor=weight_tensor,
fp8_meta_forward=True,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
cache_name=(None if is_first_microbatch is None else "weight"),
update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update,
with_transpose=with_transpose,
) )
from ..cpu_offload import CPUOffloadEnabled from ..cpu_offload import CPUOffloadEnabled
...@@ -1183,13 +1147,11 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1183,13 +1147,11 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.layer_norm_weight, self.layer_norm_weight,
self.layer_norm_bias, self.layer_norm_bias,
weight_tensor, weight_tensor,
weight1_fp8, weight_fp8,
weight1_t_fp8,
bias_tensor, bias_tensor,
self.apply_bias and not self.gemm_bias_unfused_add, self.apply_bias and not self.gemm_bias_unfused_add,
self.eps, self.eps,
is_first_microbatch, is_first_microbatch,
skip_fp8_weight_update,
self.fp8, self.fp8,
self.fp8_calibration, self.fp8_calibration,
self.fp8_meta, self.fp8_meta,
...@@ -1208,7 +1170,6 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1208,7 +1170,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.bwd_ln_sm_margin, self.bwd_ln_sm_margin,
self.zero_centered_gamma, self.zero_centered_gamma,
self.normalization, self.normalization,
self.primary_weights_in_fp8,
self.ub_bulk_wgrad, self.ub_bulk_wgrad,
self.ub_bulk_dgrad, self.ub_bulk_dgrad,
self.ub_overlap_rs_dgrad, self.ub_overlap_rs_dgrad,
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
"""Linear API""" """Linear API"""
import os import os
from typing import Union, Optional, Callable, Tuple, List, Dict, Any from typing import Any, Callable, Dict, Optional, Tuple, Union
import torch import torch
...@@ -62,13 +62,11 @@ class _Linear(torch.autograd.Function): ...@@ -62,13 +62,11 @@ class _Linear(torch.autograd.Function):
def forward( def forward(
ctx, ctx,
weight: Union[Float8Tensor, torch.Tensor], weight: Union[Float8Tensor, torch.Tensor],
weight_fp8: Union[Float8Tensor, None], weight_fp8: Optional[Float8Tensor],
weight_t_fp8: Union[Float8Tensor, None],
inp: torch.Tensor, inp: torch.Tensor,
bias: torch.Tensor, bias: torch.Tensor,
use_bias: bool, use_bias: bool,
is_first_microbatch: Union[bool, None], is_first_microbatch: Union[bool, None],
skip_fp8_weight_update: Union[torch.Tensor, None],
fp8: bool, fp8: bool,
fp8_calibration: bool, fp8_calibration: bool,
fp8_meta: Dict[str, Any], fp8_meta: Dict[str, Any],
...@@ -81,7 +79,6 @@ class _Linear(torch.autograd.Function): ...@@ -81,7 +79,6 @@ class _Linear(torch.autograd.Function):
activation_dtype: torch.dtype, activation_dtype: torch.dtype,
parallel_mode: Union[str, None], parallel_mode: Union[str, None],
is_grad_enabled: bool, is_grad_enabled: bool,
primary_weights_in_fp8: bool,
ub_overlap_rs: bool, ub_overlap_rs: bool,
ub_overlap_ag: bool, ub_overlap_ag: bool,
ub_name: str, ub_name: str,
...@@ -99,12 +96,6 @@ class _Linear(torch.autograd.Function): ...@@ -99,12 +96,6 @@ class _Linear(torch.autograd.Function):
assert_dim_for_fp8_exec(inputmat) assert_dim_for_fp8_exec(inputmat)
assert_dim_for_fp8_exec(weight) assert_dim_for_fp8_exec(weight)
update_fp8_weights = (
is_first_microbatch is None
or is_first_microbatch
or skip_fp8_weight_update is not None
)
tp_world_size = get_distributed_world_size(tp_group) tp_world_size = get_distributed_world_size(tp_group)
ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs
...@@ -163,38 +154,10 @@ class _Linear(torch.autograd.Function): ...@@ -163,38 +154,10 @@ class _Linear(torch.autograd.Function):
) )
bias = cast_if_needed(bias, bias_dtype) if use_bias else bias bias = cast_if_needed(bias, bias_dtype) if use_bias else bias
if primary_weights_in_fp8: # Use FP8 weights
# Weight is already in FP8 if weight_fp8 is None:
weight.reset_fp8_meta_scale_inv()
weight_fp8 = weight weight_fp8 = weight
elif update_fp8_weights: assert isinstance(weight_fp8, Float8Tensor)
# Need to cast weights to FP8
weight_fp8 = Float8Tensor(
data=weight_fp8._data,
fp8_meta=fp8_meta,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
)
if (is_grad_enabled
or (is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase())):
fp8_cast_transpose_fused(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
cast_out=weight_fp8._data,
transpose_out=weight_t_fp8._data,
noop_flag=skip_fp8_weight_update,
)
else:
cast_to_fp8(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
out=weight_fp8._data,
)
weight_t_fp8 = None
if is_first_module_in_mha: if is_first_module_in_mha:
proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = ( proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = (
...@@ -211,7 +174,7 @@ class _Linear(torch.autograd.Function): ...@@ -211,7 +174,7 @@ class _Linear(torch.autograd.Function):
out = ub_obj_projout.get_ubuf_output(1) out = ub_obj_projout.get_ubuf_output(1)
dim_size = list(inputmat_total.size()) dim_size = list(inputmat_total.size())
dim_size[0] = dim_size[0] // tp_world_size dim_size[0] = dim_size[0] // tp_world_size
dim_size[1] = weight.size(0) dim_size[1] = weight_fp8.size(0)
rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
if ub_obj_projout.is_p2p_overlap(): if ub_obj_projout.is_p2p_overlap():
if ub_obj_projout.is_atomic_gemm(): if ub_obj_projout.is_atomic_gemm():
...@@ -231,14 +194,14 @@ class _Linear(torch.autograd.Function): ...@@ -231,14 +194,14 @@ class _Linear(torch.autograd.Function):
ub_obj_projout.set_ubuf_scale_inv(meta_tensor.scale_inv[proj_out_index]) ub_obj_projout.set_ubuf_scale_inv(meta_tensor.scale_inv[proj_out_index])
else: else:
dim_size = list(inputmat_total.size()) dim_size = list(inputmat_total.size())
dim_size[1] = weight.size(0) dim_size[1] = weight_fp8.size(0)
out = torch.empty(dim_size, dtype=proj_out_pttype, device=inputmat_total.device) out = torch.empty(dim_size, dtype=proj_out_pttype, device=inputmat_total.device)
_ = fp8_gemm( _ = fp8_gemm(
weight_fp8._data, weight_fp8._data,
fp8_meta["scaling_fwd"].scale_inv, weight_fp8._scale_inv,
tex.FP8FwdTensors.GEMM1_WEIGHT, 0,
fp8_dtype_forward, weight_fp8._fp8_dtype,
inputmat_total._data inputmat_total._data
if isinstance(inputmat_total, Float8Tensor) else inputmat_total, if isinstance(inputmat_total, Float8Tensor) else inputmat_total,
fp8_meta["scaling_fwd"].scale_inv, fp8_meta["scaling_fwd"].scale_inv,
...@@ -329,8 +292,8 @@ class _Linear(torch.autograd.Function): ...@@ -329,8 +292,8 @@ class _Linear(torch.autograd.Function):
if cpu_offloading: if cpu_offloading:
if fuse_wgrad_accumulation: if fuse_wgrad_accumulation:
weight.main_grad.weight_offloading = True weight.main_grad.weight_offloading = True
if fp8 and weight_t_fp8 is not None: if fp8 and weight_fp8 is not None:
weight_t_fp8.weight_offloading = True weight_fp8.weight_offloading = True
weight.weight_offloading = True weight.weight_offloading = True
if saved_inputmat is not None: if saved_inputmat is not None:
...@@ -340,10 +303,9 @@ class _Linear(torch.autograd.Function): ...@@ -340,10 +303,9 @@ class _Linear(torch.autograd.Function):
saved_inputmat, saved_inputmat,
saved_inputmat_t, saved_inputmat_t,
weight, weight,
weight_fp8,
weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None, weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
weight_t_fp8 if fp8 else None,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
skip_fp8_weight_update.clone() if skip_fp8_weight_update is not None else None,
) )
ctx.activation_dtype = activation_dtype ctx.activation_dtype = activation_dtype
ctx.fp8 = fp8 ctx.fp8 = fp8
...@@ -362,7 +324,6 @@ class _Linear(torch.autograd.Function): ...@@ -362,7 +324,6 @@ class _Linear(torch.autograd.Function):
ctx.tp_size = tp_size ctx.tp_size = tp_size
ctx.requires_dgrad = inp.requires_grad ctx.requires_dgrad = inp.requires_grad
ctx.is_input_fp8 = is_input_fp8 ctx.is_input_fp8 = is_input_fp8
ctx.primary_weights_in_fp8 = primary_weights_in_fp8
ctx.reduce_and_update_bwd_fp8_tensors = False ctx.reduce_and_update_bwd_fp8_tensors = False
if ctx.fp8 and requires_grad(inp, weight, bias): if ctx.fp8 and requires_grad(inp, weight, bias):
ctx.reduce_and_update_bwd_fp8_tensors = ( ctx.reduce_and_update_bwd_fp8_tensors = (
...@@ -394,25 +355,15 @@ class _Linear(torch.autograd.Function): ...@@ -394,25 +355,15 @@ class _Linear(torch.autograd.Function):
inputmat, inputmat,
inputmat_t, inputmat_t,
weight, weight,
weight_fp8,
main_grad, main_grad,
weight_t_fp8,
fwd_scale_inverses, fwd_scale_inverses,
skip_fp8_weight_update,
) = ctx.saved_tensors ) = ctx.saved_tensors
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
weight = torch.nn.Parameter(weight, False) weight = torch.nn.Parameter(weight, False)
weight.main_grad = main_grad weight.main_grad = main_grad
# Primary weights are in FP8.
if ctx.primary_weights_in_fp8:
weight_t_fp8 = weight.transpose_2d(
cache=ctx.is_first_microbatch is not None,
noop_flag=skip_fp8_weight_update,
)
elif ctx.fp8:
weight_t_fp8 = weight_t_fp8._data
tp_world_size = get_distributed_world_size(ctx.tp_group) tp_world_size = get_distributed_world_size(ctx.tp_group)
ctx.ub_overlap_ag = False if tp_world_size == 1 else ctx.ub_overlap_ag ctx.ub_overlap_ag = False if tp_world_size == 1 else ctx.ub_overlap_ag
if ctx.ub_overlap_ag: if ctx.ub_overlap_ag:
...@@ -476,10 +427,10 @@ class _Linear(torch.autograd.Function): ...@@ -476,10 +427,10 @@ class _Linear(torch.autograd.Function):
out_index, meta_tensor, output_te_dtype, output_dtype = ( out_index, meta_tensor, output_te_dtype, output_dtype = (
None, None, None, ctx.activation_dtype) None, None, None, ctx.activation_dtype)
dgrad, _ = fp8_gemm( dgrad, _ = fp8_gemm(
weight_t_fp8, weight_fp8.transpose_2d(),
fwd_scale_inverses, weight_fp8._scale_inv,
tex.FP8FwdTensors.GEMM1_WEIGHT, 0,
fp8_dtype_forward, weight_fp8._fp8_dtype,
grad_output_c, grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv, ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT1, tex.FP8BwdTensors.GRAD_OUTPUT1,
...@@ -620,30 +571,27 @@ class _Linear(torch.autograd.Function): ...@@ -620,30 +571,27 @@ class _Linear(torch.autograd.Function):
return ( return (
wgrad, wgrad,
None, None, # weight_fp8
None,
dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
grad_bias, grad_bias,
None, None, # use_bias
None, None, # is_first_microbatch
None, None, # fp8
None, None, # fp8_calibration
None, None, # fp8_meta
None, None, # fuse_wgrad_accumulation
None, None, # cpu_offloading
None, None, # tp_group
None, None, # tp_size
None, None, # sequence_parallel
None, None, # tensor_parallel
None, None, # activation_dtype
None, None, # parallel_mode
None, None, # is_grad_enabled
None, None, # ub_overlap_rs
None, None, # ub_overlap_ag
None, None, # ub_name
None, None, # is_first_module_in_mha
None,
None,
) )
...@@ -747,7 +695,6 @@ class Linear(TransformerEngineBaseModule): ...@@ -747,7 +695,6 @@ class Linear(TransformerEngineBaseModule):
self.use_bias = bias self.use_bias = bias
self.return_bias = return_bias self.return_bias = return_bias
self.apply_bias = bias and not return_bias self.apply_bias = bias and not return_bias
self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
self.ub_overlap_rs = ub_overlap_rs self.ub_overlap_rs = ub_overlap_rs
self.ub_overlap_ag = ub_overlap_ag self.ub_overlap_ag = ub_overlap_ag
if ub_overlap_rs or ub_overlap_ag: if ub_overlap_rs or ub_overlap_ag:
...@@ -783,6 +730,9 @@ class Linear(TransformerEngineBaseModule): ...@@ -783,6 +730,9 @@ class Linear(TransformerEngineBaseModule):
self.sequence_parallel = (self.tp_size > 1) and sequence_parallel self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
# Initialize params in FP8
with_fp8_params = FP8GlobalStateManager.with_fp8_parameters()
# Contiguous buffers for params # Contiguous buffers for params
weight_tensor = torch.empty( weight_tensor = torch.empty(
self.out_features, self.out_features,
...@@ -855,7 +805,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -855,7 +805,7 @@ class Linear(TransformerEngineBaseModule):
# Check if parameters are subviews of buffers # Check if parameters are subviews of buffers
is_subview = (split_start, split_end) != (0, self.out_features) is_subview = (split_start, split_end) != (0, self.out_features)
if is_subview and self.primary_weights_in_fp8: if is_subview and with_fp8_params:
raise RuntimeError( raise RuntimeError(
"Splitting Float8Tensor into multiple params " "Splitting Float8Tensor into multiple params "
"is not supported" "is not supported"
...@@ -887,13 +837,11 @@ class Linear(TransformerEngineBaseModule): ...@@ -887,13 +837,11 @@ class Linear(TransformerEngineBaseModule):
bias = torch.Tensor().to(dtype=params_dtype, device=device) bias = torch.Tensor().to(dtype=params_dtype, device=device)
setattr(self, name, bias) setattr(self, name, bias)
if self.primary_weights_in_fp8: if with_fp8_params:
self.init_fp8_metadata() self.init_fp8_metadata()
self.reset_parameters(defer_init=(device == 'meta')) self.reset_parameters(defer_init=(device == 'meta'))
self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))
# For RPL, bias has to be added after TP collectives # For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM # So it cannot be fused with the GEMM
if self.parallel_mode == "row" and self.apply_bias: if self.parallel_mode == "row" and self.apply_bias:
...@@ -922,30 +870,6 @@ class Linear(TransformerEngineBaseModule): ...@@ -922,30 +870,6 @@ class Linear(TransformerEngineBaseModule):
elif self.parallel_mode == "column": elif self.parallel_mode == "column":
set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1) set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1)
def get_fp8_weights_scratchpad(
self,
is_first_microbatch: Union[bool, None],
) -> List[Float8Tensor]:
"""
Fetch the fp8 weight tensor placeholders if they exist (when
`is_first_microbatch` is not `None`) or return empty fp8 weight
tensors (if `is_first_microbatch is None`)
"""
if not self.fp8 or self.primary_weights_in_fp8:
return [None, None]
if is_first_microbatch is None:
# Return empty weight placeholders for each fwd/bwd pass
fp8_weight_tensors = self.get_fp8_weights_empty_tensors(
is_first_microbatch
)
else:
# These persistent weight placeholders should've been created in
# `set_fp8_weights` method
fp8_weight_tensors = [self.weight1_fp8, self.weight1_t_fp8]
return fp8_weight_tensors
@no_torch_dynamo() @no_torch_dynamo()
def forward( def forward(
self, self,
...@@ -979,18 +903,26 @@ class Linear(TransformerEngineBaseModule): ...@@ -979,18 +903,26 @@ class Linear(TransformerEngineBaseModule):
if skip_fp8_weight_update is not None: if skip_fp8_weight_update is not None:
is_first_microbatch = False is_first_microbatch = False
with self.prepare_forward(inp, with self.prepare_forward(
inp,
is_first_microbatch, is_first_microbatch,
allow_non_contiguous=isinstance(inp,Float8Tensor)) as inp: allow_non_contiguous=isinstance(inp,Float8Tensor),
assert self.fp8 or not self.primary_weights_in_fp8, \ ) as inp:
"Need to run inside fp8_autocast region when weights are stored in FP8."
is_first_module_in_mha = is_first_module_in_mha and self.fp8_meta["recipe"].fp8_mha is_first_module_in_mha = is_first_module_in_mha and self.fp8_meta["recipe"].fp8_mha
# Get concatenated weight and bias tensors # Get concatenated weight and bias tensors
weight_tensor = _noop_cat( unfused_weights = [getattr(self, name) for name in self.weight_names]
[getattr(self, name) for name in self.weight_names], if any(isinstance(w, Float8Tensor) for w in unfused_weights):
if self.fp8:
if len(unfused_weights) != 1:
raise RuntimeError(
"Splitting Float8Tensor into multiple params "
"is not supported"
) )
else:
unfused_weights = [w.from_float8() for w in unfused_weights]
weight_tensor = _noop_cat(unfused_weights)
if self.use_bias: if self.use_bias:
bias_tensor = _noop_cat( bias_tensor = _noop_cat(
[getattr(self, name) for name in self.bias_names], [getattr(self, name) for name in self.bias_names],
...@@ -998,9 +930,43 @@ class Linear(TransformerEngineBaseModule): ...@@ -998,9 +930,43 @@ class Linear(TransformerEngineBaseModule):
else: else:
bias_tensor = getattr(self, self.bias_names[0]) # Unused bias_tensor = getattr(self, self.bias_names[0]) # Unused
# Fetch the fp8 weights placeholders (for linear/gemm) # Initialize FP8 weights if needed
weight1_fp8, weight1_t_fp8 = self.get_fp8_weights_scratchpad( weight_fp8 = None
if self.fp8:
with_transpose = torch.is_grad_enabled()
if (
not with_transpose
and is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase()
):
with_transpose = True
if isinstance(weight_tensor, Float8Tensor):
# Fill transpose cache in FP8 tensor if needed
update_transpose_cache = with_transpose
if update_transpose_cache:
update_transpose_cache = (
is_first_microbatch is_first_microbatch
or skip_fp8_weight_update is not None
)
if update_transpose_cache:
weight_tensor.transpose_2d(
fill_cache=True,
noop_flag=skip_fp8_weight_update,
)
else:
# FP8 cast to workspace buffer
update_workspace = (
is_first_microbatch is None
or is_first_microbatch
)
weight_fp8 = self.get_fp8_workspace(
tensor=weight_tensor,
fp8_meta_forward=True,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
cache_name=(None if is_first_microbatch is None else "weight"),
update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update,
with_transpose=with_transpose,
) )
from ..cpu_offload import CPUOffloadEnabled from ..cpu_offload import CPUOffloadEnabled
...@@ -1013,13 +979,11 @@ class Linear(TransformerEngineBaseModule): ...@@ -1013,13 +979,11 @@ class Linear(TransformerEngineBaseModule):
args = [None] args = [None]
args += ( args += (
weight_tensor, weight_tensor,
weight1_fp8, weight_fp8,
weight1_t_fp8,
inp, inp,
bias_tensor, bias_tensor,
self.apply_bias and not self.gemm_bias_unfused_add, self.apply_bias and not self.gemm_bias_unfused_add,
is_first_microbatch, is_first_microbatch,
skip_fp8_weight_update,
self.fp8, self.fp8,
self.fp8_calibration, self.fp8_calibration,
self.fp8_meta, self.fp8_meta,
...@@ -1032,7 +996,6 @@ class Linear(TransformerEngineBaseModule): ...@@ -1032,7 +996,6 @@ class Linear(TransformerEngineBaseModule):
self.activation_dtype, self.activation_dtype,
self.parallel_mode, self.parallel_mode,
torch.is_grad_enabled(), torch.is_grad_enabled(),
self.primary_weights_in_fp8,
self.ub_overlap_rs, self.ub_overlap_rs,
self.ub_overlap_ag, self.ub_overlap_ag,
self.ub_name, self.ub_name,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment