Commit ca81f32e authored by Tri Dao

Implement rotary embedding in CUDA

parent 62025e1a
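The new CUDA kernel applies the standard rotary rotation to pairs of features, out1 = x1*cos - x2*sin and out2 = x1*sin + x2*cos, and with conj=True it applies the inverse rotation used in the backward pass. A minimal pure-PyTorch sketch of the same computation (illustrative only, not part of the commit):

import torch

def rotary_reference(x1, x2, cos, sin, conj=False):
    # Paired rotation; conj=True applies the inverse (conjugate) rotation.
    if not conj:
        return x1 * cos - x2 * sin, x1 * sin + x2 * cos
    return x1 * cos + x2 * sin, -x1 * sin + x2 * cos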
#include <torch/extension.h>
#define CHECK_DEVICE(x) \
TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) \
TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), \
#x " must have shape (" #__VA_ARGS__ ")")
void apply_rotary_cuda(const torch::Tensor x1, const torch::Tensor x2,
const torch::Tensor cos, const torch::Tensor sin,
torch::Tensor out1, torch::Tensor out2,
const bool conj);
void apply_rotary(const torch::Tensor x1, const torch::Tensor x2,
const torch::Tensor cos, const torch::Tensor sin,
torch::Tensor out1, torch::Tensor out2,
const bool conj) {
CHECK_DEVICE(x1); CHECK_DEVICE(x2);
CHECK_DEVICE(cos); CHECK_DEVICE(sin);
CHECK_DEVICE(out1); CHECK_DEVICE(out2);
TORCH_CHECK(x1.dtype() == x2.dtype());
TORCH_CHECK(cos.dtype() == sin.dtype());
TORCH_CHECK(out1.dtype() == out2.dtype());
TORCH_CHECK(x1.dtype() == cos.dtype());
TORCH_CHECK(x1.dtype() == out1.dtype());
TORCH_CHECK(x1.sizes() == x2.sizes());
TORCH_CHECK(cos.sizes() == sin.sizes());
TORCH_CHECK(out1.sizes() == out2.sizes());
apply_rotary_cuda(x1, x2, cos, sin, out1, out2, conj);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("apply_rotary", &apply_rotary, "Apply rotary embedding");
}
#include <torch/python.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>
void apply_rotary_cuda(const torch::Tensor x1, const torch::Tensor x2,
const torch::Tensor cos, const torch::Tensor sin,
torch::Tensor out1, torch::Tensor out2,
const bool conj) {
auto iter = at::TensorIteratorConfig()
.add_output(out1)
.add_output(out2)
.add_input(x1)
.add_input(x2)
.add_input(cos)
.add_input(sin)
.check_all_same_dtype(false)
.promote_inputs_to_common_dtype(false)
.build();
if (!conj) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel", [&] {
at::native::gpu_kernel_multiple_outputs(
iter, [] GPU_LAMBDA (scalar_t x1, scalar_t x2, scalar_t cos,
scalar_t sin) -> thrust::tuple<scalar_t, scalar_t> {
scalar_t out1 = float(x1) * float(cos) - float(x2) * float(sin);
scalar_t out2 = float(x1) * float(sin) + float(x2) * float(cos);
return {out1, out2};
});
});
} else {
AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel", [&] {
at::native::gpu_kernel_multiple_outputs(
iter, [] GPU_LAMBDA (scalar_t x1, scalar_t x2, scalar_t cos,
scalar_t sin) -> thrust::tuple<scalar_t, scalar_t> {
scalar_t out1 = float(x1) * float(cos) + float(x2) * float(sin);
scalar_t out2 = -float(x1) * float(sin) + float(x2) * float(cos);
return {out1, out2};
});
});
}
}
\ No newline at end of file
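Because the conj=True branch is the inverse rotation, applying the kernel forward and then with conj=True should recover the inputs up to rounding. A small round-trip check (hedged sketch, same assumptions as above):

import torch
import rotary_emb

x1, x2 = torch.randn(1, 8, 2, 16, device='cuda'), torch.randn(1, 8, 2, 16, device='cuda')
angle = torch.randn(8, 16, device='cuda')
cos, sin = torch.cos(angle)[:, None, :], torch.sin(angle)[:, None, :]
y1, y2 = torch.empty_like(x1), torch.empty_like(x2)
z1, z2 = torch.empty_like(x1), torch.empty_like(x2)
rotary_emb.apply_rotary(x1, x2, cos, sin, y1, y2, False)  # rotate
rotary_emb.apply_rotary(y1, y2, cos, sin, z1, z2, True)   # inverse rotation
assert torch.allclose(z1, x1, atol=1e-4) and torch.allclose(z2, x2, atol=1e-4)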
# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
import torch
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
from setuptools import setup, find_packages
import subprocess
import sys
import warnings
import os
# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))
def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
return raw_output, bare_metal_major, bare_metal_minor
def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
torch_binary_major = torch.version.cuda.split(".")[0]
torch_binary_minor = torch.version.cuda.split(".")[1]
print("\nCompiling cuda extensions with")
print(raw_output + "from " + cuda_dir + "/bin\n")
if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
raise RuntimeError(
"Cuda extensions are being compiled with a version of Cuda that does "
"not match the version used to compile Pytorch binaries. "
"Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda)
+ "In some cases, a minor-version mismatch will not cause later errors: "
"https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. "
"You can try commenting out this check (at your own risk)."
)
def raise_if_cuda_home_none(global_option: str) -> None:
if CUDA_HOME is not None:
return
raise RuntimeError(
f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
"If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
"only images whose names contain 'devel' will provide nvcc."
)
def append_nvcc_threads(nvcc_extra_args):
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
return nvcc_extra_args + ["--threads", "4"]
return nvcc_extra_args
if not torch.cuda.is_available():
# https://github.com/NVIDIA/apex/issues/486
# Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
# which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command).
print(
"\nWarning: Torch did not find available GPUs on this system.\n",
"If your intention is to cross-compile, this is not an error.\n"
"By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n"
"Volta (compute capability 7.0), Turing (compute capability 7.5),\n"
"and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n"
"If you wish to cross-compile for a single specific architecture,\n"
'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
)
if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) == 11:
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
if int(bare_metal_minor) > 0:
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
else:
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])
cmdclass = {}
ext_modules = []
raise_if_cuda_home_none("rotary_emb")
# Check if CUDA 11 is installed for compute capability 8.0
cc_flag = []
cc_flag.append("-gencode")
cc_flag.append("arch=compute_70,code=sm_70")
cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80")
ext_modules.append(
CUDAExtension(
'rotary_emb', [
'rotary.cpp',
'rotary_cuda.cu',
],
extra_compile_args={'cxx': ['-g', '-march=native', '-funroll-loops'],
'nvcc': append_nvcc_threads([
'-O3', '--use_fast_math', '--expt-extended-lambda'
] + cc_flag)
}
)
)
setup(
name="rotary_emb",
version="0.1",
ext_modules=ext_modules,
cmdclass={"build_ext": BuildExtension} if ext_modules else {},
)
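With CUDA_HOME pointing at a toolkit that matches the PyTorch build, the extension is typically built from this directory with pip install . (or python setup.py install), after which it is importable as rotary_emb. A quick smoke test (hedged sketch):

import torch
import rotary_emb

assert torch.version.cuda is not None         # PyTorch was built with CUDA
assert hasattr(rotary_emb, 'apply_rotary')    # the binding defined in rotary.cpp is present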
# Adapted from https://github.com/facebookresearch/xformers/blob/main/xformers/components/positional_embedding/rotary.py
# We split the input differently ((d 2) -> d 2 instead of (2 d) -> d 2), following the original
# paper's implementation. This should not matter.
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# CREDITS: This implementation is inspired by GPT-NeoX https://github.com/EleutherAI/gpt-neox
# NOTE: Almost the same right now, moving parts to Triton is the next step
# Inspired by https://github.com/facebookresearch/xformers/blob/main/xformers/components/positional_embedding/rotary.py
from typing import Tuple
import math
@@ -18,28 +7,118 @@ import torch
from einops import rearrange, repeat
import rotary_emb
def rotate_half(x):
# rearrange doesn't work with torch.jit
# x = rearrange(x, '... (d r) -> ... d r', r=2)
x = x.unflatten(dim=-1, sizes=(-1, 2))
x1, x2 = x.unbind(dim=-1)
rotated_x = torch.stack((-x2, x1), dim=-1)
# return rearrange(rotated_x, '... d r -> ... (d r)')
return rotated_x.flatten(start_dim=-2)
def rotate_half(x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
@torch.jit.script
def apply_rotary_pos_emb(x, cos, sin, seq_dimension: int = -2):
# NOTE: This could probably be moved to Triton
# Handle a possible sequence length mismatch in between q and k
cos = cos[:x.shape[seq_dimension], :]
sin = sin[:x.shape[seq_dimension], :]
if seq_dimension == -3:
cos = cos[:, None, :]
sin = sin[:, None, :]
return (x * cos) + (rotate_half(x) * sin)
def apply_rotary_emb_torch(x, cos, sin):
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2)
"""
rotary_dim = cos.shape[-1] * 2
assert rotary_dim <= x.shape[-1]
cos = repeat(cos, 's d -> s 1 (2 d)')
sin = repeat(sin, 's d -> s 1 (2 d)')
return torch.cat([x[..., :rotary_dim] * cos + rotate_half(x[..., :rotary_dim]) * sin,
x[..., rotary_dim:]], dim=-1)
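On a small fp32 tensor this reference should agree with the explicit paired rotation used by the CUDA kernel (hedged sketch; it assumes the halves-based rotate_half, which matches the kernel's chunk split):

import torch
from flash_attn.rotary import apply_rotary_emb_torch

x = torch.randn(1, 4, 1, 8)
angle = torch.randn(4, 4)
cos, sin = torch.cos(angle), torch.sin(angle)
out = apply_rotary_emb_torch(x, cos, sin)
x1, x2 = x.chunk(2, dim=-1)
c, s = cos[:, None, :], sin[:, None, :]
expected = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], dim=-1)
assert torch.allclose(out, expected, atol=1e-6)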
class ApplyRotaryEmb(torch.autograd.Function):
@staticmethod
def forward(ctx, x, cos, sin, inplace=False):
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2)
rotary_dim must be <= headdim
Apply rotary embedding to the first rotary_dim of x.
"""
batch, seqlen, nheads, headdim = x.shape
rotary_seqlen, rotary_dim = cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim
assert seqlen <= rotary_seqlen
assert cos.shape == (rotary_seqlen, rotary_dim // 2)
assert sin.shape == (rotary_seqlen, rotary_dim // 2)
x1, x2 = x[..., :rotary_dim].chunk(2, dim=-1)
out = torch.empty_like(x) if not inplace else x
o1, o2 = out[..., :rotary_dim].chunk(2, dim=-1) if not inplace else (x1, x2)
rotary_emb.apply_rotary(x1, x2, rearrange(cos[:seqlen], 's d -> s 1 d'),
rearrange(sin[:seqlen], 's d -> s 1 d'), o1, o2, False)
if not inplace and rotary_dim < headdim:
out[..., rotary_dim:].copy_(x[..., rotary_dim:])
ctx.save_for_backward(cos, sin)
ctx.inplace = inplace
return out if not inplace else x
@staticmethod
def backward(ctx, do):
cos, sin = ctx.saved_tensors
_, seqlen, _, headdim = do.shape
rotary_dim = cos.shape[-1]
rotary_dim *= 2
inplace = ctx.inplace
do1, do2 = do[..., :rotary_dim].chunk(2, dim=-1)
dx = torch.empty_like(do) if not inplace else do
dx1, dx2 = dx[..., :rotary_dim].chunk(2, dim=-1) if not inplace else (do1, do2)
rotary_emb.apply_rotary(do1, do2, rearrange(cos[:seqlen], 's d -> s 1 d'),
rearrange(sin[:seqlen], 's d -> s 1 d'), dx1, dx2, True)
if not inplace and rotary_dim < headdim:
dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
return dx, None, None, None
apply_rotary_emb_func = ApplyRotaryEmb.apply
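A hedged usage sketch (requires the compiled extension and a GPU; shapes are illustrative). With rotary_dim = 32 and headdim = 64, only the first 32 features of each head are rotated and the rest are copied through; the backward pass applies the inverse rotation via conj=True:

x = torch.randn(2, 128, 8, 64, device='cuda', dtype=torch.float16, requires_grad=True)
angle = torch.randn(128, 16, device='cuda')              # rotary_dim // 2 = 16
cos, sin = torch.cos(angle).half(), torch.sin(angle).half()
out = apply_rotary_emb_func(x, cos, sin, False)          # inplace=False
out.sum().backward()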
class ApplyRotaryEmbQKV_(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, cos, sin):
"""
qkv: (batch_size, seqlen, 3, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2)
rotary_dim must be <= headdim
Apply rotary embedding *inplace* to the first rotary_dim of q and k.
"""
batch, seqlen, three, nheads, headdim = qkv.shape
assert three == 3
rotary_seqlen, rotary_dim = cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim
assert seqlen <= rotary_seqlen
assert cos.shape == (rotary_seqlen, rotary_dim // 2)
assert sin.shape == (rotary_seqlen, rotary_dim // 2)
q1, q2 = qkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1)
rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'),
rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False)
k1, k2 = qkv[:, :, 1, :, :rotary_dim].chunk(2, dim=-1)
rotary_emb.apply_rotary(k1, k2, rearrange(cos[:seqlen], 's d -> s 1 d'),
rearrange(sin[:seqlen], 's d -> s 1 d'), k1, k2, False)
ctx.save_for_backward(cos, sin)
return qkv
@staticmethod
def backward(ctx, dqkv):
cos, sin = ctx.saved_tensors
_, seqlen, _, _, headdim = dqkv.shape
rotary_dim = cos.shape[-1]
rotary_dim *= 2
dq1, dq2 = dqkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1)
rotary_emb.apply_rotary(dq1, dq2, rearrange(cos[:seqlen], 's d -> s 1 d'),
rearrange(sin[:seqlen], 's d -> s 1 d'), dq1, dq2, True)
dk1, dk2 = dqkv[:, :, 1, :, :rotary_dim].chunk(2, dim=-1)
rotary_emb.apply_rotary(dk1, dk2, rearrange(cos[:seqlen], 's d -> s 1 d'),
rearrange(sin[:seqlen], 's d -> s 1 d'), dk1, dk2, True)
return dqkv, None, None
apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
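A hedged usage sketch for the packed-QKV variant: qkv is laid out as (batch, seqlen, 3, nheads, headdim); q and k are rotated in place and v is left untouched:

qkv = torch.randn(2, 128, 3, 8, 64, device='cuda', dtype=torch.float16)
angle = torch.randn(128, 16, device='cuda')
cos, sin = torch.cos(angle).half(), torch.sin(angle).half()
qkv = apply_rotary_emb_qkv_(qkv, cos, sin)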
class RotaryEmbedding(torch.nn.Module):
@@ -55,9 +134,6 @@ class RotaryEmbedding(torch.nn.Module):
.. _repo: https://github.com/ZhuiyiTechnology/roformer
.. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
.. warning: Please note that this embedding is not registered on purpose, as it is transformative
(it does not create the embedding dimension) and will likely be picked up (imported) on an ad-hoc basis
"""
def __init__(self, dim_model: int, *_, **__):
@@ -66,70 +142,26 @@ class RotaryEmbedding(torch.nn.Module):
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_model, 2).float() / dim_model))
self.register_buffer("inv_freq", inv_freq)
self._seq_len_cached = None
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
def _update_cos_sin_tables(self, x, seq_dimension=-2):
seq_len = x.shape[seq_dimension]
def _update_cos_sin_cache(self, x):
"""x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)
"""
seqlen = x.shape[1]
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
if (seq_len != self._seq_len_cached or self._cos_cached.device != x.device
if (seqlen > self._seq_len_cached or self._cos_cached.device != x.device
or self._cos_cached.dtype != x.dtype):
self._seq_len_cached = seq_len
t = torch.arange(x.shape[seq_dimension], device=x.device, dtype=self.inv_freq.dtype)
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=x.device, dtype=self.inv_freq.dtype)
# Don't do einsum, it converts fp32 to fp16
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
freqs = torch.outer(t, self.inv_freq)
self._cos_cached = repeat(torch.cos(freqs).to(x.dtype), '... d -> ... (d 2)')
self._sin_cached = repeat(torch.sin(freqs).to(x.dtype), '... d -> ... (d 2)')
return self._cos_cached, self._sin_cached
def forward(self, q: torch.Tensor, k: torch.Tensor,
seq_dimension=-2) -> Tuple[torch.Tensor, torch.Tensor]:
assert seq_dimension in [-2, -3] # Either (bs, h, s, d) or (bs, s, h, d)
self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
k, seq_dimension=seq_dimension
)
return (
apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached, seq_dimension),
apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached, seq_dimension),
)
self._cos_cached = torch.cos(freqs).to(x.dtype)
self._sin_cached = torch.sin(freqs).to(x.dtype)
class RotaryEmbedding2D(torch.nn.Module):
def __init__(self, dim: int):
super().__init__()
assert dim % 4 == 0
self.rotary_emb1d = RotaryEmbedding(dim // 2)
def forward(self, q: torch.Tensor, k: torch.Tensor, seq_dimension=-2):
assert seq_dimension in [-2, -3] # Either (bs, h, s, d) or (bs, s, h, d)
seqlen = q.shape[seq_dimension]
seqlen_sqrt = int(math.sqrt(seqlen))
assert seqlen == seqlen_sqrt ** 2
if seq_dimension == -3: # (bs, s, h, d)
q = rearrange(q, 'b s h d -> b h s d')
k = rearrange(k, 'b s h d -> b h s d')
q0, q1 = q.chunk(2, dim=-1)
k0, k1 = k.chunk(2, dim=-1)
# (bs, h, s, d)
q0 = rearrange(q0, 'b nheads (h w) d -> b nheads h w d', h=seqlen_sqrt)
k0 = rearrange(k0, 'b nheads (h w) d -> b nheads h w d', h=seqlen_sqrt)
q0_emb, k0_emb = self.rotary_emb1d(q0, k0, seq_dimension=-2)
q0_emb = rearrange(q0_emb, 'b nheads h w d -> b nheads (h w) d')
k0_emb = rearrange(k0_emb, 'b nheads h w d -> b nheads (h w) d')
q1 = rearrange(q1, 'b nheads (h w) d -> b nheads h w d', h=seqlen_sqrt)
k1 = rearrange(k1, 'b nheads (h w) d -> b nheads h w d', h=seqlen_sqrt)
q1_emb, k1_emb = self.rotary_emb1d(q1, k1, seq_dimension=-3)
q1_emb = rearrange(q1_emb, 'b nheads h w d -> b nheads (h w) d')
k1_emb = rearrange(k1_emb, 'b nheads h w d -> b nheads (h w) d')
q_emb, k_emb = torch.cat([q0_emb, q1_emb], dim=-1), torch.cat([k0_emb, k1_emb], dim=-1)
if seq_dimension == -3:
q_emb = rearrange(q_emb, 'b h s d -> b s h d')
k_emb = rearrange(k_emb, 'b h s d -> b s h d')
return q_emb, k_emb
def forward(self, qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
self._update_cos_sin_cache(qkv)
return apply_rotary_emb_qkv_(qkv, self._cos_cached, self._sin_cached)
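A hedged usage sketch of the module as updated in this commit: the cos/sin cache is built lazily for the longest sequence length seen so far, in the dtype and on the device of the input, and is then applied in place to q and k:

rotary = RotaryEmbedding(64).cuda()
qkv = torch.randn(2, 128, 3, 8, 64, device='cuda', dtype=torch.float16)
qkv = rotary(qkv)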
import math
import torch
import torch.nn.functional as F
import pytest
from einops import rearrange
from flash_attn.rotary import apply_rotary_emb_func, apply_rotary_emb_torch
is_sm8x = torch.cuda.get_device_capability('cuda') >= (8, 0)
@pytest.mark.parametrize('dtype', ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize('rotary_fraction', [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [0.5])
@pytest.mark.parametrize('inplace', [False, True])
# @pytest.mark.parametrize('inplace', [False])
def test_rotary_single_tensor(inplace, rotary_fraction, dtype):
rtol = 1e-3
batch_size = 32
nheads = 4
seqlen = 217
headdim = 128
x = torch.randn(batch_size, seqlen, nheads, headdim, dtype=dtype, device='cuda',
requires_grad=True)
x_pt = x.detach().clone().requires_grad_()
rotary_dim = int(rotary_fraction * headdim)
assert rotary_dim % 2 == 0
angle = torch.randn(seqlen, rotary_dim // 2, device='cuda')
cos = torch.cos(angle).to(dtype=dtype)
sin = torch.sin(angle).to(dtype=dtype)
out = apply_rotary_emb_func(x, cos, sin, inplace)
out_pt = apply_rotary_emb_torch(x_pt, cos, sin)
# Numerical error if we just do any arithmetic
atol = ((out + 0.3 - 0.3) - out).abs().max().item()
assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
g = torch.randn_like(out)
g_pt = g.clone() # If inplace=True, we might modify the gradient inplace
out.backward(g)
out_pt.backward(g_pt)
atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item()
assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=2 * atol)