Unverified commit d45f47ab authored by Park Jun, committed by GitHub

Fix: Disable torch.autocast in RotaryEmbedding of Gemma and LLaMa for MPS device (#29439)



* Fix: Disable torch.autocast in RotaryEmbedding of Gemma and LLaMa for MPS devices

* Update src/transformers/models/gemma/modeling_gemma.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update llama and gemma rope to use cpu on mps device

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 2a939f20
@@ -115,7 +115,7 @@ class GemmaRotaryEmbedding(nn.Module):
         # Force float32 since bfloat16 loses precision on long contexts
         # See https://github.com/huggingface/transformers/pull/29285
         device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) else "cpu"
+        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
@@ -139,7 +139,7 @@ class LlamaRotaryEmbedding(nn.Module):
         # Force float32 since bfloat16 loses precision on long contexts
         # See https://github.com/huggingface/transformers/pull/29285
         device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) else "cpu"
+        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
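
For reference, a minimal, self-contained sketch of the pattern both hunks apply. The helper name _autocast_device_type and the tensor shapes are illustrative, not part of the transformers source; the point is that torch.autocast rejects device_type="mps" on the PyTorch versions this fix targets, and because the context only ever runs with enabled=False, substituting "cpu" for "mps" changes nothing else.

import torch

def _autocast_device_type(x: torch.Tensor) -> str:
    # Hypothetical helper: torch.autocast raises on device_type="mps" in the
    # PyTorch versions this fix targets, so route MPS tensors to "cpu". Since
    # the context below runs with enabled=False, the substituted device type
    # has no other effect.
    device_type = x.device.type
    return device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"

# Toy stand-ins for the rotary-embedding inputs (shapes are illustrative):
# batch=2, head_dim/2=4 inverse frequencies, seq_len=8 positions.
x = torch.randn(2, 8, 16)
inv_freq_expanded = torch.randn(2, 4, 1)
position_ids_expanded = torch.arange(8, dtype=torch.float32).reshape(1, 1, 8).expand(2, 1, 8)

# Compute the RoPE angles in float32 regardless of any ambient autocast dtype,
# since bfloat16 loses precision on long contexts (see PR #29285).
with torch.autocast(device_type=_autocast_device_type(x), enabled=False):
    freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
    emb = torch.cat((freqs, freqs), dim=-1)

print(emb.shape)  # torch.Size([2, 8, 8])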