Commit 942fcbf0 authored by Tri Dao's avatar Tri Dao
Browse files

[Rotary] Implement rotary in Triton

parent 08e98471
This diff is collapsed.
......@@ -68,6 +68,8 @@ def remap_state_dict_hf_gpt_neox(state_dict, config):
# We don't store these biases
state_dict.pop(f"transformer.layers.{l}.attention.bias")
state_dict.pop(f"transformer.layers.{l}.attention.masked_bias")
# We don't store these
state_dict.pop(f"transformer.layers.{l}.attention.rotary_emb.inv_freq", None)
# GPT-NeoX stores Wqkv as ((nheads 3 headdim), hidden_dim)
# while we store Wqkv as ((3 nheads headdim), hidden_dim)
headdim = config.hidden_size // config.num_attention_heads
......@@ -89,11 +91,6 @@ def remap_state_dict_hf_gpt_neox(state_dict, config):
r"transformer.layers.\1.mixer.out_proj.",
key,
)
key = re.sub(
r"^transformer.layers.(\d+).attention.rotary_emb.",
r"transformer.layers.\1.mixer.rotary_emb.",
key,
)
return key
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
......
# Adapted on https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py
# Adapted from https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py
# and https://github.com/openai/triton/blob/master/python/triton/ops/matmul.py
from typing import Optional
import torch
import triton
import triton.language as tl
from torch.autograd.function import FunctionCtx
from torch.cuda.amp import custom_fwd
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
from flash_attn.ops.triton.k_activations import (
......
from typing import Union
import torch
import triton
import triton.language as tl
# @triton.autotune(
# configs=[
# triton.Config({"BLOCK_M": 2}),
# triton.Config({"BLOCK_M": 4}),
# triton.Config({"BLOCK_M": 8}),
# triton.Config({"BLOCK_M": 16}),
# ],
# key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"]
# )
@triton.jit
def rotary_kernel(
OUT, # Pointers to matrices
X,
COS,
SIN,
SEQLEN_OFFSETS, # this could be int or a pointer
# Matrix dimensions
seqlen,
nheads,
rotary_dim,
seqlen_ro,
CACHE_KEY_SEQLEN,
# strides
stride_out_batch,
stride_out_seqlen,
stride_out_nheads,
stride_out_headdim,
stride_x_batch,
stride_x_seqlen,
stride_x_nheads,
stride_x_headdim,
# Meta-parameters
BLOCK_K: tl.constexpr,
IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
INTERLEAVED: tl.constexpr,
CONJUGATE: tl.constexpr,
BLOCK_M: tl.constexpr,
):
pid_m = tl.program_id(axis=0)
pid_batch = tl.program_id(axis=1)
pid_head = tl.program_id(axis=2)
rotary_dim_half = rotary_dim // 2
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rk = tl.arange(0, BLOCK_K // 2)
if not IS_SEQLEN_OFFSETS_TENSOR:
rm_cs = rm + SEQLEN_OFFSETS
else:
rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)
X = X + (
pid_batch * stride_x_batch
+ rm[:, None] * stride_x_seqlen
+ pid_head * stride_x_nheads
+ rk[None, :] * stride_x_headdim * (2 if INTERLEAVED else 1)
)
COS = COS + (rm_cs[:, None] * rotary_dim_half + rk[None, :])
SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk[None, :])
cos = tl.load(
COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk[None, :] < rotary_dim_half), other=1.0
).to(tl.float32)
sin = tl.load(
SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk[None, :] < rotary_dim_half), other=0.0
).to(tl.float32)
x0 = tl.load(X, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half), other=0.0).to(
tl.float32
)
x1 = tl.load(
X + stride_x_headdim * (1 if INTERLEAVED else rotary_dim_half),
mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half),
other=0.0,
).to(tl.float32)
if not CONJUGATE:
o0 = x0 * cos - x1 * sin
o1 = x0 * sin + x1 * cos
else:
o0 = x0 * cos + x1 * sin
o1 = -x0 * sin + x1 * cos
# write back result
OUT = OUT + (
pid_batch * stride_out_batch
+ rm[:, None] * stride_out_seqlen
+ pid_head * stride_out_nheads
+ rk[None, :] * stride_out_headdim * (2 if INTERLEAVED else 1)
)
tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half))
tl.store(
OUT + stride_out_headdim * (1 if INTERLEAVED else rotary_dim_half),
o1,
mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half),
)
def apply_rotary(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
seqlen_offsets: Union[int, torch.Tensor] = 0,
interleaved=False,
inplace=False,
conjugate=False,
) -> torch.Tensor:
"""
Arguments:
x: (batch, seqlen, nheads, headdim)
cos: (seqlen_ro, rotary_dim / 2)
sin: (seqlen_ro, rotary_dim / 2)
seqlen_offsets: integer or integer tensor of size (batch,)
Returns:
y: (batch, seqlen, nheads, headdim)
"""
batch, seqlen, nheads, headdim = x.shape
seqlen_ro, rotary_dim = cos.shape
assert sin.shape == cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
assert headdim <= 256, "Only support headdim <= 256"
assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
assert (
cos.dtype == sin.dtype
), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
assert (
x.dtype == cos.dtype
), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"
cos, sin = cos.contiguous(), sin.contiguous()
if isinstance(seqlen_offsets, torch.Tensor):
assert seqlen_offsets.shape == (batch,)
assert seqlen_offsets.dtype in [torch.int32, torch.int64]
seqlen_offsets = seqlen_offsets.contiguous()
else:
assert seqlen_offsets + seqlen <= seqlen_ro
output = torch.empty_like(x) if not inplace else x
if rotary_dim < headdim and not inplace:
output[..., rotary_dim:].copy_(x[..., rotary_dim:])
BLOCK_K = (
32
if rotary_dim <= 32
else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
)
grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) # noqa
BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
rotary_kernel[grid](
output, # data ptrs
x,
cos,
sin,
seqlen_offsets,
seqlen, # shapes
nheads,
rotary_dim,
seqlen_ro,
seqlen // 128, # key for triton cache (limit number of compilations)
output.stride(0), # strides
output.stride(1),
output.stride(2),
output.stride(3),
x.stride(0),
x.stride(1),
x.stride(2),
x.stride(3),
BLOCK_K,
isinstance(seqlen_offsets, torch.Tensor),
interleaved,
conjugate,
BLOCK_M,
)
return output
......@@ -131,6 +131,8 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
)
print(out_cg.sequences)
parallel_state.destroy_model_parallel()
if not rotary:
out_hf = model_hf.generate(
input_ids=input_ids,
......@@ -171,5 +173,3 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
).abs().max().item() < 3 * (
torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)
).abs().max().item()
parallel_state.destroy_model_parallel()
import math
import random
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn.layers.rotary import apply_rotary_emb_func, apply_rotary_emb_torch
from flash_attn.layers.rotary import apply_rotary_emb, apply_rotary_emb_torch
from flash_attn.layers.rotary import apply_rotary_emb_qkv_, apply_rotary_emb_kv_
is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0)
......@@ -13,33 +15,198 @@ is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0)
"dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False])
@pytest.mark.parametrize("inplace", [False, True])
# @pytest.mark.parametrize('inplace', [False])
def test_rotary_single_tensor(inplace, rotary_fraction, dtype):
def test_rotary_emb_func(inplace, interleaved, rotary_fraction, seqlen_offsets_type, dtype):
rtol = 1e-3
batch_size = 32
nheads = 4
seqlen = 217
headdim = 128
device = "cuda"
torch.manual_seed(42)
x = torch.randn(
batch_size, seqlen, nheads, headdim, dtype=dtype, device="cuda", requires_grad=True
batch_size, seqlen, nheads, headdim, dtype=dtype, device=device, requires_grad=True
)
x_pt = x.detach().clone().requires_grad_()
rotary_dim = int(rotary_fraction * headdim)
assert rotary_dim % 2 == 0
angle = torch.randn(seqlen, rotary_dim // 2, device="cuda")
angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
cos = torch.cos(angle).to(dtype=dtype)
sin = torch.sin(angle).to(dtype=dtype)
out = apply_rotary_emb_func(x, cos, sin, inplace)
out_pt = apply_rotary_emb_torch(x_pt, cos, sin)
# Numerical error if we just do any arithmetic
atol = ((out + 0.3 - 0.3) - out).abs().max().item()
assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
if seqlen_offsets_type == 0:
seqlen_offsets = 0
elif seqlen_offsets_type is int:
seqlen_offsets = torch.randint(0, seqlen + 1, (1, )).item()
elif seqlen_offsets_type is torch.Tensor:
seqlen_offsets = torch.randint(
0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device
)
out = apply_rotary_emb(
x, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=inplace
)
if seqlen_offsets_type is torch.Tensor:
arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
idx = rearrange(seqlen_offsets, "b -> b 1") + arange
cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
else:
cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
out_pt = apply_rotary_emb_torch(
x_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
).to(dtype=dtype)
print(f"Output max diff: {(out - out_pt).abs().max().item()}")
g = torch.randn_like(out)
g_pt = g.clone() # If inplace=True, we might modify the gradient inplace
out.backward(g)
out_pt.backward(g_pt)
print(f"Grad max diff: {(x.grad - x_pt.grad).abs().max().item()}")
if not inplace:
assert torch.equal(x, x_pt)
# Numerical error if we just do any arithmetic
atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item()
assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=2 * atol)
@pytest.mark.parametrize(
"dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False])
def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_type, dtype):
rtol = 1e-3
batch_size = 32
nheads = 4
seqlen = 512
headdim = 128
device = "cuda"
torch.manual_seed(42)
qkv = torch.randn(
batch_size, seqlen, 3, nheads, headdim, dtype=dtype, device=device, requires_grad=True
)
qkv_pt = qkv.detach().clone().requires_grad_()
rotary_dim = int(rotary_fraction * headdim)
assert rotary_dim % 2 == 0
angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
cos = torch.cos(angle).to(dtype=dtype)
sin = torch.sin(angle).to(dtype=dtype)
if seqlen_offsets_type == 0:
seqlen_offsets = 0
elif seqlen_offsets_type is int:
seqlen_offsets = torch.randint(0, seqlen + 1, (1, )).item()
elif seqlen_offsets_type is torch.Tensor:
seqlen_offsets = torch.randint(
0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device
)
out = apply_rotary_emb_qkv_(
qkv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved
)
if seqlen_offsets_type is torch.Tensor:
arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
idx = rearrange(seqlen_offsets, "b -> b 1") + arange
cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
else:
cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
q_pt = apply_rotary_emb_torch(
qkv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
).to(dtype=dtype)
k_pt = apply_rotary_emb_torch(
qkv_pt[:, :, 1].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
).to(dtype=dtype)
out_pt = torch.stack([q_pt, k_pt, qkv_pt[:, :, 2]], dim=2)
print(f"Output max diff: {(out - out_pt).abs().max().item()}")
g = torch.randn_like(out)
g_pt = g.clone() # Since inplace=True, we modify the gradient inplace
out.backward(g)
out_pt.backward(g_pt)
print(f"Grad max diff: {(qkv.grad - qkv_pt.grad).abs().max().item()}")
# Numerical error if we just do any arithmetic
atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
atol = ((qkv_pt.grad + 0.3 - 0.3) - qkv_pt.grad).abs().max().item()
assert torch.allclose(qkv.grad, qkv_pt.grad, rtol=rtol, atol=2 * atol)
@pytest.mark.parametrize(
"dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False])
def test_rotary_emb_kv(interleaved, rotary_fraction, seqlen_offsets_type, dtype):
rtol = 1e-3
batch_size = 32
nheads = 4
seqlen = 781
headdim = 64
device = "cuda"
torch.manual_seed(42)
kv = torch.randn(
batch_size, seqlen, 2, nheads, headdim, dtype=dtype, device=device, requires_grad=True
)
kv_pt = kv.detach().clone().requires_grad_()
rotary_dim = int(rotary_fraction * headdim)
assert rotary_dim % 2 == 0
angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
cos = torch.cos(angle).to(dtype=dtype)
sin = torch.sin(angle).to(dtype=dtype)
if seqlen_offsets_type == 0:
seqlen_offsets = 0
elif seqlen_offsets_type is int:
seqlen_offsets = torch.randint(0, seqlen + 1, (1, )).item()
elif seqlen_offsets_type is torch.Tensor:
seqlen_offsets = torch.randint(
0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device
)
out = apply_rotary_emb_kv_(
kv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved
)
if seqlen_offsets_type is torch.Tensor:
arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
idx = rearrange(seqlen_offsets, "b -> b 1") + arange
cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
else:
cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
k_pt = apply_rotary_emb_torch(
kv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
).to(dtype=dtype)
out_pt = torch.stack([k_pt, kv_pt[:, :, 1]], dim=2)
print(f"Output max diff: {(out - out_pt).abs().max().item()}")
g = torch.randn_like(out)
g_pt = g.clone() # Since inplace=True, we modify the gradient inplace
out.backward(g)
out_pt.backward(g_pt)
print(f"Grad max diff: {(kv.grad - kv_pt.grad).abs().max().item()}")
# Numerical error if we just do any arithmetic
atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
atol = ((kv_pt.grad + 0.3 - 0.3) - kv_pt.grad).abs().max().item()
assert torch.allclose(kv.grad, kv_pt.grad, rtol=rtol, atol=2 * atol)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment