"vscode:/vscode.git/clone" did not exist on "9454ac40fc2e8604cefe11b5166d2f24dbedc675"
Commit 861c8257 authored by Tri Dao
Browse files

[Rotary] Clean up rotary Triton implementation a bit

parent 1c523c1c
...@@ -79,12 +79,10 @@ def rotary_kernel( ...@@ -79,12 +79,10 @@ def rotary_kernel(
mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
other=0.0, other=0.0,
).to(tl.float32) ).to(tl.float32)
if not CONJUGATE: if CONJUGATE:
o0 = x0 * cos - x1 * sin sin = -sin
o1 = x0 * sin + x1 * cos o0 = x0 * cos - x1 * sin
else: o1 = x0 * sin + x1 * cos
o0 = x0 * cos + x1 * sin
o1 = -x0 * sin + x1 * cos
# write back result # write back result
OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim) OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim)
tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half)) tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))
...@@ -122,13 +120,11 @@ def rotary_kernel( ...@@ -122,13 +120,11 @@ def rotary_kernel(
x1 = tl.load( x1 = tl.load(
X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0 X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0
).to(tl.float32) ).to(tl.float32)
if not CONJUGATE: if CONJUGATE:
o0 = x0 * cos - x1 * sin sin = -sin
o1 = x1 * sin + x0 * cos x0_cos = x0 * cos
else: x1_sin = x1 * sin
o0 = x0 * cos + x1 * sin out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)
o1 = -x1 * sin + x0 * cos
out = tl.where(rk[None, :] % 2 == 0, o0, o1)
OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim) OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)
tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim)) tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment