Commit 2e37ef35 (unverified), authored by Jason Phang, committed by GitHub
Browse files

Remove RuntimeErrors for NaN-checking in 20B (#17563)

parent f6ad0e05
...@@ -193,8 +193,6 @@ class GPTNeoXAttention(nn.Module): ...@@ -193,8 +193,6 @@ class GPTNeoXAttention(nn.Module):
query = query.view(batch_size * num_attention_heads, query_length, attn_head_size) query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
key = key.view(batch_size * num_attention_heads, key_length, attn_head_size) key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
attn_scores = torch.einsum("bik,bjk->bij", query, key) / self.norm_factor attn_scores = torch.einsum("bik,bjk->bij", query, key) / self.norm_factor
if torch.isnan(attn_scores).any():
raise RuntimeError()
attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length) attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
attn_scores = torch.where(causal_mask, attn_scores, self.masked_bias.to(attn_scores.dtype)) attn_scores = torch.where(causal_mask, attn_scores, self.masked_bias.to(attn_scores.dtype))
...@@ -204,8 +202,6 @@ class GPTNeoXAttention(nn.Module): ...@@ -204,8 +202,6 @@ class GPTNeoXAttention(nn.Module):
attn_scores = attn_scores + attention_mask attn_scores = attn_scores + attention_mask
attn_weights = nn.functional.softmax(attn_scores, dim=-1) attn_weights = nn.functional.softmax(attn_scores, dim=-1)
if torch.isnan(attn_weights).any():
raise RuntimeError()
attn_weights = attn_weights.to(value.dtype) attn_weights = attn_weights.to(value.dtype)
# Mask heads if we want to # Mask heads if we want to
...@@ -213,8 +209,6 @@ class GPTNeoXAttention(nn.Module): ...@@ -213,8 +209,6 @@ class GPTNeoXAttention(nn.Module):
attn_weights = attn_weights * head_mask attn_weights = attn_weights * head_mask
attn_output = torch.matmul(attn_weights, value) attn_output = torch.matmul(attn_weights, value)
if torch.isnan(attn_output).any():
raise RuntimeError()
return attn_output, attn_weights return attn_output, attn_weights
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment