Commit 1c523c1c authored by Tri Dao
Browse files

[Rotary] Speed up rotary kernel when interleaved=True

parent 26d7d92f
...@@ -13,7 +13,7 @@ import triton.language as tl ...@@ -13,7 +13,7 @@ import triton.language as tl
# triton.Config({"BLOCK_M": 8}), # triton.Config({"BLOCK_M": 8}),
# triton.Config({"BLOCK_M": 16}), # triton.Config({"BLOCK_M": 16}),
# ], # ],
# key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"] # key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"],
# ) # )
@triton.jit @triton.jit
def rotary_kernel( def rotary_kernel(
...@@ -49,34 +49,34 @@ def rotary_kernel( ...@@ -49,34 +49,34 @@ def rotary_kernel(
pid_head = tl.program_id(axis=2) pid_head = tl.program_id(axis=2)
rotary_dim_half = rotary_dim // 2 rotary_dim_half = rotary_dim // 2
X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads
OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rk = tl.arange(0, BLOCK_K // 2)
if not IS_SEQLEN_OFFSETS_TENSOR: if not IS_SEQLEN_OFFSETS_TENSOR:
rm_cs = rm + SEQLEN_OFFSETS rm_cs = rm + SEQLEN_OFFSETS
else: else:
rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)
rk = tl.arange(0, BLOCK_K)
rk_half = tl.arange(0, BLOCK_K // 2)
X = X + ( if not INTERLEAVED:
pid_batch * stride_x_batch # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT
+ rm[:, None] * stride_x_seqlen X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim)
+ pid_head * stride_x_nheads COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
+ rk[None, :] * stride_x_headdim * (2 if INTERLEAVED else 1) SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
)
COS = COS + (rm_cs[:, None] * rotary_dim_half + rk[None, :])
SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk[None, :])
cos = tl.load( cos = tl.load(
COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk[None, :] < rotary_dim_half), other=1.0 COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0
).to(tl.float32) ).to(tl.float32)
sin = tl.load( sin = tl.load(
SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk[None, :] < rotary_dim_half), other=0.0 SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0
).to(tl.float32)
x0 = tl.load(
X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0
).to(tl.float32) ).to(tl.float32)
x0 = tl.load(X, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half), other=0.0).to(
tl.float32
)
x1 = tl.load( x1 = tl.load(
X + stride_x_headdim * (1 if INTERLEAVED else rotary_dim_half), X + rotary_dim_half * stride_x_headdim,
mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half), mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
other=0.0, other=0.0,
).to(tl.float32) ).to(tl.float32)
if not CONJUGATE: if not CONJUGATE:
...@@ -85,20 +85,52 @@ def rotary_kernel( ...@@ -85,20 +85,52 @@ def rotary_kernel(
else: else:
o0 = x0 * cos + x1 * sin o0 = x0 * cos + x1 * sin
o1 = -x0 * sin + x1 * cos o1 = -x0 * sin + x1 * cos
# write back result # write back result
OUT = OUT + ( OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim)
pid_batch * stride_out_batch tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))
+ rm[:, None] * stride_out_seqlen
+ pid_head * stride_out_nheads
+ rk[None, :] * stride_out_headdim * (2 if INTERLEAVED else 1)
)
tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half))
tl.store( tl.store(
OUT + stride_out_headdim * (1 if INTERLEAVED else rotary_dim_half), OUT + rotary_dim_half * stride_out_headdim,
o1, o1,
mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half), mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
)
else:
# We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow.
# Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...].
# Loading x0 will be fast but x1 will be slow.
# Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...].
# Then we do the calculation and use tl.where to pick out the right outputs for the even
# and for the odd indices.
rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ...
rk_repeat = tl.arange(0, BLOCK_K) // 2
X0 = X + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim)
X1 = X + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim)
COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
cos = tl.load(
COS,
mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
other=1.0,
).to(tl.float32)
sin = tl.load(
SIN,
mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
other=0.0,
).to(tl.float32)
x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(
tl.float32
) )
x1 = tl.load(
X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0
).to(tl.float32)
if not CONJUGATE:
o0 = x0 * cos - x1 * sin
o1 = x1 * sin + x0 * cos
else:
o0 = x0 * cos + x1 * sin
o1 = -x1 * sin + x0 * cos
out = tl.where(rk[None, :] % 2 == 0, o0, o1)
OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)
tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))
def apply_rotary( def apply_rotary(
......
...@@ -20,7 +20,7 @@ is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0) ...@@ -20,7 +20,7 @@ is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0)
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5]) @pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0]) # @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True]) @pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False]) # @pytest.mark.parametrize('interleaved', [True])
@pytest.mark.parametrize("inplace", [False, True]) @pytest.mark.parametrize("inplace", [False, True])
# @pytest.mark.parametrize('inplace', [False]) # @pytest.mark.parametrize('inplace', [False])
def test_rotary_emb_func(inplace, interleaved, rotary_fraction, seqlen_offsets_type, dtype): def test_rotary_emb_func(inplace, interleaved, rotary_fraction, seqlen_offsets_type, dtype):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment