Unverified Commit db589510 authored by cyanguwa's avatar cyanguwa Committed by GitHub
Browse files

Fix layernorm in GQA (#434)



* [PyTorch] Implement GQA based on fused q, k, v projection. Additionally fixes #392
Signed-off-by: default avatarMarkus Schnoes <markus.schnoes@gmx.de>

* [PyTorch] Extend parameters_split option in Linear and LayerNormLinear to support splitting with different sizes as required by unfused GQA.
Signed-off-by: default avatarMarkus Schnoes <markus.schnoes@gmx.de>

* fix parameters split
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix noop cat to bypass torch.cat and support uneven split
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix unit tests
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix torch.split args
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix cuda graph due to noop_cat
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove the use of enumerate when possible
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix strides in SplitAlongDim
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: default avatarMarkus Schnoes <markus.schnoes@gmx.de>
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: default avatarMarkus Schnoes <markus.schnoes@gmx.de>
parent 903e1f4f
...@@ -141,7 +141,8 @@ def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type) ...@@ -141,7 +141,8 @@ def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type)
@pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("ckpt_attn", [False]) @pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("bias_type", ["no_bias", "post_scale_bias"]) @pytest.mark.parametrize("bias_type", ["no_bias", "post_scale_bias"])
def test_transformer_layer(dtype, bs, model, ckpt_attn, bias_type): @pytest.mark.parametrize("fused_qkv_params", [True, False])
def test_transformer_layer(dtype, bs, model, ckpt_attn, bias_type, fused_qkv_params):
"""Test TransformerLayer module when its DotProductAttention is enabled with """Test TransformerLayer module when its DotProductAttention is enabled with
FlashAttention, FusedAttention, or UnfusedDotProductAttention backend""" FlashAttention, FusedAttention, or UnfusedDotProductAttention backend"""
...@@ -149,11 +150,11 @@ def test_transformer_layer(dtype, bs, model, ckpt_attn, bias_type): ...@@ -149,11 +150,11 @@ def test_transformer_layer(dtype, bs, model, ckpt_attn, bias_type):
if bias_type == "no_bias": if bias_type == "no_bias":
flash_attn_fwd, flash_attn_bwd = _run_transformer_layer( flash_attn_fwd, flash_attn_bwd = _run_transformer_layer(
dtype, bs, config, "FlashAttention", ckpt_attn, bias_type) dtype, bs, config, "FlashAttention", ckpt_attn, bias_type, fused_qkv_params)
fused_attn_fwd, fused_attn_bwd = _run_transformer_layer( fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
dtype, bs, config, "FusedAttention", ckpt_attn, bias_type) dtype, bs, config, "FusedAttention", ckpt_attn, bias_type, fused_qkv_params)
unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer( unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer(
dtype, bs, config, "UnfusedDotProductAttention", ckpt_attn, bias_type) dtype, bs, config, "UnfusedDotProductAttention", ckpt_attn, bias_type, fused_qkv_params)
atol, rtol = (5e-1, 5e-2) atol, rtol = (5e-1, 5e-2)
if bias_type == "no_bias": if bias_type == "no_bias":
...@@ -162,7 +163,7 @@ def test_transformer_layer(dtype, bs, model, ckpt_attn, bias_type): ...@@ -162,7 +163,7 @@ def test_transformer_layer(dtype, bs, model, ckpt_attn, bias_type):
assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol) assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol) assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)
def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type): def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type, fused_qkv_params):
reset_rng_states() reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FLASH_ATTN"] = "0"
...@@ -220,7 +221,7 @@ def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type): ...@@ -220,7 +221,7 @@ def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type):
layer_type="encoder", layer_type="encoder",
drop_path_rate=drop_path_rates[layer_number - 1], drop_path_rate=drop_path_rates[layer_number - 1],
set_parallel_mode=True, set_parallel_mode=True,
fuse_qkv_params=True, fuse_qkv_params=fused_qkv_params,
zero_centered_gamma=False, zero_centered_gamma=False,
qkv_weight_interleaved=False, qkv_weight_interleaved=False,
ub_tp_comm_overlap=False, ub_tp_comm_overlap=False,
......
...@@ -8,8 +8,9 @@ import warnings ...@@ -8,8 +8,9 @@ import warnings
import math import math
from importlib.metadata import version from importlib.metadata import version
from contextlib import nullcontext from contextlib import nullcontext
from typing import Any, Callable, Optional, Tuple, Union, Dict from typing import Any, Callable, Optional, Tuple, Union, Dict, List
from pkg_resources import packaging from pkg_resources import packaging
import numpy as np
import torch import torch
...@@ -508,48 +509,61 @@ def apply_rotary_pos_emb(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: ...@@ -508,48 +509,61 @@ def apply_rotary_pos_emb(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
return torch.cat((t, t_pass), dim=-1) return torch.cat((t, t_pass), dim=-1)
class _SplitLastDim(torch.autograd.Function): class _SplitAlongDim(torch.autograd.Function):
"""""" """"""
@staticmethod @staticmethod
def forward(ctx, def forward(ctx,
mixed_x_layer: torch.Tensor, mixed_x_layer: torch.Tensor,
num_parts: int split_dim: int,
split_size_or_sections: Union[int, List[int], Tuple[int]],
) -> Tuple[torch.Tensor, ...]: ) -> Tuple[torch.Tensor, ...]:
return split_tensor_along_dim(mixed_x_layer, -1, num_parts) ctx.split_dim = split_dim
ctx.split_size_or_sections = split_size_or_sections
return torch.split(mixed_x_layer, split_size_or_sections, dim = split_dim)
@staticmethod @staticmethod
def backward(ctx, def backward(ctx,
*grad_outputs): *grad_outputs):
assert len(grad_outputs) > 0, "No gradients received for backprop!" assert len(grad_outputs) > 0, "No gradients received for backprop!"
if isinstance(ctx.split_size_or_sections, (list, tuple)):
split_sizes = ctx.split_size_or_sections
assert (len(grad_outputs) == len(split_sizes)
), "Unequal number of gradients vs split sections for backprop!"
if isinstance(ctx.split_size_or_sections, int):
split_sizes = [ctx.split_size_or_sections] * len(grad_outputs)
dims = len(grad_outputs[0].shape)
split_dim = (ctx.split_dim + dims) % dims
noop_ok = True noop_ok = True
strides = grad_outputs[0].stride() strides = grad_outputs[0].stride()
data_ptr = grad_outputs[0].storage().data_ptr() data_ptr = grad_outputs[0].storage().data_ptr()
shape = grad_outputs[0].shape shape = list(grad_outputs[0].shape)
last_dim_size = grad_outputs[0].shape[-1]
for i, tensor in enumerate(grad_outputs): for i, tensor in enumerate(grad_outputs):
shape_i = shape
shape_i[split_dim] = split_sizes[i]
offset_size = sum(split_sizes[:i]) * np.prod(shape[split_dim+1:])
if (tensor.stride() != strides or if (tensor.stride() != strides or
tensor.shape != shape or list(tensor.shape) != shape_i or
tensor.storage().data_ptr() != data_ptr or tensor.storage().data_ptr() != data_ptr or
tensor.storage_offset() != i * last_dim_size): tensor.storage_offset() != offset_size):
noop_ok = False noop_ok = False
break break
if noop_ok: if noop_ok:
ret = torch.Tensor().to(grad_outputs[0].dtype)
ret = torch.Tensor().to(device=grad_outputs[0].device, ret = torch.Tensor().to(device=grad_outputs[0].device,
dtype=grad_outputs[0].dtype) dtype=grad_outputs[0].dtype)
new_shape = list(shape) new_shape = list(shape)
new_shape[-1] = new_shape[-1] * len(grad_outputs) new_shape[split_dim] = sum(split_sizes)
ret.set_(grad_outputs[0].storage(), ret.set_(grad_outputs[0].untyped_storage(),
grad_outputs[0].storage_offset(), grad_outputs[0].storage_offset(),
new_shape, new_shape,
grad_outputs[0].stride() strides
) )
return ret, None return ret, None, None
return torch.cat(grad_outputs, dim = -1), None return torch.cat(grad_outputs, dim = split_dim), None, None
class _CombineQKV(torch.autograd.Function): class _CombineQKV(torch.autograd.Function):
"""""" """"""
...@@ -1869,8 +1883,8 @@ class MultiheadAttention(torch.nn.Module): ...@@ -1869,8 +1883,8 @@ class MultiheadAttention(torch.nn.Module):
num_attention_heads if num_gqa_groups is None else num_gqa_groups num_attention_heads if num_gqa_groups is None else num_gqa_groups
) )
assert (num_attention_heads % self.num_gqa_groups == 0 assert (num_attention_heads % self.num_gqa_groups == 0
), "The number of GQA groups must be divisible by the number of attention heads!" ), "The number of attention heads must be divisible by the number of GQA groups!"
assert (num_attention_heads % tp_size == 0 assert (self.num_gqa_groups % tp_size == 0
), "The number of GQA groups must be divisible by tensor parallel size!" ), "The number of GQA groups must be divisible by tensor parallel size!"
self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size) self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size)
self.hidden_size_kv = int(hidden_size * self.num_gqa_groups // num_attention_heads) self.hidden_size_kv = int(hidden_size * self.num_gqa_groups // num_attention_heads)
...@@ -1887,18 +1901,21 @@ class MultiheadAttention(torch.nn.Module): ...@@ -1887,18 +1901,21 @@ class MultiheadAttention(torch.nn.Module):
qkv_parallel_mode = "column" if set_parallel_mode else None qkv_parallel_mode = "column" if set_parallel_mode else None
if self.attention_type == "self" and self.num_gqa_groups == self.num_attention_heads: if self.attention_type == "self":
parameters_split = {"query_": hidden_size,
"key_": self.hidden_size_kv,
"value_": self.hidden_size_kv} if not fuse_qkv_params else None
if self.input_layernorm: if self.input_layernorm:
self.layernorm_qkv = LayerNormLinear( self.layernorm_qkv = LayerNormLinear(
hidden_size, hidden_size,
3 * hidden_size, hidden_size + 2 * self.hidden_size_kv,
eps=layernorm_epsilon, eps=layernorm_epsilon,
init_method=init_method, init_method=init_method,
bias=bias, bias=bias,
return_bias=False, return_bias=False,
parallel_mode=qkv_parallel_mode, parallel_mode=qkv_parallel_mode,
return_layernorm_output=return_layernorm_output, return_layernorm_output=return_layernorm_output,
parameters_split=("query_", "key_", "value_") if not fuse_qkv_params else None, parameters_split=parameters_split,
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
ub_bulk_wgrad=ub_bulk_wgrad, ub_bulk_wgrad=ub_bulk_wgrad,
ub_bulk_dgrad=ub_bulk_dgrad, ub_bulk_dgrad=ub_bulk_dgrad,
...@@ -1909,17 +1926,15 @@ class MultiheadAttention(torch.nn.Module): ...@@ -1909,17 +1926,15 @@ class MultiheadAttention(torch.nn.Module):
else: else:
self.qkv = Linear( self.qkv = Linear(
hidden_size, hidden_size,
3 * hidden_size, hidden_size + 2 * self.hidden_size_kv,
init_method=init_method, init_method=init_method,
bias=bias, bias=bias,
return_bias=False, return_bias=False,
parallel_mode=qkv_parallel_mode, parallel_mode=qkv_parallel_mode,
parameters_split=("query_", "key_", "value_") if not fuse_qkv_params else None, parameters_split=parameters_split,
**common_gemm_kwargs, **common_gemm_kwargs,
) )
elif ((self.attention_type == "cross") elif self.attention_type == "cross":
or (self.attention_type == "self"
and self.num_gqa_groups != self.num_attention_heads)):
if self.input_layernorm: if self.input_layernorm:
self.layernorm_query = LayerNormLinear( self.layernorm_query = LayerNormLinear(
hidden_size, hidden_size,
...@@ -1929,6 +1944,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -1929,6 +1944,7 @@ class MultiheadAttention(torch.nn.Module):
bias=bias, bias=bias,
return_bias=False, return_bias=False,
parallel_mode=qkv_parallel_mode, parallel_mode=qkv_parallel_mode,
parameters_split=("query_",) if not fuse_qkv_params else None,
return_layernorm_output=return_layernorm_output, return_layernorm_output=return_layernorm_output,
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
ub_bulk_wgrad=ub_bulk_wgrad, ub_bulk_wgrad=ub_bulk_wgrad,
...@@ -2115,8 +2131,8 @@ class MultiheadAttention(torch.nn.Module): ...@@ -2115,8 +2131,8 @@ class MultiheadAttention(torch.nn.Module):
# Query, Key, and Value # Query, Key, and Value
# ===================== # =====================
if self.attention_type == "self" and self.num_gqa_groups == self.num_attention_heads: if self.attention_type == "self":
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
if self.input_layernorm: if self.input_layernorm:
layernorm_qkv_outputs = self.layernorm_qkv( layernorm_qkv_outputs = self.layernorm_qkv(
hidden_states, hidden_states,
...@@ -2132,49 +2148,59 @@ class MultiheadAttention(torch.nn.Module): ...@@ -2132,49 +2148,59 @@ class MultiheadAttention(torch.nn.Module):
is_first_microbatch=is_first_microbatch, is_first_microbatch=is_first_microbatch,
) )
num_queries_per_key_value = (self.num_attention_heads_per_partition //
self.num_gqa_groups_per_partition)
if self.qkv_weight_interleaved: if self.qkv_weight_interleaved:
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, ng, (np/ng + 2), hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head,
)
# split along last dimension
split_dim = -1
else:
# [sq, b, (np * 3 * hn)] --> [sq, b, 3 * np, hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + ( new_tensor_shape = mixed_x_layer.size()[:-1] + (
3 * self.num_attention_heads_per_partition, self.num_gqa_groups_per_partition,
(num_queries_per_key_value + 2),
self.hidden_size_per_attention_head, self.hidden_size_per_attention_head,
) )
# split along second last dimension # split along second last dimension
split_dim = -2 split_dim = -2
else:
# [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, (np/ng + 2), ng, hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + (
(num_queries_per_key_value + 2),
self.num_gqa_groups_per_partition,
self.hidden_size_per_attention_head
)
# split along third last dimension
split_dim = -3
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# mixed_x_layer --> 3 [sq, b, np, hn] # qkv_weight_interleaved:
if split_dim == -1 and not is_in_onnx_export_mode(): # [sq, b, ng, (np/ng + 2), hn]
query_layer, key_layer, value_layer = _SplitLastDim.apply(mixed_x_layer, 3) # --> [sq, b, ng, np/ng, hn], [sq, b, ng, 1, hn], [sq, b, ng, 1, hn]
else: # not qkv_weight_interleaved:
query_layer, key_layer, value_layer = split_tensor_along_dim( # [sq, b, (np/ng + 2), ng, hn]
mixed_x_layer, split_dim, 3 # --> [sq, b, np/ng, np, hn], [sq, b, 1, ng, hn], [sq, b, 1, ng, hn]
if not is_in_onnx_export_mode():
query_layer, key_layer, value_layer = _SplitAlongDim.apply(
mixed_x_layer, split_dim, (num_queries_per_key_value, 1, 1)
) )
elif ((self.attention_type == "cross")
or (self.attention_type == "self"
and self.num_gqa_groups != self.num_attention_heads)):
if self.attention_type == "cross":
input_tensor = encoder_output
else: else:
input_tensor = hidden_states query_layer, key_layer, value_layer = torch.split(
mixed_x_layer, (num_queries_per_key_value, 1, 1), dim = split_dim,
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] )
# query: -> [sq, b, np, hn]
# key, value: -> [sq, b, ng, hn]
query_layer, key_layer, value_layer = (x.reshape(x.size(0), x.size(1), -1,
self.hidden_size_per_attention_head)
for x in (query_layer, key_layer, value_layer))
elif self.attention_type == "cross":
# Attention heads [sk, b, h] --> [sk, b, (ng * 2 * hn)]
mixed_kv_layer = self.key_value( mixed_kv_layer = self.key_value(
input_tensor, encoder_output,
is_first_microbatch=is_first_microbatch, is_first_microbatch=is_first_microbatch,
) )
if self.qkv_weight_interleaved: if self.qkv_weight_interleaved:
# [sq, b, (np * 2 * hn)] --> [sq, b, np, 2 * hn] # [sq, b, (ng * 2 * hn)] --> [sq, b, ng, 2 * hn]
new_tensor_shape = mixed_kv_layer.size()[:-1] + ( new_tensor_shape = mixed_kv_layer.size()[:-1] + (
self.num_gqa_groups_per_partition, self.num_gqa_groups_per_partition,
2 * self.hidden_size_per_attention_head, 2 * self.hidden_size_per_attention_head,
...@@ -2182,7 +2208,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -2182,7 +2208,7 @@ class MultiheadAttention(torch.nn.Module):
# split along last dimension # split along last dimension
split_dim = -1 split_dim = -1
else: else:
# [sq, b, (np * 2 * hn)] --> [sq, b, 2 * np, hn] # [sq, b, (ng * 2 * hn)] --> [sq, b, 2 * ng, hn]
new_tensor_shape = mixed_kv_layer.size()[:-1] + ( new_tensor_shape = mixed_kv_layer.size()[:-1] + (
2 * self.num_gqa_groups_per_partition, 2 * self.num_gqa_groups_per_partition,
self.hidden_size_per_attention_head, self.hidden_size_per_attention_head,
...@@ -2192,11 +2218,15 @@ class MultiheadAttention(torch.nn.Module): ...@@ -2192,11 +2218,15 @@ class MultiheadAttention(torch.nn.Module):
mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape) mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
# mixed_kv_layer --> 2 [sk, b, np, hn] # mixed_kv_layer --> 2 [sk, b, ng, hn]
if split_dim == -1 and not is_in_onnx_export_mode(): if not is_in_onnx_export_mode():
key_layer, value_layer = _SplitLastDim.apply(mixed_kv_layer, 2) key_layer, value_layer = _SplitAlongDim.apply(
mixed_kv_layer, split_dim, mixed_kv_layer.shape[split_dim] // 2,
)
else: else:
key_layer, value_layer = split_tensor_along_dim(mixed_kv_layer, split_dim, 2) key_layer, value_layer = torch.split(
mixed_kv_layer, mixed_kv_layer.shape[split_dim] // 2, dim = split_dim,
)
# Attention head [sq, b, h] --> [sq, b, hp] # Attention head [sq, b, h] --> [sq, b, hp]
if self.input_layernorm: if self.input_layernorm:
......
...@@ -212,8 +212,9 @@ class _NoopCat(torch.autograd.Function): ...@@ -212,8 +212,9 @@ class _NoopCat(torch.autograd.Function):
*params_split: Tuple[torch.Tensor, ...], *params_split: Tuple[torch.Tensor, ...],
) -> torch.Tensor: ) -> torch.Tensor:
assert not full_param_buffer.requires_grad, "Buffers should not require gradient" assert not full_param_buffer.requires_grad, "Buffers should not require gradient"
sum_params_shape = sum(p.shape[0] for p in params_split)
assert ( assert (
full_param_buffer.shape[0] % len(params_split) == 0 full_param_buffer.shape[0] == sum_params_shape
), "Dimensions not compatible for concatenation" ), "Dimensions not compatible for concatenation"
param_temp = full_param_buffer.new() param_temp = full_param_buffer.new()
...@@ -223,18 +224,19 @@ class _NoopCat(torch.autograd.Function): ...@@ -223,18 +224,19 @@ class _NoopCat(torch.autograd.Function):
full_param_buffer.stride()) full_param_buffer.stride())
param_temp.requires_grad = True param_temp.requires_grad = True
ctx.save_for_backward(full_param_buffer, *params_split) ctx.save_for_backward(*params_split)
return param_temp return param_temp
@staticmethod @staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
full_param_buffer, *params_split = ctx.saved_tensors params_split = ctx.saved_tensors
split_size = full_param_buffer.shape[0] // len(params_split)
grads = [] grads = []
slice_begin = 0
for i, _ in enumerate(params_split): for i, _ in enumerate(params_split):
grads.append(grad_output[i * split_size : (i+1) * split_size]) slice_size = params_split[i].shape[0]
slice_end = slice_begin + slice_size
grads.append(grad_output[slice_begin:slice_end])
slice_begin = slice_end
return None, *grads return None, *grads
...@@ -753,7 +755,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -753,7 +755,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
return grad_output_mat, grad_output_c, grad_output_t, grad_bias return grad_output_mat, grad_output_c, grad_output_t, grad_bias
def noop_cat(self, buffer_name: str, pnames: List[str]) -> torch.Tensor: def noop_cat(self,
buffer_name: str,
pnames: List[str],
parameters_split: Dict[str, int]
) -> torch.Tensor:
"""No-op replacement of `torch.cat`. The buffer and split parameters must occupy """No-op replacement of `torch.cat`. The buffer and split parameters must occupy
the same memory region. If this is not the case, then the split parameters the same memory region. If this is not the case, then the split parameters
are concatenated and the buffer is overwritten. The parameters' memory is then are concatenated and the buffer is overwritten. The parameters' memory is then
...@@ -762,17 +768,24 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -762,17 +768,24 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
assert hasattr(self, buffer_name), f"No buffer named {buffer_name}" assert hasattr(self, buffer_name), f"No buffer named {buffer_name}"
full_param_buffer = getattr(self, buffer_name) full_param_buffer = getattr(self, buffer_name)
split_size = full_param_buffer.shape[0] // len(pnames)
params = [getattr(self, name) for name in pnames] params = [getattr(self, name) for name in pnames]
slice_begin = 0
for i, p in enumerate(params): for i, p in enumerate(params):
if p.data.data_ptr() != full_param_buffer[i*split_size : (i+1)*split_size].data_ptr(): slice_size = parameters_split[pnames[i].split('_')[0]+'_']
slice_end = slice_begin + slice_size
if p.data.data_ptr() != full_param_buffer[slice_begin:slice_end].data_ptr():
with torch.no_grad(): with torch.no_grad():
setattr(self, buffer_name, torch.cat(params)) setattr(self, buffer_name, torch.cat(params))
for j, pname in enumerate(pnames): slice_begin_j = 0
for pname in pnames:
slice_size_j = parameters_split[pname.split('_')[0]+'_']
slice_end_j = slice_begin_j + slice_size_j
full_param_buffer = getattr(self, buffer_name) full_param_buffer = getattr(self, buffer_name)
setattr(self, pname, setattr(self, pname,
Parameter(full_param_buffer[j*split_size : (j+1)*split_size])) Parameter(full_param_buffer[slice_begin_j:slice_end_j]))
slice_begin_j = slice_end_j
break break
slice_begin = slice_end
return _NoopCat.apply(getattr(self, buffer_name), *[getattr(self, name) for name in pnames]) return _NoopCat.apply(getattr(self, buffer_name), *[getattr(self, name) for name in pnames])
......
...@@ -536,11 +536,14 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -536,11 +536,14 @@ class LayerNormLinear(TransformerEngineBaseModule):
together with the output of the linear transformation. together with the output of the linear transformation.
Example use case: residual connection for transformer module is Example use case: residual connection for transformer module is
taken post layernorm. taken post layernorm.
parameters_split : Tuple[str, ...], default = None parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None
if a tuple of strings is provided, the weight and bias parameters of the if a tuple of strings or a dict of strings to integers is provided,
module are exposed as `N` separate `torch.nn.parameter.Parameter`s each, the weight and bias parameters of the module are exposed as `N` separate
split along the first dimension, where `N` is the length of the argument `torch.nn.parameter.Parameter`s each, split along the first dimension,
and the strings contained are the names of the split parameters. where `N` is the length of the argument and the strings contained are the
names of the split parameters. In the case of a tuple, each parameter
has the same shape. In the case of a dict, the values give the
`out_features` for each projection.
zero_centered_gamma : bool, default = 'False' zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to the LayerNorm formula changes to
...@@ -607,7 +610,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -607,7 +610,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
parallel_mode: Optional[str] = None, parallel_mode: Optional[str] = None,
return_layernorm_output: bool = False, return_layernorm_output: bool = False,
skip_weight_param_allocation: bool = False, skip_weight_param_allocation: bool = False,
parameters_split: Optional[Tuple[str, ...]] = None, parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
zero_centered_gamma: bool = False, zero_centered_gamma: bool = False,
ub_bulk_wgrad: bool = False, ub_bulk_wgrad: bool = False,
ub_bulk_dgrad: bool = False, ub_bulk_dgrad: bool = False,
...@@ -707,23 +710,35 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -707,23 +710,35 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.bias_tensor.zero_() self.bias_tensor.zero_()
if parameters_split is None: if parameters_split is None:
parameters_split = ("",) parameters_split = {"": self.out_features}
elif isinstance(parameters_split, tuple):
assert ( assert (
self.out_features % len(parameters_split) == 0 self.out_features % len(parameters_split) == 0
), f"Weight and bias params cannot be split into {len(parameters_split)} parts" ), f"Weight and bias params cannot be split into {len(parameters_split)} parts"
split_size = self.out_features // len(parameters_split)
split_size = self.out_features // len(parameters_split) parameters_split = {key: split_size for key in parameters_split}
elif isinstance(parameters_split, dict):
overall_split_size = sum(parameters_split.values())
assert(
self.out_features == overall_split_size
), f"Overall sum of parameters_split (={overall_split_size}) does not match "\
f"to out features (={self.out_features})"
else:
assert False, "Type of 'parameters_split' is not None, tuple or dict"
self.updated_parameters_split = parameters_split
self.weight_names = [] self.weight_names = []
self.bias_names = [] self.bias_names = []
for i, pname in enumerate(parameters_split): slice_begin = 0
for pname, slice_size in parameters_split.items():
wname = pname + "weight" wname = pname + "weight"
bname = pname + "bias" bname = pname + "bias"
slice_end = slice_begin + slice_size
self.register_parameter( self.register_parameter(
wname, Parameter(self.weight_tensor[i * split_size : (i+1) * split_size]) wname, Parameter(self.weight_tensor[slice_begin:slice_end])
) )
set_tensor_model_parallel_attributes( set_tensor_model_parallel_attributes(
...@@ -735,7 +750,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -735,7 +750,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
if self.use_bias: if self.use_bias:
self.register_parameter( self.register_parameter(
bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size]) bname, Parameter(self.bias_tensor[slice_begin:slice_end])
) )
if parallel_mode == "row": if parallel_mode == "row":
setattr(getattr(self, bname), "sequence_parallel", sequence_parallel) setattr(getattr(self, bname), "sequence_parallel", sequence_parallel)
...@@ -748,6 +763,8 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -748,6 +763,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.weight_names.append(wname) self.weight_names.append(wname)
self.bias_names.append(bname) self.bias_names.append(bname)
slice_begin = slice_end
self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features))) self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))
...@@ -843,12 +860,14 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -843,12 +860,14 @@ class LayerNormLinear(TransformerEngineBaseModule):
bias_tensor = ( bias_tensor = (
self.bias if self.parameters_split is None self.bias if self.parameters_split is None
else self.bias_tensor if not torch.is_grad_enabled() else self.bias_tensor if not torch.is_grad_enabled()
else self.noop_cat("bias_tensor", self.bias_names) else self.noop_cat("bias_tensor", self.bias_names,
self.updated_parameters_split)
) )
weight_tensor = ( weight_tensor = (
self.weight if self.parameters_split is None self.weight if self.parameters_split is None
else self.weight_tensor if not torch.is_grad_enabled() else self.weight_tensor if not torch.is_grad_enabled()
else self.noop_cat("weight_tensor", self.weight_names) else self.noop_cat("weight_tensor", self.weight_names,
self.updated_parameters_split)
) )
# Fetch the fp8 weights placeholders (for linear/gemm) # Fetch the fp8 weights placeholders (for linear/gemm)
......
...@@ -461,11 +461,14 @@ class Linear(TransformerEngineBaseModule): ...@@ -461,11 +461,14 @@ class Linear(TransformerEngineBaseModule):
init_method : Callable, default = `None` init_method : Callable, default = `None`
used for initializing weights in the following way: `init_method(weight)`. used for initializing weights in the following way: `init_method(weight)`.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
parameters_split : Tuple[str, ...], default = None parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None
if a tuple of strings is provided, the weight and bias parameters of the if a tuple of strings or a dict of strings to integers is provided,
module are exposed as `N` separate `torch.nn.parameter.Parameter`s each, the weight and bias parameters of the module are exposed as `N` separate
split along the first dimension, where `N` is the length of the argument `torch.nn.parameter.Parameter`s each, split along the first dimension,
and the strings contained are the names of the split parameters. where `N` is the length of the argument and the strings contained are the
names of the split parameters. In the case of a tuple, each parameter
has the same shape. In the case of a dict, the values give the
`out_features` for each projection.
device : Union[torch.device, str], default = "cuda" device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will allocated. It is the user's The device on which the parameters of the model will allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the responsibility to ensure all parameters are moved to the GPU before running the
...@@ -522,7 +525,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -522,7 +525,7 @@ class Linear(TransformerEngineBaseModule):
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
parallel_mode: Optional[str] = None, parallel_mode: Optional[str] = None,
skip_weight_param_allocation: bool = False, skip_weight_param_allocation: bool = False,
parameters_split: Optional[Tuple[str, ...]] = None, parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
ub_split_rs: bool = False, ub_split_rs: bool = False,
ub_split_ag: bool = False, ub_split_ag: bool = False,
device: Union[torch.device, str] = "cuda", device: Union[torch.device, str] = "cuda",
...@@ -598,23 +601,35 @@ class Linear(TransformerEngineBaseModule): ...@@ -598,23 +601,35 @@ class Linear(TransformerEngineBaseModule):
self.bias_tensor.zero_() self.bias_tensor.zero_()
if parameters_split is None: if parameters_split is None:
parameters_split = ("",) parameters_split = {"": self.out_features}
elif isinstance(parameters_split, tuple):
assert ( assert (
self.out_features % len(parameters_split) == 0 self.out_features % len(parameters_split) == 0
), f"Weight and bias params cannot be split into {len(parameters_split)} parts" ), f"Weight and bias params cannot be split into {len(parameters_split)} parts"
split_size = self.out_features // len(parameters_split)
split_size = self.out_features // len(parameters_split) parameters_split = {key: split_size for key in parameters_split}
elif isinstance(parameters_split, dict):
overall_split_size = sum(parameters_split.values())
assert(
self.out_features == overall_split_size
), f"Overall sum of parameters_split (={overall_split_size}) does not match "\
f"to out features (={self.out_features})"
else:
assert False, "Type of 'parameters_split' is not None, tuple or dict"
self.updated_parameters_split = parameters_split
self.weight_names = [] self.weight_names = []
self.bias_names = [] self.bias_names = []
for i, pname in enumerate(parameters_split): slice_begin = 0
for pname, slice_size in parameters_split.items():
wname = pname + "weight" wname = pname + "weight"
bname = pname + "bias" bname = pname + "bias"
slice_end = slice_begin + slice_size
self.register_parameter( self.register_parameter(
wname, Parameter(self.weight_tensor[i * split_size : (i+1) * split_size]) wname, Parameter(self.weight_tensor[slice_begin:slice_end])
) )
set_tensor_model_parallel_attributes( set_tensor_model_parallel_attributes(
...@@ -626,7 +641,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -626,7 +641,7 @@ class Linear(TransformerEngineBaseModule):
if self.use_bias: if self.use_bias:
self.register_parameter( self.register_parameter(
bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size]) bname, Parameter(self.bias_tensor[slice_begin:slice_end])
) )
if parallel_mode == "row": if parallel_mode == "row":
setattr(getattr(self, bname), "sequence_parallel", sequence_parallel) setattr(getattr(self, bname), "sequence_parallel", sequence_parallel)
...@@ -639,6 +654,8 @@ class Linear(TransformerEngineBaseModule): ...@@ -639,6 +654,8 @@ class Linear(TransformerEngineBaseModule):
self.weight_names.append(wname) self.weight_names.append(wname)
self.bias_names.append(bname) self.bias_names.append(bname)
slice_begin = slice_end
self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features))) self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))
# For RPL, bias has to be added after TP collectives # For RPL, bias has to be added after TP collectives
...@@ -717,12 +734,14 @@ class Linear(TransformerEngineBaseModule): ...@@ -717,12 +734,14 @@ class Linear(TransformerEngineBaseModule):
bias_tensor = ( bias_tensor = (
self.bias if self.parameters_split is None self.bias if self.parameters_split is None
else self.bias_tensor if not torch.is_grad_enabled() else self.bias_tensor if not torch.is_grad_enabled()
else self.noop_cat("bias_tensor", self.bias_names) else self.noop_cat("bias_tensor", self.bias_names,
self.updated_parameters_split)
) )
weight_tensor = ( weight_tensor = (
self.weight if self.parameters_split is None self.weight if self.parameters_split is None
else self.weight_tensor if not torch.is_grad_enabled() else self.weight_tensor if not torch.is_grad_enabled()
else self.noop_cat("weight_tensor", self.weight_names) else self.noop_cat("weight_tensor", self.weight_names,
self.updated_parameters_split)
) )
# Fetch the fp8 weights placeholders (for linear/gemm) # Fetch the fp8 weights placeholders (for linear/gemm)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment