"tests/models/roberta/test_tokenization_roberta.py" did not exist on "7e98e211f0e86e414b22946bd89391e49d2ea900"
Unverified commit 2e37ef35, authored by Jason Phang, committed by GitHub
Browse files

Remove RuntimeErrors for NaN-checking in 20B (#17563)

parent f6ad0e05
......@@ -193,8 +193,6 @@ class GPTNeoXAttention(nn.Module):
query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
attn_scores = torch.einsum("bik,bjk->bij", query, key) / self.norm_factor
if torch.isnan(attn_scores).any():
raise RuntimeError()
attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
attn_scores = torch.where(causal_mask, attn_scores, self.masked_bias.to(attn_scores.dtype))
......@@ -204,8 +202,6 @@ class GPTNeoXAttention(nn.Module):
attn_scores = attn_scores + attention_mask
attn_weights = nn.functional.softmax(attn_scores, dim=-1)
if torch.isnan(attn_weights).any():
raise RuntimeError()
attn_weights = attn_weights.to(value.dtype)
# Mask heads if we want to
......@@ -213,8 +209,6 @@ class GPTNeoXAttention(nn.Module):
attn_weights = attn_weights * head_mask
attn_output = torch.matmul(attn_weights, value)
if torch.isnan(attn_output).any():
raise RuntimeError()
return attn_output, attn_weights
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment