Unverified commit bb759adc authored by Tim Moon, committed by GitHub

[PyTorch] Refactor parameter splitting in Linear and LayerNormLinear (#590)



* Refactor parameter split in Linear module

Remove module state from noop_cat. Support arbitrary names in parameter split. Handle tensor parallelism.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Make noop_cat a standalone operation
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update parameter splits in LayerNormLinear
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug case without bias

Fix pylint complaints.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove unused import
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 7ce7dfe5
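
Before the diff, a minimal sketch of the refactored `parameters_split` usage. The sizes and names here are illustrative, not from this PR's tests, and TE modules allocate parameters on CUDA by default:

```python
import collections
import transformer_engine.pytorch as te

# A dict (preferably an OrderedDict) now maps arbitrary split names to
# split sizes along dim 0, rather than requiring trailing underscores.
hidden_size, kv_size = 1024, 256
qkv = te.Linear(
    hidden_size,
    hidden_size + 2 * kv_size,
    parameters_split=collections.OrderedDict([
        ("query", hidden_size),
        ("key", kv_size),
        ("value", kv_size),
    ]),
)
# Exposes query_weight/query_bias, key_weight/key_bias, value_weight/value_bias.
print(sorted(name for name, _ in qkv.named_parameters()))
```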
@@ -3,6 +3,7 @@
# See LICENSE for license information.
"""Attention."""
import collections
import os
import warnings
import math
@@ -2705,9 +2706,13 @@ class MultiheadAttention(torch.nn.Module):
qkv_parallel_mode = "column" if set_parallel_mode else None
if self.attention_type == "self":
parameters_split = {"query_": hidden_size,
"key_": self.hidden_size_kv,
"value_": self.hidden_size_kv} if not fuse_qkv_params else None
parameters_split = None
if not fuse_qkv_params:
parameters_split = collections.OrderedDict([
("query", hidden_size),
("key", self.hidden_size_kv),
("value", self.hidden_size_kv),
])
if self.input_layernorm:
self.layernorm_qkv = LayerNormLinear(
hidden_size,
@@ -2749,7 +2754,7 @@
bias=bias,
return_bias=False,
parallel_mode=qkv_parallel_mode,
parameters_split=("query_",) if not fuse_qkv_params else None,
parameters_split=("query",) if not fuse_qkv_params else None,
return_layernorm_output=return_layernorm_output,
zero_centered_gamma=zero_centered_gamma,
ub_bulk_wgrad=ub_bulk_wgrad,
@@ -2777,7 +2782,7 @@
bias=bias,
return_bias=False,
parallel_mode=qkv_parallel_mode,
parameters_split=("key_", "value_") if not fuse_qkv_params else None,
parameters_split=("key", "value") if not fuse_qkv_params else None,
**common_gemm_kwargs,
)
......
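
One detail worth calling out in the hunks above: dropping the trailing underscores from the split names ("query_" to "query") does not change any parameter or checkpoint keys, since both the old and new code strip trailing underscores before appending `_weight`/`_bias`. A trivial sketch of the name mapping:

```python
def weight_names(split_names):
    """Mirrors how Linear/LayerNormLinear derive parameter names."""
    return [f"{name.rstrip('_')}_weight" for name in split_names]

# Old-style and new-style split names produce identical parameter names.
assert weight_names(("query_", "key_", "value_")) == \
       weight_names(("query", "key", "value"))  # query_weight, key_weight, ...
```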
@@ -4,7 +4,7 @@
"""Internal function used by multiple modules."""
from typing import Union, Dict, Any
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
@@ -93,3 +93,97 @@ def _apply_normalization(inputmat:torch.Tensor,
elif normalization == "LayerNorm":
output = (ln_out, output[1], output[2])
return output
class _NoopCatFunc(torch.autograd.Function):
"""No-op concatenate tensors along dim 0
`full_tensor` is assumed to already be the concatenation of
`tensors`, i.e. they occupy the same memory with the correct
offsets.
"""
@staticmethod
def forward(
ctx,
split_ranges: List[Tuple[int, int]],
full_tensor: torch.Tensor,
*tensors: torch.Tensor,
) -> torch.Tensor:
# pylint: disable=unused-argument
ctx.split_ranges = split_ranges
assert not full_tensor.requires_grad, "Concatenated tensor should not require gradient"
out = full_tensor.new()
out.set_(
full_tensor.untyped_storage(),
full_tensor.storage_offset(),
full_tensor.size(),
full_tensor.stride(),
)
out.requires_grad = True
return out
@staticmethod
def backward(
ctx,
grad_output: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]:
grads = [
grad_output[split_start:split_end]
for split_start, split_end in ctx.split_ranges
]
return None, None, *grads
def _noop_cat(
tensors: List[torch.Tensor],
full_tensor: torch.Tensor,
) -> torch.Tensor:
"""Concatenate tensors along dim 0, doing a no-op if possible
If `full_tensor` is already the concatenation of `tensors`, i.e.
they occupy the same memory region with the correct offsets, then
no copies are performed. Otherwise the buffers in all the tensors
are reallocated so that another call would result in a no-op.
In the backward pass, gradients to `tensors` will just be
views into the gradient of the concatenated tensor.
"""
# Determine split points
split_ranges = []
full_tensor_shape = full_tensor.size()
offset = 0
for tensor in tensors:
tensor_shape = tensor.size()
if tensor_shape[1:] != full_tensor_shape[1:]:
raise ValueError(
f"Attempting to concatenate tensor with shape={list(tensor_shape)} "
f"into a tensor with shape={list(full_tensor_shape)}"
)
split_start = offset
offset += tensor_shape[0]
split_end = offset
split_ranges.append((split_start, split_end))
if offset != full_tensor_shape[0]:
raise ValueError(
f"Attempting to concatenate tensors with total shape[0]={offset} "
f"into a tensor with shape[0]={full_tensor_shape[0]}"
)
# Reallocate buffers if no-op concat isn't possible
need_to_reallocate = False
for tensor, (split_start, _) in zip(tensors, split_ranges):
if tensor.data_ptr() != full_tensor[split_start].data_ptr():
need_to_reallocate = True
break
if need_to_reallocate:
with torch.no_grad():
full_tensor.data = torch.cat(tensors)
for tensor, (split_start, split_end) in zip(tensors, split_ranges):
tensor.data = full_tensor[split_start:split_end]
# Perform no-op concat
return _NoopCatFunc.apply(split_ranges, full_tensor, *tensors)
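
A sketch of the contract `_noop_cat` is meant to satisfy, assuming the helper is importable from `transformer_engine.pytorch.module._common` (it is private) and using illustrative shapes:

```python
import torch
from transformer_engine.pytorch.module._common import _noop_cat

# Parameters are views into a preallocated buffer, as in Linear/LayerNormLinear.
full = torch.zeros(6, 4)
parts = [torch.nn.Parameter(full[0:2]), torch.nn.Parameter(full[2:6])]

out = _noop_cat(parts, full)
assert out.data_ptr() == full.data_ptr()   # concatenation performed no copy
out.sum().backward()
assert parts[0].grad.shape == (2, 4)       # one gradient per split, taken as
                                           # a view of out's gradient
```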
@@ -14,7 +14,6 @@ from contextlib import contextmanager
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import transformer_engine_extensions as tex
from ..export import is_in_onnx_export_mode
@@ -213,44 +212,6 @@ def get_ub(name: str):
return _ub_communicators[name]
class _NoopCat(torch.autograd.Function):
"""This class is a no-op replacement for `torch.cat`."""
@staticmethod
def forward(ctx,
full_param_buffer: torch.Tensor,
*params_split: Tuple[torch.Tensor, ...],
) -> torch.Tensor:
assert not full_param_buffer.requires_grad, "Buffers should not require gradient"
sum_params_shape = sum(p.shape[0] for p in params_split)
assert (
full_param_buffer.shape[0] == sum_params_shape
), "Dimensions not compatible for concatenation"
param_temp = full_param_buffer.new()
param_temp.set_(full_param_buffer.untyped_storage(),
full_param_buffer.storage_offset(),
full_param_buffer.size(),
full_param_buffer.stride())
param_temp.requires_grad = True
ctx.save_for_backward(*params_split)
return param_temp
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
params_split = ctx.saved_tensors
grads = []
slice_begin = 0
for i, _ in enumerate(params_split):
slice_size = params_split[i].shape[0]
slice_end = slice_begin + slice_size
grads.append(grad_output[slice_begin:slice_end])
slice_begin = slice_end
return None, *grads
class TransformerEngineBaseModule(torch.nn.Module, ABC):
"""Base TE module."""
@@ -742,40 +703,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
return grad_output_mat, grad_output_c, grad_output_t, grad_bias
def noop_cat(self,
buffer_name: str,
pnames: List[str],
parameters_split: Dict[str, int]
) -> torch.Tensor:
"""No-op replacement of `torch.cat`. The buffer and split parameters must occupy
the same memory region. If this is not the case, then the split parameters
are concatenated and the buffer is overwritten. The parameters' memory is then
re-assigned to point to the buffer to avoid subsequent concatenations.
"""
assert hasattr(self, buffer_name), f"No buffer named {buffer_name}"
full_param_buffer = getattr(self, buffer_name)
params = [getattr(self, name) for name in pnames]
slice_begin = 0
for i, p in enumerate(params):
slice_size = parameters_split[pnames[i].split('_')[0]+'_']
slice_end = slice_begin + slice_size
if p.data.data_ptr() != full_param_buffer[slice_begin:slice_end].data_ptr():
with torch.no_grad():
setattr(self, buffer_name, torch.cat(params))
slice_begin_j = 0
for pname in pnames:
slice_size_j = parameters_split[pname.split('_')[0]+'_']
slice_end_j = slice_begin_j + slice_size_j
full_param_buffer = getattr(self, buffer_name)
setattr(self, pname,
Parameter(full_param_buffer[slice_begin_j:slice_end_j]))
slice_begin_j = slice_end_j
break
slice_begin = slice_end
return _NoopCat.apply(getattr(self, buffer_name), *[getattr(self, name) for name in pnames])
def get_fp8_weights_empty_tensors(
self,
is_first_microbatch: Union[bool, None],
......
@@ -7,9 +7,7 @@ import os
import warnings
from typing import Union, Optional, Callable, Tuple, List, Dict, Any
import torch
from torch.nn.parameter import Parameter
from torch.nn import init
from .. import cpp_extensions as tex
@@ -41,7 +39,7 @@ from ..distributed import (
)
from ..constants import GemmParallelModes, dist_group_type, TE_DType
from ..jit import no_torch_dynamo
from ._common import _apply_normalization
from ._common import _apply_normalization, _noop_cat
from ..float8_tensor import Float8Tensor
@@ -612,13 +610,13 @@ class LayerNormLinear(TransformerEngineBaseModule):
Example use case: residual connection for transformer module is
taken post layernorm.
parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None
if a tuple of strings or a dict of strings to integers is provided,
the weight and bias parameters of the module are exposed as `N` separate
`torch.nn.parameter.Parameter`s each, split along the first dimension,
where `N` is the length of the argument and the strings contained are the
names of the split parameters. In the case of a tuple, each parameter
has the same shape. In the case of a dict, the values give the
`out_features` for each projection.
Configuration for splitting the weight and bias tensors along dim 0 into
multiple PyTorch parameters. If a list or tuple of strings is provided,
they are used to make the names of equally-sized parameters. If a dict
(preferably an OrderedDict) is provided, the keys are used as names and
values as split sizes along dim 0. The resulting parameters will have
names that end in `_weight` or `_bias`, so trailing underscores are
stripped from any provided names.
zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to
@@ -705,7 +703,6 @@
self.return_bias = return_bias
self.apply_bias = self.use_bias and not return_bias
self.return_layernorm_output = return_layernorm_output
self.parameters_split = parameters_split
self.zero_centered_gamma = zero_centered_gamma
self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
self.ub_bulk_wgrad = ub_bulk_wgrad
@@ -752,12 +749,12 @@
self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
self.eps = eps
self.layer_norm_weight = Parameter(
self.layer_norm_weight = torch.nn.Parameter(
torch.empty(in_features, device=device, dtype=params_dtype)
)
setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel)
if self.normalization != "RMSNorm":
self.layer_norm_bias = Parameter(
self.layer_norm_bias = torch.nn.Parameter(
torch.empty(in_features, device=device, dtype=params_dtype)
)
setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel)
@@ -800,68 +797,100 @@
with torch.no_grad():
self.bias_tensor.zero_()
# Configure parameter splits
self.weight_names = []
self.bias_names = []
self.parameter_split_sizes = []
if parameters_split is None:
parameters_split = {"": self.out_features}
elif isinstance(parameters_split, tuple):
assert (
self.out_features % len(parameters_split) == 0
), f"Weight and bias params cannot be split into {len(parameters_split)} parts"
split_size = self.out_features // len(parameters_split)
parameters_split = {key: split_size for key in parameters_split}
# Split into a single parameter by default
self.weight_names = ["weight"]
self.bias_names = ["bias"]
self.parameter_split_sizes = [out_features]
elif not parameters_split:
raise ValueError("Cannot split weight buffer into 0 parameters")
elif isinstance(parameters_split, dict):
overall_split_size = sum(parameters_split.values())
assert(
self.out_features == overall_split_size
), f"Overall sum of parameters_split (={overall_split_size}) does not match "\
f"to out features (={self.out_features})"
# Split parameters with provided sizes
for name, split_size in parameters_split.items():
self.weight_names.append(f"{name.rstrip('_')}_weight")
self.bias_names.append(f"{name.rstrip('_')}_bias")
self.parameter_split_sizes.append(split_size)
elif all(isinstance(name, str) for name in parameters_split):
# Split parameters evenly
split_size = out_features // len(parameters_split)
for name in parameters_split:
self.weight_names.append(f"{name.rstrip('_')}_weight")
self.bias_names.append(f"{name.rstrip('_')}_bias")
self.parameter_split_sizes.append(split_size)
else:
assert False, "Type of 'parameters_split' is not None, tuple or dict"
self.updated_parameters_split = parameters_split
raise TypeError("Invalid configuration for parameters split")
self.weight_names = []
self.bias_names = []
# Make sure parameter splits are valid
if sum(self.parameter_split_sizes) != out_features:
raise ValueError(
f"Trying to split weight buffer ({out_features=}) "
f"with split sizes {self.parameter_split_sizes}"
)
slice_begin = 0
for pname, slice_size in parameters_split.items():
wname = pname + "weight"
bname = pname + "bias"
slice_end = slice_begin + slice_size
# NOTE(future): Figure out a way to support slicing when weights
# are of `Float8Tensor` class
if self.primary_weights_in_fp8:
assert len(parameters_split) == 1, ("Slicing operation is not "
"supported in Float8Tensor "
"class!")
self.register_parameter(wname, Parameter(self.weight_tensor))
else:
self.register_parameter(
wname, Parameter(self.weight_tensor[slice_begin:slice_end])
# Adjust parameter splits for tensor-parallel distribution
if self.parallel_mode == "column":
for i, size in enumerate(self.parameter_split_sizes):
if size % self.tp_size != 0:
raise RuntimeError(
f"Attempting to distribute a parameter with out_features={size} "
f"between {self.tp_size} tensor-parallel processes"
)
self.parameter_split_sizes[i] = size // self.tp_size
# Construct parameters from weight and bias buffers
offset = 0
for i, split_size in enumerate(self.parameter_split_sizes):
split_start = offset
offset += split_size
split_end = offset
# Check if parameters are subviews of buffers
is_subview = (split_start, split_end) != (0, self.out_features)
if is_subview and self.primary_weights_in_fp8:
raise RuntimeError(
"Splitting Float8Tensor into multiple params "
"is not supported"
)
set_tensor_model_parallel_attributes(
tensor=getattr(self, wname),
is_parallel=True,
dim=1 if parallel_mode == "row" else 0,
stride=1,
)
# Construct weight parameter
weight = self.weight_tensor
if is_subview:
weight = weight[split_start:split_end]
weight = torch.nn.Parameter(weight)
self.register_parameter(self.weight_names[i], weight)
# Construct bias parameter if needed
if self.use_bias:
self.register_parameter(
bname, Parameter(self.bias_tensor[slice_begin:slice_end])
)
bias = self.bias_tensor
if is_subview:
bias = bias[split_start:split_end]
bias = torch.nn.Parameter(bias)
self.register_parameter(self.bias_names[i], bias)
if parallel_mode == "row":
setattr(getattr(self, bname), "sequence_parallel", sequence_parallel)
bias.sequence_parallel = sequence_parallel
else:
setattr(self, bname, torch.Tensor().to(dtype=params_dtype, device=device))
bias = torch.Tensor().to(dtype=params_dtype, device=device)
setattr(self, self.bias_names[i], bias)
# Configure tensor parallelism
set_tensor_model_parallel_attributes(
tensor=weight,
is_parallel=True,
dim=1 if parallel_mode == "row" else 0,
stride=1,
)
if parallel_mode == "column":
set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1)
set_tensor_model_parallel_attributes(bias, True, 0, 1)
self.weight_names.append(wname)
self.bias_names.append(bname)
slice_begin = slice_end
# Concatenated tensors are not needed if not splitting
# into multiple parameters
if not is_subview:
del self.weight_tensor
del self.bias_tensor
self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))
@@ -880,12 +909,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
# Clean up weight and bias buffers
if self.parameters_split is None:
del self.weight_tensor
if self.use_bias:
del self.bias_tensor
def reset_layer_norm_parameters(self) -> None:
"""Init LN params"""
if not self.zero_centered_gamma:
@@ -950,18 +973,26 @@
with self.prepare_forward(inp, is_first_microbatch) as inp:
assert self.fp8 or not self.primary_weights_in_fp8, \
"Need to run inside fp8_autocast region when weights are stored in FP8."
bias_tensor = (
self.bias if self.parameters_split is None
else self.bias_tensor if not torch.is_grad_enabled()
else self.noop_cat("bias_tensor", self.bias_names,
self.updated_parameters_split)
)
weight_tensor = (
self.weight if self.parameters_split is None
else self.weight_tensor if not torch.is_grad_enabled()
else self.noop_cat("weight_tensor", self.weight_names,
self.updated_parameters_split)
)
# Get concatenated weight and bias tensors
if len(self.parameter_split_sizes) == 1:
weight_tensor = getattr(self, self.weight_names[0])
bias_tensor = getattr(self, self.bias_names[0])
elif torch.is_grad_enabled():
weight_tensor = _noop_cat(
[getattr(self, name) for name in self.weight_names],
self.weight_tensor,
)
if self.use_bias:
bias_tensor = _noop_cat(
[getattr(self, name) for name in self.bias_names],
self.bias_tensor,
)
else:
bias_tensor = getattr(self, self.bias_names[0]) # Unused
else:
weight_tensor = self.weight_tensor
bias_tensor = self.bias_tensor
# Fetch the fp8 weights placeholders (for linear/gemm)
weight1_fp8, weight1_t_fp8 = self.get_fp8_weights_scratchpad(
......
@@ -7,7 +7,6 @@ import warnings
from typing import Union, Optional, Callable, Tuple, List, Dict, Any
import torch
from torch.nn.parameter import Parameter
import transformer_engine_extensions as tex
@@ -20,6 +19,7 @@ from .base import (
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ._common import _noop_cat
from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager
from ..utils import (
divide,
@@ -521,8 +521,7 @@
class Linear(TransformerEngineBaseModule):
"""
Applies a linear transformation to the incoming data :math:`y = xA^T + b`
"""Applies a linear transformation to the incoming data :math:`y = xA^T + b`
On NVIDIA GPUs it is a drop-in replacement for `torch.nn.Linear`.
@@ -538,13 +537,13 @@
used for initializing weights in the following way: `init_method(weight)`.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None
if a tuple of strings or a dict of strings to integers is provided,
the weight and bias parameters of the module are exposed as `N` separate
`torch.nn.parameter.Parameter`s each, split along the first dimension,
where `N` is the length of the argument and the strings contained are the
names of the split parameters. In the case of a tuple, each parameter
has the same shape. In the case of a dict, the values give the
`out_features` for each projection.
Configuration for splitting the weight and bias tensors along dim 0 into
multiple PyTorch parameters. If a list or tuple of strings is provided,
they are used to make the names of equally-sized parameters. If a dict
(preferably an OrderedDict) is provided, the keys are used as names and
values as split sizes along dim 0. The resulting parameters will have
names that end in `_weight` or `_bias`, so trailing underscores are
stripped from any provided names.
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
@@ -584,6 +583,7 @@
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
"""
def __init__(
@@ -617,7 +617,6 @@
self.use_bias = bias
self.return_bias = return_bias
self.apply_bias = bias and not return_bias
self.parameters_split = parameters_split
self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
self.ub_split_rs = ub_split_rs
self.ub_split_ag = ub_split_ag
@@ -694,69 +693,100 @@
with torch.no_grad():
self.bias_tensor.zero_()
# Configure parameter splits
self.weight_names = []
self.bias_names = []
self.parameter_split_sizes = []
if parameters_split is None:
parameters_split = {"": self.out_features}
elif isinstance(parameters_split, tuple):
assert (
self.out_features % len(parameters_split) == 0
), f"Weight and bias params cannot be split into {len(parameters_split)} parts"
split_size = self.out_features // len(parameters_split)
parameters_split = {key: split_size for key in parameters_split}
# Split into a single parameter by default
self.weight_names = ["weight"]
self.bias_names = ["bias"]
self.parameter_split_sizes = [out_features]
elif not parameters_split:
raise ValueError("Cannot split weight buffer into 0 parameters")
elif isinstance(parameters_split, dict):
overall_split_size = sum(parameters_split.values())
assert(
self.out_features == overall_split_size
), f"Overall sum of parameters_split (={overall_split_size}) does not match "\
f"to out features (={self.out_features})"
# Split parameters with provided sizes
for name, split_size in parameters_split.items():
self.weight_names.append(f"{name.rstrip('_')}_weight")
self.bias_names.append(f"{name.rstrip('_')}_bias")
self.parameter_split_sizes.append(split_size)
elif all(isinstance(name, str) for name in parameters_split):
# Split parameters evenly
split_size = out_features // len(parameters_split)
for name in parameters_split:
self.weight_names.append(f"{name.rstrip('_')}_weight")
self.bias_names.append(f"{name.rstrip('_')}_bias")
self.parameter_split_sizes.append(split_size)
else:
assert False, "Type of 'parameters_split' is not None, tuple or dict"
self.updated_parameters_split = parameters_split
raise TypeError("Invalid configuration for parameters split")
self.weight_names = []
self.bias_names = []
# Make sure parameter splits are valid
if sum(self.parameter_split_sizes) != out_features:
raise ValueError(
f"Trying to split weight buffer ({out_features=}) "
f"with split sizes {self.parameter_split_sizes}"
)
slice_begin = 0
for pname, slice_size in parameters_split.items():
wname = pname + "weight"
bname = pname + "bias"
# Adjust parameter splits for tensor-parallel distribution
if self.parallel_mode == "column":
for i, size in enumerate(self.parameter_split_sizes):
if size % self.tp_size != 0:
raise RuntimeError(
f"Attempting to distribute a parameter with out_features={size} "
f"between {self.tp_size} tensor-parallel processes"
)
self.parameter_split_sizes[i] = size // self.tp_size
# Construct parameters from weight and bias buffers
offset = 0
for i, split_size in enumerate(self.parameter_split_sizes):
split_start = offset
offset += split_size
split_end = offset
# Check if parameters are subviews of buffers
is_subview = (split_start, split_end) != (0, self.out_features)
if is_subview and self.primary_weights_in_fp8:
raise RuntimeError(
"Splitting Float8Tensor into multiple params "
"is not supported"
)
slice_end = slice_begin + slice_size
# Construct weight parameter
weight = self.weight_tensor
if is_subview:
weight = weight[split_start:split_end]
weight = torch.nn.Parameter(weight)
self.register_parameter(self.weight_names[i], weight)
# TODO(ksivaman): Add indexing op to torch dispatcher for float8
if self.primary_weights_in_fp8:
assert len(parameters_split) == 1, ("Slicing operation is not "
"supported in Float8Tensor "
"class!")
self.register_parameter(wname, Parameter(self.weight_tensor))
# Construct bias parameter if needed
if self.use_bias:
bias = self.bias_tensor
if is_subview:
bias = bias[split_start:split_end]
bias = torch.nn.Parameter(bias)
self.register_parameter(self.bias_names[i], bias)
if parallel_mode == "row":
bias.sequence_parallel = sequence_parallel
else:
bias = torch.Tensor().to(dtype=params_dtype, device=device)
setattr(self, self.bias_names[i], bias)
self.register_parameter(
wname, Parameter(self.weight_tensor[slice_begin:slice_end])
)
# Configure tensor parallelism
set_tensor_model_parallel_attributes(
tensor=getattr(self, wname),
tensor=weight,
is_parallel=True,
dim=1 if parallel_mode == "row" else 0,
stride=1,
)
if self.use_bias:
self.register_parameter(
bname, Parameter(self.bias_tensor[slice_begin:slice_end])
)
if parallel_mode == "row":
setattr(getattr(self, bname), "sequence_parallel", sequence_parallel)
else:
setattr(self, bname, torch.Tensor().to(dtype=params_dtype, device=device))
if parallel_mode == "column":
set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1)
set_tensor_model_parallel_attributes(bias, True, 0, 1)
self.weight_names.append(wname)
self.bias_names.append(bname)
slice_begin = slice_end
# Concatenated tensors are not needed if not splitting
# into multiple parameters
if not is_subview:
del self.weight_tensor
del self.bias_tensor
self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))
@@ -767,12 +797,6 @@
else:
self.gemm_bias_unfused_add = False
# Clean up weight and bias buffers
if self.parameters_split is None:
del self.weight_tensor
if self.use_bias:
del self.bias_tensor
def get_fp8_weights_scratchpad(
self,
is_first_microbatch: Union[bool, None],
@@ -828,18 +852,26 @@
with self.prepare_forward(inp, is_first_microbatch) as inp:
assert self.fp8 or not self.primary_weights_in_fp8, \
"Need to run inside fp8_autocast region when weights are stored in FP8."
bias_tensor = (
self.bias if self.parameters_split is None
else self.bias_tensor if not torch.is_grad_enabled()
else self.noop_cat("bias_tensor", self.bias_names,
self.updated_parameters_split)
)
weight_tensor = (
self.weight if self.parameters_split is None
else self.weight_tensor if not torch.is_grad_enabled()
else self.noop_cat("weight_tensor", self.weight_names,
self.updated_parameters_split)
)
# Get concatenated weight and bias tensors
if len(self.parameter_split_sizes) == 1:
weight_tensor = getattr(self, self.weight_names[0])
bias_tensor = getattr(self, self.bias_names[0])
elif torch.is_grad_enabled():
weight_tensor = _noop_cat(
[getattr(self, name) for name in self.weight_names],
self.weight_tensor,
)
if self.use_bias:
bias_tensor = _noop_cat(
[getattr(self, name) for name in self.bias_names],
self.bias_tensor,
)
else:
bias_tensor = getattr(self, self.bias_names[0]) # Unused
else:
weight_tensor = self.weight_tensor
bias_tensor = self.bias_tensor
# Fetch the fp8 weights placeholders (for linear/gemm)
weight1_fp8, weight1_t_fp8 = self.get_fp8_weights_scratchpad(
......
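
End to end, the construction logic above can be sanity-checked by confirming that the split parameters alias the retained `weight_tensor` buffer, which the grad-enabled branch of `forward` reassembles via `_noop_cat` without copying. A sketch, assuming a CUDA device (the TE default) and illustrative names and sizes:

```python
import collections
import transformer_engine.pytorch as te

linear = te.Linear(
    64, 64,
    parameters_split=collections.OrderedDict([("q", 32), ("kv", 32)]),
)
# Each split parameter is a view into the shared buffer, so rebuilding the
# full weight in forward is a no-op; weight_tensor is kept only when the
# weight is split into multiple (subview) parameters.
assert linear.q_weight.data_ptr() == linear.weight_tensor.data_ptr()
assert linear.kv_weight.data_ptr() == linear.weight_tensor[32:].data_ptr()
```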