Unverified Commit 06d5fa97 authored by Kirthi Shankar Sivamani, committed by GitHub

Make QK layer scaling opt-in (#339)


Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 7a30ba45
@@ -157,6 +157,10 @@ class UnfusedDotProductAttention(torch.nn.Module):
         # on average it should not be partition dependent.
         self.attention_dropout = torch.nn.Dropout(attention_dropout)
 
+        # An FP16 training trick required for certain GPT-like models.
+        self.apply_qk_layer_scaling = (
+            bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0"))) and layer_number is not None)
+
     def forward(
         self,
         query_layer: torch.Tensor,
@@ -166,7 +170,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
     ) -> torch.Tensor:
         """core attention fprop"""
         batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
-        apply_qk_layer_scaling = self.layer_number is not None and key_layer.dtype == torch.float16
+        apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16
 
         # [b, np, sq, sk]
         output_size = (
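This commit makes QK layer scaling opt-in: in addition to the module having a `layer_number` and the inputs being FP16, the `NVTE_APPLY_QK_LAYER_SCALING` environment variable must now be set, instead of scaling applying automatically whenever a layer number is present. Below is a minimal standalone sketch of the resulting gating logic, not the actual TransformerEngine module; the helper name is hypothetical.

```python
import os

import torch


def qk_layer_scaling_enabled(layer_number, key_dtype):
    """Hypothetical helper mirroring the opt-in check added in this commit.

    Scaling applies only when all three conditions hold:
      1. NVTE_APPLY_QK_LAYER_SCALING=1 is set in the environment,
      2. the attention module was constructed with a layer_number,
      3. the key/query tensors are FP16.
    """
    opt_in = bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0")))
    return opt_in and layer_number is not None and key_dtype == torch.float16


# Opt in explicitly; without the variable, scaling stays off.
os.environ["NVTE_APPLY_QK_LAYER_SCALING"] = "1"
print(qk_layer_scaling_enabled(layer_number=3, key_dtype=torch.float16))     # True
print(qk_layer_scaling_enabled(layer_number=None, key_dtype=torch.float16))  # False
print(qk_layer_scaling_enabled(layer_number=3, key_dtype=torch.bfloat16))    # False
```

Note that in the actual module the environment variable is read in `__init__`, so it must be set before the attention module is constructed, while the FP16 dtype check still happens per forward pass.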