"vscode:/vscode.git/clone" did not exist on "2dfcc4f1bc03753ab438bfc2d7613623e8d04809"
Commit c33c4896 authored by gushiqiao's avatar gushiqiao Committed by GitHub
Browse files

Update utils.py (#286)

parent 1767ff4b
......@@ -69,7 +69,7 @@ def apply_rotary_emb(x, freqs_i):
n = x.size(1)
seq_len = freqs_i.size(0)
x_i = torch.view_as_complex(x[:seq_len].to(torch.float64).reshape(seq_len, n, -1, 2))
x_i = torch.view_as_complex(x[:seq_len].to(torch.float32).reshape(seq_len, n, -1, 2))
# Apply rotary embedding
x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
x_i = torch.cat([x_i, x[seq_len:]])
......@@ -113,7 +113,7 @@ def rope_params(max_seq_len, dim, theta=10000):
assert dim % 2 == 0
freqs = torch.outer(
torch.arange(max_seq_len),
1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)),
1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)),
)
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
......@@ -123,7 +123,7 @@ def sinusoidal_embedding_1d(dim, position):
# preprocess
assert dim % 2 == 0
half = dim // 2
position = position.type(torch.float64)
position = position.type(torch.float32)
# calculation
sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment