Unverified Commit 4ab41737 authored by Daniël de Kok's avatar Daniël de Kok Committed by GitHub
Browse files

Add support for Llama 3 rotary embeddings (#2286)

* Add support for Llama 3 rotary embeddings

* Update transformers to 4.43
parent 5d121a97
This diff is collapsed.
......@@ -26,7 +26,7 @@ hf-transfer = "^0.1.2"
sentencepiece = "^0.1.97"
tokenizers = "^0.19.1"
huggingface-hub = "^0.23"
transformers = "^4.42"
transformers = "^4.43"
einops = "^0.6.1"
texttable = { version = "^1.6.7", optional = true }
datasets = { version = "^2.14.0", optional = true }
......
......@@ -11,12 +11,12 @@ grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
numpy==2.0.1 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
......@@ -41,7 +41,7 @@ sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.42.4 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.43.0 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
......
......@@ -11,12 +11,12 @@ grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
numpy==2.0.1 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
......@@ -41,7 +41,7 @@ sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.42.4 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.43.0 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
......
......@@ -11,12 +11,12 @@ grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
numpy==2.0.1 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
......@@ -41,7 +41,7 @@ sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.42.4 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.43.0 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
......
import os
import math
import torch
from torch import nn
from loguru import logger
......@@ -85,9 +86,13 @@ class PositionRotaryEmbedding(nn.Module):
scaling_factor = None
rope_scaling = _get_rope_config(config)
if rope_scaling is not None:
if rope_scaling["type"] == "linear":
# `rope_type` is now standard in transformers, but some existing models
# have `type` instead.
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))
if rope_type == "linear":
pass
elif rope_scaling["type"] == "dynamic":
elif rope_type == "dynamic":
scaling_factor = rope_scaling["factor"]
return DynamicPositionRotaryEmbedding(
dim=dim,
......@@ -96,7 +101,20 @@ class PositionRotaryEmbedding(nn.Module):
device=inv_freq.device,
scaling_factor=scaling_factor,
)
elif rope_scaling["type"] == "yarn":
elif rope_type == "llama3":
inv_freq = apply_llama3_scaling(
inv_freq,
scaling_factor=rope_scaling["factor"],
low_freq_factor=rope_scaling["low_freq_factor"],
high_freq_factor=rope_scaling["high_freq_factor"],
original_max_position_embeddings=rope_scaling[
"original_max_position_embeddings"
],
)
return cls(inv_freq, scaling_factor)
elif rope_type == "yarn":
scaling_factor = rope_scaling["factor"]
mscale = rope_scaling.get("mscale", 1.0)
mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
......@@ -115,7 +133,7 @@ class PositionRotaryEmbedding(nn.Module):
mscale=mscale,
mscale_all_dim=mscale_all_dim,
)
elif rope_scaling["type"] in ["su", "longrope"]:
elif rope_type in ["su", "longrope"]:
short_factor = torch.tensor(
rope_scaling["short_factor"], dtype=torch.float32, device=device
)
......@@ -327,10 +345,6 @@ class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
self._sin_cached = torch.sin(freqs).to(dtype)
# Inverse dim formula to find dim based on number of rotations
import math
def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
2 * math.log(base)
......@@ -434,3 +448,33 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)
def apply_llama3_scaling(
freqs: torch.Tensor,
*,
scaling_factor: int,
low_freq_factor: int,
high_freq_factor: int,
original_max_position_embeddings: int,
):
low_freq_wavelen = original_max_position_embeddings / low_freq_factor
high_freq_wavelen = original_max_position_embeddings / high_freq_factor
new_freqs = []
for freq in freqs:
wavelen = 2 * math.pi / freq
if wavelen < high_freq_wavelen:
new_freqs.append(freq)
elif wavelen > low_freq_wavelen:
new_freqs.append(freq / scaling_factor)
else:
assert low_freq_wavelen != high_freq_wavelen
smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
high_freq_factor - low_freq_factor
)
new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq)
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment