Unverified commit 151a0af6, authored by Charlene Yang and committed by GitHub
Browse files

[PyTorch] Miscellaneous fixes for attention (#1780)



* add missing args in cross-attn
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* allow thd for TELayer
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* add CP note for reordering
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix wording about CP
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* add modulo cpx2 requirement
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* add example of token reordering
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* improve the CP docstring
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* tweak CP wording
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* test thd TELayer
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add enc_dec_kv for decoder
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix cross attn in decoder
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix unfused + bshd/sbhd + telayer
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* remove debugging
Signed-off-by: Charlene Yang <charleney@nvidia.com>

---------
Signed-off-by: Charlene Yang <charleney@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 75fe5601
...@@ -1107,9 +1107,11 @@ model_configs_te_layer = { ...@@ -1107,9 +1107,11 @@ model_configs_te_layer = {
"te_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"), "te_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
"te_1_1": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"), "te_1_1": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
"te_1_2": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"), "te_1_2": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"),
"te_1_3": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
"te_2_0": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"), "te_2_0": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
"te_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"), "te_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"),
"te_2_2": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), "te_2_2": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"te_2_3": ModelConfig(1, 16, 16, 64, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"),
"te_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "alibi"), "te_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "alibi"),
"te_3_1": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "alibi"), "te_3_1": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "alibi"),
} }
...@@ -1120,7 +1122,7 @@ model_configs_te_layer = { ...@@ -1120,7 +1122,7 @@ model_configs_te_layer = {
@pytest.mark.parametrize("model_configs", [model_configs_te_layer]) @pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", model_configs_te_layer.keys()) @pytest.mark.parametrize("model", model_configs_te_layer.keys())
@pytest.mark.parametrize("ckpt_attn", [False]) @pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("qkv_format", ["sbhd"]) @pytest.mark.parametrize("qkv_format", ["sbhd", "bshd", "thd"])
@pytest.mark.parametrize("fused_qkv_params", [False]) @pytest.mark.parametrize("fused_qkv_params", [False])
@pytest.mark.parametrize("RoPE", [False]) @pytest.mark.parametrize("RoPE", [False])
def test_transformer_layer( def test_transformer_layer(
...@@ -1137,13 +1139,18 @@ def test_transformer_layer( ...@@ -1137,13 +1139,18 @@ def test_transformer_layer(
available_backends, _, fused_attn_backends = _get_attention_backends( available_backends, _, fused_attn_backends = _get_attention_backends(
config, config,
qkv_dtype=dtype, qkv_dtype=dtype,
qkv_layout="sbh3d" if fused_qkv_params else "sb3hd", qkv_layout=(
qkv_format.replace("hd", "h3d") if fused_qkv_params else qkv_format.replace("hd", "3hd")
),
) )
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
# Skip if only unfused backend is supported # Skip if only unfused backend is supported
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2: if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Less than two backends to compare.") pytest.skip("Less than two backends to compare.")
# Skip if qkv_format = thd and "padding" not in attn_mask_type
if qkv_format == "thd" and "padding" not in config.attn_mask_type:
pytest.skip("THD requires padding mask.")
# UnfusedDotProductAttention backend # UnfusedDotProductAttention backend
if unfused_attn_supported: if unfused_attn_supported:
...@@ -1264,6 +1271,7 @@ def _run_transformer_layer( ...@@ -1264,6 +1271,7 @@ def _run_transformer_layer(
_attention_backends["backend_selection_requires_update"] = True _attention_backends["backend_selection_requires_update"] = True
# Create input tensor # Create input tensor
if qkv_format == "sbhd":
inp = torch.randn( inp = torch.randn(
config.max_seqlen_q, config.max_seqlen_q,
config.batch_size, config.batch_size,
...@@ -1272,40 +1280,75 @@ def _run_transformer_layer( ...@@ -1272,40 +1280,75 @@ def _run_transformer_layer(
device="cuda", device="cuda",
requires_grad=True, requires_grad=True,
) )
# In case the format to be tested is batch-first, need to transpose the inp_enc = torch.randn(
# input tensor. config.max_seqlen_kv,
config.batch_size,
config.hidden_size,
dtype=dtype,
device="cuda",
requires_grad=True,
)
if qkv_format == "bshd": if qkv_format == "bshd":
inp = inp.transpose(0, 1) inp = torch.randn(
config.batch_size,
config.max_seqlen_q,
config.hidden_size,
dtype=dtype,
device="cuda",
requires_grad=True,
)
inp_enc = torch.randn(
config.batch_size,
config.max_seqlen_kv,
config.hidden_size,
dtype=dtype,
device="cuda",
requires_grad=True,
)
# Create seqlens # Create seqlens
if "padding" in config.attn_mask_type: if "padding" in config.attn_mask_type or qkv_format == "thd":
if config.attn_type == "self":
seqlens_q = torch.randint(
1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
)
seqlens_kv = seqlens_q
if config.attn_type == "cross":
if config.max_seqlen_q > 1:
seqlens_q = torch.randint( seqlens_q = torch.randint(
1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda" 1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
) )
else:
seqlens_q = torch.ones([config.batch_size], dtype=torch.int32, device="cuda")
seqlens_kv = torch.randint(
1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
)
else: else:
seqlens_q = torch.full( seqlens_q = torch.full(
[config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda" [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
) )
seqlens_kv = torch.full(
# Create attention mask if padding [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
attention_mask = None
if "padding" in config.attn_mask_type:
attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
for i in range(config.batch_size):
attention_mask_q = torch.cat(
[
attention_mask_q,
torch.Tensor(
[False] * seqlens_q[i] + [True] * (config.max_seqlen_q - seqlens_q[i])
) )
.to(torch.bool) cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
.unsqueeze(0) cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
.unsqueeze(0) cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
.unsqueeze(0), cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
], if qkv_format == "thd":
dim=0, inp = torch.randn(
cu_seqlens_q[-1],
config.hidden_size,
dtype=dtype,
device="cuda",
requires_grad=True,
)
inp_enc = torch.randn(
cu_seqlens_kv[-1],
config.hidden_size,
dtype=dtype,
device="cuda",
requires_grad=True,
) )
attention_mask = attention_mask_q.to(device="cuda")
sigma = 0.02 sigma = 0.02
init_method = init_method_normal(sigma) init_method = init_method_normal(sigma)
...@@ -1357,7 +1400,7 @@ def _run_transformer_layer( ...@@ -1357,7 +1400,7 @@ def _run_transformer_layer(
sequence_parallel=False, sequence_parallel=False,
apply_residual_connection_post_layernorm=False, apply_residual_connection_post_layernorm=False,
output_layernorm=False, output_layernorm=False,
layer_type="encoder", layer_type="encoder" if config.attn_type == "self" else "decoder",
drop_path_rate=drop_path_rates[layer_number - 1], drop_path_rate=drop_path_rates[layer_number - 1],
set_parallel_mode=True, set_parallel_mode=True,
fuse_qkv_params=fused_qkv_params, fuse_qkv_params=fused_qkv_params,
...@@ -1376,13 +1419,18 @@ def _run_transformer_layer( ...@@ -1376,13 +1419,18 @@ def _run_transformer_layer(
# Run a forward and backward pass # Run a forward and backward pass
out = block( out = block(
inp, inp,
attention_mask=attention_mask,
self_attn_mask_type=config.attn_mask_type, self_attn_mask_type=config.attn_mask_type,
encoder_output=inp_enc if config.attn_type == "cross" else None,
enc_dec_attn_mask_type=config.attn_mask_type if config.attn_type == "cross" else None,
checkpoint_core_attention=False, checkpoint_core_attention=False,
rotary_pos_emb=rotary_pos_emb, rotary_pos_emb=rotary_pos_emb,
core_attention_bias_type=config.attn_bias_type, core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias, core_attention_bias=bias,
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
) )
loss = out.sum() loss = out.sum()
loss.backward() loss.backward()
......
...@@ -215,7 +215,12 @@ class UnfusedDotProductAttention(torch.nn.Module): ...@@ -215,7 +215,12 @@ class UnfusedDotProductAttention(torch.nn.Module):
if "padding" in attn_mask_type and attention_mask is None: if "padding" in attn_mask_type and attention_mask is None:
attention_mask = dpa_utils.get_padding_mask( attention_mask = dpa_utils.get_padding_mask(
batch_size, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv batch_size,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
self.attention_type,
) )
attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = ( attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = (
dpa_utils.get_full_mask( dpa_utils.get_full_mask(
......
...@@ -946,15 +946,23 @@ def get_attention_backend( ...@@ -946,15 +946,23 @@ def get_attention_backend(
@torch.no_grad() @torch.no_grad()
def get_padding_mask( def get_padding_mask(
batch_size: int, batch_size: int,
cu_seqlens_q: torch.Tensor, cu_seqlens_q: torch.Tensor = None,
cu_seqlens_kv: torch.Tensor, cu_seqlens_kv: torch.Tensor = None,
max_seqlen_q: int, max_seqlen_q: int = None,
max_seqlen_kv: int, max_seqlen_kv: int = None,
attention_type: str = "self",
): ):
"""Convert cu_seqlens to attention_mask""" """Convert cu_seqlens to attention_mask"""
assert (
cu_seqlens_q is not None and max_seqlen_q is not None
), "cu_seqlens_q and max_seqlen_q are required for self-attention and cross-attention"
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
attention_mask_q = torch.Tensor([]).to(dtype=torch.bool) attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
if attention_type == "cross":
assert (
cu_seqlens_kv is not None and max_seqlen_kv is not None
), "cu_seqlens_kv and max_seqlen_kv are required for cross-attention"
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
attention_mask_kv = torch.Tensor([]).to(dtype=torch.bool) attention_mask_kv = torch.Tensor([]).to(dtype=torch.bool)
for i in range(batch_size): for i in range(batch_size):
attention_mask_q = torch.cat( attention_mask_q = torch.cat(
...@@ -968,6 +976,7 @@ def get_padding_mask( ...@@ -968,6 +976,7 @@ def get_padding_mask(
], ],
dim=0, dim=0,
) )
if attention_type == "cross":
attention_mask_kv = torch.cat( attention_mask_kv = torch.cat(
[ [
attention_mask_kv, attention_mask_kv,
...@@ -979,8 +988,12 @@ def get_padding_mask( ...@@ -979,8 +988,12 @@ def get_padding_mask(
], ],
dim=0, dim=0,
) )
attention_mask_q = attention_mask_q.to(device="cuda")
if attention_type == "self":
attention_mask = attention_mask_q
else:
attention_mask = ( attention_mask = (
attention_mask_q.to(device="cuda"), attention_mask_q,
attention_mask_kv.to(device="cuda"), attention_mask_kv.to(device="cuda"),
) )
return attention_mask return attention_mask
......
...@@ -482,6 +482,8 @@ class MultiheadAttention(torch.nn.Module): ...@@ -482,6 +482,8 @@ class MultiheadAttention(torch.nn.Module):
alibi_slopes: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None,
cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None, cu_seqlens_kv: Optional[torch.Tensor] = None,
cu_seqlens_q_padded: Optional[torch.Tensor] = None,
cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None, max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None, max_seqlen_kv: Optional[int] = None,
fast_zero_fill: bool = True, fast_zero_fill: bool = True,
...@@ -556,6 +558,12 @@ class MultiheadAttention(torch.nn.Module): ...@@ -556,6 +558,12 @@ class MultiheadAttention(torch.nn.Module):
cu_seqlens_kv: Optional[torch.Tensor], default = `None` cu_seqlens_kv: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths (without offset) in a batch for `key_layer` Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
and `value_layer`, with shape [batch_size + 1] and dtype torch.int32. and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
cu_seqlens_q_padded: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths (with offset) in a batch for `query_layer`,
with shape [batch_size + 1] and dtype torch.int32.
cu_seqlens_kv_padded: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths (with offset) in a batch for `key_layer`
and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
max_seqlen_q: Optional[int], default = `None` max_seqlen_q: Optional[int], default = `None`
Maximum sequence length in `query_layer`. Maximum sequence length in `query_layer`.
Calculated from `cu_seqlens_q` if not provided. Calculated from `cu_seqlens_q` if not provided.
...@@ -714,6 +722,18 @@ class MultiheadAttention(torch.nn.Module): ...@@ -714,6 +722,18 @@ class MultiheadAttention(torch.nn.Module):
for x in (key_layer, value_layer) for x in (key_layer, value_layer)
) )
if self.qkv_format == "thd":
key_layer, value_layer = (
x.reshape(x.size(0), -1, self.hidden_size_per_attention_head)
for x in (key_layer, value_layer)
)
else:
# key, value: -> [sq, b, ng, hn]
key_layer, value_layer = (
x.reshape(x.size(0), x.size(1), -1, self.hidden_size_per_attention_head)
for x in (key_layer, value_layer)
)
# Attention head [sq, b, h] --> [sq, b, hp] # Attention head [sq, b, h] --> [sq, b, hp]
if self.input_layernorm: if self.input_layernorm:
layernorm_query_outputs = self.layernorm_query( layernorm_query_outputs = self.layernorm_query(
...@@ -803,6 +823,8 @@ class MultiheadAttention(torch.nn.Module): ...@@ -803,6 +823,8 @@ class MultiheadAttention(torch.nn.Module):
qkv_format=self.qkv_format, qkv_format=self.qkv_format,
cu_seqlens_q=cu_seqlens_q, cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_kv=cu_seqlens_kv,
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
max_seqlen_q=max_seqlen_q, max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv, max_seqlen_kv=max_seqlen_kv,
attention_mask=attention_mask, attention_mask=attention_mask,
......
...@@ -179,12 +179,12 @@ class TransformerLayer(torch.nn.Module): ...@@ -179,12 +179,12 @@ class TransformerLayer(torch.nn.Module):
The device on which the parameters of the model will be allocated. It is the user's The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the responsibility to ensure all parameters are moved to the GPU before running the
forward pass. forward pass.
attn_input_format: {'sbhd', 'bshd'}, default = 'sbhd' attn_input_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
This controls whether the dimensions of the This controls whether the dimensions of the
intermediate hidden states is 'batch first' ('bshd') or intermediate hidden states is 'sequence first' ('sbhd'), 'batch first' ('bshd'),
'sequence first' ('sbhd'). `s` stands for the sequence or 'token first' ('thd'). `s` stands for the sequence length, `b` batch size,
length, `b` batch size, `h` the number of heads, `d` `t` the total number of tokens, `h` the number of heads, `d` head size.
head size. Note that these formats are very closely Note that these formats are very closely
related to the `qkv_format` in the `MultiHeadAttention` related to the `qkv_format` in the `MultiHeadAttention`
and `DotProductAttention` modules. and `DotProductAttention` modules.
name: str, default = `None` name: str, default = `None`
...@@ -552,6 +552,8 @@ class TransformerLayer(torch.nn.Module): ...@@ -552,6 +552,8 @@ class TransformerLayer(torch.nn.Module):
alibi_slopes: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None,
cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None, cu_seqlens_kv: Optional[torch.Tensor] = None,
cu_seqlens_q_padded: Optional[torch.Tensor] = None,
cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None, max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None, max_seqlen_kv: Optional[int] = None,
fast_zero_fill: bool = True, fast_zero_fill: bool = True,
...@@ -633,15 +635,25 @@ class TransformerLayer(torch.nn.Module): ...@@ -633,15 +635,25 @@ class TransformerLayer(torch.nn.Module):
cu_seqlens_q: Optional[torch.Tensor], default = `None` cu_seqlens_q: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`, Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
with shape [batch_size + 1] and dtype torch.int32. with shape [batch_size + 1] and dtype torch.int32.
Used by encoders, or decoders' self-attention.
cu_seqlens_kv: Optional[torch.Tensor], default = `None` cu_seqlens_kv: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths (without offset) in a batch for `key_layer` Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
and `value_layer`, with shape [batch_size + 1] and dtype torch.int32. and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
Used by decoders' cross-attention.
cu_seqlens_q_padded: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths (with offset) in a batch for `query_layer`,
with shape [batch_size + 1] and dtype torch.int32. Set to `cu_seqlens_q` if None.
Used by encoders, or decoders' self-attention.
cu_seqlens_kv_padded: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths (with offset) in a batch for `key_layer`
and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
Set to `cu_seqlens_kv` if None. Used by decoders' cross-attention.
max_seqlen_q: Optional[int], default = `None` max_seqlen_q: Optional[int], default = `None`
Maximum sequence length in `query_layer`. Maximum sequence length in `query_layer`.
Calculated from `cu_seqlens_q` if not provided. Calculated from `cu_seqlens_q_padded` if not provided.
max_seqlen_kv: Optional[int], default = `None` max_seqlen_kv: Optional[int], default = `None`
Maximum sequence length in `key_layer` and `value_layer`. Maximum sequence length in `key_layer` and `value_layer`.
Calculated from `cu_seqlens_kv` if not provided. Calculated from `cu_seqlens_kv_padded` if not provided.
fast_zero_fill: bool, default = `True` fast_zero_fill: bool, default = `True`
Whether to set output tensors to 0 or not before use. Whether to set output tensors to 0 or not before use.
inference_params: InferenceParams, default = None inference_params: InferenceParams, default = None
...@@ -649,7 +661,8 @@ class TransformerLayer(torch.nn.Module): ...@@ -649,7 +661,8 @@ class TransformerLayer(torch.nn.Module):
to efficiently calculate and store the context during inference. to efficiently calculate and store the context during inference.
pad_between_seqs: Optional[bool], default = `None` pad_between_seqs: Optional[bool], default = `None`
If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded. If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded.
If true, there are padding tokens between individual sequences in a packed batch. If true, there are padding tokens between individual sequences in a packed batch,
i.e. qkv_format = 'thd'.
""" """
if self_attn_mask_type is None: if self_attn_mask_type is None:
...@@ -678,7 +691,9 @@ class TransformerLayer(torch.nn.Module): ...@@ -678,7 +691,9 @@ class TransformerLayer(torch.nn.Module):
if ( if (
"padding" in self_attn_mask_type or self_attn_mask_type == "arbitrary" "padding" in self_attn_mask_type or self_attn_mask_type == "arbitrary"
) and attention_mask is not None: ) and attention_mask is not None:
assert attention_mask.dtype == torch.bool, "Attention mask must be a boolean tensor" assert all(
attention_mask[i].dtype == torch.bool for i in range(len(attention_mask))
), "Attention mask must be a boolean tensor or a list/tuple of two boolean tensors"
if ( if (
"padding" in enc_dec_attn_mask_type or enc_dec_attn_mask_type == "arbitrary" "padding" in enc_dec_attn_mask_type or enc_dec_attn_mask_type == "arbitrary"
) and enc_dec_attn_mask is not None: ) and enc_dec_attn_mask is not None:
...@@ -707,9 +722,11 @@ class TransformerLayer(torch.nn.Module): ...@@ -707,9 +722,11 @@ class TransformerLayer(torch.nn.Module):
core_attention_bias=core_attention_bias, core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
cu_seqlens_q=cu_seqlens_q, cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_kv=cu_seqlens_q,
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_q_padded,
max_seqlen_q=max_seqlen_q, max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv, max_seqlen_kv=max_seqlen_q,
fast_zero_fill=fast_zero_fill, fast_zero_fill=fast_zero_fill,
pad_between_seqs=pad_between_seqs, pad_between_seqs=pad_between_seqs,
) )
...@@ -733,12 +750,21 @@ class TransformerLayer(torch.nn.Module): ...@@ -733,12 +750,21 @@ class TransformerLayer(torch.nn.Module):
attn_mask_type=enc_dec_attn_mask_type, attn_mask_type=enc_dec_attn_mask_type,
window_size=enc_dec_window_size, window_size=enc_dec_window_size,
encoder_output=encoder_output, encoder_output=encoder_output,
inference_params=inference_params,
is_first_microbatch=is_first_microbatch, is_first_microbatch=is_first_microbatch,
checkpoint_core_attention=checkpoint_core_attention, checkpoint_core_attention=checkpoint_core_attention,
rotary_pos_emb=rotary_pos_emb,
core_attention_bias_type=core_attention_bias_type, core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias, core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
fast_zero_fill=fast_zero_fill, fast_zero_fill=fast_zero_fill,
pad_between_seqs=pad_between_seqs,
) )
if self.apply_residual_connection_post_layernorm: if self.apply_residual_connection_post_layernorm:
attention_output, attention_bias, residual = inter_attention_outputs attention_output, attention_bias, residual = inter_attention_outputs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment