Unverified Commit 2cf87e2b authored by Nora Belrose, committed by GitHub

Prevent Dynamo graph fragmentation in GPTNeoX with torch.baddbmm fix (#24941)



* Pass a Python scalar for alpha in torch.baddbmm

* fixup

---------
Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
parent b413e061
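
For context, a minimal sketch (not part of this patch, shapes and the `fullgraph=True` check are illustrative assumptions) of the pattern the commit adopts: `alpha` in `torch.baddbmm` is passed as a plain Python float instead of a 0-dim tensor, so TorchDynamo never has to convert a tensor to a Python number mid-trace.

```python
import torch

batch, heads, q_len, k_len, head_size = 2, 4, 8, 8, 16

query = torch.randn(batch * heads, q_len, head_size)
key = torch.randn(batch * heads, k_len, head_size)
bias = torch.zeros(batch * heads, q_len, k_len)

norm_factor = head_size**-0.5  # plain Python float, as in the patched __init__

def attn_scores(query, key, bias):
    # bias + alpha * (query @ key.transpose(1, 2)); beta scales the bias term
    return torch.baddbmm(bias, query, key.transpose(1, 2), beta=1.0, alpha=norm_factor)

# fullgraph=True makes torch.compile raise if Dynamo has to split the graph,
# which is what the tensor-valued alpha previously forced it to do.
compiled = torch.compile(attn_scores, fullgraph=True)
print(compiled(query, key, bias).shape)  # torch.Size([8, 8, 8])
```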
@@ -100,11 +100,7 @@ class GPTNeoXAttention(nn.Module):
         self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
         self._init_rope()
-        self.register_buffer(
-            "norm_factor",
-            torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()),
-            persistent=False,
-        )
+        self.norm_factor = self.head_size**-0.5
         self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size)
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.attention_dropout = nn.Dropout(config.attention_dropout)
@@ -258,7 +254,7 @@ class GPTNeoXAttention(nn.Module):
             query,
             key.transpose(1, 2),
             beta=1.0,
-            alpha=(torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor),
+            alpha=self.norm_factor,
         )
         attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)