"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "a2dec768a27ab7520d4ae4f5f72643043f305fd3"
Unverified commit 2cf87e2b authored by Nora Belrose, committed by GitHub

Prevent Dynamo graph fragmentation in GPTNeoX with torch.baddbmm fix (#24941)



* Pass a Python scalar for alpha in torch.baddbmm

* fixup

---------
Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
parent b413e061
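
The change in the diff below is small but relevant for torch.compile: previously, alpha for torch.baddbmm was a 0-d tensor (1.0 divided by the norm_factor buffer), and converting that tensor into the scalar baddbmm expects is the kind of tensor-to-Python-number step that Dynamo cannot keep inside a single traced graph, which is the "graph fragmentation" the title refers to. Passing a plain Python float avoids it. The sketch below is illustrative and not part of the commit: the shapes are made up, but it shows the old and new call patterns side by side and checks that they produce the same attention scores.

import torch

# Illustrative shapes (hypothetical; not taken from any GPT-NeoX config).
batch_heads, q_len, k_len, head_size = 8, 16, 16, 64

query = torch.randn(batch_heads, q_len, head_size)
key = torch.randn(batch_heads, k_len, head_size)
bias = torch.zeros(batch_heads, q_len, k_len)

# Old pattern: norm_factor is a 0-d tensor buffer, so alpha is also a 0-d tensor.
# Converting it to the scalar baddbmm expects is what fragments the Dynamo graph.
norm_factor_buffer = torch.sqrt(torch.tensor(head_size, dtype=torch.float32))
old_scores = torch.baddbmm(
    bias,
    query,
    key.transpose(1, 2),
    beta=1.0,
    alpha=(torch.tensor(1.0) / norm_factor_buffer),  # tensor-valued alpha
)

# New pattern: norm_factor is a plain Python float, passed straight through.
norm_factor = head_size**-0.5
new_scores = torch.baddbmm(
    bias,
    query,
    key.transpose(1, 2),
    beta=1.0,
    alpha=norm_factor,  # Python scalar alpha, no tensor-to-scalar conversion
)

# The two scalings are numerically identical: 1 / sqrt(d) == d ** -0.5.
torch.testing.assert_close(old_scores, new_scores)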
@@ -100,11 +100,7 @@ class GPTNeoXAttention(nn.Module):
         self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
         self._init_rope()
-        self.register_buffer(
-            "norm_factor",
-            torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()),
-            persistent=False,
-        )
+        self.norm_factor = self.head_size**-0.5
         self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size)
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.attention_dropout = nn.Dropout(config.attention_dropout)
@@ -258,7 +254,7 @@ class GPTNeoXAttention(nn.Module):
             query,
             key.transpose(1, 2),
             beta=1.0,
-            alpha=(torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor),
+            alpha=self.norm_factor,
         )
         attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
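
Two side notes on the diff. First, dropping the registered buffer does not affect existing checkpoints: it was registered with persistent=False, so it was never written to state_dict in the first place. Second, a quick way to see the intended effect is simply to compile a GPT-NeoX model; the sketch below is not part of the commit, the Pythia checkpoint is used only because it is a small public GPT-NeoX model, and how much of the model compiles into one graph still depends on the torch version.

import torch
from transformers import AutoModelForCausalLM

# Any GPT-NeoX checkpoint works; Pythia-70m is chosen here only for its size.
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m").eval()

# With norm_factor as a Python scalar, the baddbmm call in the attention block
# no longer needs a tensor-to-scalar conversion, so it can stay inside the
# traced graph instead of forcing Dynamo to split around it.
compiled = torch.compile(model)

input_ids = torch.randint(0, model.config.vocab_size, (1, 16))
with torch.no_grad():
    logits = compiled(input_ids).logits
print(logits.shape)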