Unverified Commit 434d58fa authored by Alp Dener, committed by GitHub

[PyTorch] Deferred Initialization via `device='meta'` option (#596)



* Implemented deferred initialization via `device='meta'` option for te.Linear and added new PyTorch example to demonstrate its use with FullyShardedDataParallel execution.
Signed-off-by: Alp Dener <adener@nvidia.com>

* correcting Float8Tensor initialization and fixing linting errors
Signed-off-by: Alp Dener <adener@nvidia.com>

* removed duplicate code from upstream rebase, local tests passing
Signed-off-by: Alp Dener <adener@nvidia.com>

* improved comments/documentation for FSDP example
Signed-off-by: Alp Dener <adener@nvidia.com>

* converted reset_parameters() into a base module function
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed Float8Tensor creation with deferred init, all tests passing locally
Signed-off-by: Alp Dener <adener@nvidia.com>

* extended deferred initialization to all TE modules
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed linting errors
Signed-off-by: Alp Dener <adener@nvidia.com>

* removed unnecessary reference to the parent module of parameter, added clarifying comments in parameter reset
Signed-off-by: Alp Dener <adener@nvidia.com>

---------
Signed-off-by: Alp Dener <adener@nvidia.com>
parent c4d5f365
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Basic Example for Using PyTorch Fully Sharded Data Parallel mode with Transformer Engine
```bash
# FSDP without deferred initialization:
# Duplicate modules initialized on each device. Load on device memory reduced only after
# torch.distributed.fsdp.FullyShardedDataParallel mode shards model parameters.
$ torchrun --standalone --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) fsdp.py
# Sample output on 8xL40S:
# [GPU-0] WORLD_SIZE = 8
# [GPU-0] TransformerEngine Model:
# TransformerLayer(
# (self_attention): MultiheadAttention(
# (layernorm_qkv): LayerNormLinear()
# (core_attention): DotProductAttention(
# (flash_attention): FlashAttention()
# (fused_attention): FusedAttention()
# (unfused_attention): UnfusedDotProductAttention(
# (scale_mask_softmax): FusedScaleMaskSoftmax()
# (attention_dropout): Dropout(p=0.1, inplace=False)
# )
# )
# (proj): Linear()
# )
# (layernorm_mlp): LayerNormMLP()
# )
# [GPU-0] Pre-FSDP memory use = 83.935232MiB
# [GPU-0] Post-FSDP memory use = 10.491904MiB
# [GPU-0] Iter. 1
# [GPU-0] Iter. 2
# [GPU-0] Iter. 3
# [GPU-0] Training Time: 6.647654296875s
# [GPU-0] Avg. Iter. Time: 2.2158847656250003s
# [GPU-0] Peak memory use = 3000MiB
# FSDP with deferred initialization:
# Modules initialized with empty parameters via the `device='meta'` option. Zero load on device
# memory until torch.distributed.fsdp.FullyShardedDataParallel triggers a reset on
# the already sharded model parameters.
$ torchrun --standalone --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) fsdp.py --defer-init
# Sample output on 8xL40S:
# [GPU-0] WORLD_SIZE = 8
# ...
# [GPU-0] Pre-FSDP memory use = 0.0MiB
# [GPU-0] Post-FSDP memory use = 10.491904MiB
# ...
```
**NOTE:** This example has `fp8_autocast()` enabled by default. To run on GPUs without FP8 support
(e.g. A100), add the `--no-fp8` option to the commands shown above.
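For context, the deferred-initialization flow exercised by this example boils down to a few lines. A minimal sketch (simplified from `fsdp.py` below; assumes a `torchrun` launch so the NCCL process group can initialize, and relies on FSDP materializing meta-device modules by calling their `reset_parameters()` method):

```python
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel
import transformer_engine.pytorch as te

dist.init_process_group(backend="nccl")

# Parameters are allocated on the 'meta' device, so no GPU memory is used yet.
layer = te.Linear(1024, 1024, device='meta')

# FSDP shards the empty parameters, then materializes and initializes each
# shard on the real device via the module's reset_parameters() method.
model = FullyShardedDataParallel(layer, use_orig_params=True)
```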
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import argparse
from functools import partial
import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel, MixedPrecision
from torch.distributed.fsdp.wrap import always_wrap_policy, transformer_auto_wrap_policy
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling


def lowercase(s):
    return str(s).lower()


def torch_dtype(d):
    typemap = {
        'fp32': torch.float32,
        'float32': torch.float32,
        'fp16': torch.float16,
        'float16': torch.float16,
        'bf16': torch.bfloat16,
        'bfloat16': torch.bfloat16,
    }
    if lowercase(d) not in typemap.keys():
        raise TypeError
    return typemap[lowercase(d)]


te_layer_map = {
    'linear': te.Linear,
    'layernorm': te.LayerNorm,
    'rmsnorm': te.RMSNorm,
    'layernormlinear': te.LayerNormLinear,
    'layernormmlp': te.LayerNormMLP,
    'multiheadattention': te.MultiheadAttention,
    'transformerlayer': te.TransformerLayer,
}


def te_layer(l):
    if lowercase(l) not in te_layer_map.keys():
        raise TypeError
    return te_layer_map[lowercase(l)]


def get_layer_args(args):
    hidden_size = args.num_heads * args.head_dim
    layer_args = (hidden_size, )
    layer_kwargs = {
        'params_dtype': args.dtype,
        'device': 'meta' if args.defer_init else 'cuda',
    }
    if args.layer_type in [te.Linear, te.LayerNormLinear, te.LayerNormMLP]:
        ffn_hidden_size = 3 * hidden_size if args.num_layers == 1 else hidden_size
        layer_args += (ffn_hidden_size, )
        layer_kwargs['bias'] = True
        if args.layer_type == te.LayerNormMLP:
            layer_kwargs['seq_length'] = args.seq_length
    elif args.layer_type == te.MultiheadAttention:
        layer_args += (args.num_heads, )
        layer_kwargs['fuse_qkv_params'] = True
    elif args.layer_type == te.TransformerLayer:
        layer_args += (3 * hidden_size, args.num_heads)
        layer_kwargs['fuse_qkv_params'] = True
        layer_kwargs['seq_length'] = args.seq_length
    return layer_args, layer_kwargs


def parse_fsdp_args():
    parser = argparse.ArgumentParser(description="Run Transformer Engine modules with the " +
                                     "torch.distributed.fsdp.FullyShardedDataParallel strategy.")
    parser.add_argument("-t", "--layer-type", type=te_layer, default=te.TransformerLayer,
                        choices=list(te_layer_map.values()),
                        help="TE module type used to construct the test model.")
    parser.add_argument("--no-fp8", action="store_true", default=False,
                        help="Disables the te.fp8_autocast() context.")
    parser.add_argument('-i', "--num-iters", type=int, default=3,
                        help="Number of dummy 'training' iterations.")
    parser.add_argument('-b', "--batch-size", type=int, default=32,
                        help="Input batch size.")
    parser.add_argument('-s', "--seq-length", type=int, default=1048,
                        help="Input sequence length.")
    parser.add_argument('-n', "--num-heads", type=int, default=16,
                        help="Number of attention heads.")
    parser.add_argument('-d', "--head-dim", type=int, default=128,
                        help="Dimension of each attention head (number of KV channels).")
    parser.add_argument('-l', "--num-layers", type=int, default=1,
                        help="Number of modules chained together with nn.Sequential.")
    parser.add_argument("--seed", type=int, default=1234,
                        help="PyTorch RNG seed.")
    parser.add_argument("--defer-init", action="store_true",
                        help="Defer module parameter initialization until after FSDP sharding.")
    parser.add_argument('-v', "--verbose", action="store_true", default=False,
                        help="Print out information from all GPUs instead of only the root GPU-0.")
    parser.add_argument("--dtype", type=torch_dtype, default=torch.bfloat16,
                        help="Data type for input tensor and Transformer Engine module parameters.")
    return parser.parse_args()


def train(args):
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # Initialize torch.distributed global process group
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)
    if local_rank == 0:
        print(f"[GPU-0] WORLD_SIZE = {world_size}\n\n", end='')
    torch.manual_seed(args.seed)

    # Construct a simple homogeneous model (only one layer type) with NO PARALLELISM
    layer_args, layer_kwargs = get_layer_args(args)
    if args.num_layers > 1:
        te_layer_list = []
        for i in range(args.num_layers):
            if args.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
                layer_kwargs['layer_number'] = i + 1
            te_layer_list.append(args.layer_type(*layer_args, **layer_kwargs))
        te_model = nn.Sequential(*te_layer_list)
    else:
        # Single layer model
        te_model = args.layer_type(*layer_args, **layer_kwargs)
    if local_rank == 0:
        print(f"[GPU-0] TransformerEngine Model:\n{te_model}\n", end='')

    # Print out allocated device memory before the model parameters are sharded by FSDP
    pre_mem_use = torch.cuda.memory_allocated(device=f"cuda:{local_rank}") * 1e-6
    if local_rank == 0 or args.verbose:
        print(f"[GPU-{local_rank}] Pre-FSDP memory use = {pre_mem_use}MiB\n", end='')

    # Wrap the model with FSDP
    # NOTE: The TE model itself has no inherent parallelism. FSDP shards model parameters
    #       and controls all communication.
    all_gpus = dist.new_group(backend='nccl')
    fsdp_wrap_policy = always_wrap_policy
    if args.layer_type == te.TransformerLayer:
        # NOTE: FSDP causes illegal memory access without this special policy for Transformers
        fsdp_wrap_policy = partial(transformer_auto_wrap_policy,
                                   transformer_layer_cls={te.TransformerLayer})
    te_model = FullyShardedDataParallel(te_model,
                                        process_group=all_gpus,
                                        use_orig_params=True,
                                        mixed_precision=MixedPrecision(
                                            param_dtype=args.dtype,
                                            reduce_dtype=torch.float32,
                                        ),
                                        sync_module_states=True,
                                        auto_wrap_policy=fsdp_wrap_policy)

    # Print out allocated device memory after the model parameters are sharded
    post_mem_use = torch.cuda.memory_allocated(device=f"cuda:{local_rank}") * 1e-6
    if local_rank == 0 or args.verbose:
        print(f"[GPU-{local_rank}] Post-FSDP memory use = {post_mem_use}MiB\n", end='')

    # FP8 setup for TE
    fp8_format = Format.HYBRID
    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32,
                                amax_compute_algo="max")

    # Optimizer must be created after the model is wrapped in FSDP and the parameters are sharded
    optim = torch.optim.Adam(te_model.parameters(), lr=0.0001)

    # Start and time dummy "training" iterations
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for i in range(args.num_iters):
        # Generate a random input batch
        x = torch.rand(args.seq_length, args.batch_size,
                       args.num_heads * args.head_dim).to(dtype=args.dtype).cuda()
        # fp8_autocast needs to be given the FSDP process group for amax reductions
        with te.fp8_autocast(enabled=not args.no_fp8, fp8_recipe=fp8_recipe, fp8_group=all_gpus):
            y = te_model(x)
            loss = y.sum()
        # calculate gradient and take training step outside the fp8_autocast context
        loss.backward()
        optim.step()
        del x
        if local_rank == 0:
            print(f"[GPU-0] Iter. {i+1}\n", end='')
    end.record()
    torch.cuda.synchronize()

    # Print out "training" time and peak memory use stats
    train_time = start.elapsed_time(end) / 1000.
    max_memory_alloc = int(torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}") * 1e-6)
    if local_rank == 0 or args.verbose:
        print(f"[GPU-{local_rank}] Training Time: {train_time}s\n" +
              f"[GPU-{local_rank}] Avg. Iter. Time: {train_time / args.num_iters}s\n" +
              f"[GPU-{local_rank}] Peak memory use = {max_memory_alloc}MiB\n\n", end='')


if __name__ == "__main__":
    args = parse_fsdp_args()
    train(args)
@@ -4,12 +4,14 @@

 """Internal function used by multiple modules."""

-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union, Callable
+from dataclasses import dataclass

 import torch

 from .. import cpp_extensions as tex
 from ..fp8 import get_fp8_te_dtype
+from ..utils import get_default_init_method


 def _get_normalization_func(normalization: str,
                             fp8_output: bool,
@@ -187,3 +189,18 @@ def _noop_cat(

     # Perform no-op concat
     return _NoopCatFunc.apply(split_ranges, full_tensor, *tensors)
+
+
+@dataclass
+class _ParameterInitMeta:
+    """
+    Stores essential metadata needed to support deferred parameter initialization.
+    """
+    init_fn: Optional[Callable] = get_default_init_method()
+    get_rng_state_tracker: Optional[Callable] = None
+    fp8_meta_index: Optional[int] = None
+
+    def __post_init__(self):
+        """Safeguard reference to the parameter's initialization function."""
+        if self.init_fn is None:
+            self.init_fn = get_default_init_method()
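A short illustration of the `__post_init__` guard above; `_ParameterInitMeta` is an internal helper, so this is for exposition only:

```python
# init_fn=None falls back to the default initializer rather than leaving the
# metadata unusable when reset_parameters() runs later.
meta = _ParameterInitMeta(init_fn=None)
assert meta.init_fn is not None  # restored by __post_init__
```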
@@ -16,6 +16,7 @@ import torch
 import torch.nn.functional as F
 import transformer_engine_extensions as tex

+from ._common import _ParameterInitMeta
 from ..export import is_in_onnx_export_mode
 from ..fp8 import (
     get_default_fp8_recipe,
@@ -234,6 +235,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         self.fp8_meta["async_amax_reduction"] = bool(
             int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "0"))
         )
+        self.param_init_meta = {}
+        self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()

     def set_meta_tensor(self, fwd: bool) -> None:
         """Init scales and amaxes for fwd | bwd."""
@@ -746,6 +749,52 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         )
         return fp8_weight_tensors

+    def register_parameter(self, name, param, **kwargs):
+        """
+        Thin wrapper around PyTorch parameter registration to stash additional parameter
+        metadata used in deferred initialization.
+        """
+        super().register_parameter(name, param)
+        self.param_init_meta[name] = _ParameterInitMeta(**kwargs)
+
+    def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
+        """
+        Reset all module parameters to initial values. Unless deferred initialization
+        is specified, all parameters on a 'meta' device are also materialized on a real
+        CUDA device before the values are reset.
+        """
+        if defer_init:
+            return
+
+        for name, param in self.named_parameters(recurse=False):
+            # Ensure parameter is on a real device
+            if param.device == torch.device('meta'):
+                param = param.to(device='cuda')
+
+            # Initialize the parameter values on device
+            init_fn = self.param_init_meta[name].init_fn
+            get_rng_state_tracker = self.param_init_meta[name].get_rng_state_tracker
+            if get_rng_state_tracker is None:
+                init_fn(param)
+            else:
+                with get_rng_state_tracker().fork():
+                    init_fn(param)
+
+            # If primary weights are in FP8, wrap the parameter as a Float8Tensor
+            fp8_meta_index = self.param_init_meta[name].fp8_meta_index
+            if self.primary_weights_in_fp8 and fp8_meta_index is not None:
+                param = Float8Tensor.to_float8(
+                    param,
+                    fp8_meta=self.fp8_meta,
+                    fp8_meta_index=fp8_meta_index
+                )
+
+            # Redo the parameter wrap in case we broke it above
+            # NOTE: Currently this can only be broken when primary weights are in FP8, but
+            #       re-applying the nn.Parameter() wrap is a no-op when the input is already
+            #       a parameter, so we always re-apply it for extra safety.
+            setattr(self, name, torch.nn.Parameter(param))
+
     @abstractmethod
     def forward(self):
         """Needs override."""
...
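Taken together, the `register_parameter()` wrapper and `reset_parameters()` give every TE module a two-phase lifecycle: construct on `'meta'` with stashed init metadata, then materialize on demand. A hedged sketch of calling the new method directly (FSDP normally does this itself after sharding; assumes a CUDA-capable environment with TE installed):

```python
import torch
import transformer_engine.pytorch as te

# Construction on 'meta' allocates no device memory and defers initialization.
layer = te.Linear(768, 768, device='meta')
assert layer.weight.device == torch.device('meta')

# reset_parameters() moves the parameters to CUDA and applies their init_fn.
layer.reset_parameters()
assert layer.weight.is_cuda
```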
@@ -4,6 +4,7 @@

 """LayerNorm API"""
 import os
+import warnings
 from typing import Union, Tuple, Optional

 import torch
@@ -139,7 +140,8 @@ class LayerNorm(torch.nn.Module):
         )
         setattr(self.weight, "sequence_parallel", sequence_parallel)
         setattr(self.bias, "sequence_parallel", sequence_parallel)
-        self.reset_layer_norm_parameters()
+
+        self.reset_parameters(defer_init=(device == 'meta'))

         # These many SMs are subtracted from the total SM count when calling forward
         # and backward LayerNorm C APIs. These envvars can be used to prevent the LN
@@ -150,12 +152,25 @@ class LayerNorm(torch.nn.Module):

     def reset_layer_norm_parameters(self) -> None:
         """Init LN params"""
+        warnings.warn(
+            ("This method will be deprecated in an upcoming release. "
+             "Update your code to use LayerNorm.reset_parameters() instead."),
+            DeprecationWarning,
+            stacklevel=2
+        )
         if not self.zero_centered_gamma:
             init.ones_(self.weight)
         else:
             init.zeros_(self.weight)
         init.zeros_(self.bias)

+    def reset_parameters(self, defer_init=False) -> None:
+        """Init LayerNorm parameters"""
+        if defer_init:
+            return
+        init.constant_(self.weight, float(not self.zero_centered_gamma))
+        init.zeros_(self.bias)
+
     @no_torch_dynamo()
     def forward(self, inp: torch.Tensor) -> torch.Tensor:
         """LayerNorm FWD"""
...
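Migration sketch for the deprecation above (the old helper keeps working but now warns):

```python
import transformer_engine.pytorch as te

ln = te.LayerNorm(1024)
ln.reset_parameters()             # preferred API after this change
ln.reset_layer_norm_parameters()  # still functional, but emits DeprecationWarning
```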
@@ -25,6 +25,7 @@ from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager
 from ..utils import (
     divide,
     get_default_init_method,
+    init_method_constant,
     cast_if_needed,
     assert_dim_for_fp8_exec,
     clear_tensor_data,
@@ -33,7 +34,6 @@ from ..distributed import (
     set_tensor_model_parallel_attributes,
     get_distributed_world_size,
     allreduce,
-    initialize_affine_weight_gpu,
     reduce_scatter_along_first_dim,
     gather_along_first_dim,
 )
@@ -749,43 +749,25 @@ class LayerNormLinear(TransformerEngineBaseModule):
         self.sequence_parallel = (self.tp_size > 1) and sequence_parallel

         self.eps = eps
-        self.layer_norm_weight = torch.nn.Parameter(
+        layer_norm_weight = torch.nn.Parameter(
             torch.empty(in_features, device=device, dtype=params_dtype)
         )
-        setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel)
+        self.register_parameter('layer_norm_weight', layer_norm_weight,
+                                init_fn=init_method_constant(float(not self.zero_centered_gamma)))
+        setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel) # pylint: disable=access-member-before-definition
         if self.normalization != "RMSNorm":
-            self.layer_norm_bias = torch.nn.Parameter(
+            layer_norm_bias = torch.nn.Parameter(
                 torch.empty(in_features, device=device, dtype=params_dtype)
             )
-            setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel)
+            self.register_parameter('layer_norm_bias', layer_norm_bias)
+            setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) # pylint: disable=access-member-before-definition
         else:
             self.layer_norm_bias = None
-        self.reset_layer_norm_parameters()

-        temp_weight = torch.empty(
+        self.weight_tensor = torch.empty(
             self.out_features, self.in_features,
             device=device, dtype=params_dtype)
-        initialize_affine_weight_gpu(
-            temp_weight,
-            init_method,
-            get_rng_state_tracker,
-            partition_dim=1 if self.parallel_mode == "row" else 0,
-            stride=1,
-        )
-
-        if self.primary_weights_in_fp8:
-            self.init_fp8_metadata()
-            self.fp8_meta["update_amax_and_scale_fwd"] = True
-
-            self.weight_tensor = Float8Tensor.to_float8(
-                temp_weight,
-                fp8_meta=self.fp8_meta,
-                fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
-            )
-        else:
-            self.weight_tensor = temp_weight

         if self.use_bias:
             self.bias_tensor = torch.empty(
                 self.out_features,
@@ -794,9 +776,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
         else:
             self.bias_tensor = torch.Tensor().to(dtype=params_dtype, device=device)

-        with torch.no_grad():
-            self.bias_tensor.zero_()
-
         # Configure parameter splits
         self.weight_names = []
         self.bias_names = []
@@ -861,7 +840,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
             if is_subview:
                 weight = weight[split_start:split_end]
             weight = torch.nn.Parameter(weight)
-            self.register_parameter(self.weight_names[i], weight)
+            self.register_parameter(self.weight_names[i], weight,
+                                    init_fn=init_method,
+                                    get_rng_state_tracker=get_rng_state_tracker,
+                                    fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT)

             # Construct bias parameter if needed
             if self.use_bias:
@@ -892,8 +874,13 @@ class LayerNormLinear(TransformerEngineBaseModule):
         del self.weight_tensor
         del self.bias_tensor

+        if self.primary_weights_in_fp8:
+            self.init_fp8_metadata()
+            self.fp8_meta["update_amax_and_scale_fwd"] = True
+
+        self.reset_parameters(defer_init=(device == 'meta'))
+
         self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))

         # For RPL, bias has to be added after TP collectives
         # So it cannot be fused with the GEMM
@@ -911,6 +898,12 @@ class LayerNormLinear(TransformerEngineBaseModule):

     def reset_layer_norm_parameters(self) -> None:
         """Init LN params"""
+        warnings.warn(
+            ("This method will be deprecated in an upcoming release. "
+             "Update your code to use LayerNormLinear.reset_parameters() instead."),
+            DeprecationWarning,
+            stacklevel=2
+        )
         if not self.zero_centered_gamma:
             init.ones_(self.layer_norm_weight)
         else:
...
@@ -30,6 +30,7 @@ from ..jit import (
 from ..utils import (
     divide,
     get_default_init_method,
+    init_method_constant,
     cast_if_needed,
     assert_dim_for_fp8_exec,
     clear_tensor_data,
@@ -38,7 +39,6 @@ from ..distributed import (
     set_tensor_model_parallel_attributes,
     get_distributed_world_size,
     allreduce,
-    initialize_affine_weight_gpu,
     reduce_scatter_along_first_dim,
     gather_along_first_dim,
 )
@@ -1170,91 +1170,76 @@ class LayerNormMLP(TransformerEngineBaseModule):
         # LN init
         self.eps = eps
-        self.layer_norm_weight = Parameter(
+        layer_norm_weight = Parameter(
             torch.empty(hidden_size, device=device, dtype=params_dtype)
         )
+        self.register_parameter('layer_norm_weight', layer_norm_weight,
+                                init_fn=init_method_constant(float(not self.zero_centered_gamma)))
         setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel)
         if self.normalization != "RMSNorm":
-            self.layer_norm_bias = Parameter(
+            layer_norm_bias = Parameter(
                 torch.empty(hidden_size, device=device, dtype=params_dtype)
             )
-            setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel)
+            self.register_parameter('layer_norm_bias', layer_norm_bias)
+            setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) # pylint: disable=access-member-before-definition
         else:
             self.layer_norm_bias = None
-        self.reset_layer_norm_parameters()

+        # FC1 init
         if self.activation in ['reglu', 'geglu', 'swiglu']:
             fc1_output_features = 2 * self.size_per_partition
         else:
             fc1_output_features = self.size_per_partition

-        # FC1 init
-        fc1_temp_weight = torch.empty(
-            fc1_output_features, hidden_size, device=device, dtype=params_dtype)
-
-        initialize_affine_weight_gpu(
-            fc1_temp_weight,
-            init_method,
-            get_rng_state_tracker,
-            set_tp_attributes=False,
-        )
-
-        if self.primary_weights_in_fp8:
-            self.init_fp8_metadata(num_gemms=2)
-            self.fp8_meta["update_amax_and_scale_fwd"] = True
-            fc1_temp_weight = Float8Tensor.to_float8(
-                fc1_temp_weight,
-                fp8_meta=self.fp8_meta,
-                fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
-            )
-
-        self.fc1_weight = Parameter(fc1_temp_weight)
+        fc1_weight = Parameter(
+            torch.empty(
+                fc1_output_features, hidden_size, device=device, dtype=params_dtype
+            )
+        )
+        self.register_parameter('fc1_weight', fc1_weight,
+                                init_fn=init_method,
+                                get_rng_state_tracker=get_rng_state_tracker,
+                                fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT)
         set_tensor_model_parallel_attributes(self.fc1_weight, True, 0, 1)
         self.fp8_weight_shapes.append(self.fc1_weight.shape)

         if self.use_bias:
-            self.fc1_bias = Parameter(
+            fc1_bias = Parameter(
                 torch.empty(fc1_output_features, device=device, dtype=params_dtype)
             )
-            set_tensor_model_parallel_attributes(self.fc1_bias, True, 0, 1)
+            self.register_parameter('fc1_bias', fc1_bias)
+            set_tensor_model_parallel_attributes(self.fc1_bias, True, 0, 1) # pylint: disable=access-member-before-definition
         else:
             self.fc1_bias = torch.Tensor().to(dtype=params_dtype, device=device)

-        with torch.no_grad():
-            self.fc1_bias.zero_()
-
         # FC2 init
-        fc2_temp_weight = torch.empty(
-            hidden_size, self.size_per_partition, device=device, dtype=params_dtype)
-
-        initialize_affine_weight_gpu(
-            fc2_temp_weight,
-            output_layer_init_method,
-            get_rng_state_tracker,
-            set_tp_attributes=False,
-        )
-
-        if self.primary_weights_in_fp8:
-            fc2_temp_weight = Float8Tensor.to_float8(
-                fc2_temp_weight,
-                fp8_meta=self.fp8_meta,
-                fp8_meta_index=tex.FP8FwdTensors.GEMM2_WEIGHT,
-            )
-
-        self.fc2_weight = Parameter(fc2_temp_weight)
+        fc2_weight = Parameter(
+            torch.empty(hidden_size, self.size_per_partition, device=device, dtype=params_dtype)
+        )
+        self.register_parameter('fc2_weight', fc2_weight,
+                                init_fn=output_layer_init_method,
+                                get_rng_state_tracker=get_rng_state_tracker,
+                                fp8_meta_index=tex.FP8FwdTensors.GEMM2_WEIGHT)
         set_tensor_model_parallel_attributes(self.fc2_weight, True, 1, 1)
         self.fp8_weight_shapes.append(self.fc2_weight.shape)

         if self.use_bias:
-            self.fc2_bias = Parameter(
+            fc2_bias = Parameter(
                 torch.empty(hidden_size, device=device, dtype=params_dtype)
             )
+            self.register_parameter('fc2_bias', fc2_bias)
             # RPL
             if self.set_parallel_mode:
-                setattr(self.fc2_bias, "sequence_parallel", sequence_parallel)
+                setattr(self.fc2_bias, "sequence_parallel", sequence_parallel) # pylint: disable=access-member-before-definition
         else:
             self.fc2_bias = torch.Tensor().to(dtype=params_dtype, device=device)

+        if self.primary_weights_in_fp8:
+            self.init_fp8_metadata(num_gemms=2)
+            self.fp8_meta["update_amax_and_scale_fwd"] = True
+
+        self.reset_parameters(defer_init=(device == 'meta'))
+
         # For RPL, bias has to be added after TP collectives
         # So it cannot be fused with the GEMM
         if self.set_parallel_mode and self.apply_bias:
@@ -1262,9 +1247,6 @@ class LayerNormMLP(TransformerEngineBaseModule):
         else:
             self.gemm_bias_unfused_add = False

-        with torch.no_grad():
-            self.fc2_bias.zero_()
-
         if self.bias_gelu_nvfusion:
             set_jit_fusion_options()
             if seq_length and micro_batch_size:
@@ -1281,6 +1263,12 @@ class LayerNormMLP(TransformerEngineBaseModule):

     def reset_layer_norm_parameters(self) -> None:
         """Init LN params"""
+        warnings.warn(
+            ("This method will be deprecated in an upcoming release. "
+             "Update your code to use LayerNormMLP.reset_parameters() instead."),
+            DeprecationWarning,
+            stacklevel=2
+        )
         if not self.zero_centered_gamma:
             init.ones_(self.layer_norm_weight)
         else:
...
@@ -23,7 +23,6 @@ from ._common import _noop_cat
 from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager
 from ..utils import (
     divide,
-    get_default_init_method,
     cast_if_needed,
     assert_dim_for_fp8_exec,
     clear_tensor_data,
@@ -32,7 +31,6 @@ from ..distributed import (
     set_tensor_model_parallel_attributes,
     get_distributed_world_size,
     allreduce,
-    initialize_affine_weight_gpu,
     reduce_scatter_along_first_dim,
     gather_along_first_dim,
 )
@@ -82,7 +80,7 @@ class _Linear(torch.autograd.Function):
         ub_split_ag: bool,
         ub_atomic_gemm_rs: bool,
         ub_atomic_gemm_ag: bool,
-        ub_name: str,
+        ub_name: str
     ) -> torch.Tensor:
         # Make sure input dimensions are compatible
         in_features = weight.shape[-1]
@@ -625,6 +623,10 @@ class Linear(TransformerEngineBaseModule):
         if any([ub_atomic_gemm_rs, ub_atomic_gemm_ag]):
             assert ub_name is not None, "Userbuffer name [string] is not set."
         self.ub_name = ub_name
+        self.get_rng_state_tracker = get_rng_state_tracker
+        if device == 'meta':
+            assert parameters_split is None, ("Cannot split module parameters "
+                                              "on 'meta' device.")

         if ub_split_rs or ub_split_ag or ub_atomic_gemm_rs:
             assert (
@@ -655,44 +657,17 @@ class Linear(TransformerEngineBaseModule):
         elif self.parallel_mode == "row":
             self.in_features = divide(self.in_features, self.tp_size)

-        if init_method is None:
-            init_method = get_default_init_method()
-
         self.sequence_parallel = (self.tp_size > 1) and sequence_parallel

-        temp_weight = torch.empty(
+        self.weight_tensor = torch.empty(
             self.out_features, self.in_features,
             device=device, dtype=params_dtype)

-        # TODO(ksivaman): This functionality works with FP8 outside TE.
-        initialize_affine_weight_gpu(
-            temp_weight,
-            init_method,
-            get_rng_state_tracker,
-            partition_dim=1 if self.parallel_mode == "row" else 0,
-            stride=1,
-        )
-
-        if self.primary_weights_in_fp8:
-            self.init_fp8_metadata()
-            self.fp8_meta["update_amax_and_scale_fwd"] = True
-
-            self.weight_tensor = Float8Tensor.to_float8(
-                temp_weight,
-                fp8_meta=self.fp8_meta,
-                fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
-            )
-        else:
-            self.weight_tensor = temp_weight
-
         if self.use_bias:
             self.bias_tensor = torch.empty(self.out_features, device=device, dtype=params_dtype)
         else:
             self.bias_tensor = torch.Tensor().to(dtype=params_dtype, device=device)

-        with torch.no_grad():
-            self.bias_tensor.zero_()
-
         # Configure parameter splits
         self.weight_names = []
         self.bias_names = []
@@ -757,7 +732,10 @@ class Linear(TransformerEngineBaseModule):
             if is_subview:
                 weight = weight[split_start:split_end]
             weight = torch.nn.Parameter(weight)
-            self.register_parameter(self.weight_names[i], weight)
+            self.register_parameter(self.weight_names[i], weight,
+                                    init_fn=init_method,
+                                    get_rng_state_tracker=get_rng_state_tracker,
+                                    fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT)

             # Construct bias parameter if needed
             if self.use_bias:
@@ -788,6 +766,12 @@ class Linear(TransformerEngineBaseModule):
         del self.weight_tensor
         del self.bias_tensor

+        if self.primary_weights_in_fp8:
+            self.init_fp8_metadata()
+            self.fp8_meta["update_amax_and_scale_fwd"] = True
+
+        self.reset_parameters(defer_init=(device == 'meta'))
+
         self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))

         # For RPL, bias has to be added after TP collectives
...
@@ -4,6 +4,7 @@

 """RMSNorm API"""
 import os
+import warnings
 from typing import Union, Tuple, Optional

 import torch
@@ -141,7 +142,8 @@ class RMSNorm(torch.nn.Module):
             )
         )
         setattr(self.weight, "sequence_parallel", sequence_parallel)
-        self.reset_rms_norm_parameters()
+
+        self.reset_parameters(defer_init=(device == 'meta'))

         # These many SMs are subtracted from the total SM count when calling forward
         # and backward RMSNorm C APIs. These envvars can be used to prevent the LN
@@ -152,11 +154,22 @@ class RMSNorm(torch.nn.Module):

     def reset_rms_norm_parameters(self) -> None:
         """Init RMSNorm params"""
+        warnings.warn(
+            ("This method will be deprecated in an upcoming release. "
+             "Update your code to use RMSNorm.reset_parameters() instead."),
+            DeprecationWarning,
+            stacklevel=2
+        )
         if not self.zero_centered_gamma:
             init.ones_(self.weight)
         else:
             init.zeros_(self.weight)

+    def reset_parameters(self, defer_init=False) -> None:
+        """Reset RMSNorm parameters"""
+        if defer_init:
+            return
+        init.constant_(self.weight, float(not self.zero_centered_gamma))
+
     @no_torch_dynamo()
     def forward(self, inp: torch.Tensor) -> torch.Tensor:
...
@@ -40,6 +40,21 @@ def get_default_init_method() -> Callable:
     return init_method_normal(0.023)


+def init_method_constant(val: float) -> Callable:
+    """Init method to set all tensor elements to a constant value."""
+    if val == 1.0:
+        def init_(tensor: torch.Tensor) -> Callable:
+            return torch.nn.init.ones_(tensor)
+    elif val == 0.0:
+        def init_(tensor: torch.Tensor) -> Callable:
+            return torch.nn.init.zeros_(tensor)
+    else:
+        def init_(tensor: torch.Tensor) -> Callable:
+            return torch.nn.init.constant_(tensor, val)
+    return init_
+
+
 def init_method_normal(sigma: float) -> Callable:
     """Init method based on N(0, sigma)."""
...
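For reference, a quick sketch of the new helper's dispatch behavior (assuming it is importable from `transformer_engine.pytorch.utils`):

```python
import torch
from transformer_engine.pytorch.utils import init_method_constant

w = torch.empty(8)
init_method_constant(1.0)(w)  # dispatches to torch.nn.init.ones_
init_method_constant(0.0)(w)  # dispatches to torch.nn.init.zeros_
init_method_constant(0.5)(w)  # dispatches to torch.nn.init.constant_
```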