Unverified Commit 897cb2ae authored by Woosuk Kwon, committed by GitHub

Optimize data movement (#20)

parent 1f01a18d
import torch
import torch.nn as nn

from cacheflow import activation_ops


class SiluAndMul(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(
        self,
        x: torch.Tensor,      # (num_tokens, 2 * d)
    ) -> torch.Tensor:        # (num_tokens, d)
        num_tokens = x.shape[0]
        d = x.shape[1] // 2
        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
        activation_ops.silu_and_mul(out, x)
        return out
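A minimal usage sketch (sizes and device are illustrative, assuming the compiled activation_ops extension and a CUDA tensor): the module consumes the concatenated gate/up projections and fuses the SwiGLU activation into a single kernel call.

    # Hypothetical sizes, for illustration only.
    act = SiluAndMul()
    gate_up = torch.randn(16, 2 * 11008, dtype=torch.half, device='cuda')
    out = act(gate_up)   # silu(gate) * up, shape (16, 11008)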
-from typing import List, Optional
+from typing import Optional

-from flash_attn.flash_attention import FlashAttention
+from flash_attn.flash_attn_interface import _flash_attn_forward
 import torch
 import torch.nn as nn
@@ -16,40 +16,38 @@ class GPTCacheFlowAttention(nn.Module):
         super().__init__()
         self.scale = float(scale)
-        self.flash_attn = FlashAttention(softmax_scale=self.scale)

     def multi_query_kv_attention(
         self,
         output: torch.Tensor,   # [num_prompt_tokens, num_heads, head_size]
         query: torch.Tensor,    # [num_prompt_tokens, num_heads, head_size]
         key: torch.Tensor,      # [num_prompt_tokens, num_heads, head_size]
         value: torch.Tensor,    # [num_prompt_tokens, num_heads, head_size]
-        prompt_lens: List[int],
+        cumulative_prompt_lens: torch.Tensor,   # [num_prompts + 1]
+        max_prompt_len: int,
     ) -> None:
         if query.dtype == torch.float:
             raise ValueError('The float data type is not supported by '
                              'FlashAttention. Use the half data type instead.')
-        head_size = query.shape[2]
+        head_size = query.shape[-1]
         if head_size > 128:
             raise ValueError('FlashAttention does not support head_size > 128.')
-        device = query.device
-        prefix_sum = [0]
-        for prompt_len in prompt_lens:
-            prefix_sum.append(prefix_sum[-1] + prompt_len)
-        prefix_sum = torch.tensor(prefix_sum, dtype=torch.int, device=device)
-        max_prompt_len = max(prompt_lens)
-        # FIXME(woosuk): Unnecessary copy. Optimize this.
-        qkv = torch.stack([query, key, value], dim=1)
-        out = self.flash_attn(
-            qkv,
-            cu_seqlens=prefix_sum,
-            max_s=max_prompt_len,
-            causal=True,
-        )[0]
-        # FIXME(woosuk): Unnecessary copy. Optimize this.
-        output.copy_(out, non_blocking=True)
+        # Directly call FlashAttention's internal function to avoid allocating
+        # a new tensor for the output.
+        _flash_attn_forward(
+            query,
+            key,
+            value,
+            output,
+            cumulative_prompt_lens,
+            cumulative_prompt_lens,
+            max_prompt_len,
+            max_prompt_len,
+            dropout_p=0.0,
+            softmax_scale=self.scale,
+            causal=True,
+            return_softmax=False,
+        )

     def single_query_cached_kv_attention(
         self,
@@ -90,21 +88,18 @@ class GPTCacheFlowAttention(nn.Module):
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:   # [num_tokens, num_heads * head_size]
-        # Pre-allocate the output tensor.
-        output = torch.empty_like(query)
-        # Prune out paddings if any.
-        query = query[:input_metadata.num_valid_tokens]
-        key = key[:input_metadata.num_valid_tokens]
-        value = value[:input_metadata.num_valid_tokens]
+        # NOTE: The query, key, and value tensors must be sliced from a qkv
+        # tensor of shape [num_tokens, 3 * num_heads * head_size].

-        # Reshape the input tensors.
+        # Reshape the query, key, and value tensors.
         num_heads = value_cache.shape[1]
         head_size = value_cache.shape[2]
         query = query.view(-1, num_heads, head_size)
         key = key.view(-1, num_heads, head_size)
         value = value.view(-1, num_heads, head_size)
-        output = output.view(-1, num_heads, head_size)
+
+        # Pre-allocate the output tensor.
+        output = torch.empty_like(query)

         # Compute the attention op for prompts.
         num_prompt_tokens = input_metadata.num_prompt_tokens
@@ -114,7 +109,8 @@ class GPTCacheFlowAttention(nn.Module):
             query[:num_prompt_tokens],
             key[:num_prompt_tokens],
             value[:num_prompt_tokens],
-            input_metadata.prompt_lens,
+            input_metadata.cumulative_prompt_lens,
+            input_metadata.max_prompt_len,
         )

         # Wait until the cache op is done.
@@ -122,14 +118,22 @@ class GPTCacheFlowAttention(nn.Module):
             cache_event.wait()

         # Reshape the keys and values and store them in the cache.
-        cache_ops.reshape_and_cache(
-            key, value, key_cache, value_cache, input_metadata.slot_mapping)
+        num_valid_tokens = input_metadata.num_valid_tokens
+        if num_valid_tokens > 0:
+            # The stride is 3 because the key and value are sliced from qkv.
+            cache_ops.reshape_and_cache(
+                key[:num_valid_tokens],
+                value[:num_valid_tokens],
+                key_cache,
+                value_cache,
+                input_metadata.slot_mapping,
+            )

         if input_metadata.num_generation_tokens > 0:
             # Compute the attention op for generation tokens.
             self.single_query_cached_kv_attention(
-                output[num_prompt_tokens:],
-                query[num_prompt_tokens:],
+                output[num_prompt_tokens:num_valid_tokens],
+                query[num_prompt_tokens:num_valid_tokens],
                 key_cache,
                 value_cache,
                 input_metadata)
@@ -186,19 +190,15 @@ class LlamaCacheFlowAttention(GPTCacheFlowAttention):
     ) -> torch.Tensor:   # [num_tokens, num_heads * head_size]
         # Apply rotary embedding to the query and key before passing them
         # to the attention op.
-        out_query = torch.empty_like(query)
-        out_key = torch.empty_like(key)
         pos_encoding_ops.rotary_embedding_neox(
-            out_query,
-            out_key,
             positions,
             query,
             key,
             self.cos_sin_cache,
         )
         return super().forward(
-            out_query,
-            out_key,
+            query,
+            key,
             value,
             key_cache,
             value_cache,
...
@@ -12,6 +12,7 @@ class InputMetadata:
         seq_groups: List[Tuple[List[int], SamplingParams]],
         seq_logprobs: Dict[int, float],   # Seq id -> cumulative logprobs.
         prompt_lens: List[int],
+        cumulative_prompt_lens: torch.Tensor,
         slot_mapping: torch.Tensor,
         context_lens: torch.Tensor,
         max_context_len: int,
@@ -20,6 +21,7 @@ class InputMetadata:
         self.seq_groups = seq_groups
         self.seq_logprobs = seq_logprobs
         self.prompt_lens = prompt_lens
+        self.cumulative_prompt_lens = cumulative_prompt_lens
         self.slot_mapping = slot_mapping
         self.context_lens = context_lens
         self.max_context_len = max_context_len
@@ -27,6 +29,7 @@ class InputMetadata:
         self.num_prompts = len(prompt_lens)
         self.num_prompt_tokens = sum(prompt_lens)
+        self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
         self.num_generation_tokens = context_lens.shape[0]
         self.num_valid_tokens = slot_mapping.shape[0]
         if block_tables.numel() > 0:
@@ -40,11 +43,13 @@ class InputMetadata:
         return (f'InputMetadata('
                 f'num_prompts={self.num_prompts}, '
                 f'num_prompt_tokens={self.num_prompt_tokens}, '
+                f'max_prompt_len={self.max_prompt_len}, '
                 f'num_generation_tokens={self.num_generation_tokens}, '
                 f'num_valid_tokens={self.num_valid_tokens}, '
                 f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
                 f'max_context_len={self.max_context_len}), '
                 f'prompt_lens={self.prompt_lens}, '
+                f'cumulative_prompt_lens={self.cumulative_prompt_lens}, '
                 f'slot_mapping={self.slot_mapping}, '
                 f'context_lens={self.context_lens}, '
                 f'block_tables={self.block_tables})')
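For context, cumulative_prompt_lens is the exclusive prefix sum of prompt_lens: prompt i occupies the token range [cumulative_prompt_lens[i], cumulative_prompt_lens[i + 1]) of the packed prompt tensors, which is the cu_seqlens layout FlashAttention expects. A tiny illustration with made-up lengths:

    prompt_lens = [3, 5, 2]                  # hypothetical batch
    cumulative_prompt_lens = [0, 3, 8, 10]   # num_prompt_tokens == 10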
@@ -11,6 +11,7 @@ from torch import nn
 from transformers import LlamaConfig

 from cacheflow.models import InputMetadata
+from cacheflow.models.activation import SiluAndMul
 from cacheflow.models.attention import LlamaCacheFlowAttention
 from cacheflow.models.layernorm import RMSNorm
 from cacheflow.models.sample import Sampler
@@ -39,16 +40,14 @@ class LlamaMLP(nn.Module):
         self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
                                            bias=False, input_is_parallel=True,
                                            perform_initialization=False)
-        assert hidden_act == 'silu'
-        self.act_fn = nn.SiLU()
+        if hidden_act != 'silu':
+            raise ValueError(f'Unsupported activation: {hidden_act}. '
+                             'Only silu is supported for now.')
+        self.act_fn = SiluAndMul()

     def forward(self, x):
         gate_up, _ = self.gate_up_proj(x)
-        gate_up = gate_up.reshape(gate_up.shape[:-1] + (2, -1))
-        gate, up = torch.split(gate_up, 1, dim=-2)
-        gate = gate.squeeze(dim=-2).contiguous()
-        up = up.squeeze(dim=-2).contiguous()
-        x = self.act_fn(gate) * up
+        x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
         return x
@@ -94,11 +93,7 @@ class LlamaAttention(nn.Module):
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
-        qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))
-        q, k, v = torch.split(qkv, 1, dim=-2)
-        q = q.squeeze(dim=-2).contiguous()
-        k = k.squeeze(dim=-2).contiguous()
-        v = v.squeeze(dim=-2).contiguous()
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
         k_cache, v_cache = kv_cache
         attn_output = self.attn(
             positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
...
@@ -69,17 +69,14 @@ class OPTAttention(nn.Module):
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
-        qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))
-        q, k, v = torch.split(qkv, 1, dim=-2)
-        q = q.squeeze(dim=-2).contiguous()
-        k = k.squeeze(dim=-2).contiguous()
-        v = v.squeeze(dim=-2).contiguous()
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
         key_cache, value_cache = kv_cache
         attn_output = self.attn(
             q, k, v, key_cache, value_cache, input_metadata, cache_event)
         output, _ = self.out_proj(attn_output)
         return output


 class OPTDecoderLayer(nn.Module):

     def __init__(self, config: OPTConfig):
...
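The reason the CUDA kernels further below now take explicit stride arguments is that qkv.chunk returns views into the packed qkv tensor rather than contiguous copies. A minimal sketch with made-up sizes (not code from this commit):

    import torch

    num_tokens, num_heads, head_size = 4, 8, 64
    qkv = torch.randn(num_tokens, 3 * num_heads * head_size)
    q, k, v = qkv.chunk(chunks=3, dim=-1)

    # The chunks share qkv's storage, so no copy is made, but each view keeps
    # the row stride of the full qkv tensor (3 * num_heads * head_size).
    assert q.data_ptr() == qkv.data_ptr()
    assert q.stride(0) == 3 * num_heads * head_size
    assert not q.is_contiguous()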
@@ -128,6 +128,11 @@ class Worker:
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append(slot)

+        cumulative_prompt_lens: List[int] = [0]
+        for prompt_len in prompt_lens:
+            cumulative_prompt_lens.append(
+                cumulative_prompt_lens[-1] + prompt_len)
+
         # Add generation tokens.
         max_context_len = 0
         max_num_blocks_per_seq = 0
@@ -183,11 +188,14 @@ class Worker:
             for block_table in generation_block_tables]
         block_tables_tensor = torch.tensor(
             padded_block_tables, dtype=torch.int, device='cuda')
+        cumulative_prompt_lens_tensor = torch.tensor(
+            cumulative_prompt_lens, dtype=torch.int, device='cuda')

         input_metadata = InputMetadata(
             seq_groups=seq_groups,
             seq_logprobs=seq_logprobs,
             prompt_lens=prompt_lens,
+            cumulative_prompt_lens=cumulative_prompt_lens_tensor,
             slot_mapping=slot_mapping_tensor,
             context_lens=context_lens_tensor,
             max_context_len=max_context_len,
...
#include <torch/extension.h>

void silu_and_mul(
  torch::Tensor& out,
  torch::Tensor& input);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
    "silu_and_mul",
    &silu_and_mul,
    "Activation function used in SwiGLU.");
}
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

namespace cacheflow {

template<typename T>
__device__ __forceinline__ T silu(const T& x) {
  // x * sigmoid(x)
  return (T) (((float) x) / (1.0f + expf((float) -x)));
}

template<typename scalar_t>
__global__ void silu_and_mul_kernel(
  scalar_t* __restrict__ out,           // [num_tokens, d]
  const scalar_t* __restrict__ input,   // [num_tokens, 2, d]
  const int d) {
  const int token_idx = blockIdx.x;
  for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]);
    const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]);
    out[token_idx * d + idx] = silu(x) * y;
  }
}

} // namespace cacheflow

void silu_and_mul(
  torch::Tensor& out,     // [num_tokens, d]
  torch::Tensor& input)   // [num_tokens, 2 * d]
{
  int num_tokens = input.size(0);
  int d = input.size(1) / 2;

  dim3 grid(num_tokens);
  dim3 block(std::min(d, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    input.scalar_type(),
    "silu_and_mul_kernel",
    [&] {
      cacheflow::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<scalar_t>(),
        input.data_ptr<scalar_t>(),
        d);
    });
}
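In PyTorch terms, the kernel above reads the first half of each input row as the gate and the second half as the up projection, so the kernel comment's [num_tokens, 2, d] view and the host function's [num_tokens, 2 * d] view describe the same memory layout. An equivalent reference (a sketch, not the bound op; the test further below checks the real kernel against the same formula):

    import torch

    def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
        d = x.shape[1] // 2
        gate, up = x[:, :d], x[:, d:]
        return torch.nn.functional.silu(gate) * up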
@@ -25,7 +25,8 @@ __global__ void single_query_cached_kv_attention_kernel(
   const float scale,
   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
   const int* __restrict__ context_lens,   // [num_seqs]
-  const int max_num_blocks_per_seq) {
+  const int max_num_blocks_per_seq,
+  const int q_stride) {
   constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE;
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
   const int thread_idx = threadIdx.x;
@@ -56,7 +57,8 @@ __global__ void single_query_cached_kv_attention_kernel(
   // For example, if the the thread group size is 4, then the first thread in the group
   // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
   // th vectors of the query, and so on.
-  const scalar_t* q_ptr = q + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+  // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
+  const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
   Q_vec q_vecs[NUM_VECS_PER_THREAD];
 #pragma unroll
   for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
@@ -264,7 +266,8 @@ __global__ void single_query_cached_kv_attention_kernel(
     scale,                                                                    \
     block_tables_ptr,                                                         \
     context_lens_ptr,                                                         \
-    max_num_blocks_per_seq);
+    max_num_blocks_per_seq,                                                   \
+    query_stride);

 // TODO(woosuk): Tune NUM_THREADS.
 template<
@@ -284,6 +287,7 @@ void single_query_cached_kv_attention_launcher(
   int num_heads = query.size(1);
   int head_size = query.size(2);
   int max_num_blocks_per_seq = block_tables.size(1);
+  int query_stride = query.stride(0);

   T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
@@ -333,13 +337,13 @@ void single_query_cached_kv_attention_launcher(
 }

 void single_query_cached_kv_attention(
-  torch::Tensor& out,
-  torch::Tensor& query,
-  torch::Tensor& key_cache,
-  torch::Tensor& value_cache,
+  torch::Tensor& out,           // [num_seqs, num_heads, head_size]
+  torch::Tensor& query,         // [num_seqs, num_heads, head_size]
+  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
+  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
   float scale,
-  torch::Tensor& block_tables,
-  torch::Tensor& context_lens,
+  torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
+  torch::Tensor& context_lens,  // [num_seqs]
   int block_size,
   int max_context_len) {
   // TODO(woosuk): Support BF16.
...
@@ -81,6 +81,8 @@ __global__ void reshape_and_cache_kernel(
   scalar_t* __restrict__ key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
   scalar_t* __restrict__ value_cache,     // [num_blocks, num_heads, head_size, block_size]
   const int* __restrict__ slot_mapping,   // [num_tokens]
+  const int key_stride,
+  const int value_stride,
   const int num_heads,
   const int head_size,
   const int block_size,
@@ -92,7 +94,8 @@ __global__ void reshape_and_cache_kernel(
   const int n = num_heads * head_size;
   for (int i = threadIdx.x; i < n; i += blockDim.x) {
-    const int src_idx = token_idx * n + i;
+    const int src_key_idx = token_idx * key_stride + i;
+    const int src_value_idx = token_idx * value_stride + i;
     const int head_idx = i / head_size;
     const int head_offset = i % head_size;
@@ -108,25 +111,29 @@ __global__ void reshape_and_cache_kernel(
                               + head_idx * head_size * block_size
                               + head_offset * block_size
                               + block_offset;
-    key_cache[tgt_key_idx] = __ldg(&key[src_idx]);
-    value_cache[tgt_value_idx] = __ldg(&value[src_idx]);
+    key_cache[tgt_key_idx] = __ldg(&key[src_key_idx]);
+    value_cache[tgt_value_idx] = __ldg(&value[src_value_idx]);
   }
 }

 } // namespace cacheflow

 void reshape_and_cache(
-  torch::Tensor& key,
-  torch::Tensor& value,
-  torch::Tensor& key_cache,
-  torch::Tensor& value_cache,
-  torch::Tensor& slot_mapping) {
+  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
+  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
+  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
+  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
+  torch::Tensor& slot_mapping)  // [num_tokens]
+{
   int num_tokens = key.size(0);
   int num_heads = key.size(1);
   int head_size = key.size(2);
   int block_size = key_cache.size(3);
   int x = key_cache.size(4);
+  int key_stride = key.stride(0);
+  int value_stride = value.stride(0);

   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * head_size, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -140,6 +147,8 @@ void reshape_and_cache(
         key_cache.data_ptr<scalar_t>(),
         value_cache.data_ptr<scalar_t>(),
         slot_mapping.data_ptr<int>(),
+        key_stride,
+        value_stride,
         num_heads,
         head_size,
         block_size,
...
 #include <torch/extension.h>

 void rotary_embedding_neox(
-  torch::Tensor& out_query,
-  torch::Tensor& out_key,
   torch::Tensor& positions,
   torch::Tensor& query,
   torch::Tensor& key,
...
@@ -5,12 +5,11 @@ namespace cacheflow {

 template<typename scalar_t>
 __global__ void rotary_embedding_neox_kernel(
-  scalar_t* __restrict__ out_query,             // [num_tokens, num_heads, head_size]
-  scalar_t* __restrict__ out_key,               // [num_tokens, num_heads, head_size]
   const int64_t* __restrict__ positions,        // [num_tokens]
-  const scalar_t* __restrict__ query,           // [num_tokens, num_heads, head_size]
-  const scalar_t* __restrict__ key,             // [num_tokens, num_heads, head_size]
+  scalar_t* __restrict__ query,                 // [num_tokens, num_heads, head_size]
+  scalar_t* __restrict__ key,                   // [num_tokens, num_heads, head_size]
   const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, head_size // 2]
+  const int stride,
   const int num_heads,
   const int head_size) {
   // Each thread block is responsible for one token.
@@ -19,41 +18,36 @@ __global__ void rotary_embedding_neox_kernel(
   const scalar_t* cache_ptr = cos_sin_cache + pos * head_size;
   const int embed_dim = head_size / 2;
-  const int n = num_heads * head_size;
+  const int n = num_heads * embed_dim;
   for (int i = threadIdx.x; i < n; i += blockDim.x) {
-    const int idx = token_idx * n + i;
-    const int head_idx = i / head_size;
-    const int head_offset = i % head_size;
-    const int token_head = token_idx * n + head_idx * head_size;
-    const bool is_first_half = head_offset < embed_dim;
-    const int rot_offset = head_offset % embed_dim;
+    const int head_idx = i / embed_dim;
+    const int token_head = token_idx * stride + head_idx * head_size;
+    const int rot_offset = i % embed_dim;
     const int x_index = rot_offset;
     const int y_index = embed_dim + rot_offset;

+    const int out_x = token_idx * stride + head_idx * head_size + x_index;
+    const int out_y = token_idx * stride + head_idx * head_size + y_index;
+
     const scalar_t cos = __ldg(cache_ptr + x_index);
     const scalar_t sin = __ldg(cache_ptr + y_index);

-    const scalar_t q_x = __ldg(query + token_head + x_index);
-    const scalar_t q_y = __ldg(query + token_head + y_index);
-    const scalar_t q_cos = is_first_half ? q_x : q_y;
-    const scalar_t q_sin = is_first_half ? -q_y : q_x;
-    out_query[idx] = q_cos * cos + q_sin * sin;
+    const scalar_t q_x = query[token_head + x_index];
+    const scalar_t q_y = query[token_head + y_index];
+    query[out_x] = q_x * cos - q_y * sin;
+    query[out_y] = q_y * cos + q_x * sin;

-    const scalar_t k_x = __ldg(key + token_head + x_index);
-    const scalar_t k_y = __ldg(key + token_head + y_index);
-    const scalar_t k_cos = is_first_half ? k_x : k_y;
-    const scalar_t k_sin = is_first_half ? -k_y : k_x;
-    out_key[idx] = k_cos * cos + k_sin * sin;
+    const scalar_t k_x = key[token_head + x_index];
+    const scalar_t k_y = key[token_head + y_index];
+    key[out_x] = k_x * cos - k_y * sin;
+    key[out_y] = k_y * cos + k_x * sin;
   }
 }

 } // namespace cacheflow

 void rotary_embedding_neox(
-  torch::Tensor& out_query,     // [num_tokens, num_heads * head_size]
-  torch::Tensor& out_key,       // [num_tokens, num_heads * head_size]
   torch::Tensor& positions,     // [num_tokens]
   torch::Tensor& query,         // [num_tokens, num_heads * head_size]
   torch::Tensor& key,           // [num_tokens, num_heads * head_size]
@@ -62,21 +56,22 @@ void rotary_embedding_neox(
   int num_tokens = query.size(0);
   int head_size = cos_sin_cache.size(1);
   int num_heads = query.size(1) / head_size;
+  int stride = query.stride(0);
+  TORCH_CHECK(stride == key.stride(0));

   dim3 grid(num_tokens);
-  dim3 block(std::min(num_heads * head_size, 512));
+  dim3 block(std::min(num_heads * head_size / 2, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
     query.scalar_type(),
     "rotary_embedding_neox",
     [&] {
       cacheflow::rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        out_query.data_ptr<scalar_t>(),
-        out_key.data_ptr<scalar_t>(),
         positions.data_ptr<int64_t>(),
         query.data_ptr<scalar_t>(),
         key.data_ptr<scalar_t>(),
         cos_sin_cache.data_ptr<scalar_t>(),
+        stride,
         num_heads,
         head_size);
     });
...
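For reference, the in-place NeoX-style rotation the kernel now applies to both query and key can be written with PyTorch ops roughly as follows (a sketch with illustrative names, not part of the commit); each half-pair (x, y) of a head maps to (x cos - y sin, y cos + x sin):

    import torch

    def rotary_embedding_neox_reference(x, cos, sin):
        # x: [num_tokens, num_heads, head_size]; cos/sin: [num_tokens, head_size // 2]
        d = x.shape[-1] // 2
        x1, x2 = x[..., :d], x[..., d:]
        cos, sin = cos[:, None, :], sin[:, None, :]
        return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)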
@@ -39,6 +39,13 @@ layernorm_extension = cpp_extension.CUDAExtension(
 )
 ext_modules.append(layernorm_extension)

+activation_extension = cpp_extension.CUDAExtension(
+    name='cacheflow.activation_ops',
+    sources=['csrc/activation.cpp', 'csrc/activation_kernels.cu'],
+    extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS},
+)
+ext_modules.append(activation_extension)
+
 setuptools.setup(
     name='cacheflow',
     ext_modules=ext_modules,
...
import torch
import torch.nn.functional as F

from cacheflow import activation_ops


def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(chunks=2, dim=1)
    return F.silu(x1) * x2


@torch.inference_mode()
def test_silu_and_mul(
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
) -> None:
    x = torch.randn(num_tokens, 2 * d, dtype=dtype, device='cuda')
    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
    activation_ops.silu_and_mul(out, x)
    ref_out = ref_silu_and_mul(x)
    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)


if __name__ == '__main__':
    for dtype in [torch.half, torch.float]:
        for num_tokens in [7, 83, 2048]:
            for d in [512, 4096, 13824]:
                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
                test_silu_and_mul(num_tokens, d, dtype)
 import random
 from typing import List, Optional

-from flash_attn.flash_attention import FlashAttention
+from flash_attn.flash_attn_interface import _flash_attn_forward
 import torch

 from cacheflow import attention_ops
@@ -105,8 +105,9 @@ def test_single_query_cached_kv_attention(
     num_blocks: int,
     dtype: torch.dtype,
 ) -> None:
-    query = torch.randn(
-        num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
+    qkv = torch.randn(
+        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
+    query, _, _ = qkv.unbind(dim=1)

     x = 16 // torch.tensor([], dtype=dtype).element_size()
     key_block_shape = (num_heads, head_size // x, block_size, x)
     key_cache = torch.randn(
@@ -115,6 +116,11 @@ def test_single_query_cached_kv_attention(
     value_cache = torch.randn(
         size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')

+    # Adjust the range of the values to reduce precision errors.
+    query = query / (head_size ** 0.5)
+    key_cache = key_cache / (head_size ** 0.5)
+    value_cache = value_cache / (head_size ** 0.5)
+
     context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
     max_context_len = max(context_lens)
     context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
@@ -130,7 +136,8 @@ def test_single_query_cached_kv_attention(
     block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')

     scale = float(1.0 / (head_size ** 0.5))
-    output = torch.empty_like(query)
+    output = torch.empty(
+        num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
     attention_ops.single_query_cached_kv_attention(
         output,
         query,
@@ -175,19 +182,28 @@ def test_multi_query_kv_attention(
     cu_seq_lens = torch.tensor(cu_seq_lens, dtype=torch.int, device='cuda')

     scale = float(1.0 / (head_size ** 0.5))
-    query = torch.randn(
-        num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
-    key = torch.rand_like(query)
-    value = torch.rand_like(query)
-
-    qkv = torch.stack([query, key, value], dim=1)
-    flash_attn = FlashAttention(softmax_scale=scale)
-    output = flash_attn(
-        qkv,
-        cu_seqlens=cu_seq_lens,
-        max_s=max_seq_len,
-        causal=True,
-    )[0]
+    qkv = torch.randn(
+        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
+    # Adjust the range of the values to reduce precision errors.
+    qkv = qkv / (head_size ** 0.5)
+    query, key, value = qkv.unbind(dim=1)
+
+    output = torch.empty(
+        num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
+    _flash_attn_forward(
+        query,
+        key,
+        value,
+        output,
+        cu_seq_lens,
+        cu_seq_lens,
+        max_seq_len,
+        max_seq_len,
+        dropout_p=0.0,
+        softmax_scale=scale,
+        causal=True,
+        return_softmax=False,
+    )

     cu_seq_lens = cu_seq_lens.cpu().tolist()
     ref_output = ref_multi_query_kv_attention(
...
@@ -17,10 +17,10 @@ def test_reshape_and_cache(
     slot_mapping = random.sample(range(num_slots), num_tokens)
     slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')

-    kv_shape = (num_tokens, num_heads, head_size)
-    key = torch.randn(size=kv_shape, dtype=dtype, device='cuda')
-    value = torch.randn(size=kv_shape, dtype=dtype, device='cuda')
+    qkv = torch.randn(
+        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
+    _, key, value = qkv.unbind(dim=1)

     x = 16 // torch.tensor([], dtype=dtype).element_size()
     key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
     key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
@@ -35,7 +35,7 @@ def test_reshape_and_cache(
     for i in range(num_tokens):
         reshaped_key = key.reshape(num_tokens, num_heads, head_size // x, x)
-        block_idx = slot_mapping[i] // block_size
+        block_idx = torch.div(slot_mapping[i], block_size, rounding_mode='floor')
         block_offset = slot_mapping[i] % block_size
         cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
         cloned_value_cache[block_idx, :, :, block_offset] = value[i]
...
@@ -85,15 +85,13 @@ def test_rotary_embedding_neox(
     cos_sin_cache = torch.cat((cos, sin), dim=-1)
     cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda')

-    # Run the kernel.
-    out_query = torch.empty_like(query)
-    out_key = torch.empty_like(key)
+    # Run the kernel. The kernel is in-place, so we need to clone the inputs.
+    out_query = query.clone()
+    out_key = key.clone()
     pos_encoding_ops.rotary_embedding_neox(
+        positions,
         out_query,
         out_key,
-        positions,
-        query,
-        key,
         cos_sin_cache,
     )
...