Commit 4fe9974a authored by Casper Hansen

Update neox kernel

parent 8bb8fe20
@@ -71,7 +71,7 @@ from awq.quantize.qmodule import WQLinear
 from awq.utils.utils import set_module_name
 from awq.modules.fused_mlp import QuantLlamaMLP
 from awq.modules.fused_norm import FTLlamaRMSNorm
-from awq.modules.fused_attn import QuantLlamaAttention, CustomQuantLlamaAttention
+from awq.modules.fused_attn import QuantLlamaAttention
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP
 
 class LlamaFuser:
@@ -96,7 +96,7 @@ class LlamaFuser:
     def fuse_attention(self):
         for name, module in self.attention_modules:
             qkv_layer: WQLinear = self._fuse_qkv(module)
-            attn = CustomQuantLlamaAttention(
+            attn = QuantLlamaAttention(
                 module.hidden_size,
                 module.num_heads,
                 qkv_layer,
...
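For context, the fuser replaces the stock LlamaAttention blocks with the fused QuantLlamaAttention modules by name; the hunk above is truncated, so the exact remaining call is elided. The sketch below is only an assumption of how such a by-name swap can be done in PyTorch (replace_attention is a hypothetical helper, not the set_module_name imported above):

    import torch.nn as nn

    def replace_attention(model: nn.Module, name: str, fused_module: nn.Module):
        # Walk the dotted module path (e.g. "model.layers.0.self_attn") and
        # re-assign the attribute so the fused module takes the original's place.
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, fused_module)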
@@ -2,22 +2,23 @@ import torch
 import torch.nn as nn
 import awq_inference_engine
 from torch.nn import functional as F
-from transformers.models.llama.modeling_llama import LlamaLinearScalingRotaryEmbedding
 
-class QuantLlamaRotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+class RotaryEmbeddingNeox(nn.Module):
+    def __init__(self, head_dim, seq_len, device):
         super().__init__()
+        self.head_dim = head_dim
+        self.seq_len = seq_len
+        self.base = 10000
 
-        self.dim = dim
-        self.max_position_embeddings = max_position_embeddings
-        self.base = base
-        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
+        # create inv_frequency
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2).float().to(device) / self.head_dim))
         self.register_buffer("inv_freq", inv_freq)
 
-        # Build here to make `torch.jit.trace` work.
-        self._set_cos_sin_cache(
-            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
-        )
+        # set cache
+        self._set_cos_sin_cache(seq_len=self.seq_len, device=self.inv_freq.device)
 
-    def _set_cos_sin_cache(self, seq_len, device, dtype):
+    def _set_cos_sin_cache(self, seq_len, device):
         self.max_seq_len_cached = seq_len
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
@@ -31,122 +32,89 @@ class QuantLlamaRotaryEmbedding(nn.Module):
         self.register_buffer("cos_sin_cache", cache.half(), persistent=False)
 
-    def forward(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        positions: torch.Tensor,
-    ):
-        # Apply rotary embedding to the query and key before passing them
-        # to the attention op.
+    def forward(self, positions, query, key):
+        batch_size, seq_len, _ = query.shape
+        query = query.view(batch_size * seq_len, -1)
+        key = key.view(batch_size * seq_len, -1)
+        positions = positions.view(-1).to(query.device)
+
         query = query.contiguous()
         key = key.contiguous()
+
         awq_inference_engine.rotary_embedding_neox(
             positions,
             query,
             key,
-            self.dim,
+            self.head_dim,
             self.cos_sin_cache,
         )
-        return query, key
+
+        query = query.view(batch_size, seq_len, -1)
+        key = key.view(batch_size, seq_len, -1)
+
+        return query, key
 
 class QuantLlamaAttention(nn.Module):
-    """Multi-headed attention from 'Attention Is All You Need' paper"""
 
     def __init__(
         self,
         hidden_size,
         num_heads,
         qkv_proj,
         o_proj,
-        dev,
+        device,
         max_new_tokens
     ):
         super().__init__()
         self.hidden_size = hidden_size
         self.num_heads = num_heads
         self.head_dim = hidden_size // num_heads
+        self.seq_len = max_new_tokens
+        self.qkv_proj = qkv_proj
+        self.o_proj = o_proj
+        self.rotary_embedding_neox = RotaryEmbeddingNeox(self.head_dim, self.seq_len, device)
 
         if (self.head_dim * num_heads) != self.hidden_size:
             raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                              f" and `num_heads`: {num_heads}).")
-        self.qkv_proj = qkv_proj
-        self.o_proj = o_proj
-        self.rotary_emb = QuantLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=max_new_tokens, device = dev)
 
-    def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False):
-        """Input shape: Batch x Time x Channel"""
-
-        bsz, q_len, _ = hidden_states.size()
-
-        qkv_states = self.qkv_proj(hidden_states)
-        qkv_states = qkv_states.view(bsz, q_len, 3, self.num_heads, self.head_dim)
-
-        # This updates the query and key states in-place, saving VRAM.
-        query_states, key_states, value_states = torch.split(qkv_states, 1, dim=2)
-        query_states, key_states = self.rotary_emb(query_states, key_states, position_ids)
-
-        del qkv_states
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    def attn(self, query, key, value, past_key_value, use_cache, attention_mask):
+        batch_size, seq_len, _ = query.shape
+        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key = key.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+        value = value.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
 
+        value = value.to(key.device)
+
+        # cache ops
         is_causal = past_key_value is None
-        kv_seq_len = q_len
-        if past_key_value is not None:
-            kv_seq_len += past_key_value[0].shape[-2]
-
-        value_states = value_states.to(key_states.device)
 
         if past_key_value is not None:
             # reuse k, v, self_attention
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+            key = torch.cat([past_key_value[0], key], dim=2)
+            value = torch.cat([past_key_value[1], value], dim=2)
 
         if use_cache:
-            # Since qkv_proj is fused, query_states etc will hold a reference to the original qkv_states tensor
+            # Since qkv_proj is fused, query_states etc will hold a reference to the original qkv tensor
             # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this.
-            key_states = key_states.contiguous()
-            value_states = value_states.contiguous()
-            query_states = query_states.contiguous()
-
-        past_key_value = (key_states, value_states) if use_cache else None
-
-        # with torch.backends.cuda.sdp_kernel(enable_math=False):
-        attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=is_causal)
-        del query_states, key_states, value_states
-
-        attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
-        attn_output = self.o_proj(attn_output)
-
-        return attn_output, None, past_key_value
-
-class CustomQuantLlamaAttention(nn.Module):
-
-    def __init__(
-        self,
-        hidden_size,
-        num_heads,
-        qkv_proj,
-        o_proj,
-        dev,
-        max_new_tokens
-    ):
-        super().__init__()
-        self.hidden_size = hidden_size
-        self.num_heads = num_heads
-        self.head_dim = hidden_size // num_heads
-
-        if (self.head_dim * num_heads) != self.hidden_size:
-            raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
-                             f" and `num_heads`: {num_heads}).")
-        self.qkv_proj = qkv_proj
-        self.o_proj = o_proj
-        self.rotary_emb = QuantLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=max_new_tokens, device = dev)
-
-    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+            query = query.contiguous()
+            key = key.contiguous()
+            value = value.contiguous()
+
+        past_key_value = (key, value) if use_cache else None
+
+        # multi-head masked attention
+        attn_output = F.scaled_dot_product_attention(
+            query,
+            key,
+            value,
+            attn_mask=None if is_causal else attention_mask,
+            is_causal=is_causal
+        )
+
+        # reshape output
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(batch_size, seq_len, self.hidden_size)
+
+        return attn_output, past_key_value
 
     def forward(
         self,
@@ -158,46 +126,13 @@ class CustomQuantLlamaAttention(nn.Module):
         use_cache: bool = False,
     ):
         # qkv proj
-        qkv_states = self.qkv_proj(hidden_states)
+        query, key, value = self.qkv_proj(hidden_states).chunk(chunks=3, dim=-1)
 
-        # extract q,k,v
-        bsz, q_len, _ = hidden_states.size()
-        query_states, key_states, value_states = torch.split(qkv_states, self.hidden_size, dim=2)
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-
-        # rotary embedding
-        query_states, key_states = self.rotary_emb(query_states, key_states, position_ids)
-
-        # cache ops
-        is_causal = past_key_value is None
-        if past_key_value is not None:
-            # reuse k, v, self_attention
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
-        if use_cache:
-            # Since qkv_proj is fused, query_states etc will hold a reference to the original qkv_states tensor
-            # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this.
-            query_states = query_states.contiguous()
-            key_states = key_states.contiguous()
-            value_states = value_states.contiguous()
-
-        past_key_value = (key_states, value_states) if use_cache else None
-
-        # multi-head masked attention
-        attn_output = F.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=None if is_causal else attention_mask,
-            is_causal=is_causal
-        )
-
-        # reshape output
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        # rotary embeddings
+        query, key = self.rotary_embedding_neox(position_ids, query, key)
+
+        # attention
+        attn_output, past_key_value = self.attn(query, key, value, past_key_value, use_cache, attention_mask)
 
         # out projection
         attn_output = self.o_proj(attn_output)
...
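The refactor routes both prefill and decode through RotaryEmbeddingNeox, which flattens query and key to [num_tokens, hidden] and lets the CUDA kernel rotate them in place. As a plain-PyTorch reference of the same GPT-NeoX-style rotation, useful for sanity-checking shapes on CPU (this is not the fused awq_inference_engine kernel, and the function name below is only illustrative):

    import torch

    def rotary_neox_reference(positions, query, key, cos_sin_cache, head_dim):
        # query/key: [batch, seq_len, num_heads * head_dim]
        # positions: [batch, seq_len] int64
        # cos_sin_cache: [max_position, head_dim], laid out as [cos | sin] per position
        batch_size, seq_len, _ = query.shape
        cos, sin = cos_sin_cache[positions].chunk(2, dim=-1)  # each [batch, seq_len, head_dim // 2]
        cos, sin = cos.unsqueeze(2), sin.unsqueeze(2)         # broadcast over the head dimension

        def rotate(t):
            t = t.view(batch_size, seq_len, -1, head_dim)
            x, y = t.chunk(2, dim=-1)                         # first / second half of each head
            rotated = torch.cat((x * cos - y * sin, y * cos + x * sin), dim=-1)
            return rotated.view(batch_size, seq_len, -1)

        return rotate(query), rotate(key)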
@@ -9,15 +9,26 @@ https://github.com/vllm-project/vllm/blob/main/csrc/pos_encoding_kernels.cu
 #include <ATen/cuda/CUDAContext.h>
 #include "pos_encoding.h"
 
+#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)          \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)  \
+  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)   \
+  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
+
+#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)   \
+  AT_DISPATCH_SWITCH(                                   \
+    TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
+
 template<typename scalar_t>
 __global__ void rotary_embedding_neox_kernel(
   const int64_t* __restrict__ positions,        // [num_tokens]
   scalar_t* __restrict__ query,                 // [num_tokens, num_heads, head_size]
-  scalar_t* __restrict__ key,                   // [num_tokens, num_heads, head_size]
+  scalar_t* __restrict__ key,                   // [num_tokens, num_kv_heads, head_size]
   const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
   const int rot_dim,
-  const int stride,
+  const int query_stride,
+  const int key_stride,
   const int num_heads,
+  const int num_kv_heads,
   const int head_size) {
   // Each thread block is responsible for one token.
   const int token_idx = blockIdx.x;
@@ -25,17 +36,17 @@ __global__ void rotary_embedding_neox_kernel(
   const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
 
   const int embed_dim = rot_dim / 2;
-  const int n = num_heads * embed_dim;
-  for (int i = threadIdx.x; i < n; i += blockDim.x) {
+  const int nq = num_heads * embed_dim;
+  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
     const int head_idx = i / embed_dim;
-    const int token_head = token_idx * stride + head_idx * head_size;
+    const int token_head = token_idx * query_stride + head_idx * head_size;
 
     const int rot_offset = i % embed_dim;
     const int x_index = rot_offset;
     const int y_index = embed_dim + rot_offset;
 
-    const int out_x = token_idx * stride + head_idx * head_size + x_index;
-    const int out_y = token_idx * stride + head_idx * head_size + y_index;
+    const int out_x = token_idx * query_stride + head_idx * head_size + x_index;
+    const int out_y = token_idx * query_stride + head_idx * head_size + y_index;
 
     const scalar_t cos = __ldg(cache_ptr + x_index);
     const scalar_t sin = __ldg(cache_ptr + y_index);
@@ -44,6 +55,22 @@ __global__ void rotary_embedding_neox_kernel(
     const scalar_t q_y = query[token_head + y_index];
     query[out_x] = q_x * cos - q_y * sin;
     query[out_y] = q_y * cos + q_x * sin;
+  }
+
+  const int nk = num_kv_heads * embed_dim;
+  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
+    const int head_idx = i / embed_dim;
+    const int token_head = token_idx * key_stride + head_idx * head_size;
+
+    const int rot_offset = i % embed_dim;
+    const int x_index = rot_offset;
+    const int y_index = embed_dim + rot_offset;
+
+    const int out_x = token_idx * key_stride + head_idx * head_size + x_index;
+    const int out_y = token_idx * key_stride + head_idx * head_size + y_index;
+
+    const scalar_t cos = __ldg(cache_ptr + x_index);
+    const scalar_t sin = __ldg(cache_ptr + y_index);
 
     const scalar_t k_x = key[token_head + x_index];
     const scalar_t k_y = key[token_head + y_index];
@@ -59,18 +86,17 @@ void rotary_embedding_neox(
   int head_size,
   torch::Tensor& cos_sin_cache) // [max_position, rot_dim]
 {
-  int num_tokens = query.size(0) * query.size(1);
+  int num_tokens = query.size(0);
   int rot_dim = cos_sin_cache.size(1);
-  int num_heads = query.size(-2);
-  int stride = num_heads * head_size;
-  // TORCH_CHECK(stride == key.stride(0));
+  int num_heads = query.size(1) / head_size;
+  int num_kv_heads = key.size(1) / head_size;
+  int query_stride = query.stride(0);
+  int key_stride = key.stride(0);
 
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * rot_dim / 2, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-    at::ScalarType::Half,
-    at::ScalarType::BFloat16,
+  VLLM_DISPATCH_FLOATING_TYPES(
     query.scalar_type(),
     "rotary_embedding_neox",
     [&] {
@@ -80,9 +106,12 @@ void rotary_embedding_neox(
         key.data_ptr<scalar_t>(),
         cos_sin_cache.data_ptr<scalar_t>(),
         rot_dim,
-        stride,
+        query_stride,
+        key_stride,
         num_heads,
+        num_kv_heads,
         head_size);
     });
 }
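The host wrapper now treats query and key as 2-D [num_tokens, num_heads * head_size] tensors, derives both head counts from the flattened width, and passes separate row strides, which is what lets key carry fewer heads than query (grouped-query attention). The kernel reads each cache row as [cos | sin] halves of rot_dim; below is a sketch of a cache built to match that layout. The construction inside _set_cos_sin_cache is elided between the hunks above, so treat this helper as an assumption rather than the module's exact code:

    import torch

    def build_cos_sin_cache(head_dim: int, max_position: int, base: float = 10000.0) -> torch.Tensor:
        # Returns a [max_position, head_dim] tensor whose first half holds cos(pos * inv_freq)
        # and whose second half holds sin(pos * inv_freq), matching the kernel's indexing.
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        t = torch.arange(max_position, dtype=inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, inv_freq)          # [max_position, head_dim // 2]
        return torch.cat((freqs.cos(), freqs.sin()), dim=-1)  # [max_position, head_dim]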