Unverified Commit 0edf30b8 authored by Alp Dener, committed by GitHub

[PyTorch] Distributed intermediate/activation tensors for FSDP (#687)



* New TE wrapper for PyTorch FullyShardedDataParallel to make TE modules distribute their activations after the forward pass and gather them before the backward pass
Signed-off-by: Alp Dener <adener@nvidia.com>
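
For reference, a minimal usage sketch of the prepare_te_modules_for_fsdp() helper added here (mirroring the updated test script; the module type, sizes, and flags are illustrative, and the script is assumed to be launched with torchrun so NCCL and LOCAL_RANK are available):

    import os
    import torch
    import torch.distributed as dist
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    import transformer_engine.pytorch as te
    from transformer_engine.pytorch.distributed import prepare_te_modules_for_fsdp

    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.getenv("LOCAL_RANK", "0")))

    # A plain TE module has no inherent parallelism; FSDP shards its parameters.
    model = FSDP(te.LayerNormMLP(1024, 4096, params_dtype=torch.bfloat16, device="cuda"),
                 use_orig_params=True)

    # New in this PR: let TE modules also scatter/gather buffers FSDP cannot see
    # (FP8 weight copies and tensors saved for the backward pass) over the FSDP group.
    prepare_te_modules_for_fsdp(model)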

* simplified TE module setup for FSDP comms
Signed-off-by: Alp Dener <adener@nvidia.com>

* FSDP scatter/gather for tensors saved into autograd ctx now working for base TE modules
Signed-off-by: Alp Dener <adener@nvidia.com>

* make sure activation recompute disables FSDP scatter/gather
Signed-off-by: Alp Dener <adener@nvidia.com>
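
When activation recompute is enabled, the updated test script routes checkpointing through TE's own checkpoint implementation so that recomputed activations are never FSDP-scattered in the first place. A hedged sketch of that wiring (the helper name is invented for illustration; fsdp_model is assumed to be an already FSDP-wrapped model):

    from functools import partial
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        apply_activation_checkpointing,
        checkpoint_wrapper,
    )
    import transformer_engine.pytorch as te

    def apply_te_checkpointing(fsdp_model, block_cls=te.TransformerLayer):
        # Recompute matching blocks during the backward pass via te.distributed.checkpoint.
        wrapper = partial(checkpoint_wrapper,
                          checkpoint_fn=te.distributed.checkpoint,
                          use_reentrant=False)
        apply_activation_checkpointing(fsdp_model,
                                       checkpoint_wrapper_fn=wrapper,
                                       check_fn=lambda m: isinstance(m, block_cls))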

* make sure Fp8 weight buffers are sharded at the end of the backward pass and gathered before forward
Signed-off-by: Alp Dener <adener@nvidia.com>

* Fixed typo in attribute name
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed bug in finding FSDP-wrapped TE modules
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed typo in fp8 weight tensor name
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed incorrect # of gradients
Signed-off-by: Alp Dener <adener@nvidia.com>

* Added fp8 amax gradient hook tensor to the parameter reset
Signed-off-by: Alp Dener <adener@nvidia.com>

* get rid of erroneous dummy tensor leftover from incorrect rebase
Signed-off-by: Alp Dener <adener@nvidia.com>

* Linting fixes
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixing git snafu and removing debug statements
Signed-off-by: Alp Dener <adener@nvidia.com>

---------
Signed-off-by: Alp Dener <adener@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent a9b4af19
......@@ -4,6 +4,7 @@
import os
import argparse
from functools import partial
import torch
......@@ -11,9 +12,37 @@ import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel, MixedPrecision
from torch.distributed.fsdp.wrap import always_wrap_policy, transformer_auto_wrap_policy
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
checkpoint_wrapper
)
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
from transformer_engine.pytorch.distributed import prepare_te_modules_for_fsdp
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
# RNG state tracker for checkpointing
rng_seed = 1234
torch.manual_seed(rng_seed)
torch.cuda.manual_seed(rng_seed)
CUDA_RNG_STATES_TRACKER = te.distributed.CudaRNGStatesTracker()
CUDA_RNG_STATES_TRACKER.add('model-parallel-rng', rng_seed)
def get_cuda_rng_tracker():
return CUDA_RNG_STATES_TRACKER
def apply_fsdp_checkpointing(model, blocks):
"""apply activation checkpointing to model
returns None as model is updated directly
"""
wrapper = lambda m: checkpoint_wrapper(m,
checkpoint_fn=te.distributed.checkpoint,
use_reentrant=False,
get_rng_state_tracker=get_cuda_rng_tracker)
check_fn = lambda submodule: isinstance(submodule, blocks)
apply_activation_checkpointing(model, checkpoint_wrapper_fn=wrapper, check_fn=check_fn)
def lowercase(s):
return str(s).lower()
......@@ -41,42 +70,41 @@ te_layer_map = {
'transformerlayer': te.TransformerLayer
}
def te_layer(l):
if l is not None:
if lowercase(l) not in te_layer_map.keys():
raise TypeError
return te_layer_map[lowercase(l)]
return None
def get_layer_args(args):
hidden_size = args.num_heads * args.head_dim
def get_layer_args(opts):
hidden_size = opts.num_heads * opts.head_dim
layer_args = (hidden_size, )
layer_kwargs = {
'params_dtype': args.dtype,
'device': 'meta' if args.defer_init else 'cuda'
'params_dtype': opts.dtype,
'device': 'cuda' if opts.no_defer_init else 'meta',
'get_rng_state_tracker': get_cuda_rng_tracker,
}
if args.layer_type in [te.Linear, te.LayerNormLinear, te.LayerNormMLP]:
ffn_hidden_size = 3 * hidden_size if args.num_layers == 1 else hidden_size
if opts.layer_type in [te.Linear, te.LayerNormLinear, te.LayerNormMLP]:
ffn_hidden_size = 3 * hidden_size if opts.num_layers == 1 else hidden_size
layer_args += (ffn_hidden_size, )
layer_kwargs['bias'] = True
if args.layer_type == te.LayerNormMLP:
layer_kwargs['seq_length'] = args.seq_length
elif args.layer_type == te.MultiheadAttention:
layer_args += (args.num_heads, )
if opts.layer_type == te.LayerNormMLP:
layer_kwargs['seq_length'] = opts.seq_length
elif opts.layer_type == te.MultiheadAttention:
layer_args += (opts.num_heads, )
layer_kwargs['fuse_qkv_params'] = True
elif args.layer_type == te.TransformerLayer:
layer_args += (3 * hidden_size, args.num_heads)
layer_kwargs['input_layernorm'] = True
elif opts.layer_type == te.TransformerLayer:
layer_args += (3 * hidden_size, opts.num_heads)
layer_kwargs['fuse_qkv_params'] = True
layer_kwargs['seq_length'] = args.seq_length
layer_kwargs['seq_length'] = opts.seq_length
return layer_args, layer_kwargs
def parse_fsdp_args():
parser = argparse.ArgumentParser(description="Run Transformer Engine modules with the " +
"torch.distributed.fsdp.FullyShardedDataParallel strategy.")
parser.add_argument("-t", "--layer-type", type=te_layer, default=te.TransformerLayer,
choices=list(te_layer_map.values()),
help="TE module type used to construct the test model.")
parser.add_argument("--no-fp8", action="store_true", default=False,
help="Disables the te.fp8_autocast() context.")
parser.add_argument('-i', "--num-iters", type=int, default=3,
help="Number of dummy 'training' iterations.")
parser.add_argument('-v', "--verbose", action="store_true", default=False,
help="Print out information from all GPUs instead of only the root GPU-0.")
parser.add_argument('-b', "--batch-size", type=int, default=32,
help="Input batch size.")
parser.add_argument('-s', "--seq-length", type=int, default=1048,
......@@ -85,55 +113,67 @@ def parse_fsdp_args():
help="Number of attention heads.")
parser.add_argument('-d', "--head-dim", type=int, default=128,
help="Dimension of each attention head (number of KV channels).")
parser.add_argument('-l', "--num-layers", type=int, default=1,
parser.add_argument('-i', "--num-iters", type=int, default=5,
help="Number of dummy 'training' iterations.")
parser.add_argument('-k', "--num-layers", type=int, default=3,
help="Number of modules chained together with nn.Sequential.")
parser.add_argument("--layer-type", type=te_layer, default=te.TransformerLayer,
choices=list(te_layer_map.values()),
help="TE module type used to construct the test model.")
parser.add_argument("--seed", type=int, default=1234,
help="PyTorch RNG seed.")
parser.add_argument("--defer-init", action="store_true",
parser.add_argument("--profile-memory", action="store_true",
help="Enable memory profiling via torch.profiler.profile().")
parser.add_argument("--profile-name", type=str, default=None,
help="File path for memory profiling.")
parser.add_argument("--checkpoint-layer", type=te_layer, default=None,
help="Recompute activations of the selected layer during the backward " + \
"pass instead of saving.")
parser.add_argument("--no-fp8", action="store_true", default=False,
help="Disables the te.fp8_autocast() context.")
parser.add_argument("--no-defer-init", action="store_true",
help="Defer module parameter initialization until after FSDP sharding.")
parser.add_argument('-v', "--verbose", action="store_true", default=False,
help="Print out information from all GPUs instead of only the root GPU-0.")
parser.add_argument("--no-te-fsdp", action="store_true",
help="Disable sharding of intermediate/activation tensors in TE modules.")
parser.add_argument("--dtype", type=torch_dtype, default=torch.bfloat16,
help="Data type for input tensor and Transformer Engine module parameters.")
return parser.parse_args()
def train(args):
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
def dist_print(text, all_ranks=False, no_new_line=False):
if LOCAL_RANK == 0 or all_ranks:
end = '' if no_new_line else '\n'
print(f"[GPU-{LOCAL_RANK}] " + text, end=end)
def train(opts):
# Initialize torch.distributed global process group
dist.init_process_group(backend="nccl")
torch.cuda.set_device(local_rank)
if local_rank == 0:
print(f"[GPU-0] WORLD_SIZE = {world_size}\n\n", end='')
torch.manual_seed(args.seed)
torch.cuda.set_device(LOCAL_RANK)
dist_print(f"WORLD_SIZE = {WORLD_SIZE}")
torch.manual_seed(opts.seed)
# Construct a simple homogeneous model (only one layer type) with NO PARALLELISM
layer_args, layer_kwargs = get_layer_args(args)
if args.num_layers > 1:
layer_args, layer_kwargs = get_layer_args(opts)
if opts.num_layers > 1:
te_layer_list = []
for i in range(args.num_layers):
if args.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
for i in range(opts.num_layers):
if opts.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
layer_kwargs['layer_number'] = i+1
te_layer_list.append(args.layer_type(*layer_args, **layer_kwargs))
te_layer_list.append(opts.layer_type(*layer_args, **layer_kwargs))
te_model = nn.Sequential(*te_layer_list)
else:
# Single layer model
te_model = args.layer_type(*layer_args, **layer_kwargs)
if local_rank == 0:
print(f"[GPU-0] TransformerEngine Model:\n{te_model}\n", end='')
te_model = opts.layer_type(*layer_args, **layer_kwargs)
# Print out allocated device memory before the model parameters are sharded by FSDP
pre_mem_use = torch.cuda.memory_allocated(device=f"cuda:{local_rank}") * 1e-6
if local_rank == 0 or args.verbose:
print(f"[GPU-{local_rank}] Pre-FSDP memory use = {pre_mem_use}MiB\n", end='')
pre_mem_use = torch.cuda.memory_allocated(device=f"cuda:{LOCAL_RANK}") * 1e-6
dist_print(f"Pre-FSDP memory use = {pre_mem_use}MiB")
# Wrap the model with FSDP
# NOTE: The TE model itself has no inherent parallelism. FSDP shards model parameters and
# controls all communication.
all_gpus = dist.new_group(backend='nccl')
fsdp_wrap_policy = always_wrap_policy
if args.layer_type == te.TransformerLayer:
if opts.layer_type == te.TransformerLayer:
# NOTE: FSDP causes illegal memory access without this special policy for Transformers
fsdp_wrap_policy = partial(transformer_auto_wrap_policy,
transformer_layer_cls={te.TransformerLayer})
......@@ -141,16 +181,23 @@ def train(args):
process_group=all_gpus,
use_orig_params=True,
mixed_precision=MixedPrecision(
param_dtype=args.dtype,
param_dtype=opts.dtype,
reduce_dtype=torch.float32,
),
sync_module_states=True,
auto_wrap_policy=fsdp_wrap_policy)
if opts.checkpoint_layer is not None:
# Recompute the activations of the selected layer during the backward pass instead of
# saving them during the forward pass
apply_fsdp_checkpointing(te_model, blocks=opts.checkpoint_layer)
elif not opts.no_te_fsdp:
# Prepare TE modules to shard internal buffers that FSDP cannot shard on its own
prepare_te_modules_for_fsdp(te_model)
# Print out allocated device memory after the model parameters are sharded
post_mem_use = torch.cuda.memory_allocated(device=f"cuda:{local_rank}") * 1e-6
if local_rank == 0 or args.verbose:
print(f"[GPU-{local_rank}] Post-FSDP memory use = {post_mem_use}MiB\n", end='')
post_mem_use = torch.cuda.memory_allocated(device=f"cuda:{LOCAL_RANK}") * 1e-6
dist_print(f"Post-FSDP memory use = {post_mem_use}MiB")
dist_print(f"FSDP-Wrapped + Checkpointed TE Model:\n{te_model}")
# Fp8 setup for TE
fp8_format = Format.HYBRID
......@@ -159,37 +206,46 @@ def train(args):
# Optimizer must be created after the model is wrapped in FSDP and the parameters are sharded
optim = torch.optim.Adam(te_model.parameters(), lr=0.0001)
# Start and time dummy "training" iterations
# Profile memory use
if opts.profile_memory:
torch.cuda.memory._record_memory_history(max_entries=100000)
else:
torch.cuda.reset_peak_memory_stats()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start.record()
for i in range(args.num_iters):
for i in range(opts.num_iters):
# Generate a random input batch
x = torch.rand(args.seq_length, args.batch_size,
args.num_heads*args.head_dim).to(dtype=args.dtype).cuda()
x = torch.rand(opts.seq_length, opts.batch_size, opts.num_heads*opts.head_dim,
dtype=opts.dtype, device='cuda')
# fp8_autocast needs to be given the FSDP process group for amax reductions
with te.fp8_autocast(enabled=not args.no_fp8, fp8_recipe=fp8_recipe, fp8_group=all_gpus):
with te.fp8_autocast(enabled=not opts.no_fp8, fp8_recipe=fp8_recipe, fp8_group=all_gpus):
y = te_model(x)
loss = y.sum()
# calculate gradient and take training step outside the fp8_autocast context
loss.backward()
optim.step()
optim.zero_grad(set_to_none=True)
del x
if local_rank == 0:
print(f"[GPU-0] Iter. {i+1}\n", end='')
if opts.profile_memory:
torch.cuda.memory._dump_snapshot(f"gpu{LOCAL_RANK}_{opts.profile_name}.pickle")
torch.cuda.memory._record_memory_history(enabled=None)
else:
end.record()
torch.cuda.synchronize()
# Print out "training" time and peak memory use stats
peak_mem = torch.cuda.max_memory_allocated()
train_time = start.elapsed_time(end)/1000.
max_memory_alloc = int(torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}") * 1e-6)
if local_rank == 0 or args.verbose:
print(f"[GPU-{local_rank}] Training Time: {train_time}s\n" +
f"[GPU-{local_rank}] Avg. Iter. Time: {train_time /args.num_iters}s\n" +
f"[GPU-{local_rank}] Peak memory use = {max_memory_alloc}MiB\n\n", end='')
dist_print(f"Training Time: {train_time}s")
dist_print(f"Avg. Iter. Time: {train_time / opts.num_iters}s")
dist_print(f"Peak Memory Use: {peak_mem * 1e-6}MBs")
# Run with:
# torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) test_fsdp.py
if __name__ == "__main__":
args = parse_fsdp_args()
train(args)
......@@ -5,15 +5,19 @@
"""Methods needed for distributed training (DP/TP)."""
import warnings
from contextlib import contextmanager, AbstractContextManager, ContextDecorator
from typing import Any, Dict, List, Union, Optional, Callable, Tuple
from typing import Any, Dict, Union, Optional, Callable, Tuple, List
import torch
from torch.cuda import _lazy_call, _lazy_init
from torch.utils.checkpoint import detach_variable, noop_context_fn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules
from .utils import safely_set_viewless_tensor_data
from .constants import dist_group_type
from .fp8 import FP8GlobalStateManager
from .float8_tensor import Float8Tensor
__all__ = ["checkpoint", "CudaRNGStatesTracker"]
......@@ -630,6 +634,11 @@ def checkpoint(
**kwargs
)
# If this TE module is FSDP-wrapped, clear its FSDP group information because there's no need
# to scatter/gather activations that we will recompute anyway.
setattr(function, "fsdp_wrapped", False)
setattr(function, "fsdp_group", None)
# Otherwise discard unused te.utils.checkpoint.checkpoint() arguments
# and execute TE's own checkpointing
# NOTE: This logic uses the TE checkpoint on all custom callable `function` handles because we
......@@ -856,3 +865,110 @@ def allreduce(
handle = torch.distributed.all_reduce(input_, group=tp_group, async_op=async_op)
return input_, handle
def _fsdp_scatter_tensors(
fsdp_group: dist_group_type,
*tensors: torch.Tensor,
):
shapes = []
if fsdp_group is not None:
for t in tensors:
if isinstance(t, torch.Tensor):
target = t._data if isinstance(t, Float8Tensor) else t
shapes.append(target.data.shape)
safely_set_viewless_tensor_data(
target, split_tensor_into_1d_equal_chunks(
target.data, fsdp_group, new_buffer=True)
)
else:
shapes.append(None)
return shapes
def _fsdp_gather_tensors(
fsdp_group: dist_group_type,
shapes: List[Tuple[int,...]],
*tensors: torch.Tensor,
):
if fsdp_group is not None:
assert len(shapes) == len(tensors), "Number of tensors and tensor shapes must be equal."
for s, t in zip(shapes, tensors):
if isinstance(t, torch.Tensor):
assert s is not None, "Internal TE error."
target = t._data if isinstance(t, Float8Tensor) else t
safely_set_viewless_tensor_data(
target, gather_split_1d_tensor(target.data, fsdp_group).view(s)
)
def _is_te_module(module):
"""
Check if given module is a Transformer Engine module that requires the TE checkpoint
implementation for activation recompute.
"""
from .module import LayerNorm, RMSNorm
from .module.base import TransformerEngineBaseModule
from .attention import UnfusedDotProductAttention, DotProductAttention, MultiheadAttention
from .transformer import TransformerLayer
te_classes_list = [
LayerNorm,
RMSNorm,
TransformerEngineBaseModule,
UnfusedDotProductAttention,
DotProductAttention,
MultiheadAttention,
TransformerLayer,
]
is_te_module = False
for te_class in te_classes_list:
if isinstance(module, te_class):
is_te_module = True
break
return is_te_module
def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
"""
Inject FSDP process group references into FSDP-wrapped TE modules under an FSDP-wrapped root
module in order to scatter/gather the Fp8 weight copies at the same time FSDP scatters/gathers
its `FlatParameters`.
Parameters
----------
fsdp_root: torch.nn.Module
FSDP-wrapped root module that may contain FSDP-wrapped TE modules.
"""
assert isinstance(fsdp_root, FSDP), "Root module must be FSDP-wrapped."
assert not fsdp_root.primary_weights_in_fp8, (
"TE modules with primary weights in FP8 cannot be FSDP-wrapped. "
"Please initialize your model without the te.fp8_model_init(...) context."
)
# If the root module is a TE module, inject FSDP information into it
if _is_te_module(fsdp_root.module):
root_state = _get_module_fsdp_state(fsdp_root)
assert root_state is not None, "Root module does not have a valid _FSDPState."
setattr(fsdp_root.module, "fsdp_group", root_state.process_group)
# Iterate through all FSDP-wrapped submodules and inject FSDP information into TE modules
fsdp_states, fsdp_modules = _get_fsdp_states_with_modules(fsdp_root)
for state, fsdp_module in zip(fsdp_states, fsdp_modules):
assert not fsdp_module.module.primary_weights_in_fp8, (
"TE modules with primary weights in FP8 cannot be FSDP-wrapped. "
"Please initialize your model without the te.fp8_model_init(...) context."
)
if _is_te_module(fsdp_module.module):
setattr(fsdp_module.module, "fsdp_group", state.process_group)
class FullyShardedDataParallel(FSDP):
"""
Transformer Engine wrapper around `torch.distributed.fsdp.FullyShardedDataParallel` that
extracts necessary information out of the FSDP wrap for TE modules to scatter their
activation tensors after each forward pass and gather them before the backward pass.
"""
def __init__(self, module, *args, **kwargs):
super().__init__(module, *args, **kwargs)
prepare_te_modules_for_fsdp(self)
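# Usage sketch (illustrative, with example layer type/sizes; an NCCL process group must
# already be initialized): this subclass is a drop-in replacement for the upstream
# torch.distributed.fsdp.FullyShardedDataParallel, so prepare_te_modules_for_fsdp()
# runs automatically at wrap time:
#
#     import transformer_engine.pytorch as te
#     from transformer_engine.pytorch.distributed import FullyShardedDataParallel as TeFSDP
#     fsdp_model = TeFSDP(te.LayerNormMLP(1024, 4096, device="cuda"), use_orig_params=True)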
......@@ -26,6 +26,7 @@ from ..distributed import (
gather_along_first_dim,
is_fp8_activation_recompute_enabled,
in_fp8_activation_recompute_phase,
_fsdp_gather_tensors,
)
from ..cpp_extensions import (
fp8_cast_transpose_fused,
......@@ -254,6 +255,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.sequence_parallel = False
self.param_init_meta = {}
self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
self.fsdp_wrapped = False
self.fsdp_group = None
self._fp8_workspaces: Dict[str, Float8Tensor] = {}
def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None:
......@@ -760,6 +763,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
update_workspace: bool = True,
skip_update_flag: Optional[torch.Tensor] = None,
with_transpose: bool = False,
fsdp_group: dist_group_type = None,
) -> Float8Tensor:
"""Get FP8 workspace buffer and maybe update its values
......@@ -786,13 +790,22 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
over `update_workspace` if provided.
with_transpose: bool, default = `False`
Whether to initialize cached transpose in workspace.
fsdp_group: dist_group_type, default = None
FSDP process group that the weights are distributed over.
"""
# Construct workspace if needed
out = None
if cache_name is not None:
out = self._fp8_workspaces.get(cache_name, None)
# Gather cached Fp8 workspace if it's distributed
# NOTE: FSDP sharding is supported only for Fp8 buffers and will not work
# for models initialized with Fp8 primary weights.
if (isinstance(out, Float8Tensor) and
fsdp_group is not None and
out._data.shape != tensor.data.shape):
_fsdp_gather_tensors(fsdp_group, [tensor.data.shape], out)
if out is None:
if (
tensor is None
......
......@@ -38,6 +38,8 @@ from ..distributed import (
gather_along_first_dim,
is_fp8_activation_recompute_enabled,
in_fp8_activation_recompute_phase,
_fsdp_scatter_tensors,
_fsdp_gather_tensors,
)
from ..constants import GemmParallelModes, dist_group_type, TE_DType
from ..jit import no_torch_dynamo
......@@ -89,6 +91,7 @@ class _LayerNormLinear(torch.autograd.Function):
ub_overlap_rs_dgrad: bool,
ub_overlap_ag: bool,
ub_name: str,
fsdp_group: Union[dist_group_type, None],
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
......@@ -196,6 +199,7 @@ class _LayerNormLinear(torch.autograd.Function):
# Use FP8 weights
if weight_fp8 is None:
weight_fp8 = weight
assert isinstance(weight_fp8, Float8Tensor)
if fp8_meta["recipe"].fp8_mha:
......@@ -281,6 +285,18 @@ class _LayerNormLinear(torch.autograd.Function):
rsigma.activation_offloading = True
ln_out.activation_offloading = True
# Scatter intermediate/activation tensors saved for the backward pass
# NOTE: weight_fp8 = weight when fp8 == False and torch.distributed.FSDP already
# shards/unshards the base weights so we don't do it ourselves
ctx.fsdp_group = fsdp_group
ctx.fsdp_shapes = _fsdp_scatter_tensors(
fsdp_group,
mu,
rsigma,
weight_fp8 if fp8 and not isinstance(weight, Float8Tensor) else None,
ln_out if weight.requires_grad else None,
)
ctx.save_for_backward(
inputmat,
ln_weight,
......@@ -331,6 +347,7 @@ class _LayerNormLinear(torch.autograd.Function):
# [*, in_features] -> [*, out_features] except first dimension changes for SP
out = out.view(-1, *inp.shape[1:-1], out.shape[-1])
if return_layernorm_output:
if return_layernorm_output_gathered:
shape = list(inp.shape)
......@@ -361,6 +378,18 @@ class _LayerNormLinear(torch.autograd.Function):
fwd_scale_inverses,
) = ctx.saved_tensors
# Gather intermediate/activation tensors if needed
# NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.distributed.FSDP already
# shards/unshards the base weights so we don't do it ourselves
_fsdp_gather_tensors(
ctx.fsdp_group,
ctx.fsdp_shapes,
mu,
rsigma,
weight_fp8 if ctx.fp8 and not isinstance(weight, Float8Tensor) else None,
ln_out,
)
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
weight = torch.nn.Parameter(weight, False)
weight.main_grad = main_grad
......@@ -630,6 +659,8 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
)
dbeta = None
clear_tensor_data(mu)
clear_tensor_data(rsigma)
if not ctx.use_bias:
grad_bias = None
......@@ -658,6 +689,10 @@ class _LayerNormLinear(torch.autograd.Function):
if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
# Scatter fp8 weight buffers
if ctx.fp8 and not isinstance(weight, Float8Tensor):
_fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8)
return (
dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
dgamma,
......@@ -691,6 +726,7 @@ class _LayerNormLinear(torch.autograd.Function):
None, # ub_overlap_rs_dgrad
None, # ub_overlap_ag
None, # ub_name
None, # fsdp_group
)
......@@ -1175,6 +1211,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.ub_overlap_rs_dgrad,
self.ub_overlap_ag,
self.ub_name,
self.fsdp_group,
)
out = fwd_fn(*args)
......
......@@ -44,6 +44,8 @@ from ..distributed import (
is_fp8_activation_recompute_enabled,
in_fp8_activation_recompute_phase,
use_reentrant_activation_recompute,
_fsdp_scatter_tensors,
_fsdp_gather_tensors,
)
from .. import cpp_extensions as tex
......@@ -119,6 +121,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub_overlap_rs: bool,
ub_overlap_ag: bool,
gemm_gelu_fusion: bool,
fsdp_group: Union[dist_group_type, None],
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
......@@ -220,6 +223,7 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_weight_fp8 = fc1_weight
if fc2_weight_fp8 is None:
fc2_weight_fp8 = fc2_weight
assert isinstance(fc1_weight_fp8, Float8Tensor)
assert isinstance(fc2_weight_fp8, Float8Tensor)
......@@ -440,6 +444,21 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_out.activation_offloading = True
gelu_out.activation_offloading = True
# Scatter intermediate/activation tensors saved for the backward pass
# NOTE: weight_fp8 = weight when fp8 == False and torch.distributed.FSDP already
# shards/unshards the base weights so we don't do it ourselves
ctx.fsdp_group = fsdp_group
ctx.fsdp_shapes = _fsdp_scatter_tensors(
fsdp_group,
mu,
rsigma,
ln_out,
fc1_out,
gelu_out,
fc1_weight_fp8 if fp8 and not isinstance(fc1_weight, Float8Tensor) else None,
fc2_weight_fp8 if fp8 and not isinstance(fc2_weight, Float8Tensor) else None,
)
ctx.save_for_backward(
inputmat,
ln_weight,
......@@ -457,6 +476,7 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_bias,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
)
ctx.activation_dtype = activation_dtype
ctx.activation = activation
ctx.fp8 = fp8
......@@ -532,6 +552,21 @@ class _LayerNormMLP(torch.autograd.Function):
fwd_scale_inverses,
) = ctx.saved_tensors
# Gather saved autograd context tensors when running with FSDP
# NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.distributed.FSDP already
# shards/unshards the base weights so we don't do it ourselves
_fsdp_gather_tensors(
ctx.fsdp_group,
ctx.fsdp_shapes,
mu,
rsigma,
ln_out,
fc1_out,
gelu_out,
fc1_weight_fp8 if ctx.fp8 and not isinstance(fc1_weight, Float8Tensor) else None,
fc2_weight_fp8 if ctx.fp8 and not isinstance(fc2_weight, Float8Tensor) else None,
)
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
fc1_weight = Parameter(fc1_weight, False)
fc2_weight = Parameter(fc2_weight, False)
......@@ -1006,6 +1041,8 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
)
dbeta = None
clear_tensor_data(mu)
clear_tensor_data(rsigma)
if fc1_weight.requires_grad:
# Handle custom DDP from mcore.
......@@ -1052,6 +1089,14 @@ class _LayerNormMLP(torch.autograd.Function):
if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
# Scatter Fp8 transposed-weight buffers
if ctx.fp8:
_fsdp_scatter_tensors(
ctx.fsdp_group,
fc1_weight_fp8 if not isinstance(fc1_weight, Float8Tensor) else None,
fc2_weight_fp8 if not isinstance(fc2_weight, Float8Tensor) else None
)
return (
dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
dgamma,
......@@ -1092,6 +1137,7 @@ class _LayerNormMLP(torch.autograd.Function):
None, # ub_overlap_rs
None, # ub_overlap_ag
None, # gemm_gelu_fusion
None, # fsdp_group
)
......@@ -1542,6 +1588,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.ub_overlap_rs,
self.ub_overlap_ag,
self.gemm_gelu_fusion,
self.fsdp_group,
)
out = fwd_fn(*args)
......
......@@ -36,6 +36,8 @@ from ..distributed import (
gather_along_first_dim,
is_fp8_activation_recompute_enabled,
in_fp8_activation_recompute_phase,
_fsdp_scatter_tensors,
_fsdp_gather_tensors,
)
from ..cpp_extensions import (
fp8_gemm,
......@@ -83,6 +85,7 @@ class _Linear(torch.autograd.Function):
ub_overlap_ag: bool,
ub_name: str,
is_first_module_in_mha: bool,
fsdp_group: Union[dist_group_type, None],
) -> torch.Tensor:
is_input_fp8 = isinstance(inp, Float8Tensor)
if is_input_fp8:
......@@ -157,6 +160,7 @@ class _Linear(torch.autograd.Function):
# Use FP8 weights
if weight_fp8 is None:
weight_fp8 = weight
assert isinstance(weight_fp8, Float8Tensor)
if is_first_module_in_mha:
......@@ -299,6 +303,16 @@ class _Linear(torch.autograd.Function):
if saved_inputmat is not None:
saved_inputmat.activation_offloading = True
# Scatter intermediate/activation tensors saved for the backward pass
# NOTE: FSDP sharding is not valid for models initialized with primary Fp8 weights
ctx.fsdp_group = fsdp_group
ctx.fsdp_shapes = _fsdp_scatter_tensors(
fsdp_group,
saved_inputmat, # None if fp8 == False
saved_inputmat_t, # None if fp8 == False AND not is_grad_enabled
weight_fp8 if fp8 and not isinstance(weight, Float8Tensor) else None,
)
ctx.save_for_backward(
saved_inputmat,
saved_inputmat_t,
......@@ -307,6 +321,7 @@ class _Linear(torch.autograd.Function):
weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
)
ctx.activation_dtype = activation_dtype
ctx.fp8 = fp8
ctx.fp8_meta = fp8_meta
......@@ -360,6 +375,16 @@ class _Linear(torch.autograd.Function):
fwd_scale_inverses,
) = ctx.saved_tensors
# Gather intermediate/activation tensors if needed
# NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.distributed.FSDP already
# shards/unshards the base weights so we don't do it ourselves
_fsdp_gather_tensors(
ctx.fsdp_group,
ctx.fsdp_shapes,
inputmat,
inputmat_t,
weight_fp8 if ctx.fp8 and not isinstance(weight, Float8Tensor) else None)
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
weight = torch.nn.Parameter(weight, False)
weight.main_grad = main_grad
......@@ -569,6 +594,10 @@ class _Linear(torch.autograd.Function):
if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
# Scatter fp8 weight buffers
if ctx.fp8 and not isinstance(weight, Float8Tensor):
_fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8)
return (
wgrad,
None, # weight_fp8
......@@ -592,6 +621,7 @@ class _Linear(torch.autograd.Function):
None, # ub_overlap_ag
None, # ub_name
None, # is_first_module_in_mha
None, # fsdp_group
)
......@@ -967,6 +997,7 @@ class Linear(TransformerEngineBaseModule):
update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update,
with_transpose=with_transpose,
fsdp_group=self.fsdp_group,
)
from ..cpu_offload import CPUOffloadEnabled
......@@ -1000,6 +1031,7 @@ class Linear(TransformerEngineBaseModule):
self.ub_overlap_ag,
self.ub_name,
is_first_module_in_mha,
self.fsdp_group,
)
out = linear_fn(*args)
......