Unverified Commit 36047fd7 authored by Sudhakar Singh, committed by GitHub

make TransformerLayer accept a `bshd` or `sbhd` tensor format (#557)



* make TransformerLayer accept a `bshd` or `sbhd` tensor format
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Fixes from feedback
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* more feedback fixes
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* remove incorrect info from docstring
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* fix from feedback
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

---------
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
parent 434d58fa
@@ -666,10 +666,10 @@ def test_transformer_layer(dtype, model_configs, model, ckpt_attn, qkv_format, f
 @pytest.mark.parametrize("dtype", param_types_lean)
 @pytest.mark.parametrize("model_configs", [model_configs_te_layer])
 @pytest.mark.parametrize("model", ["te_1_2", "te_2_0"])
-def test_te_layer_misc(dtype, model_configs, model):
+@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd"])
+def test_te_layer_misc(dtype, model_configs, model, qkv_format):
     """Test TransformerLayer module with miscellanous settings"""
     ckpt_attn = True
-    qkv_format = "bshd"
     fused_qkv_params = True
     RoPE = True
     test_transformer_layer(dtype, model_configs, model,
@@ -705,7 +705,7 @@ def _run_transformer_layer(
     config: ModelConfig,
     backend: str,
     ckpt_attn: bool,
-    qkv_layout: str,
+    qkv_format: str,
     workspace_opt: bool,
     fused_qkv_params: bool,
     RoPE: bool,
@@ -724,6 +724,10 @@ def _run_transformer_layer(
     # Create input tensor
     inp = torch.randn(config.max_seqlen_q, config.batch_size, config.hidden_size,
                       dtype=dtype, device="cuda", requires_grad = True)
+    # In case the format to be tested is batch-first, need to transpose the
+    # input tensor.
+    if qkv_format == "bshd":
+        inp = inp.transpose(0,1)

     # Create seqlens
     if "padding" in config.attn_mask_type:
@@ -815,6 +819,7 @@ def _run_transformer_layer(
             qkv_weight_interleaved=False,
             ub_tp_comm_overlap=False,
             bias=True,
+            attn_input_format=qkv_format,
         )
         .to(dtype=dtype, device="cuda")
     )
@@ -1197,3 +1197,80 @@ def test_gpt_fp8_parameters(dtype, bs, model):
     outputs = _test_gpt_fp8_parameters(bs, dtype, config, False)
     outputs_fp8_params = _test_gpt_fp8_parameters(bs, dtype, config, True)
     assert_all_equal(outputs, outputs_fp8_params)
+
+
+@pytest.mark.parametrize("dtype", param_types)
+@pytest.mark.parametrize("bs", batch_sizes)
+@pytest.mark.parametrize("model", model_configs.keys())
+def test_transformer_layer_hidden_states_format(dtype, bs, model):
+    config = model_configs[model]
+
+    sigma = 0.023
+    init_method = init_method_normal(sigma)
+    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
+
+    # Set `torch.manual_seed` to make sure the weights are identical to the
+    # other layer. Set `*dropout` values to 0 to make sure the forward pass
+    # is identical to the other layer.
+    torch.manual_seed(0)
+    block_sbhd = (
+        TransformerLayer(
+            config.hidden_size,
+            4 * config.hidden_size,
+            config.num_attention_heads,
+            layernorm_epsilon=config.eps,
+            init_method=init_method,
+            output_layer_init_method=output_layer_init_method,
+            hidden_dropout=0,
+            attention_dropout=0,
+            kv_channels=config.embed,
+            apply_residual_connection_post_layernorm=False,
+            output_layernorm=False,
+            hidden_states_format="sbhd"
+        )
+        .to(dtype=dtype)
+        .cuda()
+    )
+
+    # Set `torch.manual_seed` to make sure the weights are identical to the
+    # other layer. Set `*dropout` values to 0 to make sure the forward pass
+    # is identical to the other layer.
+    torch.manual_seed(0)
+    block_bshd = (
+        TransformerLayer(
+            config.hidden_size,
+            4 * config.hidden_size,
+            config.num_attention_heads,
+            layernorm_epsilon=config.eps,
+            init_method=init_method,
+            output_layer_init_method=output_layer_init_method,
+            hidden_dropout=0,
+            attention_dropout=0,
+            kv_channels=config.embed,
+            apply_residual_connection_post_layernorm=False,
+            output_layernorm=False,
+            hidden_states_format="bshd"
+        )
+        .to(dtype=dtype)
+        .cuda()
+    )
+
+    for (n1, p1), (n2, p2) in zip(block_bshd.named_parameters(), block_sbhd.named_parameters()):
+        assert torch.all(torch.eq(p1, p2)), f"{n1}, {n2} not identical"
+
+    x_sbhd = torch.randn(
+        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
+    ).to(dtype).cuda()
+
+    x_bshd = x_sbhd.transpose(0,1).contiguous()
+
+    # To make sure forward is also identical (just in case some module decides
+    # to act fancy)
+    torch.manual_seed(0)
+    y_sbhd = block_sbhd(x_sbhd)
+
+    # To make sure forward is also identical (just in case some module decides
+    # to act fancy)
+    torch.manual_seed(0)
+    y_bshd = block_bshd(x_bshd)
+
+    assert_all_equal([y_bshd], [y_sbhd.transpose(0,1).contiguous()])
@@ -1034,11 +1034,34 @@ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
     return torch.cat((-x2, x1), dim=-1)


-def apply_rotary_pos_emb(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
+def apply_rotary_pos_emb(t: torch.Tensor, freqs: torch.Tensor, tensor_format: str = "sbhd") -> torch.Tensor:
     """
-    input tensor t is of shape [seq_length, ..., dim]
-    rotary positional embeding tensor `freqs` is of shape [seq_length, ..., dim]
+    Parameters
+    ----------
+    t: torch.Tensor
+        input tensor on which rotary positional embedding will be applied
+    freqs: torch.Tensor
+        rotary positional embeding tensor `freqs` is of shape
+        `[seq_length, ..., dim]`
+    tensor_format: {'sbhd', 'bshd'}, default = 'sbhd'
+        is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is
+        of shape `[seq, bs, ...]`.
     """
+    assert tensor_format in ("sbhd", "bshd"), ("Only formats `sbhd` or `bshd` "
+                                               "are supported for input tensor "
+                                               "`t`.")
+
+    max_seq_len = freqs.shape[0]
+    cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0]
+
+    # Only apply the rotary embeddings up to the sequence length of the running
+    # input.
+    assert cur_seq_len <= max_seq_len, (f"Rotary Embeddings only supported "
+                                        "upto {max_seq_len} sequence length!")
+    freqs = freqs[:cur_seq_len].to(t.dtype)
+    if tensor_format == "bshd":
+        freqs = freqs.transpose(0,1)  # [seq, 1, 1, dim] -> [1, seq, 1, dim]
+
     rot_dim = freqs.shape[-1]
     # ideally t_pass is empty so rotary pos embedding is applied to all tensor t
     t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
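For context, a minimal sketch (not part of the diff) of how the new `tensor_format` argument behaves. The shapes, the import path `transformer_engine.pytorch.attention`, and the random `freqs` tensor (standing in for the output of a real rotary-embedding module) are assumptions:

```python
import torch
# Assumed import path; apply_rotary_pos_emb lives in the same file as MultiheadAttention.
from transformer_engine.pytorch.attention import apply_rotary_pos_emb

seq_len, batch, heads, dim = 32, 2, 4, 64
freqs = torch.randn(seq_len, 1, 1, dim)            # [max_seq_len, 1, 1, dim]

t_sbhd = torch.randn(seq_len, batch, heads, dim)   # sequence-first input
t_bshd = t_sbhd.transpose(0, 1).contiguous()       # the same data, batch-first

out_sbhd = apply_rotary_pos_emb(t_sbhd, freqs)                        # default "sbhd"
out_bshd = apply_rotary_pos_emb(t_bshd, freqs, tensor_format="bshd")

# The two layouts should agree up to the transpose.
torch.testing.assert_close(out_bshd, out_sbhd.transpose(0, 1))
```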
@@ -2821,6 +2844,14 @@ class MultiheadAttention(torch.nn.Module):
         The device on which the parameters of the model will allocated. It is the user's
         responsibility to ensure all parameters are moved to the GPU before running the
         forward pass.
+    qkv_format: str, default = `sbhd`
+        dimension format for `query_layer`, `key_layer` and `value_layer`,
+        {`sbhd`, `bshd`}. `s` stands for the sequence length, `b` batch size,
+        `h` the number of heads and `d` head size. `sbhd` and `bshd` formats
+        are used for when sequences in a batch are of equal length or padded to
+        equal length. Please note that these formats do not reflect how
+        tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
+        For that, please use `_get_qkv_layout` to gain the layout information.

     Parallelism parameters
     ----------------------
@@ -2899,9 +2930,11 @@ class MultiheadAttention(torch.nn.Module):
         bias: bool = True,
         normalization: str = "LayerNorm",
         device: Union[torch.device, str] = "cuda",
+        qkv_format: str = "sbhd",
     ) -> None:
         super().__init__()
+        self.qkv_format = qkv_format
         self.attn_mask_type = attn_mask_type
         self.window_size = window_size
         self.window_size = check_set_window_size(attn_mask_type, self.window_size)
@@ -3045,6 +3078,7 @@ class MultiheadAttention(torch.nn.Module):
             kv_channels,
             num_gqa_groups=self.num_gqa_groups,
             attention_dropout=attention_dropout,
+            qkv_format=self.qkv_format,
             tp_size=tp_size,
             get_rng_state_tracker=get_rng_state_tracker,
             sequence_parallel=sequence_parallel,
@@ -3398,14 +3432,14 @@ class MultiheadAttention(torch.nn.Module):
         # apply relative positional encoding (rotary embedding)
         if rotary_pos_emb is not None:
             q_pos_emb, k_pos_emb = rotary_pos_emb
-            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb)
-            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb)
+            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format)
+            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format)

         context_layer = self.core_attention(
             query_layer,
             key_layer,
             value_layer,
-            qkv_format='sbhd',
+            qkv_format=self.qkv_format,
             cu_seqlens_q=None,
             cu_seqlens_kv=None,
             attention_mask=attention_mask,
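A hypothetical usage sketch of the new `qkv_format` argument on `MultiheadAttention`; the module sizes, dtype, and reliance on the default (causal) attention mask are illustrative assumptions, not taken from the diff:

```python
import torch
import transformer_engine.pytorch as te

# Attention module whose hidden states are batch-first throughout.
mha = te.MultiheadAttention(
    hidden_size=1024,
    num_attention_heads=16,
    qkv_format="bshd",          # new argument from this change
).to(dtype=torch.bfloat16).cuda()

# [batch, seq, hidden] input instead of the default [seq, batch, hidden].
x = torch.randn(2, 128, 1024, dtype=torch.bfloat16, device="cuda")
out = mha(x)                    # output keeps the same batch-first layout
```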
@@ -168,6 +168,14 @@ class TransformerLayer(torch.nn.Module):
         The device on which the parameters of the model will allocated. It is the user's
         responsibility to ensure all parameters are moved to the GPU before running the
         forward pass.
+    attn_input_format: {'sbhd', 'bshd'}, default = 'sbhd'
+        This controls whether the dimensions of the
+        intermediate hidden states is 'batch first' ('bshd') or
+        'sequence first' ('sbhd'). `s` stands for the sequence
+        length, `b` batch size, `h` the number of heads, `d`
+        head size. Note that these formats are very closely
+        related to the `qkv_format` in the `MultiHeadAttention`
+        and `DotProductAttention` modules.

     Parallelism parameters
     ----------------------
@@ -253,6 +261,7 @@ class TransformerLayer(torch.nn.Module):
         activation: str = 'gelu',
         normalization: str = "LayerNorm",
         device: Union[torch.device, str] = "cuda",
+        attn_input_format: str = "sbhd",
     ) -> None:
         super().__init__()
@@ -331,6 +340,8 @@ class TransformerLayer(torch.nn.Module):
         self.get_rng_state_tracker = get_rng_state_tracker

+        self.attn_input_format = attn_input_format
+
         attention_args = (
             hidden_size,
             num_attention_heads,
@@ -360,6 +371,7 @@ class TransformerLayer(torch.nn.Module):
             "ub_split_rs" : ub_split_rs,
             "ub_atomic_gemm_rs" : ub_atomic_gemm_rs,
             "ub_atomic_gemm_ag" : ub_atomic_gemm_ag,
+            "qkv_format" : self.attn_input_format,
         }

         self.self_attention = MultiheadAttention(
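Finally, a hypothetical end-to-end sketch of the new `attn_input_format` argument on `TransformerLayer`; the sizes and dtype are made up, and the layer uses its default (causal) self-attention mask:

```python
import torch
import transformer_engine.pytorch as te

# A layer that keeps its hidden states batch-first end to end.
layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    attn_input_format="bshd",   # new argument: hidden states are [batch, seq, hidden]
).to(dtype=torch.bfloat16).cuda()

x = torch.randn(2, 128, 1024, dtype=torch.bfloat16, device="cuda")  # [batch, seq, hidden]
y = layer(x)                                                         # same batch-first layout
```

With the default `attn_input_format="sbhd"`, the same layer would instead expect `[seq, batch, hidden]` inputs, which is the layout the tests construct before the transpose shown earlier.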