Unverified Commit d3a4b475 authored by Daniel Han's avatar Daniel Han Committed by GitHub
Browse files

RoPE loses precision for Llama / Gemma + Gemma logits.float() (#29285)



* Update modeling_llama.py

Llama - Force float32 since bfloat16 loses precision on long contexts

* Update modeling_llama.py

* Update modeling_gemma.py

Fix RoPE and logits.float()

* @torch.no_grad()

* @torch.no_grad()

* Cos, Sin to float32

* cos, sin to float32

* Update src/transformers/models/gemma/modeling_gemma.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_llama.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* Resolve PR conflicts

* Fix RoPE for llama

* Revert "Fix RoPE for llama"

This reverts commit b860a22dab9bb01cd15cb9a3220abeaefad3e458.

* Fix RoPE for llama

* RoPE device

* Autocast device type

* RoPE

* RoPE isinstance

---------
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>
parent 7628b3a0
......@@ -101,18 +101,25 @@ class GemmaRotaryEmbedding(nn.Module):
self.base = base
self.register_buffer("inv_freq", None, persistent=False)
@torch.no_grad()
def forward(self, x, position_ids, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if self.inv_freq is None:
self.inv_freq = 1.0 / (
self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
)
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# Copied from transformers.models.llama.modeling_llama.rotate_half
......@@ -1082,7 +1089,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
......
......@@ -126,6 +126,7 @@ class LlamaRotaryEmbedding(nn.Module):
)
return self._cos_cached
@torch.no_grad()
def forward(self, x, position_ids, seq_len=None):
if seq_len is not None:
logger.warning_once("The `seq_len` argument is deprecated and unused. It will be removed in v4.39.")
......@@ -133,9 +134,16 @@ class LlamaRotaryEmbedding(nn.Module):
# x: [bs, num_attention_heads, seq_len, head_size]
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment