Unverified commit e67b4f2c authored by Woosuk Kwon, committed by GitHub

Use FP32 in RoPE initialization (#1004)


Co-authored-by: One <imone@tuta.io>
parent d6770d1f
@@ -133,9 +133,10 @@ def test_rotary_embedding(
                         device="cuda")
     # Create the rotary embedding.
-    inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
+    inv_freq = 1.0 / (base**(
+        torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
     t = torch.arange(max_position).float()
-    freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
+    freqs = torch.einsum("i,j -> ij", t, inv_freq)
     cos = freqs.cos()
     sin = freqs.sin()
     cos_sin_cache = torch.cat((cos, sin), dim=-1)
@@ -264,10 +264,10 @@ class PagedAttentionWithRoPE(PagedAttention):
         self.is_neox_style = is_neox_style
         # Create the cos and sin cache.
-        inv_freq = 1.0 / (base**(
-            torch.arange(0, rotary_dim, 2, device="cuda") / rotary_dim))
-        t = torch.arange(max_position, device="cuda").float()
-        freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
+        inv_freq = 1.0 / (base**(torch.arange(
+            0, rotary_dim, 2, dtype=torch.float, device="cuda") / rotary_dim))
+        t = torch.arange(max_position, dtype=torch.float, device="cuda")
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
         cos = freqs.cos()
         sin = freqs.sin()
         cache = torch.cat((cos, sin), dim=-1)
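For context, a minimal sketch of the precision issue this commit addresses (not part of the diff; the constants base, rotary_dim, and max_position below are illustrative). If the RoPE frequency table is rounded to half precision before the outer product with position indices, the rotation angles, and therefore the cached cos/sin values, drift noticeably at large positions; building the table with an explicit dtype=torch.float avoids this.

import torch

# Illustrative constants, not taken from the commit.
base, rotary_dim, max_position = 10000, 128, 2048

# FP32 frequency table, as in the updated code.
inv_freq = 1.0 / (base**(
    torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
t = torch.arange(max_position, dtype=torch.float)

# Angles from the FP32 table vs. a table rounded to FP16 first.
freqs_fp32 = torch.einsum("i,j -> ij", t, inv_freq)
freqs_fp16 = torch.einsum("i,j -> ij", t, inv_freq.half().float())

# The rounding error in the FP16 table is amplified by large position
# indices, so the resulting cos values differ visibly near the end of
# the context window.
print((freqs_fp32.cos() - freqs_fp16.cos()).abs().max())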