Unverified Commit bb759adc authored by Tim Moon, committed by GitHub

[PyTorch] Refactor parameter splitting in Linear and LayerNormLinear (#590)



* Refactor parameter split in Linear module

Remove module state from noop_cat. Support arbitrary names in parameter split. Handle tensor parallelism.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Make noop_cat a standalone operation
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update parameter splits in LayerNormLinear
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug case without bias

Fix pylint complaints.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove unused import
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 7ce7dfe5
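
The net effect on the public API: `parameters_split` now accepts arbitrary names, either as a dict mapping names to split sizes or as a tuple of names for equal splits, and the split parameters are exposed as `<name>_weight`/`<name>_bias` with trailing underscores stripped. A usage sketch (the sizes and layer below are illustrative, not taken from this commit):

```python
import collections
import transformer_engine.pytorch as te

# Dict form: keys are names, values are split sizes along dim 0.
# out_features must equal the sum of the split sizes.
qkv = te.Linear(
    1024,
    1024 + 256 + 256,
    parameters_split=collections.OrderedDict([
        ("query", 1024),
        ("key", 256),
        ("value", 256),
    ]),
)
# Exposed parameters: qkv.query_weight, qkv.key_weight, qkv.value_weight,
# plus the matching *_bias tensors.

# Tuple form: names only, with out_features split evenly.
kv = te.Linear(512, 512, parameters_split=("key", "value"))
# kv.key_weight.shape == (256, 512)
```
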
@@ -3,6 +3,7 @@
 # See LICENSE for license information.
 
 """Attention."""
+import collections
 import os
 import warnings
 import math
@@ -2705,9 +2706,13 @@ class MultiheadAttention(torch.nn.Module):
         qkv_parallel_mode = "column" if set_parallel_mode else None
 
         if self.attention_type == "self":
-            parameters_split = {"query_": hidden_size,
-                                "key_": self.hidden_size_kv,
-                                "value_": self.hidden_size_kv} if not fuse_qkv_params else None
+            parameters_split = None
+            if not fuse_qkv_params:
+                parameters_split = collections.OrderedDict([
+                    ("query", hidden_size),
+                    ("key", self.hidden_size_kv),
+                    ("value", self.hidden_size_kv),
+                ])
             if self.input_layernorm:
                 self.layernorm_qkv = LayerNormLinear(
                     hidden_size,
@@ -2749,7 +2754,7 @@
                 bias=bias,
                 return_bias=False,
                 parallel_mode=qkv_parallel_mode,
-                parameters_split=("query_",) if not fuse_qkv_params else None,
+                parameters_split=("query",) if not fuse_qkv_params else None,
                 return_layernorm_output=return_layernorm_output,
                 zero_centered_gamma=zero_centered_gamma,
                 ub_bulk_wgrad=ub_bulk_wgrad,
@@ -2777,7 +2782,7 @@
                 bias=bias,
                 return_bias=False,
                 parallel_mode=qkv_parallel_mode,
-                parameters_split=("key_", "value_") if not fuse_qkv_params else None,
+                parameters_split=("key", "value") if not fuse_qkv_params else None,
                 **common_gemm_kwargs,
             )
...
@@ -4,7 +4,7 @@
 """Internal function used by multiple modules."""
 
-from typing import Union, Dict, Any
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -93,3 +93,97 @@ def _apply_normalization(inputmat:torch.Tensor,
     elif normalization == "LayerNorm":
         output = (ln_out, output[1], output[2])
     return output
+
+
+class _NoopCatFunc(torch.autograd.Function):
+    """No-op concatenate tensors along dim 0
+
+    `full_tensor` is assumed to already be the concatenation of
+    `tensors`, i.e. they occupy the same memory with the correct
+    offsets.
+
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        split_ranges: List[Tuple[int, int]],
+        full_tensor: torch.Tensor,
+        *tensors: Tuple[torch.Tensor, ...],
+    ) -> torch.Tensor:
+        # pylint: disable=unused-argument
+        ctx.split_ranges = split_ranges
+        assert not full_tensor.requires_grad, "Concatenated tensor should not require gradient"
+        out = full_tensor.new()
+        out.set_(
+            full_tensor.untyped_storage(),
+            full_tensor.storage_offset(),
+            full_tensor.size(),
+            full_tensor.stride(),
+        )
+        out.requires_grad = True
+        return out
+
+    @staticmethod
+    def backward(
+        ctx,
+        grad_output: torch.Tensor,
+    ) -> Tuple[Optional[torch.Tensor], ...]:
+        grads = [
+            grad_output[split_start:split_end]
+            for split_start, split_end in ctx.split_ranges
+        ]
+        return None, None, *grads
+
+
+def _noop_cat(
+    tensors: List[torch.Tensor],
+    full_tensor: torch.Tensor,
+) -> torch.Tensor:
+    """Concatenate tensors along dim 0, doing a no-op if possible
+
+    If `full_tensor` is already the concatenation of `tensors`, i.e.
+    they occupy the same memory region with the correct offsets, then
+    no copies are performed. Otherwise the buffers in all the tensors
+    are reallocated so that another call would result in a no-op.
+
+    In the backward pass, gradients to `tensors` will just be tensor
+    views.
+
+    """
+
+    # Determine split points
+    split_ranges = []
+    full_tensor_shape = full_tensor.size()
+    offset = 0
+    for tensor in tensors:
+        tensor_shape = tensor.size()
+        if tensor_shape[1:] != full_tensor_shape[1:]:
+            raise ValueError(
+                f"Attempting to concatenate tensor with shape={list(tensor_shape)} "
+                f"into a tensor with shape={list(full_tensor_shape)}"
+            )
+        split_start = offset
+        offset += tensor_shape[0]
+        split_end = offset
+        split_ranges.append((split_start, split_end))
+    if offset != full_tensor_shape[0]:
+        raise ValueError(
+            f"Attempting to concatenate tensors with total shape[0]={offset} "
+            f"into a tensor with shape[0]={full_tensor_shape[0]}"
+        )
+
+    # Reallocate buffers if no-op concat isn't possible
+    need_to_reallocate = False
+    for tensor, (split_start, _) in zip(tensors, split_ranges):
+        if tensor.data_ptr() != full_tensor[split_start].data_ptr():
+            need_to_reallocate = True
+            break
+    if need_to_reallocate:
+        with torch.no_grad():
+            full_tensor.data = torch.cat(tensors)
+            for tensor, (split_start, split_end) in zip(tensors, split_ranges):
+                tensor.data = full_tensor[split_start:split_end]
+
+    # Perform no-op concat
+    return _NoopCatFunc.apply(split_ranges, full_tensor, *tensors)
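
For orientation, a minimal sketch of how the new `_noop_cat` behaves, assuming the function is in scope as defined in the hunk above (the shapes below are illustrative):

```python
import torch

# A buffer to back the concatenation; it must not require grad.
full = torch.zeros(6, 4)
a = torch.nn.Parameter(torch.randn(2, 4))
b = torch.nn.Parameter(torch.randn(4, 4))

# The first call reallocates: `full` is overwritten with cat([a, b])
# and a/b are re-pointed to views of it, so later calls copy nothing.
out = _noop_cat([a, b], full)
assert a.data_ptr() == full[0:2].data_ptr()
assert b.data_ptr() == full[2:6].data_ptr()

# Gradients to the split parameters are views of grad_output.
out.sum().backward()
assert a.grad.shape == (2, 4) and b.grad.shape == (4, 4)
```

The first call pays for a single `torch.cat`; after that the parameters alias the buffer, so each subsequent call only builds the autograd node, and backward hands each parameter a view of `grad_output` rather than a copy.
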
@@ -14,7 +14,6 @@ from contextlib import contextmanager
 
 import torch
 import torch.nn.functional as F
-from torch.nn.parameter import Parameter
 
 import transformer_engine_extensions as tex
 from ..export import is_in_onnx_export_mode
@@ -213,44 +212,6 @@ def get_ub(name: str):
     return _ub_communicators[name]
-
-
-class _NoopCat(torch.autograd.Function):
-    """This class is a no-op replacement for `torch.cat`."""
-
-    @staticmethod
-    def forward(ctx,
-                full_param_buffer: torch.Tensor,
-                *params_split: Tuple[torch.Tensor, ...],
-    ) -> torch.Tensor:
-        assert not full_param_buffer.requires_grad, "Buffers should not require gradient"
-        sum_params_shape = sum(p.shape[0] for p in params_split)
-        assert (
-            full_param_buffer.shape[0] == sum_params_shape
-        ), "Dimensions not compatible for concatenation"
-
-        param_temp = full_param_buffer.new()
-        param_temp.set_(full_param_buffer.untyped_storage(),
-                        full_param_buffer.storage_offset(),
-                        full_param_buffer.size(),
-                        full_param_buffer.stride())
-        param_temp.requires_grad = True
-
-        ctx.save_for_backward(*params_split)
-        return param_temp
-
-    @staticmethod
-    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
-        params_split = ctx.saved_tensors
-
-        grads = []
-        slice_begin = 0
-        for i, _ in enumerate(params_split):
-            slice_size = params_split[i].shape[0]
-            slice_end = slice_begin + slice_size
-            grads.append(grad_output[slice_begin:slice_end])
-            slice_begin = slice_end
-
-        return None, *grads
-
 
 class TransformerEngineBaseModule(torch.nn.Module, ABC):
     """Base TE module."""
@@ -742,40 +703,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         return grad_output_mat, grad_output_c, grad_output_t, grad_bias
-
-    def noop_cat(self,
-                 buffer_name: str,
-                 pnames: List[str],
-                 parameters_split: Dict[str, int]
-    ) -> torch.Tensor:
-        """No-op replacement of `torch.cat`. The buffer and split parameters must occupy
-        the same memory region. If this is not the case, then the split parameters
-        are concatenated and the buffer is overwritten. The parameters' memory is then
-        re-assigned to point to the buffer to avoid subsequent concatenations.
-        """
-
-        assert hasattr(self, buffer_name), f"No buffer named {buffer_name}"
-        full_param_buffer = getattr(self, buffer_name)
-        params = [getattr(self, name) for name in pnames]
-
-        slice_begin = 0
-        for i, p in enumerate(params):
-            slice_size = parameters_split[pnames[i].split('_')[0]+'_']
-            slice_end = slice_begin + slice_size
-            if p.data.data_ptr() != full_param_buffer[slice_begin:slice_end].data_ptr():
-                with torch.no_grad():
-                    setattr(self, buffer_name, torch.cat(params))
-                    slice_begin_j = 0
-                    for pname in pnames:
-                        slice_size_j = parameters_split[pname.split('_')[0]+'_']
-                        slice_end_j = slice_begin_j + slice_size_j
-                        full_param_buffer = getattr(self, buffer_name)
-                        setattr(self, pname,
-                                Parameter(full_param_buffer[slice_begin_j:slice_end_j]))
-                        slice_begin_j = slice_end_j
-                break
-            slice_begin = slice_end
-
-        return _NoopCat.apply(getattr(self, buffer_name), *[getattr(self, name) for name in pnames])
 
     def get_fp8_weights_empty_tensors(
         self,
         is_first_microbatch: Union[bool, None],
...
@@ -7,9 +7,7 @@ import os
 import warnings
 from typing import Union, Optional, Callable, Tuple, List, Dict, Any
 
 import torch
-from torch.nn.parameter import Parameter
 from torch.nn import init
 
 from .. import cpp_extensions as tex
@@ -41,7 +39,7 @@ from ..distributed import (
 )
 from ..constants import GemmParallelModes, dist_group_type, TE_DType
 from ..jit import no_torch_dynamo
-from ._common import _apply_normalization
+from ._common import _apply_normalization, _noop_cat
 from ..float8_tensor import Float8Tensor
@@ -612,13 +610,13 @@ class LayerNormLinear(TransformerEngineBaseModule):
                              Example use case: residual connection for transformer module is
                              taken post layernorm.
    parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None
-                      if a tuple of strings or a dict of strings to integers is provided,
-                      the weight and bias parameters of the module are exposed as `N` separate
-                      `torch.nn.parameter.Parameter`s each, split along the first dimension,
-                      where `N` is the length of the argument and the strings contained are the
-                      names of the split parameters. In the case of a tuple, each parameter
-                      has the same shape. In the case of a dict, the values give the
-                      `out_features` for each projection.
+                      Configuration for splitting the weight and bias tensors along dim 0 into
+                      multiple PyTorch parameters. If a list or tuple of strings is provided,
+                      they are used to make the names of equally-sized parameters. If a dict
+                      (preferably an OrderedDict) is provided, the keys are used as names and
+                      values as split sizes along dim 0. The resulting parameters will have
+                      names that end in `_weight` or `_bias`, so trailing underscores are
+                      stripped from any provided names.
    zero_centered_gamma : bool, default = 'False'
                          if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
                          the LayerNorm formula changes to
@@ -705,7 +703,6 @@
         self.return_bias = return_bias
         self.apply_bias = self.use_bias and not return_bias
         self.return_layernorm_output = return_layernorm_output
-        self.parameters_split = parameters_split
         self.zero_centered_gamma = zero_centered_gamma
         self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
         self.ub_bulk_wgrad = ub_bulk_wgrad
@@ -752,12 +749,12 @@
         self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
         self.eps = eps
-        self.layer_norm_weight = Parameter(
+        self.layer_norm_weight = torch.nn.Parameter(
             torch.empty(in_features, device=device, dtype=params_dtype)
         )
         setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel)
         if self.normalization != "RMSNorm":
-            self.layer_norm_bias = Parameter(
+            self.layer_norm_bias = torch.nn.Parameter(
                 torch.empty(in_features, device=device, dtype=params_dtype)
             )
             setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel)
@@ -800,68 +797,100 @@
             with torch.no_grad():
                 self.bias_tensor.zero_()
 
+        # Configure parameter splits
+        self.weight_names = []
+        self.bias_names = []
+        self.parameter_split_sizes = []
         if parameters_split is None:
-            parameters_split = {"": self.out_features}
-        elif isinstance(parameters_split, tuple):
-            assert (
-                self.out_features % len(parameters_split) == 0
-            ), f"Weight and bias params cannot be split into {len(parameters_split)} parts"
-            split_size = self.out_features // len(parameters_split)
-            parameters_split = {key: split_size for key in parameters_split}
+            # Split into a single parameter by default
+            self.weight_names = ["weight"]
+            self.bias_names = ["bias"]
+            self.parameter_split_sizes = [out_features]
+        elif not parameters_split:
+            raise ValueError("Cannot split weight buffer into 0 parameters")
         elif isinstance(parameters_split, dict):
-            overall_split_size = sum(parameters_split.values())
-            assert(
-                self.out_features == overall_split_size
-            ), f"Overall sum of parameters_split (={overall_split_size}) does not match "\
-               f"to out features (={self.out_features})"
+            # Split parameters with provided sizes
+            for name, split_size in parameters_split.items():
+                self.weight_names.append(f"{name.rstrip('_')}_weight")
+                self.bias_names.append(f"{name.rstrip('_')}_bias")
+                self.parameter_split_sizes.append(split_size)
+        elif all(isinstance(name, str) for name in parameters_split):
+            # Split parameters evenly
+            split_size = out_features // len(parameters_split)
+            for name in parameters_split:
+                self.weight_names.append(f"{name.rstrip('_')}_weight")
+                self.bias_names.append(f"{name.rstrip('_')}_bias")
+                self.parameter_split_sizes.append(split_size)
         else:
-            assert False, "Type of 'parameters_split' is not None, tuple or dict"
+            raise TypeError("Invalid configuration for parameters split")
 
-        self.updated_parameters_split = parameters_split
-        self.weight_names = []
-        self.bias_names = []
+        # Make sure parameter splits are valid
+        if sum(self.parameter_split_sizes) != out_features:
+            raise ValueError(
+                f"Trying to split weight buffer ({out_features=}) "
+                f"with split sizes {self.parameter_split_sizes}"
+            )
 
-        slice_begin = 0
-        for pname, slice_size in parameters_split.items():
-            wname = pname + "weight"
-            bname = pname + "bias"
-
-            slice_end = slice_begin + slice_size
-
-            # NOTE(future): Figure out a way to support slicing when weights
-            # are of `Float8Tensor` class
-            if self.primary_weights_in_fp8:
-                assert len(parameters_split) == 1, ("Slicing operation is not "
-                                                    "supported in Float8Tensor "
-                                                    "class!")
-                self.register_parameter(wname, Parameter(self.weight_tensor))
-            else:
-                self.register_parameter(
-                    wname, Parameter(self.weight_tensor[slice_begin:slice_end])
-                )
-
-            set_tensor_model_parallel_attributes(
-                tensor=getattr(self, wname),
-                is_parallel=True,
-                dim=1 if parallel_mode == "row" else 0,
-                stride=1,
-            )
+        # Adjust parameter splits for tensor-parallel distribution
+        if self.parallel_mode == "column":
+            for i, size in enumerate(self.parameter_split_sizes):
+                if size % self.tp_size != 0:
+                    raise RuntimeError(
+                        f"Attempting to distribute a parameter with out_features={size} "
+                        f"between {self.tp_size} tensor-parallel processes"
+                    )
+                self.parameter_split_sizes[i] = size // self.tp_size
+
+        # Construct parameters from weight and bias buffers
+        offset = 0
+        for i, split_size in enumerate(self.parameter_split_sizes):
+            split_start = offset
+            offset += split_size
+            split_end = offset
+
+            # Check if parameters are subviews of buffers
+            is_subview = (split_start, split_end) != (0, self.out_features)
+            if is_subview and self.primary_weights_in_fp8:
+                raise RuntimeError(
+                    "Splitting Float8Tensor into multiple params "
+                    "is not supported"
+                )
+
+            # Construct weight parameter
+            weight = self.weight_tensor
+            if is_subview:
+                weight = weight[split_start:split_end]
+            weight = torch.nn.Parameter(weight)
+            self.register_parameter(self.weight_names[i], weight)
 
+            # Construct bias parameter if needed
             if self.use_bias:
-                self.register_parameter(
-                    bname, Parameter(self.bias_tensor[slice_begin:slice_end])
-                )
+                bias = self.bias_tensor
+                if is_subview:
+                    bias = bias[split_start:split_end]
+                bias = torch.nn.Parameter(bias)
+                self.register_parameter(self.bias_names[i], bias)
                 if parallel_mode == "row":
-                    setattr(getattr(self, bname), "sequence_parallel", sequence_parallel)
+                    bias.sequence_parallel = sequence_parallel
             else:
-                setattr(self, bname, torch.Tensor().to(dtype=params_dtype, device=device))
+                bias = torch.Tensor().to(dtype=params_dtype, device=device)
+                setattr(self, self.bias_names[i], bias)
 
+            # Configure tensor parallelism
+            set_tensor_model_parallel_attributes(
+                tensor=weight,
+                is_parallel=True,
+                dim=1 if parallel_mode == "row" else 0,
+                stride=1,
+            )
             if parallel_mode == "column":
-                set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1)
+                set_tensor_model_parallel_attributes(bias, True, 0, 1)
 
-            self.weight_names.append(wname)
-            self.bias_names.append(bname)
-
-            slice_begin = slice_end
+        # Concatenated tensors are not needed if not splitting
+        # into multiple parameters
+        if not is_subview:
+            del self.weight_tensor
+            del self.bias_tensor
 
         self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))
@@ -880,12 +909,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
         self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
         self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
 
-        # Clean up weight and bias buffers
-        if self.parameters_split is None:
-            del self.weight_tensor
-            if self.use_bias:
-                del self.bias_tensor
-
     def reset_layer_norm_parameters(self) -> None:
         """Init LN params"""
         if not self.zero_centered_gamma:
@@ -950,18 +973,26 @@ class LayerNormLinear(TransformerEngineBaseModule):
         with self.prepare_forward(inp, is_first_microbatch) as inp:
             assert self.fp8 or not self.primary_weights_in_fp8, \
                    "Need to run inside fp8_autocast region when weights are stored in FP8."
-            bias_tensor = (
-                self.bias if self.parameters_split is None
-                else self.bias_tensor if not torch.is_grad_enabled()
-                else self.noop_cat("bias_tensor", self.bias_names,
-                                   self.updated_parameters_split)
-            )
-            weight_tensor = (
-                self.weight if self.parameters_split is None
-                else self.weight_tensor if not torch.is_grad_enabled()
-                else self.noop_cat("weight_tensor", self.weight_names,
-                                   self.updated_parameters_split)
-            )
+
+            # Get concatenated weight and bias tensors
+            if len(self.parameter_split_sizes) == 1:
+                weight_tensor = getattr(self, self.weight_names[0])
+                bias_tensor = getattr(self, self.bias_names[0])
+            elif torch.is_grad_enabled():
+                weight_tensor = _noop_cat(
+                    [getattr(self, name) for name in self.weight_names],
+                    self.weight_tensor,
+                )
+                if self.use_bias:
+                    bias_tensor = _noop_cat(
+                        [getattr(self, name) for name in self.bias_names],
+                        self.bias_tensor,
+                    )
+                else:
+                    bias_tensor = getattr(self, self.bias_names[0])  # Unused
+            else:
+                weight_tensor = self.weight_tensor
+                bias_tensor = self.bias_tensor
 
             # Fetch the fp8 weights placeholders (for linear/gemm)
             weight1_fp8, weight1_t_fp8 = self.get_fp8_weights_scratchpad(
...
@@ -7,7 +7,6 @@ import warnings
 from typing import Union, Optional, Callable, Tuple, List, Dict, Any
 
 import torch
-from torch.nn.parameter import Parameter
 
 import transformer_engine_extensions as tex
@@ -20,6 +19,7 @@ from .base import (
     _2X_ACC_DGRAD,
     _2X_ACC_WGRAD,
 )
+from ._common import _noop_cat
 from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager
 from ..utils import (
     divide,
@@ -521,8 +521,7 @@ class _Linear(torch.autograd.Function):
 class Linear(TransformerEngineBaseModule):
-    """
-    Applies a linear transformation to the incoming data :math:`y = xA^T + b`
+    """Applies a linear transformation to the incoming data :math:`y = xA^T + b`
 
     On NVIDIA GPUs it is a drop-in replacement for `torch.nn.Linear`.
@@ -538,13 +537,13 @@ class Linear(TransformerEngineBaseModule):
                  used for initializing weights in the following way: `init_method(weight)`.
                  When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None
-                      if a tuple of strings or a dict of strings to integers is provided,
-                      the weight and bias parameters of the module are exposed as `N` separate
-                      `torch.nn.parameter.Parameter`s each, split along the first dimension,
-                      where `N` is the length of the argument and the strings contained are the
-                      names of the split parameters. In the case of a tuple, each parameter
-                      has the same shape. In the case of a dict, the values give the
-                      `out_features` for each projection.
+                      Configuration for splitting the weight and bias tensors along dim 0 into
+                      multiple PyTorch parameters. If a list or tuple of strings is provided,
+                      they are used to make the names of equally-sized parameters. If a dict
+                      (preferably an OrderedDict) is provided, the keys are used as names and
+                      values as split sizes along dim 0. The resulting parameters will have
+                      names that end in `_weight` or `_bias`, so trailing underscores are
+                      stripped from any provided names.
    device : Union[torch.device, str], default = "cuda"
             The device on which the parameters of the model will allocated. It is the user's
             responsibility to ensure all parameters are moved to the GPU before running the
@@ -584,6 +583,7 @@
                   it controls the type used to allocate the initial parameters. Useful when
                   the model is trained with lower precision and the original FP32 parameters
                   would not fit in GPU memory.
+
     """
 
     def __init__(
@@ -617,7 +617,6 @@
         self.use_bias = bias
         self.return_bias = return_bias
         self.apply_bias = bias and not return_bias
-        self.parameters_split = parameters_split
         self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
         self.ub_split_rs = ub_split_rs
         self.ub_split_ag = ub_split_ag
@@ -694,69 +693,100 @@
             with torch.no_grad():
                 self.bias_tensor.zero_()
 
+        # Configure parameter splits
+        self.weight_names = []
+        self.bias_names = []
+        self.parameter_split_sizes = []
         if parameters_split is None:
-            parameters_split = {"": self.out_features}
-        elif isinstance(parameters_split, tuple):
-            assert (
-                self.out_features % len(parameters_split) == 0
-            ), f"Weight and bias params cannot be split into {len(parameters_split)} parts"
-            split_size = self.out_features // len(parameters_split)
-            parameters_split = {key: split_size for key in parameters_split}
+            # Split into a single parameter by default
+            self.weight_names = ["weight"]
+            self.bias_names = ["bias"]
+            self.parameter_split_sizes = [out_features]
+        elif not parameters_split:
+            raise ValueError("Cannot split weight buffer into 0 parameters")
         elif isinstance(parameters_split, dict):
-            overall_split_size = sum(parameters_split.values())
-            assert(
-                self.out_features == overall_split_size
-            ), f"Overall sum of parameters_split (={overall_split_size}) does not match "\
-               f"to out features (={self.out_features})"
+            # Split parameters with provided sizes
+            for name, split_size in parameters_split.items():
+                self.weight_names.append(f"{name.rstrip('_')}_weight")
+                self.bias_names.append(f"{name.rstrip('_')}_bias")
+                self.parameter_split_sizes.append(split_size)
+        elif all(isinstance(name, str) for name in parameters_split):
+            # Split parameters evenly
+            split_size = out_features // len(parameters_split)
+            for name in parameters_split:
+                self.weight_names.append(f"{name.rstrip('_')}_weight")
+                self.bias_names.append(f"{name.rstrip('_')}_bias")
+                self.parameter_split_sizes.append(split_size)
         else:
-            assert False, "Type of 'parameters_split' is not None, tuple or dict"
+            raise TypeError("Invalid configuration for parameters split")
 
-        self.updated_parameters_split = parameters_split
-        self.weight_names = []
-        self.bias_names = []
+        # Make sure parameter splits are valid
+        if sum(self.parameter_split_sizes) != out_features:
+            raise ValueError(
+                f"Trying to split weight buffer ({out_features=}) "
+                f"with split sizes {self.parameter_split_sizes}"
+            )
 
-        slice_begin = 0
-        for pname, slice_size in parameters_split.items():
-            wname = pname + "weight"
-            bname = pname + "bias"
-
-            slice_end = slice_begin + slice_size
-
-            # TODO(ksivaman): Add indexing op to torch dispatcher for float8
-            if self.primary_weights_in_fp8:
-                assert len(parameters_split) == 1, ("Slicing operation is not "
-                                                    "supported in Float8Tensor "
-                                                    "class!")
-                self.register_parameter(wname, Parameter(self.weight_tensor))
-            else:
-                self.register_parameter(
-                    wname, Parameter(self.weight_tensor[slice_begin:slice_end])
-                )
-
-            set_tensor_model_parallel_attributes(
-                tensor=getattr(self, wname),
-                is_parallel=True,
-                dim=1 if parallel_mode == "row" else 0,
-                stride=1,
-            )
+        # Adjust parameter splits for tensor-parallel distribution
+        if self.parallel_mode == "column":
+            for i, size in enumerate(self.parameter_split_sizes):
+                if size % self.tp_size != 0:
+                    raise RuntimeError(
+                        f"Attempting to distribute a parameter with out_features={size} "
+                        f"between {self.tp_size} tensor-parallel processes"
+                    )
+                self.parameter_split_sizes[i] = size // self.tp_size
+
+        # Construct parameters from weight and bias buffers
+        offset = 0
+        for i, split_size in enumerate(self.parameter_split_sizes):
+            split_start = offset
+            offset += split_size
+            split_end = offset
+
+            # Check if parameters are subviews of buffers
+            is_subview = (split_start, split_end) != (0, self.out_features)
+            if is_subview and self.primary_weights_in_fp8:
+                raise RuntimeError(
+                    "Splitting Float8Tensor into multiple params "
+                    "is not supported"
+                )
+
+            # Construct weight parameter
+            weight = self.weight_tensor
+            if is_subview:
+                weight = weight[split_start:split_end]
+            weight = torch.nn.Parameter(weight)
+            self.register_parameter(self.weight_names[i], weight)
 
+            # Construct bias parameter if needed
             if self.use_bias:
-                self.register_parameter(
-                    bname, Parameter(self.bias_tensor[slice_begin:slice_end])
-                )
+                bias = self.bias_tensor
+                if is_subview:
+                    bias = bias[split_start:split_end]
+                bias = torch.nn.Parameter(bias)
+                self.register_parameter(self.bias_names[i], bias)
                 if parallel_mode == "row":
-                    setattr(getattr(self, bname), "sequence_parallel", sequence_parallel)
+                    bias.sequence_parallel = sequence_parallel
             else:
-                setattr(self, bname, torch.Tensor().to(dtype=params_dtype, device=device))
+                bias = torch.Tensor().to(dtype=params_dtype, device=device)
+                setattr(self, self.bias_names[i], bias)
 
+            # Configure tensor parallelism
+            set_tensor_model_parallel_attributes(
+                tensor=weight,
+                is_parallel=True,
+                dim=1 if parallel_mode == "row" else 0,
+                stride=1,
+            )
             if parallel_mode == "column":
-                set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1)
+                set_tensor_model_parallel_attributes(bias, True, 0, 1)
 
-            self.weight_names.append(wname)
-            self.bias_names.append(bname)
-
-            slice_begin = slice_end
+        # Concatenated tensors are not needed if not splitting
+        # into multiple parameters
+        if not is_subview:
+            del self.weight_tensor
+            del self.bias_tensor
 
         self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))
@@ -767,12 +797,6 @@
         else:
             self.gemm_bias_unfused_add = False
 
-        # Clean up weight and bias buffers
-        if self.parameters_split is None:
-            del self.weight_tensor
-            if self.use_bias:
-                del self.bias_tensor
-
     def get_fp8_weights_scratchpad(
         self,
         is_first_microbatch: Union[bool, None],
@@ -828,18 +852,26 @@
         with self.prepare_forward(inp, is_first_microbatch) as inp:
             assert self.fp8 or not self.primary_weights_in_fp8, \
                    "Need to run inside fp8_autocast region when weights are stored in FP8."
-            bias_tensor = (
-                self.bias if self.parameters_split is None
-                else self.bias_tensor if not torch.is_grad_enabled()
-                else self.noop_cat("bias_tensor", self.bias_names,
-                                   self.updated_parameters_split)
-            )
-            weight_tensor = (
-                self.weight if self.parameters_split is None
-                else self.weight_tensor if not torch.is_grad_enabled()
-                else self.noop_cat("weight_tensor", self.weight_names,
-                                   self.updated_parameters_split)
-            )
+
+            # Get concatenated weight and bias tensors
+            if len(self.parameter_split_sizes) == 1:
+                weight_tensor = getattr(self, self.weight_names[0])
+                bias_tensor = getattr(self, self.bias_names[0])
+            elif torch.is_grad_enabled():
+                weight_tensor = _noop_cat(
+                    [getattr(self, name) for name in self.weight_names],
+                    self.weight_tensor,
+                )
+                if self.use_bias:
+                    bias_tensor = _noop_cat(
+                        [getattr(self, name) for name in self.bias_names],
+                        self.bias_tensor,
+                    )
+                else:
+                    bias_tensor = getattr(self, self.bias_names[0])  # Unused
+            else:
+                weight_tensor = self.weight_tensor
+                bias_tensor = self.bias_tensor
 
             # Fetch the fp8 weights placeholders (for linear/gemm)
             weight1_fp8, weight1_t_fp8 = self.get_fp8_weights_scratchpad(
...
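
One behavioral addition in both modules: with column parallelism, each named split is now validated against and divided by the tensor-parallel world size. A standalone sketch of that arithmetic, using illustrative values (`tp_size` and the sizes are not from this commit):

```python
tp_size = 4
parameter_split_sizes = [1024, 256, 256]  # e.g. query, key, value

local_sizes = []
for size in parameter_split_sizes:
    # Mirrors the check added above: every split must divide evenly
    # across tensor-parallel processes.
    if size % tp_size != 0:
        raise RuntimeError(
            f"Attempting to distribute a parameter with out_features={size} "
            f"between {tp_size} tensor-parallel processes"
        )
    local_sizes.append(size // tp_size)

print(local_sizes)  # [256, 64, 64]: per-rank out_features for each split
```
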