Unverified Commit c0caadbe authored by Sudhakar Singh's avatar Sudhakar Singh Committed by GitHub
Browse files

Expose `rotary_base` as an arg instead of hardcoding (#944)



* make rotary_base arg
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* rotary base can be a float
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

---------
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent 841634ca
......@@ -4051,6 +4051,7 @@ class RotaryPositionEmbedding(torch.nn.Module):
rotary_percent: float = 1.0,
seq_len_interpolation_factor: Optional[int] = None,
pretrained_max_position_embeddings: Optional[int] = None,
rotary_base: float = 10000.0,
):
"""
Parameters
......@@ -4069,8 +4070,9 @@ class RotaryPositionEmbedding(torch.nn.Module):
if rotary_percent < 1.0:
dim = int(dim * rotary_percent)
self.seq_len_interpolation_factor = seq_len_interpolation_factor
self.rotary_base = rotary_base
inv_freq = 1.0 / (
10000
self.rotary_base
** (
torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device())
/ dim
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment