Unverified Commit 4ddd9de9 authored by Tim Dettmers, committed by GitHub

Bugfix: LLaMA layer norm incorrectly changes input type and consumes lots of memory (#23535)



* Fixed bug where LLaMA layer norm would change input type.

* make fix-copies

---------
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
parent fe34486f
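For context, a minimal sketch of the failure mode described in the commit message. OldRMSNorm is a hypothetical stand-in (not a transformers class) that keeps only the dtype-relevant lines of the pre-fix forward: when the norm weight is kept in float32 (as it is with 8-bit loading or mixed-precision setups), the half-precision input is promoted to float32 and never cast back, so every consumer downstream sees a changed dtype and larger activations.

import torch
import torch.nn as nn

class OldRMSNorm(nn.Module):
    """Pre-fix forward pass, reduced to the dtype-relevant lines (illustrative only)."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        # float16 activation * float32 rsqrt promotes the activation to float32
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # the old cast only fires when the *weight* is half precision,
        # so a float32 weight leaves the output in float32
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)
        return self.weight * hidden_states

norm = OldRMSNorm(4096)                                  # weight stays float32
x = torch.randn(2, 8, 4096, dtype=torch.float16)
print(norm(x).dtype)                                     # torch.float32 -- input type changed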
@@ -81,14 +81,11 @@ class LlamaRMSNorm(nn.Module):
         self.variance_epsilon = eps
 
     def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
 
-        # convert into half-precision if necessary
-        if self.weight.dtype in [torch.float16, torch.bfloat16]:
-            hidden_states = hidden_states.to(self.weight.dtype)
-
-        return self.weight * hidden_states
+        return (self.weight * hidden_states).to(input_dtype)
 
 
 class LlamaRotaryEmbedding(torch.nn.Module):
...
@@ -91,14 +91,11 @@ class OpenLlamaRMSNorm(nn.Module):
         self.variance_epsilon = eps
 
     def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
 
-        # convert into half-precision if necessary
-        if self.weight.dtype in [torch.float16, torch.bfloat16]:
-            hidden_states = hidden_states.to(self.weight.dtype)
-
-        return self.weight * hidden_states
+        return (self.weight * hidden_states).to(input_dtype)
 
 
 # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->OpenLlama
...
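As a quick sanity check, assuming a transformers install that includes this commit, the patched norm should now hand activations back in the caller's dtype even though the variance is still accumulated in float32:

import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm

norm = LlamaRMSNorm(4096)                                # weight is initialised in float32
x = torch.randn(2, 8, 4096, dtype=torch.float16)
out = norm(x)
print(out.dtype)                                         # torch.float16 -- input dtype preserved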