Commit 942fcbf0 authored by Tri Dao

[Rotary] Implement rotary in Triton

parent 08e98471
# Copyright (c) 2023, Tri Dao.

import math
from typing import Optional, Tuple, Union

import torch
from einops import rearrange, repeat

from flash_attn.ops.triton.rotary import apply_rotary


def rotate_half(x, interleaved=False):
@@ -20,12 +20,12 @@ def rotate_half(x, interleaved=False):


def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    return torch.cat(
        [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
        dim=-1,
    )
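

# For reference, per pair (x1, x2) drawn from the first rotary_dim channels (the two
# halves for GPT-NeoX style, or adjacent even/odd positions for GPT-J style), the
# functions in this file compute a plain 2D rotation by the cached angle:
#   out1 = x1 * cos - x2 * sin
#   out2 = x1 * sin + x2 * cos
# The remaining headdim - rotary_dim channels are passed through unchanged.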
@@ -34,229 +34,242 @@ def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
class ApplyRotaryEmb(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x,
        cos,
        sin,
        interleaved=False,
        inplace=False,
        seqlen_offsets: Union[int, torch.Tensor] = 0,
    ):
        out = apply_rotary(
            x, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=inplace
        )
        if isinstance(seqlen_offsets, int):
            ctx.save_for_backward(cos, sin)  # Can't save int with save_for_backward
            ctx.seqlen_offsets = seqlen_offsets
        else:
            ctx.save_for_backward(cos, sin, seqlen_offsets)
            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        ctx.inplace = inplace
        return out if not inplace else x

    @staticmethod
    def backward(ctx, do):
        seqlen_offsets = ctx.seqlen_offsets
        if seqlen_offsets is None:
            cos, sin, seqlen_offsets = ctx.saved_tensors
        else:
            cos, sin = ctx.saved_tensors
        # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
        # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
        if not ctx.interleaved and not ctx.inplace:
            do = do.clone()
        dx = apply_rotary(
            do,
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            interleaved=ctx.interleaved,
            inplace=ctx.inplace,
            conjugate=True,
        )
        return dx, None, None, None, None, None


def apply_rotary_emb(
    x, cos, sin, interleaved=False, inplace=False, seqlen_offsets: Union[int, torch.Tensor] = 0
):
    """
    Arguments:
        x: (batch_size, seqlen, nheads, headdim)
        cos, sin: (seqlen_rotary, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        inplace: if True, apply rotary embedding in-place.
        seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
            Most commonly used in inference when we have KV cache.
    Return:
        out: (batch_size, seqlen, nheads, headdim)
    rotary_dim must be <= headdim
    Apply rotary embedding to the first rotary_dim of x.
    """
    return ApplyRotaryEmb.apply(x, cos, sin, interleaved, inplace, seqlen_offsets)


# For backward compatibility
apply_rotary_emb_func = apply_rotary_emb
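

# A minimal usage sketch (illustrative only, not part of the library): shapes, dtype and
# device below are assumptions. cos/sin cover seqlen_ro = 256 positions with rotary_dim = 32,
# applied to the first 32 of 64 head dimensions.
#
#     import math
#     import torch
#     x = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")
#     angle = torch.rand(256, 16, device="cuda") * 2 * math.pi
#     cos, sin = torch.cos(angle).half(), torch.sin(angle).half()
#     out = apply_rotary_emb(x, cos, sin, interleaved=False, inplace=False, seqlen_offsets=0)
#     # out: (2, 128, 8, 64), same shape as x; channels 32: are copied through unrotated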
class ApplyRotaryEmbQKV_(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        qkv,
        cos,
        sin,
        cos_k=None,
        sin_k=None,
        interleaved=False,
        seqlen_offsets: Union[int, torch.Tensor] = 0,
    ):
        batch, seqlen, three, nheads, headdim = qkv.shape
        assert three == 3
        if cos_k is None and sin_k is None and qkv.is_contiguous():
            # Call 1 kernel instead of 2 kernels
            # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
            # dimensions, we get the same tensor
            qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
            apply_rotary(
                qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True
            )
        else:
            cos_k = cos if cos_k is None else cos_k
            sin_k = sin if sin_k is None else sin_k
            q, k = qkv[:, :, 0], qkv[:, :, 1]
            apply_rotary(q, cos, sin, seqlen_offsets, interleaved=interleaved, inplace=True)
            apply_rotary(k, cos_k, sin_k, seqlen_offsets, interleaved=interleaved, inplace=True)
        if isinstance(seqlen_offsets, int):
            ctx.save_for_backward(cos, sin, cos_k, sin_k)
            ctx.seqlen_offsets = seqlen_offsets
        else:
            ctx.save_for_backward(cos, sin, cos_k, sin_k, seqlen_offsets)
            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        return qkv

    @staticmethod
    def backward(ctx, dqkv):
        seqlen_offsets = ctx.seqlen_offsets
        if seqlen_offsets is None:
            cos, sin, cos_k, sin_k, seqlen_offsets = ctx.saved_tensors
        else:
            cos, sin, cos_k, sin_k = ctx.saved_tensors
        if cos_k is None and sin_k is None and dqkv.is_contiguous():
            # Call 1 kernel instead of 2 kernels
            # We need dqkv to be contiguous so that when we reshape to combine (3, nheads)
            # dimensions, we get the same tensor
            dqk = rearrange(dqkv[:, :, :2], "b s t h d -> b s (t h) d")
            apply_rotary(
                dqk,
                cos,
                sin,
                seqlen_offsets=seqlen_offsets,
                interleaved=ctx.interleaved,
                inplace=True,
                conjugate=True,
            )
        else:
            cos_k = cos if cos_k is None else cos_k
            sin_k = sin if sin_k is None else sin_k
            dq, dk = dqkv[:, :, 0], dqkv[:, :, 1]
            apply_rotary(
                dq, cos, sin, seqlen_offsets, interleaved=ctx.interleaved, inplace=True, conjugate=True
            )
            apply_rotary(
                dk,
                cos_k,
                sin_k,
                seqlen_offsets,
                interleaved=ctx.interleaved,
                inplace=True,
                conjugate=True,
            )
        return dqkv, None, None, None, None, None, None


def apply_rotary_emb_qkv_(
    qkv,
    cos,
    sin,
    cos_k=None,
    sin_k=None,
    interleaved=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
):
    """
    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        cos, sin: (seqlen, rotary_dim / 2)
        cos_k, sin_k: (seqlen, rotary_dim / 2), optional
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
            1st half and 2nd half (GPT-NeoX style).
        seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
            Most commonly used in inference when we have KV cache.
    Return:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
    rotary_dim must be <= headdim
    Apply rotary embedding *inplace* to the first rotary_dim of Q and K.
    """
    return ApplyRotaryEmbQKV_.apply(qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets)
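

# A minimal usage sketch (illustrative only): rotary is applied in place to Q and K of a
# packed qkv tensor, reusing cos/sin of shape (seqlen_ro, rotary_dim / 2) as in the
# sketch above; V is left untouched.
#
#     qkv = torch.randn(2, 128, 3, 8, 64, dtype=torch.float16, device="cuda")
#     qkv = apply_rotary_emb_qkv_(qkv, cos, sin, interleaved=False, seqlen_offsets=0)
#     # qkv is modified in place; qkv[:, :, 2] (the values) is unchanged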
class ApplyRotaryEmbKV_(torch.autograd.Function):
    @staticmethod
    def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0):
        batch, seqlen, two, nheads, headdim = kv.shape
        assert two == 2
        k = kv[:, :, 0]
        apply_rotary(
            k, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True
        )
        if isinstance(seqlen_offsets, int):
            ctx.save_for_backward(cos, sin)  # Can't save int with save_for_backward
            ctx.seqlen_offsets = seqlen_offsets
        else:
            ctx.save_for_backward(cos, sin, seqlen_offsets)
            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        return kv

    @staticmethod
    def backward(ctx, dkv):
        seqlen_offsets = ctx.seqlen_offsets
        if seqlen_offsets is None:
            cos, sin, seqlen_offsets = ctx.saved_tensors
        else:
            cos, sin = ctx.saved_tensors
        apply_rotary(
            dkv[:, :, 0],
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            interleaved=ctx.interleaved,
            inplace=True,
            conjugate=True,
        )
        return dkv, None, None, None, None


apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply


def apply_rotary_emb_kv_(
    kv,
    cos,
    sin,
    interleaved=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
):
    """
    Arguments:
        kv: (batch_size, seqlen, 2, nheads, headdim)
        cos, sin: (seqlen, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
            1st half and 2nd half (GPT-NeoX style).
        seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
            Most commonly used in inference when we have KV cache.
    Return:
        kv: (batch_size, seqlen, 2, nheads, headdim)
    rotary_dim must be <= headdim
    Apply rotary embedding *inplace* to the first rotary_dim of K.
    """
    return ApplyRotaryEmbKV_.apply(kv, cos, sin, interleaved, seqlen_offsets)
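

# A minimal usage sketch (illustrative only): for the packed kv tensor, only K
# (index 0 along dim 2) is rotated in place; V (index 1) passes through.
#
#     kv = torch.randn(2, 128, 2, 8, 64, dtype=torch.float16, device="cuda")
#     kv = apply_rotary_emb_kv_(kv, cos, sin, interleaved=False, seqlen_offsets=0)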
class RotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et. al).
@@ -372,57 +385,70 @@ class RotaryEmbedding(torch.nn.Module):
        self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(
        self,
        qkv: torch.Tensor,
        kv: Optional[torch.Tensor] = None,
        seqlen_offset: Union[int, torch.Tensor] = 0,
        max_seqlen: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
            else it's just q of shape (batch, seqlen, nheads, headdim)
        kv: (batch, seqlen, 2, nheads, headdim)
        seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
            Most commonly used in inference when we have KV cache.
            If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
            should pass in max_seqlen, which will update the cos / sin cache up to that length.
        Apply rotary embedding *inplace* to qkv and / or kv.
        """
        seqlen = qkv.shape[1]
        if isinstance(seqlen_offset, int):
            self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
        elif max_seqlen is not None:
            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
        if kv is None:
            if self.scale is None:
                return apply_rotary_emb_qkv_(
                    qkv,
                    self._cos_cached,
                    self._sin_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
            else:
                return apply_rotary_emb_qkv_(
                    qkv,
                    self._cos_cached,
                    self._sin_cached,
                    self._cos_k_cached,
                    self._sin_k_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
        else:
            q = qkv
            q = apply_rotary_emb_func(
                q,
                self._cos_cached,
                self._sin_cached,
                interleaved=self.interleaved,
                inplace=True,
                seqlen_offsets=seqlen_offset,
            )
            if self.scale is None:
                kv = apply_rotary_emb_kv_(
                    kv,
                    self._cos_cached,
                    self._sin_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
            else:
                kv = apply_rotary_emb_kv_(
                    kv,
                    self._cos_k_cached,
                    self._sin_k_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
            return q, kv
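

# A minimal decoding-style sketch (illustrative only; it assumes the module is constructed
# with just the rotary dimension plus an optional device argument): prefill with
# seqlen_offset=0, then pass the number of already-cached tokens for each new step.
#
#     rotary = RotaryEmbedding(64, device="cuda")
#     qkv = torch.randn(2, 128, 3, 8, 64, dtype=torch.float16, device="cuda")
#     qkv = rotary(qkv, seqlen_offset=0)               # prefill: positions 0..127
#     qkv_new = torch.randn(2, 1, 3, 8, 64, dtype=torch.float16, device="cuda")
#     qkv_new = rotary(qkv_new, seqlen_offset=128)     # decode: position 128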
@@ -68,6 +68,8 @@ def remap_state_dict_hf_gpt_neox(state_dict, config):
        # We don't store these biases
        state_dict.pop(f"transformer.layers.{l}.attention.bias")
        state_dict.pop(f"transformer.layers.{l}.attention.masked_bias")
        # We don't store these
        state_dict.pop(f"transformer.layers.{l}.attention.rotary_emb.inv_freq", None)
    # GPT-NeoX stores Wqkv as ((nheads 3 headdim), hidden_dim)
    # while we store Wqkv as ((3 nheads headdim), hidden_dim)
    headdim = config.hidden_size // config.num_attention_heads
@@ -89,11 +91,6 @@ def remap_state_dict_hf_gpt_neox(state_dict, config):
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
...
# Adapted from https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py
# and https://github.com/openai/triton/blob/master/python/triton/ops/matmul.py
from typing import Optional

import torch
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

from flash_attn.ops.triton.k_activations import (
...
from typing import Union
import torch
import triton
import triton.language as tl
# @triton.autotune(
# configs=[
# triton.Config({"BLOCK_M": 2}),
# triton.Config({"BLOCK_M": 4}),
# triton.Config({"BLOCK_M": 8}),
# triton.Config({"BLOCK_M": 16}),
# ],
# key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"]
# )
@triton.jit
def rotary_kernel(
OUT, # Pointers to matrices
X,
COS,
SIN,
SEQLEN_OFFSETS, # this could be int or a pointer
# Matrix dimensions
seqlen,
nheads,
rotary_dim,
seqlen_ro,
CACHE_KEY_SEQLEN,
# strides
stride_out_batch,
stride_out_seqlen,
stride_out_nheads,
stride_out_headdim,
stride_x_batch,
stride_x_seqlen,
stride_x_nheads,
stride_x_headdim,
# Meta-parameters
BLOCK_K: tl.constexpr,
IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
INTERLEAVED: tl.constexpr,
CONJUGATE: tl.constexpr,
BLOCK_M: tl.constexpr,
):
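    # Each program instance handles a BLOCK_M x (BLOCK_K // 2) tile of one (batch, head)
    # slice: it loads the paired elements x0 / x1 (first and second half of rotary_dim,
    # or adjacent even/odd positions when INTERLEAVED), loads the cos / sin rows indexed
    # by position plus the per-sequence offset, rotates each pair (with the sin term
    # negated when CONJUGATE, i.e. in the backward pass), and stores the result to OUT
    # at the matching locations.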
pid_m = tl.program_id(axis=0)
pid_batch = tl.program_id(axis=1)
pid_head = tl.program_id(axis=2)
rotary_dim_half = rotary_dim // 2
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rk = tl.arange(0, BLOCK_K // 2)
if not IS_SEQLEN_OFFSETS_TENSOR:
rm_cs = rm + SEQLEN_OFFSETS
else:
rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)
X = X + (
pid_batch * stride_x_batch
+ rm[:, None] * stride_x_seqlen
+ pid_head * stride_x_nheads
+ rk[None, :] * stride_x_headdim * (2 if INTERLEAVED else 1)
)
COS = COS + (rm_cs[:, None] * rotary_dim_half + rk[None, :])
SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk[None, :])
cos = tl.load(
COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk[None, :] < rotary_dim_half), other=1.0
).to(tl.float32)
sin = tl.load(
SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk[None, :] < rotary_dim_half), other=0.0
).to(tl.float32)
x0 = tl.load(X, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half), other=0.0).to(
tl.float32
)
x1 = tl.load(
X + stride_x_headdim * (1 if INTERLEAVED else rotary_dim_half),
mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half),
other=0.0,
).to(tl.float32)
if not CONJUGATE:
o0 = x0 * cos - x1 * sin
o1 = x0 * sin + x1 * cos
else:
o0 = x0 * cos + x1 * sin
o1 = -x0 * sin + x1 * cos
# write back result
OUT = OUT + (
pid_batch * stride_out_batch
+ rm[:, None] * stride_out_seqlen
+ pid_head * stride_out_nheads
+ rk[None, :] * stride_out_headdim * (2 if INTERLEAVED else 1)
)
tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half))
tl.store(
OUT + stride_out_headdim * (1 if INTERLEAVED else rotary_dim_half),
o1,
mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half),
)
def apply_rotary(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
seqlen_offsets: Union[int, torch.Tensor] = 0,
interleaved=False,
inplace=False,
conjugate=False,
) -> torch.Tensor:
"""
Arguments:
x: (batch, seqlen, nheads, headdim)
cos: (seqlen_ro, rotary_dim / 2)
sin: (seqlen_ro, rotary_dim / 2)
seqlen_offsets: integer or integer tensor of size (batch,)
Returns:
y: (batch, seqlen, nheads, headdim)
"""
batch, seqlen, nheads, headdim = x.shape
seqlen_ro, rotary_dim = cos.shape
assert sin.shape == cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
assert headdim <= 256, "Only support headdim <= 256"
assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
assert (
cos.dtype == sin.dtype
), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
assert (
x.dtype == cos.dtype
), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"
cos, sin = cos.contiguous(), sin.contiguous()
if isinstance(seqlen_offsets, torch.Tensor):
assert seqlen_offsets.shape == (batch,)
assert seqlen_offsets.dtype in [torch.int32, torch.int64]
seqlen_offsets = seqlen_offsets.contiguous()
else:
assert seqlen_offsets + seqlen <= seqlen_ro
output = torch.empty_like(x) if not inplace else x
if rotary_dim < headdim and not inplace:
output[..., rotary_dim:].copy_(x[..., rotary_dim:])
BLOCK_K = (
32
if rotary_dim <= 32
else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
)
grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) # noqa
BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
rotary_kernel[grid](
output, # data ptrs
x,
cos,
sin,
seqlen_offsets,
seqlen, # shapes
nheads,
rotary_dim,
seqlen_ro,
seqlen // 128, # key for triton cache (limit number of compilations)
output.stride(0), # strides
output.stride(1),
output.stride(2),
output.stride(3),
x.stride(0),
x.stride(1),
x.stride(2),
x.stride(3),
BLOCK_K,
isinstance(seqlen_offsets, torch.Tensor),
interleaved,
conjugate,
BLOCK_M,
)
return output
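

# A minimal usage sketch for the wrapper above (illustrative only; shapes, dtype and
# device are assumptions): with inplace=False a new tensor is returned and the
# non-rotary tail of headdim is copied over unchanged.
#
#     import math
#     import torch
#     x = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")
#     angle = torch.rand(256, 16, device="cuda") * 2 * math.pi
#     cos, sin = torch.cos(angle).half(), torch.sin(angle).half()
#     y = apply_rotary(x, cos, sin, seqlen_offsets=0, interleaved=False, inplace=False)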
@@ -131,6 +131,8 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
    )
    print(out_cg.sequences)

    parallel_state.destroy_model_parallel()

    if not rotary:
        out_hf = model_hf.generate(
            input_ids=input_ids,
@@ -171,5 +173,3 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
    ).abs().max().item() < 3 * (
        torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)
    ).abs().max().item()
import math
import random

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange

from flash_attn.layers.rotary import apply_rotary_emb, apply_rotary_emb_torch
from flash_attn.layers.rotary import apply_rotary_emb_qkv_, apply_rotary_emb_kv_

is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0)
@@ -13,33 +15,198 @@ is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0)
    "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False])
@pytest.mark.parametrize("inplace", [False, True])
# @pytest.mark.parametrize('inplace', [False])
def test_rotary_emb_func(inplace, interleaved, rotary_fraction, seqlen_offsets_type, dtype):
    rtol = 1e-3
    batch_size = 32
    nheads = 4
    seqlen = 217
    headdim = 128
    device = "cuda"
    torch.manual_seed(42)
    x = torch.randn(
        batch_size, seqlen, nheads, headdim, dtype=dtype, device=device, requires_grad=True
    )
    x_pt = x.detach().clone().requires_grad_()
    rotary_dim = int(rotary_fraction * headdim)
    assert rotary_dim % 2 == 0
    angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
    cos = torch.cos(angle).to(dtype=dtype)
    sin = torch.sin(angle).to(dtype=dtype)
    if seqlen_offsets_type == 0:
        seqlen_offsets = 0
    elif seqlen_offsets_type is int:
        seqlen_offsets = torch.randint(0, seqlen + 1, (1,)).item()
    elif seqlen_offsets_type is torch.Tensor:
        seqlen_offsets = torch.randint(
            0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device
        )
    out = apply_rotary_emb(
        x, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=inplace
    )
    if seqlen_offsets_type is torch.Tensor:
        arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
        idx = rearrange(seqlen_offsets, "b -> b 1") + arange
        cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
        sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
    else:
        cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
        sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
    out_pt = apply_rotary_emb_torch(
        x_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    print(f"Output max diff: {(out - out_pt).abs().max().item()}")
    g = torch.randn_like(out)
    g_pt = g.clone()  # If inplace=True, we might modify the gradient inplace
    out.backward(g)
    out_pt.backward(g_pt)
    print(f"Grad max diff: {(x.grad - x_pt.grad).abs().max().item()}")
    if not inplace:
        assert torch.equal(x, x_pt)
    # Numerical error if we just do any arithmetic
    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
    atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item()
    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=2 * atol)
@pytest.mark.parametrize(
"dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False])
def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_type, dtype):
rtol = 1e-3
batch_size = 32
nheads = 4
seqlen = 512
headdim = 128
device = "cuda"
torch.manual_seed(42)
qkv = torch.randn(
batch_size, seqlen, 3, nheads, headdim, dtype=dtype, device=device, requires_grad=True
)
qkv_pt = qkv.detach().clone().requires_grad_()
rotary_dim = int(rotary_fraction * headdim)
assert rotary_dim % 2 == 0
angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
cos = torch.cos(angle).to(dtype=dtype)
sin = torch.sin(angle).to(dtype=dtype)
if seqlen_offsets_type == 0:
seqlen_offsets = 0
elif seqlen_offsets_type is int:
seqlen_offsets = torch.randint(0, seqlen + 1, (1, )).item()
elif seqlen_offsets_type is torch.Tensor:
seqlen_offsets = torch.randint(
0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device
)
out = apply_rotary_emb_qkv_(
qkv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved
)
if seqlen_offsets_type is torch.Tensor:
arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
idx = rearrange(seqlen_offsets, "b -> b 1") + arange
cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
else:
cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
q_pt = apply_rotary_emb_torch(
qkv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
).to(dtype=dtype)
k_pt = apply_rotary_emb_torch(
qkv_pt[:, :, 1].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
).to(dtype=dtype)
out_pt = torch.stack([q_pt, k_pt, qkv_pt[:, :, 2]], dim=2)
print(f"Output max diff: {(out - out_pt).abs().max().item()}")
g = torch.randn_like(out)
g_pt = g.clone() # Since inplace=True, we modify the gradient inplace
out.backward(g)
out_pt.backward(g_pt)
print(f"Grad max diff: {(qkv.grad - qkv_pt.grad).abs().max().item()}")
# Numerical error if we just do any arithmetic
atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
atol = ((qkv_pt.grad + 0.3 - 0.3) - qkv_pt.grad).abs().max().item()
assert torch.allclose(qkv.grad, qkv_pt.grad, rtol=rtol, atol=2 * atol)
@pytest.mark.parametrize(
"dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False])
def test_rotary_emb_kv(interleaved, rotary_fraction, seqlen_offsets_type, dtype):
rtol = 1e-3
batch_size = 32
nheads = 4
seqlen = 781
headdim = 64
device = "cuda"
torch.manual_seed(42)
kv = torch.randn(
batch_size, seqlen, 2, nheads, headdim, dtype=dtype, device=device, requires_grad=True
)
kv_pt = kv.detach().clone().requires_grad_()
rotary_dim = int(rotary_fraction * headdim)
assert rotary_dim % 2 == 0
angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
cos = torch.cos(angle).to(dtype=dtype)
sin = torch.sin(angle).to(dtype=dtype)
if seqlen_offsets_type == 0:
seqlen_offsets = 0
elif seqlen_offsets_type is int:
seqlen_offsets = torch.randint(0, seqlen + 1, (1, )).item()
elif seqlen_offsets_type is torch.Tensor:
seqlen_offsets = torch.randint(
0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device
)
out = apply_rotary_emb_kv_(
kv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved
)
if seqlen_offsets_type is torch.Tensor:
arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
idx = rearrange(seqlen_offsets, "b -> b 1") + arange
cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
else:
cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
k_pt = apply_rotary_emb_torch(
kv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
).to(dtype=dtype)
out_pt = torch.stack([k_pt, kv_pt[:, :, 1]], dim=2)
print(f"Output max diff: {(out - out_pt).abs().max().item()}")
g = torch.randn_like(out)
g_pt = g.clone() # Since inplace=True, we modify the gradient inplace
out.backward(g)
out_pt.backward(g_pt)
print(f"Grad max diff: {(kv.grad - kv_pt.grad).abs().max().item()}")
# Numerical error if we just do any arithmetic
atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
atol = ((kv_pt.grad + 0.3 - 0.3) - kv_pt.grad).abs().max().item()
assert torch.allclose(kv.grad, kv_pt.grad, rtol=rtol, atol=2 * atol)