Commit ee8984d2 authored by Alexander Ploshkin
Browse files

add asserts for sin shape

parent c7c66976
...@@ -43,6 +43,7 @@ class ApplyRotaryEmb(torch.autograd.Function): ...@@ -43,6 +43,7 @@ class ApplyRotaryEmb(torch.autograd.Function):
rotary_dim *= 2 rotary_dim *= 2
assert rotary_dim <= headdim assert rotary_dim <= headdim
assert seqlen <= rotary_seqlen assert seqlen <= rotary_seqlen
assert sin.shape == (rotary_seqlen, rotary_dim // 2)
x1, x2 = x[..., :rotary_dim].chunk(2, dim=-1) x1, x2 = x[..., :rotary_dim].chunk(2, dim=-1)
out = torch.empty_like(x) if not inplace else x out = torch.empty_like(x) if not inplace else x
o1, o2 = out[..., :rotary_dim].chunk(2, dim=-1) if not inplace else (x1, x2) o1, o2 = out[..., :rotary_dim].chunk(2, dim=-1) if not inplace else (x1, x2)
...@@ -90,6 +91,7 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function): ...@@ -90,6 +91,7 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
rotary_dim *= 2 rotary_dim *= 2
assert rotary_dim <= headdim assert rotary_dim <= headdim
assert seqlen <= rotary_seqlen assert seqlen <= rotary_seqlen
assert sin.shape == (rotary_seqlen, rotary_dim // 2)
q1, q2 = qkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1) q1, q2 = qkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1)
rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'), rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'),
rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False) rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment