Commit a11c313a authored by Casper Hansen's avatar Casper Hansen
Browse files

Update kernel, remove unused code

parent fbfa9d82
......@@ -70,7 +70,7 @@ from typing import List, Tuple, Union
from awq.utils.utils import set_module_name
from awq.modules.fused.mlp import QuantLlamaMLP
from awq.modules.fused.norm import FTLlamaRMSNorm
from awq.modules.fused.attn import QuantLlamaAttention, QuantLlamaAttentionFused
from awq.modules.fused.attn import QuantLlamaAttentionFused
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP
......@@ -96,16 +96,7 @@ class LlamaFuser:
def fuse_attention(self):
for name, module in self.attention_modules:
qkv_layer: Union[WQLinear_GEMM, WQLinear_GEMV] = self._fuse_qkv2(module)
# attn = QuantLlamaAttention(
# module.hidden_size,
# module.num_heads,
# module.num_key_value_heads,
# qkv_layer,
# module.o_proj,
# next(iter(qkv_layer.state_dict().values())).device,
# self.model.config.max_new_tokens
# )
qkv_layer: Union[WQLinear_GEMM, WQLinear_GEMV] = self._fuse_qkv(module)
attn = QuantLlamaAttentionFused(
module.hidden_size,
module.num_heads,
......@@ -116,21 +107,9 @@ class LlamaFuser:
)
set_module_name(self.model, name, attn)
def _fuse_qkv2(self, module: LlamaAttention):
q_proj = module.q_proj
k_proj = module.k_proj
v_proj = module.v_proj
qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0)
qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0)
scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0)
# g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
g_idx = None
bias = (
torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0)
if q_proj.bias is not None
else None
)
def _fuse_qkv(self, module: LlamaAttention):
q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj
bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
qkv_layer = WQLinear_GEMV(
q_proj.w_bit,
......@@ -140,39 +119,10 @@ class LlamaFuser:
q_proj.bias is not None,
q_proj.qweight.device,
)
qkv_layer.qweight = qweights
qkv_layer.qzeros = qzeros
qkv_layer.scales = scales
qkv_layer.bias = bias
qkv_layer.split_k_iters = q_proj.split_k_iters
return qkv_layer
def _fuse_qkv(self, module: LlamaAttention):
# get qkv and bias
q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj
bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
# create module
if self.quant_config["version"] == 'GEMM':
qkv_module = WQLinear_GEMM
elif self.quant_config["version"] == 'GEMV':
qkv_module = WQLinear_GEMV
qkv_layer = qkv_module(
q_proj.w_bit,
q_proj.group_size,
q_proj.in_features,
q_proj.out_features + k_proj.out_features + v_proj.out_features,
q_proj.bias is not None,
next(iter(module.state_dict().values())).device
)
# replace buffers with real weights
qkv_layer.qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0)
qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0)
qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0)
qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0)
qkv_layer.bias = bias
qkv_layer.split_k_iters = q_proj.split_k_iters
......
......@@ -4,74 +4,6 @@ import torch.nn as nn
import awq_inference_engine
from torch.nn import functional as F
try:
from flash_attn import flash_attn_func
FLASH_INSTALLED = True
except:
FLASH_INSTALLED = False
class QuantLlamaRotary(nn.Module):
def __init__(self, dim=4096, max_position_embeddings=2048, base=10000, device=None,
is_neox=True, num_heads=None, num_kv_heads=None):
super().__init__()
self.dim = dim
self.is_neox = is_neox
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
# create cache
inv_freq = 1.0 / (base**(torch.arange(0, dim, 2, device=device) / dim))
t = torch.arange(max_position_embeddings, device=device).float()
freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1).to(torch.get_default_dtype())
# Embedding size: [max_position, rotary_dim]
self.register_buffer("cos_sin_cache", cache.half(), persistent=False)
def forward(
self,
qkv_states: torch.Tensor,
position_ids: torch.Tensor,
batch_size: int,
q_len: int
):
# get qkv
query, key, value = qkv_states.chunk(chunks=3, dim=-1)
del qkv_states
# [num_tokens, num_heads * head_size]
query_batch_size, query_len, _ = query.shape
query = query.view(query_len*query_batch_size, self.num_heads * self.dim)
# [num_tokens, num_kv_heads * head_size]
key_batch_size, key_len, _ = key.shape
key = key.view(key_len*key_batch_size, self.num_kv_heads * self.dim)
# [num_tokens]
positions = position_ids.view(-1).to(query.device)
# Apply rotary embedding to the query and key before passing them
# to the attention op.
query = query.contiguous()
key = key.contiguous()
awq_inference_engine.rotary_embedding(
positions,
query,
key,
self.dim,
self.cos_sin_cache,
self.is_neox
)
query = query.view(batch_size, q_len, self.num_heads, self.dim).transpose(1, 2)
key = key.view(batch_size, q_len, self.num_heads, self.dim).transpose(1, 2)
value = value.view(batch_size, q_len, self.num_heads, self.dim).transpose(1, 2)
return query, key, value
class QuantLlamaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
......@@ -120,139 +52,16 @@ class QuantLlamaRotaryEmbedding(nn.Module):
# print(positions.shape, query.shape, key.shape, self.cos_sin_cache.shape)
query = query.contiguous()
key = key.contiguous()
awq_inference_engine.rotary_embedding(
awq_inference_engine.rotary_embedding_neox(
positions,
query,
key,
self.dim,
self.cos_sin_cache,
True
self.cos_sin_cache
)
return query, key
class TorchAttention(nn.Module):
def __init__(self, hidden_size, use_flash=False):
super().__init__()
self.hidden_size = hidden_size
self.use_flash = use_flash
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
use_cache: bool,
past_key_value: torch.Tensor,
batch_size: int,
q_len: int
):
is_causal = past_key_value is None
kv_seq_len = q_len
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
value = value.to(key.device)
if past_key_value is not None:
# reuse k, v, self_attention
key = torch.cat([past_key_value[0], key], dim=2)
value = torch.cat([past_key_value[1], value], dim=2)
if use_cache:
# Since qkv_proj is fused, query etc will hold a reference to the original qkv_states tensor
# which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this.
key = key.contiguous()
value = value.contiguous()
query = query.contiguous()
past_key_value = (key, value) if use_cache else None
if self.use_flash and FLASH_INSTALLED:
query = query.transpose(1,2)
key = key.transpose(1,2)
value = value.transpose(1,2)
attn_output = flash_attn_func(query, key, value, causal=is_causal)
else:
attn_output = F.scaled_dot_product_attention(query, key, value, is_causal=is_causal)
del query, key, value
attn_output = attn_output.transpose(1, 2).reshape(batch_size, q_len, self.hidden_size)
return attn_output, past_key_value
class QuantLlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
hidden_size,
num_heads,
num_kv_heads,
qkv_proj,
o_proj,
dev,
max_new_tokens
):
super().__init__()
self.qkv_proj = qkv_proj
self.o_proj = o_proj
self.attn = TorchAttention(hidden_size)
self.rotary_emb = QuantLlamaRotary(
dim=hidden_size // num_heads,
max_position_embeddings=max_new_tokens,
device=dev,
is_neox=True,
num_heads=num_heads,
num_kv_heads=num_heads
)
def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False):
"""Input shape: Batch x Time x Channel"""
batch_size, q_len, _ = hidden_states.size()
qkv_states = self.qkv_proj(hidden_states)
query, key, value = self.rotary_emb(qkv_states, position_ids, batch_size, q_len)
attn_output, past_key_value = self.attn(query, key, value, use_cache, past_key_value, batch_size, q_len)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
):
xq_ = torch.view_as_complex(
xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
)
xk_ = torch.view_as_complex(
xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
)
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class QuantLlamaAttentionFused(nn.Module):
def __init__(self, hidden_size, num_heads, qkv_layer, o_proj, dev, max_position_embeddings):
super().__init__()
......@@ -263,42 +72,24 @@ class QuantLlamaAttentionFused(nn.Module):
self.o_proj = o_proj
self.start_pos = 0
self.freqs_cis = precompute_freqs_cis(
self.head_dim ,
max_position_embeddings * 2,
)
# following fastertransformer definition
self.cache_v = (
torch.zeros(
(
1,
self.n_local_heads,
max_position_embeddings,
self.head_dim,
)
)
.to(dev)
.half()
) # added to half
( 1, self.n_local_heads, max_position_embeddings, self.head_dim, )
).to(dev).half()
)
# 8: pack 8 fp16 in FT, if fp32 then use 4
self.cache_k = (
torch.zeros(
(
1,
self.n_local_heads,
self.head_dim // 8,
max_position_embeddings,
8,
)
)
.to(dev)
.half()
) # added to half
( 1, self.n_local_heads, self.head_dim // 8, max_position_embeddings, 8, )
).to(dev).half()
)
# dummy
self.rotary_emb = QuantLlamaRotaryEmbedding(
hidden_size // num_heads, max_position_embeddings=max_position_embeddings, base=10000, device=dev
dim=hidden_size // num_heads,
max_position_embeddings=max_position_embeddings,
device=dev
)
def forward(
......
#pragma once
#include <torch/extension.h>
void rotary_embedding(
void rotary_embedding_neox(
torch::Tensor& positions,
torch::Tensor& query,
torch::Tensor& key,
int head_size,
torch::Tensor& cos_sin_cache,
bool is_neox);
\ No newline at end of file
torch::Tensor& cos_sin_cache);
\ No newline at end of file
......@@ -9,56 +9,15 @@ https://github.com/vllm-project/vllm/blob/main/csrc/pos_encoding_kernels.cu
#include <ATen/cuda/CUDAContext.h>
#include "pos_encoding.h"
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
template<typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding(
scalar_t* __restrict__ arr,
const scalar_t* __restrict__ cos_ptr,
const scalar_t* __restrict__ sin_ptr,
int rot_offset,
int embed_dim)
{
int x_index, y_index;
scalar_t cos, sin;
if (IS_NEOX) {
// GPT-NeoX style rotary embedding.
x_index = rot_offset;
y_index = embed_dim + rot_offset;
cos = __ldg(cos_ptr + x_index);
sin = __ldg(sin_ptr + x_index);
} else {
// GPT-J style rotary embedding.
x_index = 2 * rot_offset;
y_index = 2 * rot_offset + 1;
cos = __ldg(cos_ptr + x_index / 2);
sin = __ldg(sin_ptr + x_index / 2);
}
const scalar_t x = arr[x_index];
const scalar_t y = arr[y_index];
arr[x_index] = x * cos - y * sin;
arr[y_index] = y * cos + x * sin;
}
template<typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel(
template<typename scalar_t>
__global__ void rotary_embedding_neox_kernel(
const int64_t* __restrict__ positions, // [num_tokens]
scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size]
scalar_t* __restrict__ key, // [num_tokens, num_kv_heads, head_size]
scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
const int rot_dim,
const int query_stride,
const int key_stride,
const int stride,
const int num_heads,
const int num_kv_heads,
const int head_size) {
// Each thread block is responsible for one token.
const int token_idx = blockIdx.x;
......@@ -66,72 +25,64 @@ __global__ void rotary_embedding_kernel(
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
const int embed_dim = rot_dim / 2;
const scalar_t* cos_ptr = cache_ptr;
const scalar_t* sin_ptr = cache_ptr + embed_dim;
const int nq = num_heads * embed_dim;
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
const int n = num_heads * embed_dim;
for (int i = threadIdx.x; i < n; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int token_head = token_idx * query_stride + head_idx * head_size;
const int rot_offset = i % embed_dim;
apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
sin_ptr, rot_offset, embed_dim);
}
const int token_head = token_idx * stride + head_idx * head_size;
const int nk = num_kv_heads * embed_dim;
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int token_head = token_idx * key_stride + head_idx * head_size;
const int rot_offset = i % embed_dim;
apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
sin_ptr, rot_offset, embed_dim);
const int x_index = rot_offset;
const int y_index = embed_dim + rot_offset;
const int out_x = token_idx * stride + head_idx * head_size + x_index;
const int out_y = token_idx * stride + head_idx * head_size + y_index;
const scalar_t cos = __ldg(cache_ptr + x_index);
const scalar_t sin = __ldg(cache_ptr + y_index);
const scalar_t q_x = query[token_head + x_index];
const scalar_t q_y = query[token_head + y_index];
query[out_x] = q_x * cos - q_y * sin;
query[out_y] = q_y * cos + q_x * sin;
const scalar_t k_x = key[token_head + x_index];
const scalar_t k_y = key[token_head + y_index];
key[out_x] = k_x * cos - k_y * sin;
key[out_y] = k_y * cos + k_x * sin;
}
}
void rotary_embedding(
torch::Tensor& positions, // [num_tokens]
torch::Tensor& query, // [num_tokens, num_heads * head_size]
torch::Tensor& key, // [num_tokens, num_kv_heads * head_size]
void rotary_embedding_neox(
torch::Tensor& positions, // [b, num_tokens]
torch::Tensor& query, // [b, num_tokens, 1, num_heads, head_size]
torch::Tensor& key, // [b, num_tokens, 1, num_heads, head_size]
int head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox) {
int num_tokens = query.size(0);
torch::Tensor& cos_sin_cache) // [max_position, rot_dim]
{
int num_tokens = query.size(0) * query.size(1);
int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(1) / head_size;
int num_kv_heads = key.size(1) / head_size;
int query_stride = query.stride(0);
int key_stride = key.stride(0);
int num_heads = query.size(-2);
int stride = num_heads * head_size;
// TORCH_CHECK(stride == key.stride(0));
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * rot_dim / 2, 512));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
query.scalar_type(),
"rotary_embedding",
"rotary_embedding_neox",
[&] {
if (is_neox) {
rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(),
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(),
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
} else {
rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(),
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(),
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
}
rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(),
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(),
rot_dim,
stride,
num_heads,
head_size);
});
}
\ No newline at end of file
}
......@@ -11,7 +11,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("layernorm_forward_cuda", &layernorm_forward_cuda, "FasterTransformer layernorm kernel");
m.def("gemm_forward_cuda", &gemm_forward_cuda, "Quantized GEMM kernel.");
m.def("gemv_forward_cuda", &gemv_forward_cuda, "Quantized GEMV kernel.");
m.def("rotary_embedding", &rotary_embedding, "Apply rotary embedding to query and key");
m.def("rotary_embedding_neox", &rotary_embedding_neox, "Apply GPT-NeoX style rotary embedding to query and key");
m.def("single_query_attention", &single_query_attention, "Attention with a single query",
py::arg("q"), py::arg("k"), py::arg("v"), py::arg("k_cache"), py::arg("v_cache"),
py::arg("length_per_sample_"), py::arg("alibi_slopes_"), py::arg("timestep"), py::arg("rotary_embedding_dim")=0,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment