Unverified commit 151a0af6, authored by Charlene Yang and committed by GitHub

[PyTorch] Miscellaneous fixes for attention (#1780)



* add missing args in cross-attn
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* allow thd for TELayer
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* add CP note for reordering
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix wording about CP
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* add modulo cpx2 requirement
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* add example of token reordering
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* improve the CP docstring
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* tweak CP wording
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* test thd TELayer
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add enc_dec_kv for decoder
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix cross attn in decoder
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix unfused + bshd/sbhd + telayer
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove debugging
Signed-off-by: Charlene Yang <charleney@nvidia.com>

---------
Signed-off-by: Charlene Yang <charleney@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 75fe5601
@@ -1107,9 +1107,11 @@ model_configs_te_layer = {
     "te_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
     "te_1_1": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
     "te_1_2": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"),
+    "te_1_3": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
     "te_2_0": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
     "te_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"),
     "te_2_2": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
+    "te_2_3": ModelConfig(1, 16, 16, 64, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"),
     "te_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "alibi"),
     "te_3_1": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "alibi"),
 }
@@ -1120,7 +1122,7 @@ model_configs_te_layer = {
 @pytest.mark.parametrize("model_configs", [model_configs_te_layer])
 @pytest.mark.parametrize("model", model_configs_te_layer.keys())
 @pytest.mark.parametrize("ckpt_attn", [False])
-@pytest.mark.parametrize("qkv_format", ["sbhd"])
+@pytest.mark.parametrize("qkv_format", ["sbhd", "bshd", "thd"])
 @pytest.mark.parametrize("fused_qkv_params", [False])
 @pytest.mark.parametrize("RoPE", [False])
 def test_transformer_layer(
@@ -1137,13 +1139,18 @@ def test_transformer_layer(
     available_backends, _, fused_attn_backends = _get_attention_backends(
         config,
         qkv_dtype=dtype,
-        qkv_layout="sbh3d" if fused_qkv_params else "sb3hd",
+        qkv_layout=(
+            qkv_format.replace("hd", "h3d") if fused_qkv_params else qkv_format.replace("hd", "3hd")
+        ),
     )
     flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
     # Skip if only unfused backend is supported
     if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
         pytest.skip("Less than two backends to compare.")
+    # Skip if qkv_format = thd and "padding" not in attn_mask_type
+    if qkv_format == "thd" and "padding" not in config.attn_mask_type:
+        pytest.skip("THD requires padding mask.")
     # UnfusedDotProductAttention backend
     if unfused_attn_supported:
@@ -1264,48 +1271,84 @@ def _run_transformer_layer(
     _attention_backends["backend_selection_requires_update"] = True
     # Create input tensor
-    inp = torch.randn(
-        config.max_seqlen_q,
-        config.batch_size,
-        config.hidden_size,
-        dtype=dtype,
-        device="cuda",
-        requires_grad=True,
-    )
-    # In case the format to be tested is batch-first, need to transpose the
-    # input tensor.
+    if qkv_format == "sbhd":
+        inp = torch.randn(
+            config.max_seqlen_q,
+            config.batch_size,
+            config.hidden_size,
+            dtype=dtype,
+            device="cuda",
+            requires_grad=True,
+        )
+        inp_enc = torch.randn(
+            config.max_seqlen_kv,
+            config.batch_size,
+            config.hidden_size,
+            dtype=dtype,
+            device="cuda",
+            requires_grad=True,
+        )
     if qkv_format == "bshd":
-        inp = inp.transpose(0, 1)
+        inp = torch.randn(
+            config.batch_size,
+            config.max_seqlen_q,
+            config.hidden_size,
+            dtype=dtype,
+            device="cuda",
+            requires_grad=True,
+        )
+        inp_enc = torch.randn(
+            config.batch_size,
+            config.max_seqlen_kv,
+            config.hidden_size,
+            dtype=dtype,
+            device="cuda",
+            requires_grad=True,
+        )
     # Create seqlens
-    if "padding" in config.attn_mask_type:
-        seqlens_q = torch.randint(
-            1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
-        )
+    if "padding" in config.attn_mask_type or qkv_format == "thd":
+        if config.attn_type == "self":
+            seqlens_q = torch.randint(
+                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
+            )
+            seqlens_kv = seqlens_q
+        if config.attn_type == "cross":
+            if config.max_seqlen_q > 1:
+                seqlens_q = torch.randint(
+                    1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
+                )
+            else:
+                seqlens_q = torch.ones([config.batch_size], dtype=torch.int32, device="cuda")
+            seqlens_kv = torch.randint(
+                1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
+            )
     else:
         seqlens_q = torch.full(
             [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
         )
-
-    # Create attention mask if padding
-    attention_mask = None
-    if "padding" in config.attn_mask_type:
-        attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
-        for i in range(config.batch_size):
-            attention_mask_q = torch.cat(
-                [
-                    attention_mask_q,
-                    torch.Tensor(
-                        [False] * seqlens_q[i] + [True] * (config.max_seqlen_q - seqlens_q[i])
-                    )
-                    .to(torch.bool)
-                    .unsqueeze(0)
-                    .unsqueeze(0)
-                    .unsqueeze(0),
-                ],
-                dim=0,
-            )
-        attention_mask = attention_mask_q.to(device="cuda")
+        seqlens_kv = torch.full(
+            [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
+        )
+    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
+    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
+    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
+    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
+    if qkv_format == "thd":
+        inp = torch.randn(
+            cu_seqlens_q[-1],
+            config.hidden_size,
+            dtype=dtype,
+            device="cuda",
+            requires_grad=True,
+        )
+        inp_enc = torch.randn(
+            cu_seqlens_kv[-1],
+            config.hidden_size,
+            dtype=dtype,
+            device="cuda",
+            requires_grad=True,
+        )
     sigma = 0.02
     init_method = init_method_normal(sigma)
@@ -1357,7 +1400,7 @@ def _run_transformer_layer(
         sequence_parallel=False,
         apply_residual_connection_post_layernorm=False,
         output_layernorm=False,
-        layer_type="encoder",
+        layer_type="encoder" if config.attn_type == "self" else "decoder",
         drop_path_rate=drop_path_rates[layer_number - 1],
         set_parallel_mode=True,
         fuse_qkv_params=fused_qkv_params,
@@ -1376,13 +1419,18 @@ def _run_transformer_layer(
     # Run a forward and backward pass
     out = block(
         inp,
-        attention_mask=attention_mask,
         self_attn_mask_type=config.attn_mask_type,
+        encoder_output=inp_enc if config.attn_type == "cross" else None,
+        enc_dec_attn_mask_type=config.attn_mask_type if config.attn_type == "cross" else None,
         checkpoint_core_attention=False,
         rotary_pos_emb=rotary_pos_emb,
         core_attention_bias_type=config.attn_bias_type,
         core_attention_bias=bias,
         alibi_slopes=alibi_slopes,
+        max_seqlen_q=config.max_seqlen_q,
+        max_seqlen_kv=config.max_seqlen_kv,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_kv=cu_seqlens_kv,
     )
     loss = out.sum()
     loss.backward()
...
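For reference, a minimal standalone sketch (plain PyTorch, illustrative sizes only, not part of this diff) of the relationship the updated test relies on between per-sequence lengths, cumulative sequence lengths, and a packed 'thd' input:

```python
import torch

# Three sequences of different lengths packed into one batch (made-up lengths).
seqlens_q = torch.tensor([5, 3, 7], dtype=torch.int32)            # per-sequence token counts
cu_seqlens_q = torch.zeros(len(seqlens_q) + 1, dtype=torch.int32)
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)                 # tensor([0, 5, 8, 15])

hidden_size = 16
# 'thd' packs all tokens along one leading dimension: [total_tokens, hidden_size].
inp = torch.randn(int(cu_seqlens_q[-1]), hidden_size)

# Tokens of sequence i live in inp[cu_seqlens_q[i] : cu_seqlens_q[i + 1]].
for i in range(len(seqlens_q)):
    seq_i = inp[int(cu_seqlens_q[i]) : int(cu_seqlens_q[i + 1])]
    assert seq_i.shape[0] == int(seqlens_q[i])
```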
@@ -215,7 +215,12 @@ class UnfusedDotProductAttention(torch.nn.Module):
         if "padding" in attn_mask_type and attention_mask is None:
             attention_mask = dpa_utils.get_padding_mask(
-                batch_size, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
+                batch_size,
+                cu_seqlens_q,
+                cu_seqlens_kv,
+                max_seqlen_q,
+                max_seqlen_kv,
+                self.attention_type,
             )
         attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = (
             dpa_utils.get_full_mask(
...
@@ -946,16 +946,24 @@ def get_attention_backend(
 @torch.no_grad()
 def get_padding_mask(
     batch_size: int,
-    cu_seqlens_q: torch.Tensor,
-    cu_seqlens_kv: torch.Tensor,
-    max_seqlen_q: int,
-    max_seqlen_kv: int,
+    cu_seqlens_q: torch.Tensor = None,
+    cu_seqlens_kv: torch.Tensor = None,
+    max_seqlen_q: int = None,
+    max_seqlen_kv: int = None,
+    attention_type: str = "self",
 ):
     """Convert cu_seqlens to attention_mask"""
+    assert (
+        cu_seqlens_q is not None and max_seqlen_q is not None
+    ), "cu_seqlens_q and max_seqlen_q are required for self-attention and cross-attention"
     seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
-    seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
     attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
-    attention_mask_kv = torch.Tensor([]).to(dtype=torch.bool)
+    if attention_type == "cross":
+        assert (
+            cu_seqlens_kv is not None and max_seqlen_kv is not None
+        ), "cu_seqlens_kv and max_seqlen_kv are required for cross-attention"
+        seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
+        attention_mask_kv = torch.Tensor([]).to(dtype=torch.bool)
     for i in range(batch_size):
         attention_mask_q = torch.cat(
             [
@@ -968,21 +976,26 @@ def get_padding_mask(
             ],
             dim=0,
         )
-        attention_mask_kv = torch.cat(
-            [
-                attention_mask_kv,
-                torch.Tensor([False] * seqlens_kv[i] + [True] * (max_seqlen_kv - seqlens_kv[i]))
-                .to(dtype=torch.bool)
-                .unsqueeze(0)
-                .unsqueeze(0)
-                .unsqueeze(0),
-            ],
-            dim=0,
-        )
-    attention_mask = (
-        attention_mask_q.to(device="cuda"),
-        attention_mask_kv.to(device="cuda"),
-    )
+        if attention_type == "cross":
+            attention_mask_kv = torch.cat(
+                [
+                    attention_mask_kv,
+                    torch.Tensor([False] * seqlens_kv[i] + [True] * (max_seqlen_kv - seqlens_kv[i]))
+                    .to(dtype=torch.bool)
+                    .unsqueeze(0)
+                    .unsqueeze(0)
+                    .unsqueeze(0),
+                ],
+                dim=0,
+            )
+    attention_mask_q = attention_mask_q.to(device="cuda")
+    if attention_type == "self":
+        attention_mask = attention_mask_q
+    else:
+        attention_mask = (
+            attention_mask_q,
+            attention_mask_kv.to(device="cuda"),
+        )
    return attention_mask
...
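With this change, `get_padding_mask` returns a single query-side mask for self-attention and a `(query, key/value)` tuple for cross-attention. A rough standalone sketch of the same idea in plain, vectorized PyTorch (illustrative lengths; this is not the library function itself, which builds the mask per batch element with `torch.cat`):

```python
import torch

def padding_mask_sketch(cu_seqlens, max_seqlen):
    """Return a [batch, 1, 1, max_seqlen] bool mask; True marks padded positions."""
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]            # per-sequence valid lengths
    positions = torch.arange(max_seqlen).unsqueeze(0)     # [1, max_seqlen]
    mask = positions >= seqlens.unsqueeze(1)              # [batch, max_seqlen]
    return mask.unsqueeze(1).unsqueeze(1)                 # [batch, 1, 1, max_seqlen]

cu_seqlens_q = torch.tensor([0, 5, 8, 15], dtype=torch.int32)
cu_seqlens_kv = torch.tensor([0, 6, 10, 12], dtype=torch.int32)

mask_q = padding_mask_sketch(cu_seqlens_q, max_seqlen=16)
mask_kv = padding_mask_sketch(cu_seqlens_kv, max_seqlen=16)

attention_mask_self = mask_q               # self-attention: single query mask
attention_mask_cross = (mask_q, mask_kv)   # cross-attention: (query, key/value) tuple
```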
@@ -482,6 +482,8 @@ class MultiheadAttention(torch.nn.Module):
         alibi_slopes: Optional[torch.Tensor] = None,
         cu_seqlens_q: Optional[torch.Tensor] = None,
         cu_seqlens_kv: Optional[torch.Tensor] = None,
+        cu_seqlens_q_padded: Optional[torch.Tensor] = None,
+        cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
         max_seqlen_q: Optional[int] = None,
         max_seqlen_kv: Optional[int] = None,
         fast_zero_fill: bool = True,
@@ -556,6 +558,12 @@ class MultiheadAttention(torch.nn.Module):
         cu_seqlens_kv: Optional[torch.Tensor], default = `None`
             Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
             and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
+        cu_seqlens_q_padded: Optional[torch.Tensor], default = `None`
+            Cumulative sum of sequence lengths (with offset) in a batch for `query_layer`,
+            with shape [batch_size + 1] and dtype torch.int32.
+        cu_seqlens_kv_padded: Optional[torch.Tensor], default = `None`
+            Cumulative sum of sequence lengths (with offset) in a batch for `key_layer`
+            and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
         max_seqlen_q: Optional[int], default = `None`
             Maximum sequence length in `query_layer`.
             Calculated from `cu_seqlens_q` if not provided.
@@ -714,6 +722,18 @@ class MultiheadAttention(torch.nn.Module):
                 for x in (key_layer, value_layer)
             )
+            if self.qkv_format == "thd":
+                key_layer, value_layer = (
+                    x.reshape(x.size(0), -1, self.hidden_size_per_attention_head)
+                    for x in (key_layer, value_layer)
+                )
+            else:
+                # key, value: -> [sq, b, ng, hn]
+                key_layer, value_layer = (
+                    x.reshape(x.size(0), x.size(1), -1, self.hidden_size_per_attention_head)
+                    for x in (key_layer, value_layer)
+                )
         # Attention head [sq, b, h] --> [sq, b, hp]
         if self.input_layernorm:
             layernorm_query_outputs = self.layernorm_query(
@@ -803,6 +823,8 @@ class MultiheadAttention(torch.nn.Module):
             qkv_format=self.qkv_format,
             cu_seqlens_q=cu_seqlens_q,
             cu_seqlens_kv=cu_seqlens_kv,
+            cu_seqlens_q_padded=cu_seqlens_q_padded,
+            cu_seqlens_kv_padded=cu_seqlens_kv_padded,
             max_seqlen_q=max_seqlen_q,
             max_seqlen_kv=max_seqlen_kv,
             attention_mask=attention_mask,
...
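The new reshape branch above keeps key/value tensors 3-dimensional for 'thd' ([total_tokens, kv_heads, head_dim]) and 4-dimensional for 'sbhd'/'bshd' ([seq, batch, kv_heads, head_dim]). A small shape-only sketch with made-up sizes, mirroring those two reshapes:

```python
import torch

total_tokens, seq_len, batch, num_kv_heads, head_dim = 15, 8, 2, 4, 64

# 'thd': all tokens packed along the first dimension, no explicit batch dimension.
kv_thd = torch.randn(total_tokens, num_kv_heads * head_dim)
kv_thd = kv_thd.reshape(kv_thd.size(0), -1, head_dim)                       # [15, 4, 64]

# 'sbhd': sequence-first layout keeps an explicit batch dimension.
kv_sbhd = torch.randn(seq_len, batch, num_kv_heads * head_dim)
kv_sbhd = kv_sbhd.reshape(kv_sbhd.size(0), kv_sbhd.size(1), -1, head_dim)   # [8, 2, 4, 64]
```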
@@ -179,12 +179,12 @@ class TransformerLayer(torch.nn.Module):
         The device on which the parameters of the model will be allocated. It is the user's
         responsibility to ensure all parameters are moved to the GPU before running the
         forward pass.
-    attn_input_format: {'sbhd', 'bshd'}, default = 'sbhd'
+    attn_input_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
         This controls whether the dimensions of the
-        intermediate hidden states is 'batch first' ('bshd') or
-        'sequence first' ('sbhd'). `s` stands for the sequence
-        length, `b` batch size, `h` the number of heads, `d`
-        head size. Note that these formats are very closely
+        intermediate hidden states is 'sequence first' ('sbhd'), 'batch first' ('bshd'),
+        or 'token first' ('thd'). `s` stands for the sequence length, `b` batch size,
+        `t` the total number of tokens, `h` the number of heads, `d` head size.
+        Note that these formats are very closely
         related to the `qkv_format` in the `MultiHeadAttention`
         and `DotProductAttention` modules.
     name: str, default = `None`
@@ -552,6 +552,8 @@ class TransformerLayer(torch.nn.Module):
         alibi_slopes: Optional[torch.Tensor] = None,
         cu_seqlens_q: Optional[torch.Tensor] = None,
         cu_seqlens_kv: Optional[torch.Tensor] = None,
+        cu_seqlens_q_padded: Optional[torch.Tensor] = None,
+        cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
         max_seqlen_q: Optional[int] = None,
         max_seqlen_kv: Optional[int] = None,
         fast_zero_fill: bool = True,
@@ -568,88 +570,99 @@ class TransformerLayer(torch.nn.Module):
         Parameters
         ----------
         hidden_states : torch.Tensor
             Input tensor.
         attention_mask : Optional[torch.Tensor], default = `None`
             Boolean tensor used to mask out self-attention softmax input. It should be
             in [batch_size, 1, 1, seqlen_q] for padding masks, and broadcastable
             to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] for "`arbitrary`"
             mask. It should be `None` for causal masks and "`no_mask`" type.
             A `True` value means the corresponding position is masked out and
             a `False` means that position is allowed to participate in attention.
         self_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal',
                              'causal_bottom_right', 'padding_causal_bottom_right', 'arbitrary'},
                              default = `causal`
             Type of attention mask passed into softmax operation for encoder.
             By default, causal masks are aligned to the top left corner of
             the softmax matrix. When "`bottom_right`" is specified in the mask type,
             causal masks are aligned to the bottom right corner.
         window_size: Optional[Tuple[int, int]], default = `None`
             Sliding window size for local attention in encoder.
         encoder_output : Optional[torch.Tensor], default = `None`
             Output of the encoder block to be fed into the decoder block if using
             `layer_type="decoder"`.
         enc_dec_attn_mask : Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
             default = `None`. Boolean tensors used to mask out inter-attention softmax input if
             using `layer_type="decoder"`. It should be a tuple of two masks in
             [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv] for padding masks.
             It should be broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]
             for "`arbitrary`" mask. It should be `None` for causal masks and "`no_mask`".
             A `True` value means the corresponding position is masked out and a `False`
             means that position is allowed to participate in attention.
         enc_dec_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
             default = `None`
             Type of attention mask passed into softmax operation for decoder.
         enc_dec_window_size: Optional[Tuple[int, int]], default = `None`
             Sliding window size for local attention in decoder.
         is_first_microbatch : {True, False, None}, default = None
             During training using either gradient accumulation or
             pipeline parallelism a minibatch of data is further split
             into microbatches. Between the microbatches of the same minibatch
             the model weights are not updated. Setting this parameter indicates
             whether the current microbatch is the first in a minibatch or not.
             When set, this parameter enables additional optimizations:
             * during FP8 training, it allows caching of the FP8 versions of
               the weights
             * it also allows skipping gradient accumulation during the
               first microbatch (since it is the first gradient being
               produced)
         checkpoint_core_attention: bool, default = `False`
             If true, forward activations for core attention are recomputed
             during the backward pass in order to save memory that would
             otherwise be occupied to store the forward activations until
             backprop.
         rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
             Embeddings for query and key tensors for applying rotary position
             embedding. By default no input embedding is applied.
         core_attention_bias_type: str, default = `no_bias`
             Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
         core_attention_bias: Optional[torch.Tensor], default = `None`
             Bias tensor for Q * K.T
         alibi_slopes: Optional[torch.Tensor], default = `None`
             ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
             It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
             to the attention score of query i and key j.
         cu_seqlens_q: Optional[torch.Tensor], default = `None`
             Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
             with shape [batch_size + 1] and dtype torch.int32.
+            Used by encoders, or decoders' self-attention.
         cu_seqlens_kv: Optional[torch.Tensor], default = `None`
             Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
             and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
+            Used by decoders' cross-attention.
+        cu_seqlens_q_padded: Optional[torch.Tensor], default = `None`
+            Cumulative sum of sequence lengths (with offset) in a batch for `query_layer`,
+            with shape [batch_size + 1] and dtype torch.int32. Set to `cu_seqlens_q` if None.
+            Used by encoders, or decoders' self-attention.
+        cu_seqlens_kv_padded: Optional[torch.Tensor], default = `None`
+            Cumulative sum of sequence lengths (with offset) in a batch for `key_layer`
+            and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
+            Set to `cu_seqlens_kv` if None. Used by decoders' cross-attention.
         max_seqlen_q: Optional[int], default = `None`
             Maximum sequence length in `query_layer`.
-            Calculated from `cu_seqlens_q` if not provided.
+            Calculated from `cu_seqlens_q_padded` if not provided.
         max_seqlen_kv: Optional[int], default = `None`
             Maximum sequence length in `key_layer` and `value_layer`.
-            Calculated from `cu_seqlens_kv` if not provided.
+            Calculated from `cu_seqlens_kv_padded` if not provided.
         fast_zero_fill: bool, default = `True`
             Whether to set output tensors to 0 or not before use.
         inference_params: InferenceParams, default = None
             Inference parameters that are passed to the main model in order
             to efficiently calculate and store the context during inference.
         pad_between_seqs: Optional[bool], default = `None`
             If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded.
-            If true, there are padding tokens between individual sequences in a packed batch.
+            If true, there are padding tokens between individual sequences in a packed batch,
+            i.e. qkv_format = 'thd'.
         """
         if self_attn_mask_type is None:
@@ -678,7 +691,9 @@ class TransformerLayer(torch.nn.Module):
         if (
             "padding" in self_attn_mask_type or self_attn_mask_type == "arbitrary"
         ) and attention_mask is not None:
-            assert attention_mask.dtype == torch.bool, "Attention mask must be a boolean tensor"
+            assert all(
+                attention_mask[i].dtype == torch.bool for i in range(len(attention_mask))
+            ), "Attention mask must be a boolean tensor or a list/tuple of two boolean tensors"
         if (
             "padding" in enc_dec_attn_mask_type or enc_dec_attn_mask_type == "arbitrary"
         ) and enc_dec_attn_mask is not None:
@@ -707,9 +722,11 @@ class TransformerLayer(torch.nn.Module):
             core_attention_bias=core_attention_bias,
             alibi_slopes=alibi_slopes,
             cu_seqlens_q=cu_seqlens_q,
-            cu_seqlens_kv=cu_seqlens_kv,
+            cu_seqlens_kv=cu_seqlens_q,
+            cu_seqlens_q_padded=cu_seqlens_q_padded,
+            cu_seqlens_kv_padded=cu_seqlens_q_padded,
             max_seqlen_q=max_seqlen_q,
-            max_seqlen_kv=max_seqlen_kv,
+            max_seqlen_kv=max_seqlen_q,
             fast_zero_fill=fast_zero_fill,
             pad_between_seqs=pad_between_seqs,
         )
@@ -733,12 +750,21 @@ class TransformerLayer(torch.nn.Module):
                 attn_mask_type=enc_dec_attn_mask_type,
                 window_size=enc_dec_window_size,
                 encoder_output=encoder_output,
+                inference_params=inference_params,
                 is_first_microbatch=is_first_microbatch,
                 checkpoint_core_attention=checkpoint_core_attention,
+                rotary_pos_emb=rotary_pos_emb,
                 core_attention_bias_type=core_attention_bias_type,
                 core_attention_bias=core_attention_bias,
                 alibi_slopes=alibi_slopes,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_kv=cu_seqlens_kv,
+                cu_seqlens_q_padded=cu_seqlens_q_padded,
+                cu_seqlens_kv_padded=cu_seqlens_kv_padded,
+                max_seqlen_q=max_seqlen_q,
+                max_seqlen_kv=max_seqlen_kv,
                 fast_zero_fill=fast_zero_fill,
+                pad_between_seqs=pad_between_seqs,
             )
         if self.apply_residual_connection_post_layernorm:
             attention_output, attention_bias, residual = inter_attention_outputs
...
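Putting the pieces together, a hedged sketch of how the updated test exercises a decoder-style TransformerLayer with packed 'thd' inputs. The forward arguments (`self_attn_mask_type`, `encoder_output`, `enc_dec_attn_mask_type`, `cu_seqlens_*`, `max_seqlen_*`) and the constructor options `layer_type`/`attn_input_format` come from the diff above; the concrete sizes, mask types, and the assumption of a CUDA device with a Transformer Engine build that includes this change are illustrative only.

```python
import torch
import transformer_engine.pytorch as te

hidden_size, num_heads = 1024, 16
layer = te.TransformerLayer(
    hidden_size,
    4 * hidden_size,
    num_heads,
    layer_type="decoder",        # enables the cross-attention (inter-attention) block
    attn_input_format="thd",     # packed token-first layout allowed by this change
).cuda()

# Packed query-side and encoder-side batches (made-up lengths).
seqlens_q = torch.tensor([5, 3, 7], dtype=torch.int32, device="cuda")
seqlens_kv = torch.tensor([6, 4, 2], dtype=torch.int32, device="cuda")
cu_seqlens_q = torch.zeros(len(seqlens_q) + 1, dtype=torch.int32, device="cuda")
cu_seqlens_kv = torch.zeros(len(seqlens_kv) + 1, dtype=torch.int32, device="cuda")
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)

inp = torch.randn(int(cu_seqlens_q[-1]), hidden_size, device="cuda", requires_grad=True)
enc_out = torch.randn(int(cu_seqlens_kv[-1]), hidden_size, device="cuda")

out = layer(
    inp,
    self_attn_mask_type="padding",          # THD requires a padding-style mask
    encoder_output=enc_out,
    enc_dec_attn_mask_type="padding",
    cu_seqlens_q=cu_seqlens_q,
    cu_seqlens_kv=cu_seqlens_kv,
    max_seqlen_q=int(seqlens_q.max()),
    max_seqlen_kv=int(seqlens_kv.max()),
)
out.sum().backward()
```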