Unverified Commit 0c9c0ba1 authored by Kirthi Shankar Sivamani, Committed by GitHub
Browse files

Enforce boolean attention mask type (#49)



* Enforce boolean attention mask type
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d6ff6f4d
......@@ -520,6 +520,11 @@ class MultiHeadAttention(torch.nn.Module):
"""MultiHeadAttention FWD"""
# hidden_states: [sq, b, h]
if attention_mask is not None:
assert (
attention_mask.dtype == torch.bool
), "Attention mask must be a boolean tensor"
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
......@@ -1006,6 +1011,11 @@ class TransformerLayer(torch.nn.Module):
hidden_states = hidden_states.contiguous()
if attention_mask is not None:
assert (
attention_mask.dtype == torch.bool
), "Attention mask must be a boolean tensor"
# For AMP
if torch.is_autocast_enabled():
hidden_states = cast_if_needed(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment