"...git@developer.sourcefind.cn:tsoc/superbenchmark.git" did not exist on "949f9cb406a0263e45c38825b6953f3b46953c9e"
Unverified commit db589510 authored by cyanguwa, committed by GitHub

Fix layernorm in GQA (#434)



* [PyTorch] Implement GQA based on fused q, k, v projection. Additionally fixes #392
Signed-off-by: Markus Schnoes <markus.schnoes@gmx.de>

* [PyTorch] Extend parameters_split option in Linear and LayerNormLinear to support splitting with different sizes as required by unfused GQA.
Signed-off-by: Markus Schnoes <markus.schnoes@gmx.de>

* fix parameters split
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix noop cat to bypass torch.cat and support uneven split
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix unit tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix torch.split args
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix cuda graph due to noop_cat
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove the use of enumerate when possible
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix strides in SplitAlongDim
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Markus Schnoes <markus.schnoes@gmx.de>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: Markus Schnoes <markus.schnoes@gmx.de>
parent 903e1f4f
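To make the new dict form of `parameters_split` concrete before the diff: a minimal sketch of the uneven q/k/v sizing that unfused GQA needs (the head counts and sizes below are illustrative, not taken from this commit):

```python
# Hypothetical GQA sizing: 16 query heads, 4 key/value groups, head size 64.
num_heads, num_gqa_groups, head_dim = 16, 4, 64
hidden_size = num_heads * head_dim          # 1024: width of the query projection
hidden_size_kv = num_gqa_groups * head_dim  # 256: width of the key/value projections

# The dict form maps each split name to its slice of out_features, so the three
# projections may have different widths; the older tuple form implies equal slices.
parameters_split = {
    "query_": hidden_size,
    "key_": hidden_size_kv,
    "value_": hidden_size_kv,
}
assert sum(parameters_split.values()) == hidden_size + 2 * hidden_size_kv
```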
@@ -141,7 +141,8 @@ def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("bias_type", ["no_bias", "post_scale_bias"])
def test_transformer_layer(dtype, bs, model, ckpt_attn, bias_type):
@pytest.mark.parametrize("fused_qkv_params", [True, False])
def test_transformer_layer(dtype, bs, model, ckpt_attn, bias_type, fused_qkv_params):
"""Test TransformerLayer module when its DotProductAttention is enabled with
FlashAttention, FusedAttention, or UnfusedDotProductAttention backend"""
@@ -149,11 +150,11 @@ def test_transformer_layer(dtype, bs, model, ckpt_attn, bias_type):
if bias_type == "no_bias":
flash_attn_fwd, flash_attn_bwd = _run_transformer_layer(
dtype, bs, config, "FlashAttention", ckpt_attn, bias_type)
dtype, bs, config, "FlashAttention", ckpt_attn, bias_type, fused_qkv_params)
fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
dtype, bs, config, "FusedAttention", ckpt_attn, bias_type)
dtype, bs, config, "FusedAttention", ckpt_attn, bias_type, fused_qkv_params)
unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer(
dtype, bs, config, "UnfusedDotProductAttention", ckpt_attn, bias_type)
dtype, bs, config, "UnfusedDotProductAttention", ckpt_attn, bias_type, fused_qkv_params)
atol, rtol = (5e-1, 5e-2)
if bias_type == "no_bias":
@@ -162,7 +163,7 @@ def test_transformer_layer(dtype, bs, model, ckpt_attn, bias_type):
assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)
def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type):
def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type, fused_qkv_params):
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
@@ -220,7 +221,7 @@ def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type):
layer_type="encoder",
drop_path_rate=drop_path_rates[layer_number - 1],
set_parallel_mode=True,
fuse_qkv_params=True,
fuse_qkv_params=fused_qkv_params,
zero_centered_gamma=False,
qkv_weight_interleaved=False,
ub_tp_comm_overlap=False,
......
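The updated test keeps its existing comparison pattern, now exercised for both QKV parameter layouts; a condensed, self-contained sketch of that pattern (the tensors are stand-ins for the real backend outputs):

```python
import torch

# The same layer is run with the FlashAttention, FusedAttention and
# UnfusedDotProductAttention backends, and the forward/backward results must
# agree within the loose tolerances used in the test.
atol, rtol = 5e-1, 5e-2
flash_fwd = torch.randn(8, 16)
fused_fwd = flash_fwd + 1e-3      # stand-in for a slightly different backend result
unfused_fwd = flash_fwd - 1e-3

assert torch.allclose(fused_fwd, unfused_fwd, atol=atol, rtol=rtol)
assert torch.allclose(flash_fwd, unfused_fwd, atol=atol, rtol=rtol)
```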
@@ -8,8 +8,9 @@ import warnings
import math
from importlib.metadata import version
from contextlib import nullcontext
from typing import Any, Callable, Optional, Tuple, Union, Dict
from typing import Any, Callable, Optional, Tuple, Union, Dict, List
from pkg_resources import packaging
import numpy as np
import torch
@@ -508,48 +509,61 @@ def apply_rotary_pos_emb(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
return torch.cat((t, t_pass), dim=-1)
class _SplitLastDim(torch.autograd.Function):
class _SplitAlongDim(torch.autograd.Function):
""""""
@staticmethod
def forward(ctx,
mixed_x_layer: torch.Tensor,
num_parts: int
split_dim: int,
split_size_or_sections: Union[int, List[int], Tuple[int]],
) -> Tuple[torch.Tensor, ...]:
return split_tensor_along_dim(mixed_x_layer, -1, num_parts)
ctx.split_dim = split_dim
ctx.split_size_or_sections = split_size_or_sections
return torch.split(mixed_x_layer, split_size_or_sections, dim = split_dim)
@staticmethod
def backward(ctx,
*grad_outputs):
assert len(grad_outputs) > 0, "No gradients received for backprop!"
if isinstance(ctx.split_size_or_sections, (list, tuple)):
split_sizes = ctx.split_size_or_sections
assert (len(grad_outputs) == len(split_sizes)
), "Unequal number of gradients vs split sections for backprop!"
if isinstance(ctx.split_size_or_sections, int):
split_sizes = [ctx.split_size_or_sections] * len(grad_outputs)
dims = len(grad_outputs[0].shape)
split_dim = (ctx.split_dim + dims) % dims
noop_ok = True
strides = grad_outputs[0].stride()
data_ptr = grad_outputs[0].storage().data_ptr()
shape = grad_outputs[0].shape
last_dim_size = grad_outputs[0].shape[-1]
shape = list(grad_outputs[0].shape)
for i, tensor in enumerate(grad_outputs):
shape_i = shape
shape_i[split_dim] = split_sizes[i]
offset_size = sum(split_sizes[:i]) * np.prod(shape[split_dim+1:])
if (tensor.stride() != strides or
tensor.shape != shape or
list(tensor.shape) != shape_i or
tensor.storage().data_ptr() != data_ptr or
tensor.storage_offset() != i * last_dim_size):
tensor.storage_offset() != offset_size):
noop_ok = False
break
if noop_ok:
ret = torch.Tensor().to(grad_outputs[0].dtype)
ret = torch.Tensor().to(device=grad_outputs[0].device,
dtype=grad_outputs[0].dtype)
new_shape = list(shape)
new_shape[-1] = new_shape[-1] * len(grad_outputs)
ret.set_(grad_outputs[0].storage(),
new_shape[split_dim] = sum(split_sizes)
ret.set_(grad_outputs[0].untyped_storage(),
grad_outputs[0].storage_offset(),
new_shape,
grad_outputs[0].stride()
strides
)
return ret, None
return ret, None, None
return torch.cat(grad_outputs, dim = -1), None
return torch.cat(grad_outputs, dim = split_dim), None, None
class _CombineQKV(torch.autograd.Function):
""""""
@@ -1869,8 +1883,8 @@ class MultiheadAttention(torch.nn.Module):
num_attention_heads if num_gqa_groups is None else num_gqa_groups
)
assert (num_attention_heads % self.num_gqa_groups == 0
), "The number of GQA groups must be divisible by the number of attention heads!"
assert (num_attention_heads % tp_size == 0
), "The number of attention heads must be divisible by the number of GQA groups!"
assert (self.num_gqa_groups % tp_size == 0
), "The number of GQA groups must be divisible by tensor parallel size!"
self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size)
self.hidden_size_kv = int(hidden_size * self.num_gqa_groups // num_attention_heads)
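Illustrative numbers satisfying the divisibility requirements discussed above (not taken from the commit): heads divide evenly into GQA groups, and both survive sharding across tensor-parallel ranks.

```python
num_attention_heads, num_gqa_groups, tp_size, head_dim = 32, 8, 4, 64

assert num_attention_heads % num_gqa_groups == 0   # 4 query heads per KV group
assert num_attention_heads % tp_size == 0          # 8 heads per TP rank
assert num_gqa_groups % tp_size == 0               # 2 KV groups per TP rank

hidden_size = num_attention_heads * head_dim                            # 2048
hidden_size_kv = hidden_size * num_gqa_groups // num_attention_heads    # 512
num_gqa_groups_per_partition = num_gqa_groups // tp_size                # 2
```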
@@ -1887,18 +1901,21 @@ class MultiheadAttention(torch.nn.Module):
qkv_parallel_mode = "column" if set_parallel_mode else None
if self.attention_type == "self" and self.num_gqa_groups == self.num_attention_heads:
if self.attention_type == "self":
parameters_split = {"query_": hidden_size,
"key_": self.hidden_size_kv,
"value_": self.hidden_size_kv} if not fuse_qkv_params else None
if self.input_layernorm:
self.layernorm_qkv = LayerNormLinear(
hidden_size,
3 * hidden_size,
hidden_size + 2 * self.hidden_size_kv,
eps=layernorm_epsilon,
init_method=init_method,
bias=bias,
return_bias=False,
parallel_mode=qkv_parallel_mode,
return_layernorm_output=return_layernorm_output,
parameters_split=("query_", "key_", "value_") if not fuse_qkv_params else None,
parameters_split=parameters_split,
zero_centered_gamma=zero_centered_gamma,
ub_bulk_wgrad=ub_bulk_wgrad,
ub_bulk_dgrad=ub_bulk_dgrad,
@@ -1909,17 +1926,15 @@ class MultiheadAttention(torch.nn.Module):
else:
self.qkv = Linear(
hidden_size,
3 * hidden_size,
hidden_size + 2 * self.hidden_size_kv,
init_method=init_method,
bias=bias,
return_bias=False,
parallel_mode=qkv_parallel_mode,
parameters_split=("query_", "key_", "value_") if not fuse_qkv_params else None,
parameters_split=parameters_split,
**common_gemm_kwargs,
)
elif ((self.attention_type == "cross")
or (self.attention_type == "self"
and self.num_gqa_groups != self.num_attention_heads)):
elif self.attention_type == "cross":
if self.input_layernorm:
self.layernorm_query = LayerNormLinear(
hidden_size,
@@ -1929,6 +1944,7 @@ class MultiheadAttention(torch.nn.Module):
bias=bias,
return_bias=False,
parallel_mode=qkv_parallel_mode,
parameters_split=("query_",) if not fuse_qkv_params else None,
return_layernorm_output=return_layernorm_output,
zero_centered_gamma=zero_centered_gamma,
ub_bulk_wgrad=ub_bulk_wgrad,
@@ -2115,8 +2131,8 @@ class MultiheadAttention(torch.nn.Module):
# Query, Key, and Value
# =====================
if self.attention_type == "self" and self.num_gqa_groups == self.num_attention_heads:
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
if self.attention_type == "self":
# Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
if self.input_layernorm:
layernorm_qkv_outputs = self.layernorm_qkv(
hidden_states,
@@ -2132,49 +2148,59 @@ class MultiheadAttention(torch.nn.Module):
is_first_microbatch=is_first_microbatch,
)
num_queries_per_key_value = (self.num_attention_heads_per_partition //
self.num_gqa_groups_per_partition)
if self.qkv_weight_interleaved:
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head,
)
# split along last dimension
split_dim = -1
else:
# [sq, b, (np * 3 * hn)] --> [sq, b, 3 * np, hn]
# [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, ng, (np/ng + 2), hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + (
3 * self.num_attention_heads_per_partition,
self.num_gqa_groups_per_partition,
(num_queries_per_key_value + 2),
self.hidden_size_per_attention_head,
)
# split along second last dimension
split_dim = -2
else:
# [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, (np/ng + 2), ng, hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + (
(num_queries_per_key_value + 2),
self.num_gqa_groups_per_partition,
self.hidden_size_per_attention_head
)
# split along third last dimension
split_dim = -3
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# mixed_x_layer --> 3 [sq, b, np, hn]
if split_dim == -1 and not is_in_onnx_export_mode():
query_layer, key_layer, value_layer = _SplitLastDim.apply(mixed_x_layer, 3)
else:
query_layer, key_layer, value_layer = split_tensor_along_dim(
mixed_x_layer, split_dim, 3
# qkv_weight_interleaved:
# [sq, b, ng, (np/ng + 2), hn]
# --> [sq, b, ng, np/ng, hn], [sq, b, ng, 1, hn], [sq, b, ng, 1, hn]
# not qkv_weight_interleaved:
# [sq, b, (np/ng + 2), ng, hn]
# --> [sq, b, np/ng, ng, hn], [sq, b, 1, ng, hn], [sq, b, 1, ng, hn]
if not is_in_onnx_export_mode():
query_layer, key_layer, value_layer = _SplitAlongDim.apply(
mixed_x_layer, split_dim, (num_queries_per_key_value, 1, 1)
)
elif ((self.attention_type == "cross")
or (self.attention_type == "self"
and self.num_gqa_groups != self.num_attention_heads)):
if self.attention_type == "cross":
input_tensor = encoder_output
else:
input_tensor = hidden_states
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
query_layer, key_layer, value_layer = torch.split(
mixed_x_layer, (num_queries_per_key_value, 1, 1), dim = split_dim,
)
# query: -> [sq, b, np, hn]
# key, value: -> [sq, b, ng, hn]
query_layer, key_layer, value_layer = (x.reshape(x.size(0), x.size(1), -1,
self.hidden_size_per_attention_head)
for x in (query_layer, key_layer, value_layer))
elif self.attention_type == "cross":
# Attention heads [sk, b, h] --> [sk, b, (ng * 2 * hn)]
mixed_kv_layer = self.key_value(
input_tensor,
encoder_output,
is_first_microbatch=is_first_microbatch,
)
if self.qkv_weight_interleaved:
# [sq, b, (np * 2 * hn)] --> [sq, b, np, 2 * hn]
# [sq, b, (ng * 2 * hn)] --> [sq, b, ng, 2 * hn]
new_tensor_shape = mixed_kv_layer.size()[:-1] + (
self.num_gqa_groups_per_partition,
2 * self.hidden_size_per_attention_head,
@@ -2182,7 +2208,7 @@ class MultiheadAttention(torch.nn.Module):
# split along last dimension
split_dim = -1
else:
# [sq, b, (np * 2 * hn)] --> [sq, b, 2 * np, hn]
# [sq, b, (ng * 2 * hn)] --> [sq, b, 2 * ng, hn]
new_tensor_shape = mixed_kv_layer.size()[:-1] + (
2 * self.num_gqa_groups_per_partition,
self.hidden_size_per_attention_head,
@@ -2192,11 +2218,15 @@ class MultiheadAttention(torch.nn.Module):
mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
# mixed_kv_layer --> 2 [sk, b, np, hn]
if split_dim == -1 and not is_in_onnx_export_mode():
key_layer, value_layer = _SplitLastDim.apply(mixed_kv_layer, 2)
# mixed_kv_layer --> 2 [sk, b, ng, hn]
if not is_in_onnx_export_mode():
key_layer, value_layer = _SplitAlongDim.apply(
mixed_kv_layer, split_dim, mixed_kv_layer.shape[split_dim] // 2,
)
else:
key_layer, value_layer = split_tensor_along_dim(mixed_kv_layer, split_dim, 2)
key_layer, value_layer = torch.split(
mixed_kv_layer, mixed_kv_layer.shape[split_dim] // 2, dim = split_dim,
)
# Attention head [sq, b, h] --> [sq, b, hp]
if self.input_layernorm:
......
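To make the shape bookkeeping in the self-attention path above concrete, a small stand-alone sketch of the fused-QKV view and split for the interleaved weight layout (sizes are illustrative; this is not the library code):

```python
import torch

sq, b = 128, 2              # sequence length, batch
np_, ng, hn = 16, 4, 64     # heads per partition, GQA groups per partition, head size
queries_per_kv = np_ // ng  # np/ng = 4

# Fused projection output: [sq, b, ng * (np/ng + 2) * hn]
mixed_x_layer = torch.randn(sq, b, ng * (queries_per_kv + 2) * hn)

# qkv_weight_interleaved: view as [sq, b, ng, (np/ng + 2), hn], split along dim -2
mixed = mixed_x_layer.view(sq, b, ng, queries_per_kv + 2, hn)
q, k, v = torch.split(mixed, (queries_per_kv, 1, 1), dim=-2)

# query: [sq, b, ng, np/ng, hn] -> [sq, b, np, hn]
# key/value: [sq, b, ng, 1, hn] -> [sq, b, ng, hn]
q = q.reshape(sq, b, -1, hn)
k = k.reshape(sq, b, -1, hn)
v = v.reshape(sq, b, -1, hn)
assert q.shape == (sq, b, np_, hn)
assert k.shape == (sq, b, ng, hn) and v.shape == (sq, b, ng, hn)
```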
@@ -212,8 +212,9 @@ class _NoopCat(torch.autograd.Function):
*params_split: Tuple[torch.Tensor, ...],
) -> torch.Tensor:
assert not full_param_buffer.requires_grad, "Buffers should not require gradient"
sum_params_shape = sum(p.shape[0] for p in params_split)
assert (
full_param_buffer.shape[0] % len(params_split) == 0
full_param_buffer.shape[0] == sum_params_shape
), "Dimensions not compatible for concatenation"
param_temp = full_param_buffer.new()
@@ -223,18 +224,19 @@ class _NoopCat(torch.autograd.Function):
full_param_buffer.stride())
param_temp.requires_grad = True
ctx.save_for_backward(full_param_buffer, *params_split)
ctx.save_for_backward(*params_split)
return param_temp
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
full_param_buffer, *params_split = ctx.saved_tensors
split_size = full_param_buffer.shape[0] // len(params_split)
params_split = ctx.saved_tensors
grads = []
slice_begin = 0
for i, _ in enumerate(params_split):
grads.append(grad_output[i * split_size : (i+1) * split_size])
slice_size = params_split[i].shape[0]
slice_end = slice_begin + slice_size
grads.append(grad_output[slice_begin:slice_end])
slice_begin = slice_end
return None, *grads
@@ -753,7 +755,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
return grad_output_mat, grad_output_c, grad_output_t, grad_bias
def noop_cat(self, buffer_name: str, pnames: List[str]) -> torch.Tensor:
def noop_cat(self,
buffer_name: str,
pnames: List[str],
parameters_split: Dict[str, int]
) -> torch.Tensor:
"""No-op replacement of `torch.cat`. The buffer and split parameters must occupy
the same memory region. If this is not the case, then the split parameters
are concatenated and the buffer is overwritten. The parameters' memory is then
@@ -762,17 +768,24 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
assert hasattr(self, buffer_name), f"No buffer named {buffer_name}"
full_param_buffer = getattr(self, buffer_name)
split_size = full_param_buffer.shape[0] // len(pnames)
params = [getattr(self, name) for name in pnames]
slice_begin = 0
for i, p in enumerate(params):
if p.data.data_ptr() != full_param_buffer[i*split_size : (i+1)*split_size].data_ptr():
slice_size = parameters_split[pnames[i].split('_')[0]+'_']
slice_end = slice_begin + slice_size
if p.data.data_ptr() != full_param_buffer[slice_begin:slice_end].data_ptr():
with torch.no_grad():
setattr(self, buffer_name, torch.cat(params))
for j, pname in enumerate(pnames):
slice_begin_j = 0
for pname in pnames:
slice_size_j = parameters_split[pname.split('_')[0]+'_']
slice_end_j = slice_begin_j + slice_size_j
full_param_buffer = getattr(self, buffer_name)
setattr(self, pname,
Parameter(full_param_buffer[j*split_size : (j+1)*split_size]))
Parameter(full_param_buffer[slice_begin_j:slice_end_j]))
slice_begin_j = slice_end_j
break
slice_begin = slice_end
return _NoopCat.apply(getattr(self, buffer_name), *[getattr(self, name) for name in pnames])
......
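A self-contained sketch (not the library code) of the uneven gradient slicing that the reworked `_NoopCat.backward` performs: the gradient of the full buffer is cut back into per-parameter slices whose sizes come from the saved split parameters, rather than from an equal division of the buffer.

```python
import torch

# Stand-ins for three split parameters with unequal first dimensions (q, k, v).
params_split = [torch.zeros(1024, 512), torch.zeros(256, 512), torch.zeros(256, 512)]
grad_output = torch.randn(1024 + 256 + 256, 512)  # gradient w.r.t. the full buffer

grads = []
slice_begin = 0
for p in params_split:
    slice_end = slice_begin + p.shape[0]
    grads.append(grad_output[slice_begin:slice_end])
    slice_begin = slice_end

assert [g.shape[0] for g in grads] == [1024, 256, 256]
assert sum(g.shape[0] for g in grads) == grad_output.shape[0]
```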
@@ -536,11 +536,14 @@ class LayerNormLinear(TransformerEngineBaseModule):
together with the output of the linear transformation.
Example use case: residual connection for transformer module is
taken post layernorm.
parameters_split : Tuple[str, ...], default = None
if a tuple of strings is provided, the weight and bias parameters of the
module are exposed as `N` separate `torch.nn.parameter.Parameter`s each,
split along the first dimension, where `N` is the length of the argument
and the strings contained are the names of the split parameters.
parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None
if a tuple of strings or a dict of strings to integers is provided,
the weight and bias parameters of the module are exposed as `N` separate
`torch.nn.parameter.Parameter`s each, split along the first dimension,
where `N` is the length of the argument and the strings contained are the
names of the split parameters. In the case of a tuple, each parameter
has the same shape. In the case of a dict, the values give the
`out_features` for each projection.
zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to
@@ -607,7 +610,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
parallel_mode: Optional[str] = None,
return_layernorm_output: bool = False,
skip_weight_param_allocation: bool = False,
parameters_split: Optional[Tuple[str, ...]] = None,
parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
zero_centered_gamma: bool = False,
ub_bulk_wgrad: bool = False,
ub_bulk_dgrad: bool = False,
@@ -707,23 +710,35 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.bias_tensor.zero_()
if parameters_split is None:
parameters_split = ("",)
assert (
self.out_features % len(parameters_split) == 0
), f"Weight and bias params cannot be split into {len(parameters_split)} parts"
split_size = self.out_features // len(parameters_split)
parameters_split = {"": self.out_features}
elif isinstance(parameters_split, tuple):
assert (
self.out_features % len(parameters_split) == 0
), f"Weight and bias params cannot be split into {len(parameters_split)} parts"
split_size = self.out_features // len(parameters_split)
parameters_split = {key: split_size for key in parameters_split}
elif isinstance(parameters_split, dict):
overall_split_size = sum(parameters_split.values())
assert(
self.out_features == overall_split_size
), f"Overall sum of parameters_split (={overall_split_size}) does not match "\
f"to out features (={self.out_features})"
else:
assert False, "Type of 'parameters_split' is not None, tuple or dict"
self.updated_parameters_split = parameters_split
self.weight_names = []
self.bias_names = []
for i, pname in enumerate(parameters_split):
slice_begin = 0
for pname, slice_size in parameters_split.items():
wname = pname + "weight"
bname = pname + "bias"
slice_end = slice_begin + slice_size
self.register_parameter(
wname, Parameter(self.weight_tensor[i * split_size : (i+1) * split_size])
wname, Parameter(self.weight_tensor[slice_begin:slice_end])
)
set_tensor_model_parallel_attributes(
@@ -735,7 +750,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
if self.use_bias:
self.register_parameter(
bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size])
bname, Parameter(self.bias_tensor[slice_begin:slice_end])
)
if parallel_mode == "row":
setattr(getattr(self, bname), "sequence_parallel", sequence_parallel)
@@ -748,6 +763,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.weight_names.append(wname)
self.bias_names.append(bname)
slice_begin = slice_end
self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))
@@ -843,12 +860,14 @@ class LayerNormLinear(TransformerEngineBaseModule):
bias_tensor = (
self.bias if self.parameters_split is None
else self.bias_tensor if not torch.is_grad_enabled()
else self.noop_cat("bias_tensor", self.bias_names)
else self.noop_cat("bias_tensor", self.bias_names,
self.updated_parameters_split)
)
weight_tensor = (
self.weight if self.parameters_split is None
else self.weight_tensor if not torch.is_grad_enabled()
else self.noop_cat("weight_tensor", self.weight_names)
else self.noop_cat("weight_tensor", self.weight_names,
self.updated_parameters_split)
)
# Fetch the fp8 weights placeholders (for linear/gemm)
......
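A minimal pure-PyTorch sketch (not the module itself; sizes are illustrative) of what the registration loop above does with the dict form of `parameters_split`: each named parameter becomes a slice view of the single weight buffer, which is the memory layout that later lets `noop_cat` skip the real concatenation.

```python
import torch
from torch.nn.parameter import Parameter

in_features = 1024
parameters_split = {"query_": 1024, "key_": 256, "value_": 256}
out_features = sum(parameters_split.values())
weight_tensor = torch.empty(out_features, in_features)

params = {}
slice_begin = 0
for pname, slice_size in parameters_split.items():
    slice_end = slice_begin + slice_size
    params[pname + "weight"] = Parameter(weight_tensor[slice_begin:slice_end])
    slice_begin = slice_end

assert params["query_weight"].shape == (1024, in_features)
assert params["key_weight"].data_ptr() == weight_tensor[1024:1280].data_ptr()
# All split parameters view the same storage as the full buffer.
assert all(p.untyped_storage().data_ptr() == weight_tensor.untyped_storage().data_ptr()
           for p in params.values())
```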
@@ -461,11 +461,14 @@ class Linear(TransformerEngineBaseModule):
init_method : Callable, default = `None`
used for initializing weights in the following way: `init_method(weight)`.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
parameters_split : Tuple[str, ...], default = None
if a tuple of strings is provided, the weight and bias parameters of the
module are exposed as `N` separate `torch.nn.parameter.Parameter`s each,
split along the first dimension, where `N` is the length of the argument
and the strings contained are the names of the split parameters.
parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None
if a tuple of strings or a dict of strings to integers is provided,
the weight and bias parameters of the module are exposed as `N` separate
`torch.nn.parameter.Parameter`s each, split along the first dimension,
where `N` is the length of the argument and the strings contained are the
names of the split parameters. In the case of a tuple, each parameter
has the same shape. In the case of a dict, the values give the
`out_features` for each projection.
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
@@ -522,7 +525,7 @@ class Linear(TransformerEngineBaseModule):
params_dtype: Optional[torch.dtype] = None,
parallel_mode: Optional[str] = None,
skip_weight_param_allocation: bool = False,
parameters_split: Optional[Tuple[str, ...]] = None,
parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
ub_split_rs: bool = False,
ub_split_ag: bool = False,
device: Union[torch.device, str] = "cuda",
@@ -598,23 +601,35 @@ class Linear(TransformerEngineBaseModule):
self.bias_tensor.zero_()
if parameters_split is None:
parameters_split = ("",)
assert (
self.out_features % len(parameters_split) == 0
), f"Weight and bias params cannot be split into {len(parameters_split)} parts"
split_size = self.out_features // len(parameters_split)
parameters_split = {"": self.out_features}
elif isinstance(parameters_split, tuple):
assert (
self.out_features % len(parameters_split) == 0
), f"Weight and bias params cannot be split into {len(parameters_split)} parts"
split_size = self.out_features // len(parameters_split)
parameters_split = {key: split_size for key in parameters_split}
elif isinstance(parameters_split, dict):
overall_split_size = sum(parameters_split.values())
assert(
self.out_features == overall_split_size
), f"Overall sum of parameters_split (={overall_split_size}) does not match "\
f"to out features (={self.out_features})"
else:
assert False, "Type of 'parameters_split' is not None, tuple or dict"
self.updated_parameters_split = parameters_split
self.weight_names = []
self.bias_names = []
for i, pname in enumerate(parameters_split):
slice_begin = 0
for pname, slice_size in parameters_split.items():
wname = pname + "weight"
bname = pname + "bias"
slice_end = slice_begin + slice_size
self.register_parameter(
wname, Parameter(self.weight_tensor[i * split_size : (i+1) * split_size])
wname, Parameter(self.weight_tensor[slice_begin:slice_end])
)
set_tensor_model_parallel_attributes(
@@ -626,7 +641,7 @@ class Linear(TransformerEngineBaseModule):
if self.use_bias:
self.register_parameter(
bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size])
bname, Parameter(self.bias_tensor[slice_begin:slice_end])
)
if parallel_mode == "row":
setattr(getattr(self, bname), "sequence_parallel", sequence_parallel)
@@ -639,6 +654,8 @@ class Linear(TransformerEngineBaseModule):
self.weight_names.append(wname)
self.bias_names.append(bname)
slice_begin = slice_end
self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))
# For RPL, bias has to be added after TP collectives
@@ -717,12 +734,14 @@ class Linear(TransformerEngineBaseModule):
bias_tensor = (
self.bias if self.parameters_split is None
else self.bias_tensor if not torch.is_grad_enabled()
else self.noop_cat("bias_tensor", self.bias_names)
else self.noop_cat("bias_tensor", self.bias_names,
self.updated_parameters_split)
)
weight_tensor = (
self.weight if self.parameters_split is None
else self.weight_tensor if not torch.is_grad_enabled()
else self.noop_cat("weight_tensor", self.weight_names)
else self.noop_cat("weight_tensor", self.weight_names,
self.updated_parameters_split)
)
# Fetch the fp8 weights placeholders (for linear/gemm)
......
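Finally, a short sketch of how the two accepted forms of `parameters_split` relate after the normalization shown above: an even tuple split becomes a dict with equal values, while the dict form allows the uneven GQA split (numbers are illustrative).

```python
out_features = 3 * 1024

# Tuple form: N equal slices of out_features // N each.
split_names = ("query_", "key_", "value_")
assert out_features % len(split_names) == 0
split_size = out_features // len(split_names)
as_dict_from_tuple = {key: split_size for key in split_names}

# Dict form: slices may differ, but must sum to out_features (the GQA case).
as_dict_uneven = {"query_": 2048, "key_": 512, "value_": 512}

assert sum(as_dict_from_tuple.values()) == out_features
assert sum(as_dict_uneven.values()) == out_features
```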