Unverified Commit 04490337 authored by Kirthi Shankar Sivamani, committed by GitHub

QKV parameters unfused path fixes and optimization (#66)



* Bug fixes from PR 22
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add FP8 tests to ci
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Better QKV parameter fusion
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* small fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* keep original param for unfused case to retain externally set attrs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* lint fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix ONNX exports
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* improve arg naming
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* No need to set data pointers
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Assert memory loc in NoopCat
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Handle case of different memory in param and buffer
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix assert always true
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Reassign params memory to avoid more concats
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 78b4e933
@@ -7,7 +7,7 @@ import os
import pickle
import warnings
from abc import ABC, abstractmethod
from typing import Union, Optional, Callable, Tuple, Dict, Any, Mapping
from typing import Union, Optional, Callable, Tuple, Dict, Any, Mapping, List
from functools import partial
from contextlib import contextmanager
@@ -135,6 +135,42 @@ def _prepare_backward(fp8: bool,
delete_key_from_amax_buffer(forward=False)
class _NoopCat(torch.autograd.Function):
"""This class is a no-op replacement for `torch.cat`."""
@staticmethod
def forward(ctx,
full_param_buffer: torch.Tensor,
*params_split: Tuple[torch.Tensor, ...],
) -> torch.Tensor:
assert not full_param_buffer.requires_grad, "Buffers should not require gradient"
assert (
full_param_buffer.shape[0] % len(params_split) == 0
), "Dimensions not compatible for concatenation"
param_temp = full_param_buffer.new()
param_temp.set_(full_param_buffer.storage(),
full_param_buffer.storage_offset(),
full_param_buffer.size(),
full_param_buffer.stride())
param_temp.requires_grad = True
ctx.save_for_backward(full_param_buffer, *params_split)
return param_temp
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
full_param_buffer, *params_split = ctx.saved_tensors
split_size = full_param_buffer.shape[0] // len(params_split)
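# The buffer itself receives no gradient (the leading None below); each split
# parameter gets the matching slice of grad_output.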
grads = []
for i, _ in enumerate(params_split):
grads.append(grad_output[i * split_size : (i+1) * split_size])
return None, *grads
class TransformerEngineBaseModule(torch.nn.Module, ABC):
"""Base TE module."""
@@ -572,6 +608,29 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
return grad_output_mat, grad_output_c, grad_output_t, grad_bias
def noop_cat(self, buffer_name: str, pnames: List[str]) -> torch.Tensor:
"""No-op replacement of `torch.cat`. The buffer and split parameters must occupy
the same memory region. If this is not the case, then the split parameters
are concatenated and the buffer is overwritten. The parameters' memory is then
re-assigned to point to the buffer to avoid subsequent concatenations.
"""
assert hasattr(self, buffer_name), f"No buffer named {buffer_name}"
full_param_buffer = getattr(self, buffer_name)
split_size = full_param_buffer.shape[0] // len(pnames)
params = [getattr(self, name) for name in pnames]
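# If any split parameter no longer aliases its slice of the buffer (e.g. its
# .data was reassigned externally), rebuild the buffer with a real torch.cat
# once, then re-point the split parameters at views of the new buffer.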
for i, p in enumerate(params):
if p.data.data_ptr() != full_param_buffer[i*split_size : (i+1)*split_size].data_ptr():
with torch.no_grad():
setattr(self, buffer_name, torch.cat(params))
for j, pname in enumerate(pnames):
full_param_buffer = getattr(self, buffer_name)
setattr(self, pname,
Parameter(full_param_buffer[j*split_size : (j+1)*split_size]))
break
return _NoopCat.apply(getattr(self, buffer_name), *[getattr(self, name) for name in pnames])
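# A minimal sketch of the intended behavior (hypothetical standalone usage;
# `buf`, `q`, `k`, `v` are illustrative names, not module attributes):
#
#     buf = torch.randn(6, 4)                    # fused QKV-style buffer
#     q, k, v = (Parameter(buf[i * 2 : (i + 1) * 2]) for i in range(3))
#     fused = _NoopCat.apply(buf, q, k, v)       # aliases buf, no copy
#     fused.sum().backward()                     # q.grad, k.grad, v.grad filled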
@abstractmethod
def forward(self):
"""Needs override."""
@@ -993,6 +1052,11 @@ class LayerNormLinear(TransformerEngineBaseModule):
together with the output of the linear transformation.
Example use case: residual connection for transformer module is
taken post layernorm.
parameters_split : Tuple[str, ...], default = None
if a tuple of strings is provided, the weight and bias parameters of the
module are exposed as `N` separate `torch.nn.parameter.Parameter`s each,
split along the first dimension, where `N` is the length of the argument
and the strings contained are the names of the split parameters.
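For example (an illustrative sketch, not from this diff),
`LayerNormLinear(1024, 3072, parameters_split=("query_", "key_", "value_"))`
exposes `query_weight`, `key_weight`, and `value_weight` (plus the matching
`*_bias` parameters), each a 1024-row view into the fused weight buffer.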
Parallelism parameters
----------------------
@@ -1047,6 +1111,7 @@
parallel_mode: Optional[str] = None,
return_layernorm_output: bool = False,
skip_weight_param_allocation: bool = False,
parameters_split: Optional[Tuple[str, ...]] = None,
) -> None:
super().__init__()
self.in_features = in_features
@@ -1055,7 +1120,7 @@
self.use_bias = bias
self.return_bias = return_bias
self.return_layernorm_output = return_layernorm_output
self.skip_weight_param_allocation = skip_weight_param_allocation
self.parameters_split = parameters_split
if tp_group is None:
self.tp_size = tp_size
@@ -1101,17 +1166,16 @@
self.reset_layer_norm_parameters()
if not skip_weight_param_allocation:
self.weight = Parameter(
torch.empty(
self.out_features,
self.in_features,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
self.register_buffer("weight_tensor",
torch.empty(
self.out_features,
self.in_features,
device=torch.cuda.current_device(),
dtype=params_dtype),
persistent=False)
initialize_affine_weight_gpu(
self.weight,
self.weight_tensor,
init_method,
get_rng_state_tracker,
partition_dim=1 if self.parallel_mode == "row" else 0,
@@ -1119,20 +1183,59 @@
)
if self.use_bias or self.return_bias:
self.bias = Parameter(
torch.empty(
self.out_features,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
if self.parallel_mode == "column":
set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
self.register_buffer("bias_tensor",
torch.empty(
self.out_features,
device=torch.cuda.current_device(),
dtype=params_dtype),
persistent=False)
else:
self.register_buffer("bias", torch.Tensor().type(params_dtype), persistent=False)
self.register_buffer(
"bias_tensor", torch.Tensor().type(params_dtype), persistent=False
)
with torch.no_grad():
self.bias.zero_()
self.bias_tensor.zero_()
if parameters_split is None:
parameters_split = ("",)
assert (
self.out_features % len(parameters_split) == 0
), f"Weight and bias params cannot be split into {len(parameters_split)} parts"
split_size = self.out_features // len(parameters_split)
self.weight_names = []
self.bias_names = []
for i, pname in enumerate(parameters_split):
wname = pname + "weight"
bname = pname + "bias"
self.register_parameter(
wname, Parameter(self.weight_tensor[i * split_size : (i+1) * split_size])
)
set_tensor_model_parallel_attributes(
tensor=getattr(self, wname),
is_parallel=True,
dim=1 if parallel_mode == "row" else 0,
stride=1,
)
if self.use_bias or self.return_bias:
self.register_parameter(
bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size])
)
else:
self.register_buffer(bname, torch.Tensor().type(params_dtype), persistent=False)
if parallel_mode == "column":
set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1)
self.weight_names.append(wname)
self.bias_names.append(bname)
self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))
@@ -1193,7 +1296,18 @@ class LayerNormLinear(TransformerEngineBaseModule):
"""
with self.prepare_forward(inp, is_first_microbatch) as inp:
bias_tensor = bias if bias is not None else self.bias
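# Effective-tensor selection (applies to both bias and weight below): an
# explicitly passed tensor wins; otherwise the single registered parameter
# when no split was requested; otherwise the fused buffer in eval mode; and
# in training, the no-op concatenation of the split parameters.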
bias_tensor = (
bias if bias is not None
else self.bias if self.parameters_split is None
else self.bias_tensor if not self.training
else self.noop_cat("bias_tensor", self.bias_names)
)
weight_tensor = (
weight if weight is not None
else self.weight if self.parameters_split is None
else self.weight_tensor if not self.training
else self.noop_cat("weight_tensor", self.weight_names)
)
if self.training:
fwd_fn = _LayerNormLinear.apply
@@ -1205,7 +1319,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
inp,
self.layer_norm_weight,
self.layer_norm_bias,
weight if weight is not None else self.weight,
weight_tensor,
self.weight1_fp8 if self.fp8 else None,
self.weight1_t_fp8 if self.fp8 else None,
bias_tensor,
@@ -1586,6 +1700,11 @@ class Linear(TransformerEngineBaseModule):
init_method : Callable, default = `None`
used for initializing weights in the following way: `init_method(weight)`.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
parameters_split : Tuple[str, ...], default = None
if a tuple of strings is provided, the weight and bias parameters of the
module are exposed as `N` separate `torch.nn.parameter.Parameter`s each,
split along the first dimension, where `N` is the length of the argument
and the strings contained are the names of the split parameters.
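For example, `Linear(1024, 2048, parameters_split=("key_", "value_"))`
(an illustrative sketch) exposes `key_weight`/`value_weight` and
`key_bias`/`value_bias` as views into the fused buffers; this is how
`MultiHeadAttention` below splits its fused KV projection.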
Parallelism parameters
----------------------
@@ -1641,6 +1760,7 @@
params_dtype: torch.dtype = torch.float32,
parallel_mode: Optional[str] = None,
skip_weight_param_allocation: bool = False,
parameters_split: Optional[Tuple[str, ...]] = None,
) -> None:
super().__init__()
self.in_features = in_features
@@ -1648,7 +1768,7 @@
self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
self.use_bias = bias
self.return_bias = return_bias
self.skip_weight_param_allocation = skip_weight_param_allocation
self.parameters_split = parameters_split
if tp_group is None:
self.tp_size = tp_size
@@ -1675,17 +1795,16 @@
self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
if not skip_weight_param_allocation:
self.weight = Parameter(
torch.empty(
self.out_features,
self.in_features,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
self.register_buffer("weight_tensor",
torch.empty(
self.out_features,
self.in_features,
device=torch.cuda.current_device(),
dtype=params_dtype),
persistent=False)
initialize_affine_weight_gpu(
self.weight,
self.weight_tensor,
init_method,
get_rng_state_tracker,
partition_dim=1 if self.parallel_mode == "row" else 0,
@@ -1693,20 +1812,59 @@
)
if self.use_bias or self.return_bias:
self.bias = Parameter(
torch.empty(
self.out_features,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
if self.parallel_mode == "column":
set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
self.register_buffer("bias_tensor",
torch.empty(
self.out_features,
device=torch.cuda.current_device(),
dtype=params_dtype),
persistent=False)
else:
self.register_buffer("bias", torch.Tensor().type(params_dtype), persistent=False)
self.register_buffer(
"bias_tensor", torch.Tensor().type(params_dtype), persistent=False
)
with torch.no_grad():
self.bias.zero_()
self.bias_tensor.zero_()
if parameters_split is None:
parameters_split = ("",)
assert (
self.out_features % len(parameters_split) == 0
), f"Weight and bias params cannot be split into {len(parameters_split)} parts"
split_size = self.out_features // len(parameters_split)
self.weight_names = []
self.bias_names = []
for i, pname in enumerate(parameters_split):
wname = pname + "weight"
bname = pname + "bias"
self.register_parameter(
wname, Parameter(self.weight_tensor[i * split_size : (i+1) * split_size])
)
set_tensor_model_parallel_attributes(
tensor=getattr(self, wname),
is_parallel=True,
dim=1 if parallel_mode == "row" else 0,
stride=1,
)
if self.use_bias or self.return_bias:
self.register_parameter(
bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size])
)
else:
self.register_buffer(bname, torch.Tensor().type(params_dtype), persistent=False)
if parallel_mode == "column":
set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1)
self.weight_names.append(wname)
self.bias_names.append(bname)
self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))
@@ -1755,7 +1913,18 @@ class Linear(TransformerEngineBaseModule):
"""
with self.prepare_forward(inp, is_first_microbatch) as inp:
bias_tensor = bias if bias is not None else self.bias
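# Same selection precedence as in LayerNormLinear.forward: explicit argument,
# single registered parameter, fused buffer in eval mode, or the training-time
# no-op concatenation of the split parameters.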
bias_tensor = (
bias if bias is not None
else self.bias if self.parameters_split is None
else self.bias_tensor if not self.training
else self.noop_cat("bias_tensor", self.bias_names)
)
weight_tensor = (
weight if weight is not None
else self.weight if self.parameters_split is None
else self.weight_tensor if not self.training
else self.noop_cat("weight_tensor", self.weight_names)
)
if self.training:
linear_fn = _Linear.apply
@@ -1764,7 +1933,7 @@ class Linear(TransformerEngineBaseModule):
linear_fn = _Linear.forward
args = [None]
args += (
weight if weight is not None else self.weight,
weight_tensor,
self.weight1_fp8 if self.fp8 else None,
self.weight1_t_fp8 if self.fp8 else None,
inp,
......
@@ -9,7 +9,6 @@ from contextlib import nullcontext
from typing import Any, Callable, Optional, Tuple, Union
import torch
from torch.nn.parameter import Parameter
from transformer_engine.pytorch import LayerNormLinear, Linear, LayerNormMLP, LayerNorm
from transformer_engine.pytorch.jit import (
@@ -36,8 +35,6 @@ from transformer_engine.pytorch.softmax import FusedScaleMaskSoftmax
from transformer_engine.pytorch.distributed import (
get_distributed_world_size,
checkpoint,
initialize_affine_weight_gpu,
set_tensor_model_parallel_attributes,
)
@@ -261,7 +258,6 @@ class MultiHeadAttention(torch.nn.Module):
self.return_layernorm_output = return_layernorm_output
self.params_dtype = params_dtype
self.init_method = init_method
self.fuse_qkv_params = fuse_qkv_params
assert (
attention_type in AttnTypes
@@ -284,13 +280,6 @@
}
qkv_parallel_mode = "column" if set_parallel_mode else None
if not fuse_qkv_params:
self.set_qkv_params(
hidden_size,
3 * hidden_size,
parallel_mode=qkv_parallel_mode,
bias=True,
)
if self.attention_type == "self":
if self.input_layernorm:
@@ -303,7 +292,7 @@
return_bias=False,
parallel_mode=qkv_parallel_mode,
return_layernorm_output=return_layernorm_output,
skip_weight_param_allocation=not fuse_qkv_params,
parameters_split=("query_", "key_", "value_") if not fuse_qkv_params else None,
**common_gemm_kwargs,
)
else:
@@ -314,7 +303,7 @@
bias=True,
return_bias=False,
parallel_mode=qkv_parallel_mode,
skip_weight_param_allocation=not fuse_qkv_params,
parameters_split=("query_", "key_", "value_") if not fuse_qkv_params else None,
**common_gemm_kwargs,
)
else:
@@ -328,7 +317,6 @@
return_bias=False,
parallel_mode=qkv_parallel_mode,
return_layernorm_output=return_layernorm_output,
skip_weight_param_allocation=not fuse_qkv_params,
**common_gemm_kwargs,
)
else:
@@ -339,7 +327,6 @@
bias=True,
return_bias=False,
parallel_mode=qkv_parallel_mode,
skip_weight_param_allocation=not fuse_qkv_params,
**common_gemm_kwargs,
)
self.key_value = Linear(
@@ -349,7 +336,7 @@
bias=True,
return_bias=False,
parallel_mode=qkv_parallel_mode,
skip_weight_param_allocation=not fuse_qkv_params,
parameters_split=("key_", "value_") if not fuse_qkv_params else None,
**common_gemm_kwargs,
)
@@ -378,88 +365,6 @@
**common_gemm_kwargs,
)
def set_qkv_params(
self,
in_features: torch.Tensor,
out_features: torch.Tensor,
parallel_mode: Optional[bool] = None,
bias: bool = False,
) -> None:
"""Initialize separate Parameters for query, key, and value tensors."""
if parallel_mode == "column":
out_features = divide(out_features, self.tp_size)
elif parallel_mode == "row":
in_features = divide(in_features, self.tp_size)
assert (
out_features % 3 == 0
), f"3 way QKV split with dimension {out_features} not possible."
weight_tensor = torch.empty(
out_features,
in_features,
device=torch.cuda.current_device(),
dtype=self.params_dtype,
)
initialize_affine_weight_gpu(
weight_tensor,
self.init_method,
self.get_rng_state_tracker,
partition_dim=1 if parallel_mode == "row" else 0,
stride=1,
)
qkv_first_dim = out_features // 3
self.query = Parameter(weight_tensor[0:qkv_first_dim, :])
self.key = Parameter(weight_tensor[qkv_first_dim : 2 * qkv_first_dim, :])
self.value = Parameter(weight_tensor[2 * qkv_first_dim : 3 * qkv_first_dim, :])
set_tensor_model_parallel_attributes(
tensor=self.query,
is_parallel=True,
dim=1 if parallel_mode == "row" else 0,
stride=1,
)
set_tensor_model_parallel_attributes(
tensor=self.key,
is_parallel=True,
dim=1 if parallel_mode == "row" else 0,
stride=1,
)
set_tensor_model_parallel_attributes(
tensor=self.value,
is_parallel=True,
dim=1 if parallel_mode == "row" else 0,
stride=1,
)
if bias:
bias_tensor = torch.empty(
out_features,
device=torch.cuda.current_device(),
dtype=self.params_dtype,
)
self.query_bias = Parameter(bias_tensor[0:qkv_first_dim])
self.key_bias = Parameter(bias_tensor[qkv_first_dim : 2 * qkv_first_dim])
self.value_bias = Parameter(
bias_tensor[2 * qkv_first_dim : 3 * qkv_first_dim]
)
if parallel_mode == "column":
set_tensor_model_parallel_attributes(self.query_bias, True, 0, 1)
set_tensor_model_parallel_attributes(self.key_bias, True, 0, 1)
set_tensor_model_parallel_attributes(self.value_bias, True, 0, 1)
else:
self.register_buffer("query_bias", torch.Tensor(), persistent=False)
self.register_buffer("key_bias", torch.Tensor(), persistent=False)
self.register_buffer("value_bias", torch.Tensor(), persistent=False)
with torch.no_grad():
self.query_bias.zero_()
self.key_bias.zero_()
self.value_bias.zero_()
def _checkpointed_core_attention_forward(
self,
query_layer: torch.Tensor,
@@ -554,23 +459,10 @@
# =====================
if self.attention_type == "self":
qkv_weight = (
torch.cat((self.query, self.key, self.value))
if not self.fuse_qkv_params
else None
)
qkv_bias = (
torch.cat((self.query_bias, self.key_bias, self.value_bias))
if not self.fuse_qkv_params
else None
)
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
if self.input_layernorm:
layernorm_qkv_outputs = self.layernorm_qkv(
hidden_states,
weight=qkv_weight,
bias=qkv_bias,
is_first_microbatch=is_first_microbatch,
)
if self.return_layernorm_output:
@@ -580,8 +472,6 @@
else:
mixed_x_layer = self.qkv(
hidden_states,
weight=qkv_weight,
bias=qkv_bias,
is_first_microbatch=is_first_microbatch,
)
@@ -597,20 +487,9 @@
mixed_x_layer, 3
)
else:
kv_weight = (
torch.cat((self.key, self.value)) if not self.fuse_qkv_params else None
)
kv_bias = (
torch.cat((self.key_bias, self.value_bias))
if not self.fuse_qkv_params
else None
)
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer = self.key_value(
encoder_output,
weight=kv_weight,
bias=kv_bias,
is_first_microbatch=is_first_microbatch,
)
@@ -628,8 +507,6 @@
if self.input_layernorm:
layernorm_query_outputs = self.layernorm_query(
hidden_states,
weight=self.query,
bias=self.query_bias,
is_first_microbatch=is_first_microbatch,
)
if self.return_layernorm_output:
@@ -639,8 +516,6 @@
else:
query_layer = self.query_layer(
hidden_states,
weight=self.query,
bias=self.query_bias,
is_first_microbatch=is_first_microbatch,
)
......