Unverified commit 7d6c1d02 authored by Kirthi Shankar Sivamani, committed by GitHub

Fix unfused QKV params case; stack vs interleave option (#83)



* fix qkv weight unfused path
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix non-FA, non-interleaved case
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 2417a53a
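
The "stack vs interleave" option concerns how the single fused QKV parameter lays out the per-head query, key and value weights along its first dimension. A minimal sketch of the two layouts with toy sizes (np heads, hn channels per head, h hidden size); the tensors and sizes here are hypothetical, only the layout convention follows the docstring added below:

    import torch

    np_, hn, h = 4, 8, 32  # heads, channels per head, hidden size (toy values)
    # hypothetical per-head projection weights, each of shape [hn, h]
    q = [torch.randn(hn, h) for _ in range(np_)]
    k = [torch.randn(hn, h) for _ in range(np_)]
    v = [torch.randn(hn, h) for _ in range(np_)]

    # interleaved (default): head0-q, head0-k, head0-v, head1-q, ... along dim 0
    w_interleaved = torch.cat([t for i in range(np_) for t in (q[i], k[i], v[i])], dim=0)

    # stacked (qkv_weight_interleaved=False): all q heads, then all k heads, then all v heads
    w_stacked = torch.cat(q + k + v, dim=0)

    assert w_interleaved.shape == w_stacked.shape == (3 * np_ * hn, h)

Both layouts hold the same parameters; they differ only in how the packed projection output has to be reshaped and split back into q, k and v, which is what the forward-path changes below handle.
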
@@ -24,7 +24,7 @@ from transformer_engine.pytorch.jit import (
 from transformer_engine.pytorch.utils import (
     divide,
     attention_mask_func,
-    split_tensor_along_last_dim,
+    split_tensor_along_dim,
     cast_if_needed,
     get_default_init_method,
 )
@@ -126,11 +126,11 @@ class UnfusedDotProductAttention(torch.nn.Module):
         )
 
         # [sq, b, np, hn] -> [sq, b * np, hn]
-        query_layer = query_layer.view(
+        query_layer = query_layer.reshape(
             output_size[2], output_size[0] * output_size[1], -1
         )
         # [sk, b, np, hn] -> [sk, b * np, hn]
-        key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
+        key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1)
 
         # preallocating result tensor: [b * np, sq, sk]
         matmul_result = torch.empty(
@@ -171,7 +171,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
         )
 
         # change view [sk, b * np, hn]
-        value_layer = value_layer.view(
+        value_layer = value_layer.reshape(
             value_layer.size(0), output_size[0] * output_size[1], -1
         )
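
Switching from .view() to .reshape() is what makes the non-flash, non-interleaved path work: .view() requires a memory layout compatible with the requested shape, whereas .reshape() falls back to a copy when needed. The q/k/v tensors obtained by splitting along the second-to-last dimension are non-contiguous views, so .view() would raise here. A small standalone illustration of the underlying PyTorch behavior (toy shapes, not this module's tensors):

    import torch

    t = torch.randn(5, 2, 12, 8)         # e.g. [sq, b, 3*np, hn] with np=4, hn=8
    q, k, v = torch.split(t, 4, dim=-2)   # split along the head dimension -> non-contiguous views
    print(q.is_contiguous())              # False

    try:
        q.view(5, 2 * 4, 8)               # view needs a compatible memory layout
    except RuntimeError as e:
        print("view failed:", e)

    out = q.reshape(5, 2 * 4, 8)          # reshape copies when necessary
    print(out.shape)                      # torch.Size([5, 8, 8])
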
@@ -504,6 +504,7 @@ class MultiHeadAttention(torch.nn.Module):
         set_parallel_mode: bool = False,
         fuse_qkv_params: bool = False,
         zero_centered_gamma: bool = False,
+        qkv_weight_interleaved: bool = True,
     ) -> None:
         super().__init__()
         self.layer_number = (layer_number,)
@@ -515,6 +516,10 @@ class MultiHeadAttention(torch.nn.Module):
         self.params_dtype = params_dtype
         self.init_method = init_method
 
+        if not fuse_qkv_params:
+            qkv_weight_interleaved = False
+        self.qkv_weight_interleaved = qkv_weight_interleaved
+
         assert (
             attention_type in AttnTypes
         ), f"attention_type {attention_type} not supported"
@@ -703,16 +708,28 @@ class MultiHeadAttention(torch.nn.Module):
                 is_first_microbatch=is_first_microbatch,
             )
 
-            # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
-            new_tensor_shape = mixed_x_layer.size()[:-1] + (
-                self.num_attention_heads_per_partition,
-                3 * self.hidden_size_per_attention_head,
-            )
+            if self.qkv_weight_interleaved:
+                # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
+                new_tensor_shape = mixed_x_layer.size()[:-1] + (
+                    self.num_attention_heads_per_partition,
+                    3 * self.hidden_size_per_attention_head,
+                )
+                # split along last dimension
+                split_dim = -1
+            else:
+                # [sq, b, (np * 3 * hn)] --> [sq, b, 3 * np, hn]
+                new_tensor_shape = mixed_x_layer.size()[:-1] + (
+                    3 * self.num_attention_heads_per_partition,
+                    self.hidden_size_per_attention_head,
+                )
+                # split along second last dimension
+                split_dim = -2
+
             mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
 
-            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
-            query_layer, key_layer, value_layer = split_tensor_along_last_dim(
-                mixed_x_layer, 3
+            # mixed_x_layer --> 3 [sq, b, np, hn]
+            query_layer, key_layer, value_layer = split_tensor_along_dim(
+                mixed_x_layer, split_dim, 3
             )
         else:
             # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
@@ -721,15 +738,27 @@ class MultiHeadAttention(torch.nn.Module):
                 is_first_microbatch=is_first_microbatch,
             )
 
-            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
-            new_tensor_shape = mixed_kv_layer.size()[:-1] + (
-                self.num_attention_heads_per_partition,
-                2 * self.hidden_size_per_attention_head,
-            )
+            if self.qkv_weight_interleaved:
+                # [sq, b, (np * 2 * hn)] --> [sq, b, np, 2 * hn]
+                new_tensor_shape = mixed_kv_layer.size()[:-1] + (
+                    self.num_attention_heads_per_partition,
+                    2 * self.hidden_size_per_attention_head,
+                )
+                # split along last dimension
+                split_dim = -1
+            else:
+                # [sq, b, (np * 2 * hn)] --> [sq, b, 2 * np, hn]
+                new_tensor_shape = mixed_kv_layer.size()[:-1] + (
+                    2 * self.num_attention_heads_per_partition,
+                    self.hidden_size_per_attention_head,
+                )
+                # split along second last dimension
+                split_dim = -2
+
             mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
 
-            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
-            (key_layer, value_layer) = split_tensor_along_last_dim(mixed_kv_layer, 2)
+            # mixed_kv_layer --> 2 [sk, b, np, hn]
+            key_layer, value_layer = split_tensor_along_dim(mixed_kv_layer, split_dim, 2)
 
             # Attention head [sq, b, h] --> [sq, b, hp]
             if self.input_layernorm:
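
Putting the two branches together, the packed projection output is unpacked either by splitting the last dimension (interleaved) or the head dimension (stacked). A shape-only sketch with toy sizes, using torch.split directly in place of the split_tensor_along_dim helper (which divides the chosen dimension into equal parts the same way):

    import torch

    sq, b, np_, hn = 5, 2, 4, 8
    mixed = torch.randn(sq, b, np_ * 3 * hn)    # packed QKV projection output

    # interleaved weights: per-head chunks are [q | k | v] -> split the last dim
    x = mixed.view(sq, b, np_, 3 * hn)
    q, k, v = torch.split(x, hn, dim=-1)        # each [sq, b, np, hn]

    # stacked weights: chunks are [all q | all k | all v] -> split the head dim
    y = mixed.view(sq, b, 3 * np_, hn)
    q2, k2, v2 = torch.split(y, np_, dim=-2)    # each [sq, b, np, hn]

    assert q.shape == q2.shape == (sq, b, np_, hn)
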
@@ -863,7 +892,12 @@ class TransformerLayer(torch.nn.Module):
                          .. math::
                             y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta
+    qkv_weight_interleaved : bool, default = `True`
+                            if set to `False`, the QKV weight is interpreted as a concatenation of
+                            query, key, and value weights along the `0th` dimension. The default
+                            interpretation is that the individual `q`, `k`, and `v` weights for each
+                            attention head are interleaved. This parameter is set to `False` when
+                            using :attr:`fuse_qkv_params=False`.
 
     Parallelism parameters
     ----------------------
     set_parallel_mode : bool, default = `False`
@@ -938,6 +972,7 @@ class TransformerLayer(torch.nn.Module):
         set_parallel_mode: bool = False,
         fuse_qkv_params: bool = False,
         zero_centered_gamma: bool = False,
+        qkv_weight_interleaved: bool = True,
     ) -> None:
         super().__init__()
@@ -958,6 +993,9 @@ class TransformerLayer(torch.nn.Module):
                 not fuse_wgrad_accumulation
             ), "Gradient accumulation fusion requires single QKV parameter."
 
+        if not fuse_qkv_params:
+            qkv_weight_interleaved = False
+
         self.kv_channels = (
             kv_channels if kv_channels else (hidden_size // num_attention_heads)
         )
@@ -995,6 +1033,7 @@ class TransformerLayer(torch.nn.Module):
             "set_parallel_mode": set_parallel_mode,
             "fuse_qkv_params": fuse_qkv_params,
             "zero_centered_gamma": zero_centered_gamma,
+            "qkv_weight_interleaved" : qkv_weight_interleaved,
         }
 
         self.self_attention = MultiHeadAttention(
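
At the user level the new flag is simply passed through TransformerLayer to MultiHeadAttention. A hedged usage sketch; the import alias and the toy sizes are illustrative, while the argument names are the ones shown in this diff:

    import transformer_engine.pytorch as te

    layer = te.TransformerLayer(
        hidden_size=1024,
        ffn_hidden_size=4096,
        num_attention_heads=16,
        fuse_qkv_params=True,            # single fused QKV parameter
        qkv_weight_interleaved=False,    # interpret it as stacked q / k / v blocks
    )

    # with fuse_qkv_params=False the flag is forced to False internally,
    # since separate q/k/v parameters have no interleaved layout to speak of
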
@@ -78,8 +78,8 @@ def divide(numerator: int, denominator: int) -> int:
     return numerator // denominator
 
 
-def split_tensor_along_last_dim(
-    tensor: torch.Tensor, num_partitions: int, contiguous_split_chunks: bool = False
+def split_tensor_along_dim(
+    tensor: torch.Tensor, dim: int, num_partitions: int, contiguous_split_chunks: bool = False
 ) -> Tuple[torch.Tensor, ...]:
     """Split a tensor along its last dimension.
 
     Arguments:
@@ -89,10 +89,9 @@ def split_tensor_along_last_dim(
         in memory.
     """
     # Get the size and dimension.
-    last_dim = tensor.dim() - 1
-    last_dim_size = divide(tensor.size()[last_dim], num_partitions)
+    split_size = divide(tensor.size()[dim], num_partitions)
     # Split.
-    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
+    tensor_list = torch.split(tensor, split_size, dim=dim)
     # Note: torch.split does not create contiguous tensors by default.
     if contiguous_split_chunks:
         return tuple(chunk.contiguous() for chunk in tensor_list)
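
After the rename, calling the helper with dim=-1 reproduces the old split_tensor_along_last_dim behavior, while any other dimension can now be split as well. An illustrative check, assuming the helper is importable from transformer_engine.pytorch.utils as the import hunk above indicates:

    import torch
    from transformer_engine.pytorch.utils import split_tensor_along_dim  # added in this change

    t = torch.randn(6, 4, 12)

    # dim=-1 reproduces the old last-dimension split
    a, b, c = split_tensor_along_dim(t, -1, 3)
    assert a.shape == (6, 4, 4)

    # any other dimension can now be split as well, optionally made contiguous
    x, y = split_tensor_along_dim(t, 1, 2, contiguous_split_chunks=True)
    assert x.shape == (6, 2, 12) and x.is_contiguous()
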