"vscode:/vscode.git/clone" did not exist on "7e61d56a45c19284cfda0cee8995fb552f6b1f4e"
Unverified Commit 587a19c7 authored by SeongBeomLEE's avatar SeongBeomLEE Committed by GitHub
Browse files

fix: GPTNeoX half inference error (#22888)

* fix: half inference error

norm_factor is still torch.float32 after using model.half

So I changed it to register_buffer so I can change it to torch.float16 after using model.half

* fix: Added a variable "persistent=False"

* run make style
parent 3d852da2
...@@ -94,7 +94,11 @@ class GPTNeoXAttention(nn.Module): ...@@ -94,7 +94,11 @@ class GPTNeoXAttention(nn.Module):
self.rotary_emb = RotaryEmbedding( self.rotary_emb = RotaryEmbedding(
self.rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base self.rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base
) )
self.norm_factor = torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()) self.register_buffer(
"norm_factor",
torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()),
persistent=False,
)
self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size) self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size)
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment