Unverified Commit f692b98d authored by Ivan Komarov, committed by GitHub

Fix spurious re-compilations of `rotary_kernel` (#911)

Triton specializes all integer kernel parameters by default, so the two
parameters removed in this commit (`nheads` and `CACHE_KEY_SEQLEN`) could
trigger kernel re-compilation even though they were completely unused
inside the kernel.
parent 23e8fa5a
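For context, here is a minimal sketch of the failure mode, separate from this patch. Triton specializes integer kernel arguments on value properties (for example, divisibility by 16 and equality to 1) and folds those properties into the compilation cache key, so passing a different integer can trigger a fresh compile even when the argument is never read. The kernel and hook names below are hypothetical, and the sketch assumes a CUDA device and a Triton version where `JITFunction.cache_hook` fires once per compilation (the same assumption the new test below makes):

```python
import torch
import triton
import triton.language as tl
from triton.runtime.jit import JITFunction


@triton.jit
def noop_kernel(X, unused_int):
    # `unused_int` is never read below, yet Triton still folds its
    # specialization (e.g. divisibility by 16) into the cache key.
    tl.store(X, 0.0)


compilation_count = 0


def count_compilations(*args, **kwargs):
    global compilation_count
    compilation_count += 1


JITFunction.cache_hook = count_compilations
x = torch.zeros(1, device="cuda")
for n in (4, 32):  # 4 is not divisible by 16, 32 is -> different specializations
    noop_kernel[(1,)](x, n)
print(compilation_count)  # expected: 2, one compile per specialization
```

Dropping the unused integer from the signature, as this commit does for `nheads` and `CACHE_KEY_SEQLEN`, removes it from the cache key altogether.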
@@ -8,15 +8,6 @@ import triton
 import triton.language as tl
 
-# @triton.autotune(
-#     configs=[
-#         triton.Config({"BLOCK_M": 2}),
-#         triton.Config({"BLOCK_M": 4}),
-#         triton.Config({"BLOCK_M": 8}),
-#         triton.Config({"BLOCK_M": 16}),
-#     ],
-#     key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"],
-# )
 @triton.jit
 def rotary_kernel(
     OUT,  # Pointers to matrices
@@ -27,10 +18,8 @@ def rotary_kernel(
     SEQLEN_OFFSETS,  # this could be int or a pointer
     # Matrix dimensions
     seqlen,
-    nheads,
     rotary_dim,
     seqlen_ro,
-    CACHE_KEY_SEQLEN,
     # strides
     stride_out_batch,
     stride_out_seqlen,
...
@@ -218,10 +207,8 @@ def apply_rotary(
         cu_seqlens,
         seqlen_offsets,
         seqlen,  # shapes
-        nheads,
         rotary_dim,
         seqlen_ro,
-        seqlen // 128,  # key for triton cache (limit number of compilations)
         output.stride(0) if not is_varlen else 0,  # batch_strides if not varlen else 0
         output.stride(-3),  # seqlen_stride or total_seqlen_stride
         output.stride(-2),  # nheads_stride
...
@@ -252,3 +252,40 @@ def test_rotary_emb_varlen_func(inplace, interleaved, rotary_fraction, seqlen_of
     assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
     atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item()
     assert torch.allclose(x_grad, x_pt.grad, rtol=rtol, atol=2 * atol)
+
+
+def test_compilation_count():
+    batch_size = 1
+    headdim = 128
+    device = "cuda"
+    dtype = torch.float16
+
+    torch.manual_seed(42)
+
+    from triton.runtime.jit import JITFunction
+    from flash_attn.ops.triton.rotary import rotary_kernel
+
+    compilation_count = 0
+    def count_compilations(*args, **kwargs):
+        nonlocal compilation_count
+        compilation_count += 1
+
+    old_cache_func = JITFunction.cache_hook
+    try:
+        rotary_kernel.cache.clear()
+        JITFunction.cache_hook = count_compilations
+
+        for seqlen in (128, 256):
+            for nheads in (4, 32):
+                x = torch.randn(batch_size, seqlen, nheads, headdim, dtype=dtype, device=device)
+                x.requires_grad_()
+                cos, sin = generate_cos_sin(seqlen, headdim, device, dtype)
+                out = apply_rotary_emb(x, cos, sin)
+                out.backward(torch.randn_like(out))
+
+        # Only two kernels are expected to be compiled:
+        # * for the forward pass (conjugate=False)
+        # * for the backward pass (conjugate=True)
+        assert compilation_count == 2
+    finally:
+        JITFunction.cache_hook = old_cache_func
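To verify the fix locally, the new test can be run on its own, e.g. `pytest tests/test_rotary.py::test_compilation_count` (the file path is assumed from the surrounding tests, and a CUDA-capable GPU plus a Triton build exposing `JITFunction.cache_hook` are required). With the two integer parameters gone, varying `nheads` (4 vs. 32) and `seqlen` (128 vs. 256) no longer changes the specialization key, so only the forward (`conjugate=False`) and backward (`conjugate=True`) variants are compiled.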