Unverified Commit 44f6fdd7 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Llama 3.1: replace for loop by tensor ops at inv_freq initialization (#32244)

* replace for loop by tensor ops

* rm assert; readability
parent 8da90687
...@@ -324,18 +324,17 @@ def _compute_llama3_parameters( ...@@ -324,18 +324,17 @@ def _compute_llama3_parameters(
low_freq_wavelen = old_context_len / low_freq_factor low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor high_freq_wavelen = old_context_len / high_freq_factor
new_freqs = []
for freq in inv_freq: wavelen = 2 * math.pi / inv_freq
wavelen = 2 * math.pi / freq # wavelen < high_freq_wavelen: do nothing
if wavelen < high_freq_wavelen: # wavelen > low_freq_wavelen: divide by factor
new_freqs.append(freq) inv_freq_new = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
elif wavelen > low_freq_wavelen: # otherwise: interpolate between the two, using a smooth factor
new_freqs.append(freq / factor) smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
else: smoothed_inv_freq = (1 - smooth_factor) * inv_freq_new / factor + smooth_factor * inv_freq_new
assert low_freq_wavelen != high_freq_wavelen is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) inv_freq_new = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_new)
new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
inv_freq = torch.tensor(new_freqs, dtype=inv_freq.dtype, device=inv_freq.device)
return inv_freq, attention_factor return inv_freq, attention_factor
...@@ -501,7 +500,7 @@ def _validate_llama3_parameters(config: PretrainedConfig): ...@@ -501,7 +500,7 @@ def _validate_llama3_parameters(config: PretrainedConfig):
logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}") logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
if high_freq_factor is None or not isinstance(high_freq_factor, float): if high_freq_factor is None or not isinstance(high_freq_factor, float):
logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}") logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
if high_freq_factor < low_freq_factor: if high_freq_factor <= low_freq_factor:
logger.warning( logger.warning(
"`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor=" "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
f"{high_freq_factor} and low_freq_factor={low_freq_factor}" f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment