Commit 942fcbf0 authored by Tri Dao

[Rotary] Implement rotary in Triton

parent 08e98471
@@ -68,6 +68,8 @@ def remap_state_dict_hf_gpt_neox(state_dict, config):
# We don't store these biases
state_dict.pop(f"transformer.layers.{l}.attention.bias")
state_dict.pop(f"transformer.layers.{l}.attention.masked_bias")
# We don't store these
state_dict.pop(f"transformer.layers.{l}.attention.rotary_emb.inv_freq", None)
# GPT-NeoX stores Wqkv as ((nheads 3 headdim), hidden_dim)
# while we store Wqkv as ((3 nheads headdim), hidden_dim)
headdim = config.hidden_size // config.num_attention_heads
@@ -89,11 +91,6 @@ def remap_state_dict_hf_gpt_neox(state_dict, config):
r"transformer.layers.\1.mixer.out_proj.",
key,
)
key = re.sub(
r"^transformer.layers.(\d+).attention.rotary_emb.",
r"transformer.layers.\1.mixer.rotary_emb.",
key,
)
return key
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
...
# Adapted from https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py
# and https://github.com/openai/triton/blob/master/python/triton/ops/matmul.py
from typing import Optional
import torch
import triton
import triton.language as tl
from torch.autograd.function import FunctionCtx
from torch.cuda.amp import custom_fwd
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
from flash_attn.ops.triton.k_activations import (
...
from typing import Union
import torch
import triton
import triton.language as tl
# @triton.autotune(
# configs=[
# triton.Config({"BLOCK_M": 2}),
# triton.Config({"BLOCK_M": 4}),
# triton.Config({"BLOCK_M": 8}),
# triton.Config({"BLOCK_M": 16}),
# ],
# key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"]
# )
@triton.jit
def rotary_kernel(
OUT, # Pointers to matrices
X,
COS,
SIN,
SEQLEN_OFFSETS, # this could be int or a pointer
# Matrix dimensions
seqlen,
nheads,
rotary_dim,
seqlen_ro,
CACHE_KEY_SEQLEN,
# strides
stride_out_batch,
stride_out_seqlen,
stride_out_nheads,
stride_out_headdim,
stride_x_batch,
stride_x_seqlen,
stride_x_nheads,
stride_x_headdim,
# Meta-parameters
BLOCK_K: tl.constexpr,
IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
INTERLEAVED: tl.constexpr,
CONJUGATE: tl.constexpr,
BLOCK_M: tl.constexpr,
):
pid_m = tl.program_id(axis=0)
pid_batch = tl.program_id(axis=1)
pid_head = tl.program_id(axis=2)
rotary_dim_half = rotary_dim // 2
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rk = tl.arange(0, BLOCK_K // 2)
if not IS_SEQLEN_OFFSETS_TENSOR:
rm_cs = rm + SEQLEN_OFFSETS
else:
rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)
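# rm_cs indexes the rows of COS/SIN: the current sequence positions shifted by either
# a scalar offset or a per-batch offset loaded from SEQLEN_OFFSETS.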
X = X + (
pid_batch * stride_x_batch
+ rm[:, None] * stride_x_seqlen
+ pid_head * stride_x_nheads
+ rk[None, :] * stride_x_headdim * (2 if INTERLEAVED else 1)
)
COS = COS + (rm_cs[:, None] * rotary_dim_half + rk[None, :])
SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk[None, :])
cos = tl.load(
COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk[None, :] < rotary_dim_half), other=1.0
).to(tl.float32)
sin = tl.load(
SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk[None, :] < rotary_dim_half), other=0.0
).to(tl.float32)
x0 = tl.load(X, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half), other=0.0).to(
tl.float32
)
x1 = tl.load(
X + stride_x_headdim * (1 if INTERLEAVED else rotary_dim_half),
mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half),
other=0.0,
).to(tl.float32)
if not CONJUGATE:
o0 = x0 * cos - x1 * sin
o1 = x0 * sin + x1 * cos
else:
o0 = x0 * cos + x1 * sin
o1 = -x0 * sin + x1 * cos
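# Each (x0, x1) pair is rotated by the angle theta encoded by (cos, sin):
# (o0, o1) = (x0 cos(theta) - x1 sin(theta), x0 sin(theta) + x1 cos(theta)).
# CONJUGATE flips the sign of sin, i.e. applies the inverse rotation by -theta.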
# write back result
OUT = OUT + (
pid_batch * stride_out_batch
+ rm[:, None] * stride_out_seqlen
+ pid_head * stride_out_nheads
+ rk[None, :] * stride_out_headdim * (2 if INTERLEAVED else 1)
)
tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half))
tl.store(
OUT + stride_out_headdim * (1 if INTERLEAVED else rotary_dim_half),
o1,
mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half),
)
def apply_rotary(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
seqlen_offsets: Union[int, torch.Tensor] = 0,
interleaved=False,
inplace=False,
conjugate=False,
) -> torch.Tensor:
"""
Arguments:
x: (batch, seqlen, nheads, headdim)
cos: (seqlen_ro, rotary_dim / 2)
sin: (seqlen_ro, rotary_dim / 2)
seqlen_offsets: integer or integer tensor of size (batch,)
Returns:
y: (batch, seqlen, nheads, headdim)
"""
batch, seqlen, nheads, headdim = x.shape
seqlen_ro, rotary_dim = cos.shape
assert sin.shape == cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
assert headdim <= 256, "Only support headdim <= 256"
assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
assert (
cos.dtype == sin.dtype
), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
assert (
x.dtype == cos.dtype
), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"
cos, sin = cos.contiguous(), sin.contiguous()
if isinstance(seqlen_offsets, torch.Tensor):
assert seqlen_offsets.shape == (batch,)
assert seqlen_offsets.dtype in [torch.int32, torch.int64]
seqlen_offsets = seqlen_offsets.contiguous()
else:
assert seqlen_offsets + seqlen <= seqlen_ro
output = torch.empty_like(x) if not inplace else x
if rotary_dim < headdim and not inplace:
output[..., rotary_dim:].copy_(x[..., rotary_dim:])
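# Only the first rotary_dim features are rotated; the remaining headdim - rotary_dim
# features pass through unchanged (already present if inplace, copied over otherwise).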
BLOCK_K = (
32
if rotary_dim <= 32
else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
)
grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) # noqa
BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
rotary_kernel[grid](
output, # data ptrs
x,
cos,
sin,
seqlen_offsets,
seqlen, # shapes
nheads,
rotary_dim,
seqlen_ro,
seqlen // 128, # key for triton cache (limit number of compilations)
output.stride(0), # strides
output.stride(1),
output.stride(2),
output.stride(3),
x.stride(0),
x.stride(1),
x.stride(2),
x.stride(3),
BLOCK_K,
isinstance(seqlen_offsets, torch.Tensor),
interleaved,
conjugate,
BLOCK_M,
)
return output
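# A minimal usage sketch of apply_rotary (illustrative only: the shapes and the
# inv_freq construction below are assumptions for this example, not part of the diff;
# it assumes the apply_rotary defined above and a CUDA device):
#
# batch, seqlen, nheads, headdim, rotary_dim = 2, 128, 4, 64, 64
# x = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
# inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_dim, 2, device="cuda").float() / rotary_dim))
# freqs = torch.outer(torch.arange(seqlen, device="cuda").float(), inv_freq)
# cos, sin = freqs.cos().to(x.dtype), freqs.sin().to(x.dtype)  # each (seqlen, rotary_dim // 2)
# out = apply_rotary(x, cos, sin, seqlen_offsets=0, interleaved=False)  # same shape as x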
@@ -131,6 +131,8 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
)
print(out_cg.sequences)
parallel_state.destroy_model_parallel()
if not rotary:
out_hf = model_hf.generate(
input_ids=input_ids,
@@ -171,5 +173,3 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
).abs().max().item() < 3 * (
torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)
).abs().max().item()
parallel_state.destroy_model_parallel()
import math
import random
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn.layers.rotary import apply_rotary_emb, apply_rotary_emb_torch
from flash_attn.layers.rotary import apply_rotary_emb_qkv_, apply_rotary_emb_kv_
is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0)
@@ -13,33 +15,198 @@ is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0)
"dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5]) @pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [0.5]) # @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False])
@pytest.mark.parametrize("inplace", [False, True]) @pytest.mark.parametrize("inplace", [False, True])
# @pytest.mark.parametrize('inplace', [False]) # @pytest.mark.parametrize('inplace', [False])
def test_rotary_single_tensor(inplace, rotary_fraction, dtype): def test_rotary_emb_func(inplace, interleaved, rotary_fraction, seqlen_offsets_type, dtype):
rtol = 1e-3 rtol = 1e-3
batch_size = 32 batch_size = 32
nheads = 4 nheads = 4
seqlen = 217 seqlen = 217
headdim = 128 headdim = 128
device = "cuda"
torch.manual_seed(42)
x = torch.randn(
batch_size, seqlen, nheads, headdim, dtype=dtype, device=device, requires_grad=True
)
x_pt = x.detach().clone().requires_grad_()
rotary_dim = int(rotary_fraction * headdim)
assert rotary_dim % 2 == 0
angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
cos = torch.cos(angle).to(dtype=dtype)
sin = torch.sin(angle).to(dtype=dtype)
if seqlen_offsets_type == 0:
seqlen_offsets = 0
elif seqlen_offsets_type is int:
seqlen_offsets = torch.randint(0, seqlen + 1, (1, )).item()
elif seqlen_offsets_type is torch.Tensor:
seqlen_offsets = torch.randint(
0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device
)
out = apply_rotary_emb(
x, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=inplace
)
if seqlen_offsets_type is torch.Tensor:
arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
idx = rearrange(seqlen_offsets, "b -> b 1") + arange
cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
else:
cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
out_pt = apply_rotary_emb_torch(
x_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
).to(dtype=dtype)
print(f"Output max diff: {(out - out_pt).abs().max().item()}")
g = torch.randn_like(out)
g_pt = g.clone() # If inplace=True, we might modify the gradient inplace
out.backward(g)
out_pt.backward(g_pt)
print(f"Grad max diff: {(x.grad - x_pt.grad).abs().max().item()}")
if not inplace:
assert torch.equal(x, x_pt)
# Numerical error if we just do any arithmetic
atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item()
assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=2 * atol)
@pytest.mark.parametrize(
"dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False])
def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_type, dtype):
rtol = 1e-3
batch_size = 32
nheads = 4
seqlen = 512
headdim = 128
device = "cuda"
torch.manual_seed(42)
qkv = torch.randn(
batch_size, seqlen, 3, nheads, headdim, dtype=dtype, device=device, requires_grad=True
)
qkv_pt = qkv.detach().clone().requires_grad_()
rotary_dim = int(rotary_fraction * headdim)
assert rotary_dim % 2 == 0
angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
cos = torch.cos(angle).to(dtype=dtype)
sin = torch.sin(angle).to(dtype=dtype)
if seqlen_offsets_type == 0:
seqlen_offsets = 0
elif seqlen_offsets_type is int:
seqlen_offsets = torch.randint(0, seqlen + 1, (1, )).item()
elif seqlen_offsets_type is torch.Tensor:
seqlen_offsets = torch.randint(
0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device
)
out = apply_rotary_emb_qkv_(
qkv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved
)
if seqlen_offsets_type is torch.Tensor:
arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
idx = rearrange(seqlen_offsets, "b -> b 1") + arange
cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
else:
cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
q_pt = apply_rotary_emb_torch(
qkv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
).to(dtype=dtype)
k_pt = apply_rotary_emb_torch(
qkv_pt[:, :, 1].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
).to(dtype=dtype)
out_pt = torch.stack([q_pt, k_pt, qkv_pt[:, :, 2]], dim=2)
print(f"Output max diff: {(out - out_pt).abs().max().item()}")
g = torch.randn_like(out)
g_pt = g.clone() # Since inplace=True, we modify the gradient inplace
out.backward(g)
out_pt.backward(g_pt)
print(f"Grad max diff: {(qkv.grad - qkv_pt.grad).abs().max().item()}")
# Numerical error if we just do any arithmetic
atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
atol = ((qkv_pt.grad + 0.3 - 0.3) - qkv_pt.grad).abs().max().item()
assert torch.allclose(qkv.grad, qkv_pt.grad, rtol=rtol, atol=2 * atol)
@pytest.mark.parametrize(
"dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False])
def test_rotary_emb_kv(interleaved, rotary_fraction, seqlen_offsets_type, dtype):
rtol = 1e-3
batch_size = 32
nheads = 4
seqlen = 781
headdim = 64
device = "cuda"
torch.manual_seed(42)
kv = torch.randn(
batch_size, seqlen, 2, nheads, headdim, dtype=dtype, device=device, requires_grad=True
)
kv_pt = kv.detach().clone().requires_grad_()
rotary_dim = int(rotary_fraction * headdim)
assert rotary_dim % 2 == 0
angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
cos = torch.cos(angle).to(dtype=dtype)
sin = torch.sin(angle).to(dtype=dtype)
if seqlen_offsets_type == 0:
seqlen_offsets = 0
elif seqlen_offsets_type is int:
seqlen_offsets = torch.randint(0, seqlen + 1, (1, )).item()
elif seqlen_offsets_type is torch.Tensor:
seqlen_offsets = torch.randint(
0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device
)
out = apply_rotary_emb_kv_(
kv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved
)
if seqlen_offsets_type is torch.Tensor:
arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
idx = rearrange(seqlen_offsets, "b -> b 1") + arange
cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
else:
cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
k_pt = apply_rotary_emb_torch(
kv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
).to(dtype=dtype)
out_pt = torch.stack([k_pt, kv_pt[:, :, 1]], dim=2)
print(f"Output max diff: {(out - out_pt).abs().max().item()}")
g = torch.randn_like(out)
g_pt = g.clone() # Since inplace=True, we modify the gradient inplace
out.backward(g)
out_pt.backward(g_pt)
print(f"Grad max diff: {(kv.grad - kv_pt.grad).abs().max().item()}")
# Numerical error if we just do any arithmetic
atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
atol = ((kv_pt.grad + 0.3 - 0.3) - kv_pt.grad).abs().max().item()
assert torch.allclose(kv.grad, kv_pt.grad, rtol=rtol, atol=2 * atol)