"tests/models/vscode:/vscode.git/clone" did not exist on "cbb8a37929c3860210f95c9ec99b8b84b8cf57a1"
Unverified Commit 587a19c7 authored by SeongBeomLEE's avatar SeongBeomLEE Committed by GitHub
Browse files

fix: GPTNeoX half inference error (#22888)

* fix: half inference error

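norm_factor was a plain tensor attribute, so it stayed torch.float32 after calling model.half(), which only converts parameters and buffers.

It is now registered with register_buffer so that model.half() converts it to torch.float16 along with the rest of the model.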

* fix: Added the argument "persistent=False" so the buffer is not saved in the state_dict

* run make style
parent 3d852da2
......@@ -94,7 +94,11 @@ class GPTNeoXAttention(nn.Module):
self.rotary_emb = RotaryEmbedding(
self.rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base
)
self.norm_factor = torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype())
self.register_buffer(
"norm_factor",
torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()),
persistent=False,
)
self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size)
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment