Unverified Commit 9d732fd2 authored by Gabriele Sarti's avatar Gabriele Sarti Committed by GitHub
Browse files

XGLM - Fix Softmax NaNs when using FP16 (#18057)



* fix fp16 for xglm

* Removed misleading comment

* Fix undefined variable
Co-authored-by: default avatarGabriele Sarti <gsarti@amazon.com>
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: default avatarYounes Belkada <49240599+younesbelkada@users.noreply.github.com>
parent 99c32493
......@@ -235,7 +235,6 @@ class XGLMSinusoidalPositionalEmbedding(nn.Module):
return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->XGLM
class XGLMAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
......@@ -338,8 +337,13 @@ class XGLMAttention(nn.Module):
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
# upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
if attn_weights.dtype == torch.float16:
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16)
else:
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None:
......
......@@ -18,7 +18,7 @@ import math
import unittest
from transformers import XGLMConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
from ...generation.test_generation_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
......@@ -468,3 +468,22 @@ class XGLMModelLanguageGenerationTest(unittest.TestCase):
model.generate(input_ids, do_sample=False, max_time=None, max_length=256)
duration = datetime.datetime.now() - start
self.assertGreater(duration, datetime.timedelta(seconds=1.25 * MAX_TIME))
@require_torch_gpu
def test_batched_nan_fp16(self):
model_name = "facebook/xglm-564M"
tokenizer = XGLMTokenizer.from_pretrained(model_name, use_fast=False, padding_side="left")
model = XGLMForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, use_cache=True).cuda()
model = model.eval()
batch = tokenizer(["Who are you?", "Joe Biden is the president of"], padding=True, return_tensors="pt")
input_ids = batch["input_ids"].cuda()
attention_mask = batch["attention_mask"].cuda()
with torch.no_grad():
outputs = model(input_ids, attention_mask=attention_mask)
self.assertFalse(
torch.isnan(outputs.logits[0]).any().item()
) # the first logits could contain NaNs if it fails
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment