"git@developer.sourcefind.cn:gaoqiong/pybind11.git" did not exist on "0f82a0b01482037554689dd945c3b10fcc9792cb"
Unverified Commit 7d6c1d02 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Fix unfused QKV params case; stack vs interleave option (#83)



* fix qkv weight unfused path
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix non FA non interleaved case
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 2417a53a
...@@ -24,7 +24,7 @@ from transformer_engine.pytorch.jit import ( ...@@ -24,7 +24,7 @@ from transformer_engine.pytorch.jit import (
from transformer_engine.pytorch.utils import ( from transformer_engine.pytorch.utils import (
divide, divide,
attention_mask_func, attention_mask_func,
split_tensor_along_last_dim, split_tensor_along_dim,
cast_if_needed, cast_if_needed,
get_default_init_method, get_default_init_method,
) )
...@@ -126,11 +126,11 @@ class UnfusedDotProductAttention(torch.nn.Module): ...@@ -126,11 +126,11 @@ class UnfusedDotProductAttention(torch.nn.Module):
) )
# [sq, b, np, hn] -> [sq, b * np, hn] # [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view( query_layer = query_layer.reshape(
output_size[2], output_size[0] * output_size[1], -1 output_size[2], output_size[0] * output_size[1], -1
) )
# [sk, b, np, hn] -> [sk, b * np, hn] # [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1)
# preallocting result tensor: [b * np, sq, sk] # preallocting result tensor: [b * np, sq, sk]
matmul_result = torch.empty( matmul_result = torch.empty(
...@@ -171,7 +171,7 @@ class UnfusedDotProductAttention(torch.nn.Module): ...@@ -171,7 +171,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
) )
# change view [sk, b * np, hn] # change view [sk, b * np, hn]
value_layer = value_layer.view( value_layer = value_layer.reshape(
value_layer.size(0), output_size[0] * output_size[1], -1 value_layer.size(0), output_size[0] * output_size[1], -1
) )
...@@ -504,6 +504,7 @@ class MultiHeadAttention(torch.nn.Module): ...@@ -504,6 +504,7 @@ class MultiHeadAttention(torch.nn.Module):
set_parallel_mode: bool = False, set_parallel_mode: bool = False,
fuse_qkv_params: bool = False, fuse_qkv_params: bool = False,
zero_centered_gamma: bool = False, zero_centered_gamma: bool = False,
qkv_weight_interleaved: bool = True,
) -> None: ) -> None:
super().__init__() super().__init__()
self.layer_number = (layer_number,) self.layer_number = (layer_number,)
...@@ -515,6 +516,10 @@ class MultiHeadAttention(torch.nn.Module): ...@@ -515,6 +516,10 @@ class MultiHeadAttention(torch.nn.Module):
self.params_dtype = params_dtype self.params_dtype = params_dtype
self.init_method = init_method self.init_method = init_method
if not fuse_qkv_params:
qkv_weight_interleaved = False
self.qkv_weight_interleaved = qkv_weight_interleaved
assert ( assert (
attention_type in AttnTypes attention_type in AttnTypes
), f"attention_type {attention_type} not supported" ), f"attention_type {attention_type} not supported"
...@@ -703,16 +708,28 @@ class MultiHeadAttention(torch.nn.Module): ...@@ -703,16 +708,28 @@ class MultiHeadAttention(torch.nn.Module):
is_first_microbatch=is_first_microbatch, is_first_microbatch=is_first_microbatch,
) )
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] if self.qkv_weight_interleaved:
new_tensor_shape = mixed_x_layer.size()[:-1] + ( # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
self.num_attention_heads_per_partition, new_tensor_shape = mixed_x_layer.size()[:-1] + (
3 * self.hidden_size_per_attention_head, self.num_attention_heads_per_partition,
) 3 * self.hidden_size_per_attention_head,
)
# split along last dimension
split_dim = -1
else:
# [sq, b, (np * 3 * hn)] --> [sq, b, 3 * np, hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + (
3 * self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
)
# split along second last dimension
split_dim = -2
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] # mixed_x_layer --> 3 [sq, b, np, hn]
query_layer, key_layer, value_layer = split_tensor_along_last_dim( query_layer, key_layer, value_layer = split_tensor_along_dim(
mixed_x_layer, 3 mixed_x_layer, split_dim, 3
) )
else: else:
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
...@@ -721,15 +738,27 @@ class MultiHeadAttention(torch.nn.Module): ...@@ -721,15 +738,27 @@ class MultiHeadAttention(torch.nn.Module):
is_first_microbatch=is_first_microbatch, is_first_microbatch=is_first_microbatch,
) )
# [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn] if self.qkv_weight_interleaved:
new_tensor_shape = mixed_kv_layer.size()[:-1] + ( # [sq, b, (np * 2 * hn)] --> [sq, b, np, 2 * hn]
self.num_attention_heads_per_partition, new_tensor_shape = mixed_kv_layer.size()[:-1] + (
2 * self.hidden_size_per_attention_head, self.num_attention_heads_per_partition,
) 2 * self.hidden_size_per_attention_head,
)
# split along last dimension
split_dim = -1
else:
# [sq, b, (np * 2 * hn)] --> [sq, b, 2 * np, hn]
new_tensor_shape = mixed_kv_layer.size()[:-1] + (
2 * self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
)
# split along second last dimension
split_dim = -2
mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape) mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
# [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn] # mixed_kv_layer --> 2 [sk, b, np, hn]
(key_layer, value_layer) = split_tensor_along_last_dim(mixed_kv_layer, 2) key_layer, value_layer = split_tensor_along_dim(mixed_kv_layer, split_dim, 2)
# Attention head [sq, b, h] --> [sq, b, hp] # Attention head [sq, b, h] --> [sq, b, hp]
if self.input_layernorm: if self.input_layernorm:
...@@ -863,7 +892,12 @@ class TransformerLayer(torch.nn.Module): ...@@ -863,7 +892,12 @@ class TransformerLayer(torch.nn.Module):
.. math:: .. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
(1 + \gamma) + \beta (1 + \gamma) + \beta
qkv_weight_interleaved : bool, default = `True`
if set to `False`, the QKV weight is interpreted as a concatenation of
query, key, and value weights along the `0th` dimension. The default
interpretation is that the individual `q`, `k`, and `v` weights for each
attention head are interleaved. This parameter is set to `False` when
using :attr:`fuse_qkv_params=False`.
Parallelism parameters Parallelism parameters
---------------------- ----------------------
set_parallel_mode : bool, default = `False` set_parallel_mode : bool, default = `False`
...@@ -938,6 +972,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -938,6 +972,7 @@ class TransformerLayer(torch.nn.Module):
set_parallel_mode: bool = False, set_parallel_mode: bool = False,
fuse_qkv_params: bool = False, fuse_qkv_params: bool = False,
zero_centered_gamma: bool = False, zero_centered_gamma: bool = False,
qkv_weight_interleaved: bool = True,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -958,6 +993,9 @@ class TransformerLayer(torch.nn.Module): ...@@ -958,6 +993,9 @@ class TransformerLayer(torch.nn.Module):
not fuse_wgrad_accumulation not fuse_wgrad_accumulation
), "Gradient accumulation fusion requires single QKV parameter." ), "Gradient accumulation fusion requires single QKV parameter."
if not fuse_qkv_params:
qkv_weight_interleaved = False
self.kv_channels = ( self.kv_channels = (
kv_channels if kv_channels else (hidden_size // num_attention_heads) kv_channels if kv_channels else (hidden_size // num_attention_heads)
) )
...@@ -995,6 +1033,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -995,6 +1033,7 @@ class TransformerLayer(torch.nn.Module):
"set_parallel_mode": set_parallel_mode, "set_parallel_mode": set_parallel_mode,
"fuse_qkv_params": fuse_qkv_params, "fuse_qkv_params": fuse_qkv_params,
"zero_centered_gamma": zero_centered_gamma, "zero_centered_gamma": zero_centered_gamma,
"qkv_weight_interleaved" : qkv_weight_interleaved,
} }
self.self_attention = MultiHeadAttention( self.self_attention = MultiHeadAttention(
......
...@@ -78,8 +78,8 @@ def divide(numerator: int, denominator: int) -> int: ...@@ -78,8 +78,8 @@ def divide(numerator: int, denominator: int) -> int:
return numerator // denominator return numerator // denominator
def split_tensor_along_last_dim( def split_tensor_along_dim(
tensor: torch.Tensor, num_partitions: int, contiguous_split_chunks: bool = False tensor: torch.Tensor, dim: int, num_partitions: int, contiguous_split_chunks: bool = False
) -> Tuple[torch.Tensor, ...]: ) -> Tuple[torch.Tensor, ...]:
"""Split a tensor along its last dimension. """Split a tensor along its last dimension.
Arguments: Arguments:
...@@ -89,10 +89,9 @@ def split_tensor_along_last_dim( ...@@ -89,10 +89,9 @@ def split_tensor_along_last_dim(
in memory. in memory.
""" """
# Get the size and dimension. # Get the size and dimension.
last_dim = tensor.dim() - 1 split_size = divide(tensor.size()[dim], num_partitions)
last_dim_size = divide(tensor.size()[last_dim], num_partitions)
# Split. # Split.
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) tensor_list = torch.split(tensor, split_size, dim=dim)
# Note: torch.split does not create contiguous tensors by default. # Note: torch.split does not create contiguous tensors by default.
if contiguous_split_chunks: if contiguous_split_chunks:
return tuple(chunk.contiguous() for chunk in tensor_list) return tuple(chunk.contiguous() for chunk in tensor_list)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment