Unverified Commit 44f6fdd7 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Llama 3.1: replace for loop by tensor ops at inv_freq initialization (#32244)

* replace for loop by tensor ops

* rm assert; readability
parent 8da90687
......@@ -324,18 +324,17 @@ def _compute_llama3_parameters(
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
new_freqs = []
for freq in inv_freq:
wavelen = 2 * math.pi / freq
if wavelen < high_freq_wavelen:
new_freqs.append(freq)
elif wavelen > low_freq_wavelen:
new_freqs.append(freq / factor)
else:
assert low_freq_wavelen != high_freq_wavelen
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
inv_freq = torch.tensor(new_freqs, dtype=inv_freq.dtype, device=inv_freq.device)
wavelen = 2 * math.pi / inv_freq
# wavelen < high_freq_wavelen: do nothing
# wavelen > low_freq_wavelen: divide by factor
inv_freq_new = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
# otherwise: interpolate between the two, using a smooth factor
smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
smoothed_inv_freq = (1 - smooth_factor) * inv_freq_new / factor + smooth_factor * inv_freq_new
is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
inv_freq_new = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_new)
return inv_freq, attention_factor
......@@ -501,7 +500,7 @@ def _validate_llama3_parameters(config: PretrainedConfig):
logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
if high_freq_factor is None or not isinstance(high_freq_factor, float):
logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
if high_freq_factor < low_freq_factor:
if high_freq_factor <= low_freq_factor:
logger.warning(
"`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment