"tests/python/common/test_heterograph.py" did not exist on "cded5b80fe566c174fce1185b5456448d6f9c671"
Unverified commit 198ba2fb authored by Casper, committed by GitHub

Merge pull request #28 from casper-hansen/fix_rotary_emb

[BUG] Fix illegal memory access + Quantized Multi-GPU support
parents 85430ddc 562c0d52
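
Note (not part of the diff): the headline changes below are (1) the speed benchmark now times prefill and generation through model.generate with a GenerationConfig instead of a hand-rolled sampling loop, (2) quantized checkpoints are dispatched with an accelerate device map, which is what enables quantized multi-GPU inference, and (3) the rotary-embedding CUDA kernel is swapped for the newer vLLM kernel with separate query/key strides and a NeoX/GPT-J switch. A rough usage sketch of the multi-GPU path; the from_quantized signature and the paths are assumptions, not taken from this diff:

# Sketch only (not from this diff): load an AWQ-quantized checkpoint and let
# accelerate's device map spread it over the visible GPUs.
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

quant_path = "path/to/quantized-model"   # hypothetical
quant_file = "awq_model_w4_g128.pt"      # hypothetical

tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file, fuse_layers=True)

ids = tokenizer("The meaning of life is", return_tensors="pt").input_ids.cuda()
out = model.generate(ids, max_new_tokens=32)
print(tokenizer.decode(out[0]))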
@@ -3,11 +3,11 @@ import time
 import torch
 import argparse
 from lm_eval import evaluator
-from transformers import AutoTokenizer
 from awq import AutoAWQForCausalLM
 from awq.quantize.auto_clip import apply_clip
 from awq.quantize.auto_scale import apply_scale
 from awq.utils.lm_eval_adaptor import LMEvalAdaptor
+from transformers import AutoTokenizer, GenerationConfig

 def load_search_result_into_memory(model, search_path):
@@ -80,22 +80,6 @@ def run_speed(model_path, quant_file, device, n_generate=128, n_context=256, bat
         out = func()
         return out, time.time() - start

-    def _generate(model, model_out, n_generate, batch_size):
-        past_key_values = model_out.past_key_values
-
-        for i in range(n_generate):
-            logits = model_out.logits[:, -1, :]
-
-            new_tokens = []
-            for batch_index in range(batch_size):
-                probs = torch.softmax(logits[batch_index], dim=-1)
-                token = torch.multinomial(probs, num_samples=1)
-                new_tokens.append(token)
-
-            tokens = torch.as_tensor(new_tokens, device=device).unsqueeze(-1)
-            model_out = model(tokens, use_cache=True, past_key_values=past_key_values)
-
     def _warmup(device:str):
         warm_up = torch.randn((4096,4096)).to(device)
         torch.mm(warm_up,warm_up)
@@ -114,19 +98,36 @@ def run_speed(model_path, quant_file, device, n_generate=128, n_context=256, bat
     ids = torch.randint(0, tokenizer.vocab_size, (batch_size, n_context)).cuda()

     # Context stage
-    model_out, context_time = _timer(lambda: model(ids, use_cache=True))
+    _, context_time = _timer(lambda: model.generate(
+        ids,
+        generation_config=GenerationConfig(
+            max_new_tokens=0,
+            min_new_tokens=0,
+            use_cache=True
+        )
+    ))

     # Generation stage
-    _, generation_time = _timer(lambda: _generate(model, model_out, n_generate, batch_size))
+    _, generation_time = _timer(lambda: model.generate(
+        ids,
+        generation_config=GenerationConfig(
+            max_new_tokens=n_context,
+            min_new_tokens=n_context,
+            forced_eos_token_id=-100,
+            pad_token_id=tokenizer.pad_token_id,
+            eos_token_id=-100,
+            use_cache=True
+        )
+    ))

     # Prints
     memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 2)
     context_tokens_per_second = n_context / context_time * batch_size
-    context_ms_per_token = (context_time*1000) / n_context * batch_size
+    context_ms_per_token = (context_time*1000) / n_context / batch_size
     inference_tokens_per_second = n_generate / generation_time * batch_size
-    inference_ms_per_token = (generation_time*1000) / n_generate * batch_size
+    inference_ms_per_token = (generation_time*1000) / n_generate / batch_size

-    print(f"[======] Model summary: {model_path} [======]")
+    print(f"[=] Model summary: {model_path} [=]")
     print(f"[*] Load time: {load_time:.2f} seconds")
     print(f"[*] Context speed: {context_tokens_per_second:.2f} tokens/second ({context_ms_per_token:.2f} ms/token)")
     print(f"[*] Generation speed: {inference_tokens_per_second:.2f} tokens/second ({inference_ms_per_token:.2f} ms/token)")
@@ -185,9 +186,6 @@ if __name__ == '__main__':
         run_eval(args.model_path, args.quant_file, args.device,
                  args.tasks, args.task_batch_size, args.task_n_shot, args.task_use_pretrained)
     elif args.entry_type == 'speed':
-        if args.batch_size > 1 and not args.disable_fused_layers:
-            raise Exception('Fused layers only support batch_size=1. Pass --disable_fused_layers to run batch_size>1 (much slower).')
-
         run_speed(args.model_path, args.quant_file, args.device, args.n_generate, args.n_context, args.batch_size, args.disable_fused_layers)
     else:
         raise Exception('--entry_type must be one of (search|quant|eval|speed)')
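
Note (not part of the diff): the metric fix above is the switch from multiplying to dividing by batch_size in the ms/token figures. Tokens per second scales up with batch size, while per-token latency scales down; a quick numeric check of the corrected formulas with made-up numbers:

# Hypothetical values, just to check the arithmetic.
n_context = 256       # tokens per sequence in the context stage
batch_size = 8
context_time = 0.5    # seconds for the whole batch

tokens_per_second = n_context / context_time * batch_size       # 4096.0
ms_per_token = (context_time * 1000) / n_context / batch_size   # ~0.24 ms

# Consistency: 1000 ms divided by tokens/second is exactly ms/token.
assert abs(1000 / tokens_per_second - ms_per_token) < 1e-9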
@@ -297,21 +297,26 @@ class BaseAWQForCausalLM(nn.Module):
         model.tie_weights()

+        device_map = infer_auto_device_map(
+            model,
+            no_split_module_classes=[self.layer_type],
+            dtype=torch_dtype
+        )
+
         # Load model weights
         if is_quantized:
-            model = load_checkpoint_and_dispatch(model, model_filename, device_map=device, no_split_module_classes=[self.layer_type])
+            model = load_checkpoint_and_dispatch(
+                model,
+                model_filename,
+                device_map=device_map,
+                no_split_module_classes=[self.layer_type]
+            )

             if fuse_layers:
                 self.fuse_layers(model)
         else:
             # If not quantized, must load with AutoModelForCausalLM
-            device_map = infer_auto_device_map(
-                model,
-                no_split_module_classes=[self.layer_type],
-                dtype=torch_dtype
-            )
             del model

             # Load model weights
...
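
Note (not part of the diff): the change above is the core of the multi-GPU support. The device map is now computed once with accelerate and passed to load_checkpoint_and_dispatch for the quantized branch too, instead of pinning the whole checkpoint to a single device string. A minimal standalone sketch of the same accelerate pattern; the model id, checkpoint path and no_split class name are placeholders:

# Sketch of the accelerate dispatch pattern used above.
import torch
from accelerate import init_empty_weights, infer_auto_device_map, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")  # hypothetical model id
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)
model.tie_weights()

# Split the model across available GPUs, never splitting a decoder layer.
device_map = infer_auto_device_map(
    model,
    no_split_module_classes=["LlamaDecoderLayer"],
    dtype=torch.float16,
)

model = load_checkpoint_and_dispatch(
    model,
    "path/to/checkpoint",            # hypothetical: file or sharded directory
    device_map=device_map,
    no_split_module_classes=["LlamaDecoderLayer"],
)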
@@ -99,9 +99,10 @@ class LlamaFuser:
             attn = QuantLlamaAttention(
                 module.hidden_size,
                 module.num_heads,
+                module.num_key_value_heads,
                 qkv_layer,
                 module.o_proj,
-                qkv_layer.qweight.device,
+                next(iter(qkv_layer.state_dict().values())).device,
                 self.model.config.max_new_tokens
             )
             set_module_name(self.model, name, attn)
@@ -118,7 +119,7 @@ class LlamaFuser:
             q_proj.in_features,
             q_proj.out_features + k_proj.out_features + v_proj.out_features,
             q_proj.bias is not None,
-            q_proj.qweight.device
+            next(iter(module.state_dict().values())).device
         )

         # replace buffers with real weights
...
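
Note (not part of the diff): the fuser changes above stop assuming a qweight attribute and instead take the device of whatever tensor the module exposes first in its state_dict, which also works for modules that accelerate has dispatched. The pattern in isolation:

import torch
import torch.nn as nn

def module_device(module: nn.Module) -> torch.device:
    # The first entry of state_dict() is some parameter or buffer of the module,
    # so this works for a plain nn.Linear (weight) as well as a quantized linear
    # whose only tensors are buffers such as qweight/scales/qzeros.
    return next(iter(module.state_dict().values())).device

print(module_device(nn.Linear(8, 8)))            # cpu
# print(module_device(nn.Linear(8, 8).cuda()))   # cuda:0, if a GPU is present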
@@ -2,6 +2,7 @@ import torch
 import torch.nn as nn
 import awq_inference_engine
 from torch.nn import functional as F
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, LlamaRotaryEmbedding


 class QuantLlamaRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
@@ -29,6 +30,7 @@ class QuantLlamaRotaryEmbedding(nn.Module):
         sin = freqs.sin()
         cache = torch.cat((cos, sin), dim=-1)

+        # [max_position, rot_dim]
         self.register_buffer("cos_sin_cache", cache.half(), persistent=False)

     def forward(
@@ -41,13 +43,16 @@ class QuantLlamaRotaryEmbedding(nn.Module):
         # to the attention op.
         query = query.contiguous()
         key = key.contiguous()
-        awq_inference_engine.rotary_embedding_neox(
+        awq_inference_engine.rotary_embedding(
            positions,
            query,
            key,
            self.dim,
            self.cos_sin_cache,
+           True # is_neox
        )
        return query, key
@@ -57,21 +62,29 @@ class QuantLlamaAttention(nn.Module):
         self,
         hidden_size,
         num_heads,
+        num_kv_heads,
         qkv_proj,
         o_proj,
         dev,
-        max_new_tokens
+        max_new_tokens,
+        use_hf_rotary=False
     ):
         super().__init__()
         self.hidden_size = hidden_size
         self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads
         self.head_dim = hidden_size // num_heads
+        self.use_hf_rotary = use_hf_rotary

         if (self.head_dim * num_heads) != self.hidden_size:
             raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                              f" and `num_heads`: {num_heads}).")
         self.qkv_proj = qkv_proj
         self.o_proj = o_proj
-        self.rotary_emb = QuantLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=max_new_tokens, device = dev)
+
+        if use_hf_rotary:
+            self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_new_tokens, device=dev)
+        else:
+            self.rotary_emb = QuantLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=max_new_tokens, device = dev)

     def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False):
@@ -80,16 +93,46 @@ class QuantLlamaAttention(nn.Module):
         bsz, q_len, _ = hidden_states.size()

         qkv_states = self.qkv_proj(hidden_states)
-        qkv_states = qkv_states.view(bsz, q_len, 3, self.num_heads, self.head_dim)
-
-        # This updates the query and key states in-place, saving VRAM.
-        query_states, key_states, value_states = torch.split(qkv_states, 1, dim=2)
-        query_states, key_states = self.rotary_emb(query_states, key_states, position_ids)
-
-        del qkv_states
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        if self.use_hf_rotary:
+            # get qkv
+            qkv_states = qkv_states.view(bsz, q_len, 3, self.num_heads, self.head_dim)
+            query, key, value = torch.split(qkv_states, 1, dim=2)
+            del qkv_states
+
+            # reshape for hf rotary
+            query = query.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            key = key.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            value = value.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+            kv_seq_len = key.shape[-2]
+            if past_key_value is not None:
+                kv_seq_len += past_key_value[0].shape[-2]
+
+            cos, sin = self.rotary_emb(value, seq_len=kv_seq_len)
+            query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids)
+        else:
+            # get qkv
+            query, key, value = qkv_states.chunk(chunks=3, dim=-1)
+            del qkv_states
+
+            # [num_tokens, num_heads * head_size]
+            query_batch_size, query_len, _ = query.shape
+            query = query.view(query_len*query_batch_size, self.num_heads * self.head_dim)
+
+            # [num_tokens, num_kv_heads * head_size]
+            key_batch_size, key_len, _ = key.shape
+            key = key.view(key_len*key_batch_size, self.num_kv_heads * self.head_dim)
+
+            # [num_tokens]
+            positions = position_ids.view(-1).to(query.device)
+
+            query, key = self.rotary_emb(query, key, positions)
+
+            query = query.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            key = key.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            value = value.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

         is_causal = past_key_value is None
@@ -97,25 +140,25 @@ class QuantLlamaAttention(nn.Module):
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]

-        value_states = value_states.to(key_states.device)
+        value = value.to(key.device)

         if past_key_value is not None:
             # reuse k, v, self_attention
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+            key = torch.cat([past_key_value[0], key], dim=2)
+            value = torch.cat([past_key_value[1], value], dim=2)

         if use_cache:
-            # Since qkv_proj is fused, query_states etc will hold a reference to the original qkv_states tensor
+            # Since qkv_proj is fused, query etc will hold a reference to the original qkv_states tensor
             # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this.
-            key_states = key_states.contiguous()
-            value_states = value_states.contiguous()
-            query_states = query_states.contiguous()
+            key = key.contiguous()
+            value = value.contiguous()
+            query = query.contiguous()

-        past_key_value = (key_states, value_states) if use_cache else None
+        past_key_value = (key, value) if use_cache else None

         # with torch.backends.cuda.sdp_kernel(enable_math=False):
-        attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=is_causal)
-        del query_states, key_states, value_states
+        attn_output = F.scaled_dot_product_attention(query, key, value, is_causal=is_causal)
+        del query, key, value

         attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
         attn_output = self.o_proj(attn_output)
...
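
Note (not part of the diff): in the attention forward above, is_causal is only enabled when there is no KV cache. During prefill the query attends to keys of the same length and needs the causal mask; during decoding the single new query token may attend to every cached position, so no mask is needed. Illustrated with synthetic tensors:

import torch
import torch.nn.functional as F

bsz, n_heads, head_dim = 1, 8, 64

# Prefill: no cache yet, q_len == kv_len, so causal masking is required.
q = torch.randn(bsz, n_heads, 16, head_dim)
k = torch.randn(bsz, n_heads, 16, head_dim)
v = torch.randn(bsz, n_heads, 16, head_dim)
prefill_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Decode: one new query token attends to all cached keys/values,
# so no mask is needed (is_causal=False, as when past_key_value is set).
q_new = torch.randn(bsz, n_heads, 1, head_dim)
k_cache = torch.cat([k, torch.randn(bsz, n_heads, 1, head_dim)], dim=2)
v_cache = torch.cat([v, torch.randn(bsz, n_heads, 1, head_dim)], dim=2)
decode_out = F.scaled_dot_product_attention(q_new, k_cache, v_cache, is_causal=False)

print(prefill_out.shape, decode_out.shape)  # (1, 8, 16, 64) (1, 8, 1, 64)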
 #pragma once
 #include <torch/extension.h>

-void rotary_embedding_neox(
+void rotary_embedding(
   torch::Tensor& positions,
   torch::Tensor& query,
   torch::Tensor& key,
   int head_size,
-  torch::Tensor& cos_sin_cache);
\ No newline at end of file
+  torch::Tensor& cos_sin_cache,
+  bool is_neox);
\ No newline at end of file
@@ -9,15 +9,56 @@ https://github.com/vllm-project/vllm/blob/main/csrc/pos_encoding_kernels.cu
 #include <ATen/cuda/CUDAContext.h>
 #include "pos_encoding.h"

-template<typename scalar_t>
-__global__ void rotary_embedding_neox_kernel(
+#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)          \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)  \
+  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)   \
+  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
+
+#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)   \
+  AT_DISPATCH_SWITCH(                                   \
+    TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
+
+template<typename scalar_t, bool IS_NEOX>
+inline __device__ void apply_rotary_embedding(
+  scalar_t* __restrict__ arr,
+  const scalar_t* __restrict__ cos_ptr,
+  const scalar_t* __restrict__ sin_ptr,
+  int rot_offset,
+  int embed_dim)
+{
+  int x_index, y_index;
+  scalar_t cos, sin;
+  if (IS_NEOX) {
+    // GPT-NeoX style rotary embedding.
+    x_index = rot_offset;
+    y_index = embed_dim + rot_offset;
+    cos = __ldg(cos_ptr + x_index);
+    sin = __ldg(sin_ptr + x_index);
+  } else {
+    // GPT-J style rotary embedding.
+    x_index = 2 * rot_offset;
+    y_index = 2 * rot_offset + 1;
+    cos = __ldg(cos_ptr + x_index / 2);
+    sin = __ldg(sin_ptr + x_index / 2);
+  }
+
+  const scalar_t x = arr[x_index];
+  const scalar_t y = arr[y_index];
+  arr[x_index] = x * cos - y * sin;
+  arr[y_index] = y * cos + x * sin;
+}
+
+template<typename scalar_t, bool IS_NEOX>
+__global__ void rotary_embedding_kernel(
   const int64_t* __restrict__ positions,        // [num_tokens]
   scalar_t* __restrict__ query,                 // [num_tokens, num_heads, head_size]
-  scalar_t* __restrict__ key,                   // [num_tokens, num_heads, head_size]
+  scalar_t* __restrict__ key,                   // [num_tokens, num_kv_heads, head_size]
   const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
   const int rot_dim,
-  const int stride,
+  const int query_stride,
+  const int key_stride,
   const int num_heads,
+  const int num_kv_heads,
   const int head_size) {
   // Each thread block is responsible for one token.
   const int token_idx = blockIdx.x;
@@ -25,64 +66,72 @@ __global__ void rotary_embedding_neox_kernel(
   const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
   const int embed_dim = rot_dim / 2;
-  const int n = num_heads * embed_dim;
-  for (int i = threadIdx.x; i < n; i += blockDim.x) {
-    const int head_idx = i / embed_dim;
-    const int token_head = token_idx * stride + head_idx * head_size;
+  const scalar_t* cos_ptr = cache_ptr;
+  const scalar_t* sin_ptr = cache_ptr + embed_dim;
+
+  const int nq = num_heads * embed_dim;
+  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
+    const int head_idx = i / embed_dim;
+    const int token_head = token_idx * query_stride + head_idx * head_size;
     const int rot_offset = i % embed_dim;
-    const int x_index = rot_offset;
-    const int y_index = embed_dim + rot_offset;
-
-    const int out_x = token_idx * stride + head_idx * head_size + x_index;
-    const int out_y = token_idx * stride + head_idx * head_size + y_index;
-
-    const scalar_t cos = __ldg(cache_ptr + x_index);
-    const scalar_t sin = __ldg(cache_ptr + y_index);
-
-    const scalar_t q_x = query[token_head + x_index];
-    const scalar_t q_y = query[token_head + y_index];
-    query[out_x] = q_x * cos - q_y * sin;
-    query[out_y] = q_y * cos + q_x * sin;
-
-    const scalar_t k_x = key[token_head + x_index];
-    const scalar_t k_y = key[token_head + y_index];
-    key[out_x] = k_x * cos - k_y * sin;
-    key[out_y] = k_y * cos + k_x * sin;
+    apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
+                                              sin_ptr, rot_offset, embed_dim);
+  }
+
+  const int nk = num_kv_heads * embed_dim;
+  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
+    const int head_idx = i / embed_dim;
+    const int token_head = token_idx * key_stride + head_idx * head_size;
+    const int rot_offset = i % embed_dim;
+    apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
+                                              sin_ptr, rot_offset, embed_dim);
   }
 }

-void rotary_embedding_neox(
-  torch::Tensor& positions,       // [b, num_tokens]
-  torch::Tensor& query,           // [b, num_tokens, 1, num_heads, head_size]
-  torch::Tensor& key,             // [b, num_tokens, 1, num_heads, head_size]
+void rotary_embedding(
+  torch::Tensor& positions,       // [num_tokens]
+  torch::Tensor& query,           // [num_tokens, num_heads * head_size]
+  torch::Tensor& key,             // [num_tokens, num_kv_heads * head_size]
   int head_size,
-  torch::Tensor& cos_sin_cache)   // [max_position, rot_dim]
-{
-  int num_tokens = query.size(0) * query.size(1);
+  torch::Tensor& cos_sin_cache,   // [max_position, rot_dim]
+  bool is_neox) {
+  int num_tokens = query.size(0);
   int rot_dim = cos_sin_cache.size(1);
-  int num_heads = query.size(-2);
-  int stride = num_heads * head_size;
-  // TORCH_CHECK(stride == key.stride(0));
+  int num_heads = query.size(1) / head_size;
+  int num_kv_heads = key.size(1) / head_size;
+  int query_stride = query.stride(0);
+  int key_stride = key.stride(0);

   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * rot_dim / 2, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-    at::ScalarType::Half,
-    at::ScalarType::BFloat16,
+  VLLM_DISPATCH_FLOATING_TYPES(
     query.scalar_type(),
-    "rotary_embedding_neox",
+    "rotary_embedding",
     [&] {
-      rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        positions.data_ptr<int64_t>(),
-        query.data_ptr<scalar_t>(),
-        key.data_ptr<scalar_t>(),
-        cos_sin_cache.data_ptr<scalar_t>(),
-        rot_dim,
-        stride,
-        num_heads,
-        head_size);
+      if (is_neox) {
+        rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
+          positions.data_ptr<int64_t>(),
+          query.data_ptr<scalar_t>(),
+          key.data_ptr<scalar_t>(),
+          cos_sin_cache.data_ptr<scalar_t>(),
+          rot_dim,
+          query_stride,
+          key_stride,
+          num_heads,
+          num_kv_heads,
+          head_size);
+      } else {
+        rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
+          positions.data_ptr<int64_t>(),
+          query.data_ptr<scalar_t>(),
+          key.data_ptr<scalar_t>(),
+          cos_sin_cache.data_ptr<scalar_t>(),
+          rot_dim,
+          query_stride,
+          key_stride,
+          num_heads,
+          num_kv_heads,
+          head_size);
+      }
     });
 }
\ No newline at end of file
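
Note (not part of the diff): what the NeoX path of the kernel computes, written out in PyTorch. For each position, the rotary dimensions of every head are split into two halves; pair i is (x_i, x_{i+rot_dim/2}) and is rotated by the cached (cos, sin) for that position, matching the cos_sin_cache layout built in QuantLlamaRotaryEmbedding above. This reference is my sketch, not code shipped in the repo, and is only meant for sanity-checking the extension:

import torch

def rotary_embedding_neox_ref(positions, x, head_size, cos_sin_cache):
    # positions:     [num_tokens]                          (int64)
    # x:             [num_tokens, num_heads * head_size]
    # cos_sin_cache: [max_position, rot_dim], cos in the first half of each
    #                row and sin in the second half.
    num_tokens = x.shape[0]
    rot_dim = cos_sin_cache.shape[1]
    embed_dim = rot_dim // 2

    cos_sin = cos_sin_cache[positions]              # [num_tokens, rot_dim]
    cos, sin = cos_sin.split(embed_dim, dim=-1)     # each [num_tokens, embed_dim]
    cos = cos.unsqueeze(1)                          # broadcast over heads
    sin = sin.unsqueeze(1)

    x = x.view(num_tokens, -1, head_size)           # [num_tokens, num_heads, head_size]
    x1 = x[..., :embed_dim]
    x2 = x[..., embed_dim:rot_dim]

    out = x.clone()
    out[..., :embed_dim] = x1 * cos - x2 * sin
    out[..., embed_dim:rot_dim] = x2 * cos + x1 * sin
    return out.view(num_tokens, -1)

Once the extension is built, running half-precision copies of query and key through awq_inference_engine.rotary_embedding(positions, q, k, head_size, cos_sin_cache, True), which updates them in place, should agree with this reference to within fp16 rounding.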
@@ -8,5 +8,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 {
     m.def("layernorm_forward_cuda", &layernorm_forward_cuda, "FasterTransformer layernorm kernel");
     m.def("gemm_forward_cuda", &gemm_forward_cuda, "Quantized GEMM kernel.");
-    m.def("rotary_embedding_neox", &rotary_embedding_neox, "Apply GPT-NeoX style rotary embedding to query and key");
+    m.def("rotary_embedding", &rotary_embedding, "Apply rotary embedding to query and key");
 }
@@ -85,9 +85,11 @@ if os.name == "nt":
         "nvcc": arch_flags
     }
 else:
+    threads = ["--threads", str(min(os.cpu_count(), 8))]
     extra_compile_args={
         "cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17"],
-        "nvcc": ["-O3", "-std=c++17"] + arch_flags
+        "nvcc": ["-O3", "-std=c++17"] + arch_flags + threads
     }

 extensions = [
...