norm / vllm · Commits

Commit 897cb2ae (unverified)
Authored Apr 02, 2023 by Woosuk Kwon; committed by GitHub on Apr 02, 2023
Parent: 1f01a18d

Optimize data movement (#20)
Showing 17 changed files with 275 additions and 135 deletions (+275, -135):
cacheflow/models/activation.py      +20  -0
cacheflow/models/attention.py       +46  -46
cacheflow/models/input_metadata.py  +5   -0
cacheflow/models/llama.py           +7   -12
cacheflow/models/opt.py             +2   -5
cacheflow/worker/worker.py          +8   -0
csrc/activation.cpp                 +12  -0
csrc/activation_kernels.cu          +46  -0
csrc/attention_kernels.cu           +13  -9
csrc/cache_kernels.cu               +17  -8
csrc/pos_encoding.cpp               +0   -2
csrc/pos_encoding_kernels.cu        +22  -27
setup.py                            +7   -0
tests/kernels/activation.py         +30  -0
tests/kernels/attention.py          +31  -15
tests/kernels/cache.py              +5   -5
tests/kernels/pos_encoding.py       +4   -6
cacheflow/models/activation.py (new file, mode 100644)

import torch
import torch.nn as nn

from cacheflow import activation_ops


class SiluAndMul(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(
        self,
        x: torch.Tensor,      # (num_tokens, 2 * d)
    ) -> torch.Tensor:        # (num_tokens, d)
        num_tokens = x.shape[0]
        d = x.shape[1] // 2
        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
        activation_ops.silu_and_mul(out, x)
        return out
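For reference, the fused op computes the same thing as the pure-PyTorch helper below (a minimal sketch mirroring the ref_silu_and_mul reference added in tests/kernels/activation.py later in this commit); the custom kernel simply writes into a pre-allocated output instead of materializing the two halves separately:

import torch
import torch.nn.functional as F

def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # x: (num_tokens, 2 * d); SiLU-gate the first half and multiply by the second half.
    x1, x2 = x.chunk(chunks=2, dim=-1)
    return F.silu(x1) * x2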
cacheflow/models/attention.py

-from typing import List, Optional
+from typing import Optional

-from flash_attn.flash_attention import FlashAttention
+from flash_attn.flash_attn_interface import _flash_attn_forward
 import torch
 import torch.nn as nn

@@ -16,40 +16,38 @@ class GPTCacheFlowAttention(nn.Module):
         super().__init__()
         self.scale = float(scale)
-        self.flash_attn = FlashAttention(softmax_scale=self.scale)

     def multi_query_kv_attention(
         self,
         output: torch.Tensor,       # [num_prompt_tokens, num_heads, head_size]
         query: torch.Tensor,        # [num_prompt_tokens, num_heads, head_size]
         key: torch.Tensor,          # [num_prompt_tokens, num_heads, head_size]
         value: torch.Tensor,        # [num_prompt_tokens, num_heads, head_size]
-        prompt_lens: List[int],
+        cumulative_prompt_lens: torch.Tensor,   # [num_prompts + 1]
+        max_prompt_len: int,
     ) -> None:
         if query.dtype == torch.float:
             raise ValueError('The float data type is not supported by '
                              'FlashAttention. Use the half data type instead.')
-        head_size = query.shape[2]
+        head_size = query.shape[-1]
         if head_size > 128:
             raise ValueError('FlashAttention does not support head_size > 128.')

-        device = query.device
-        prefix_sum = [0]
-        for prompt_len in prompt_lens:
-            prefix_sum.append(prefix_sum[-1] + prompt_len)
-        prefix_sum = torch.tensor(prefix_sum, dtype=torch.int, device=device)
-        max_prompt_len = max(prompt_lens)
-
-        # FIXME(woosuk): Unnecessary copy. Optimize this.
-        qkv = torch.stack([query, key, value], dim=1)
-        out = self.flash_attn(
-            qkv,
-            cu_seqlens=prefix_sum,
-            max_s=max_prompt_len,
-            causal=True,
-        )[0]
-        # FIXME(woosuk): Unnecessary copy. Optimize this.
-        output.copy_(out, non_blocking=True)
+        # Directly call FlashAttention's internal function to avoid allocating
+        # a new tensor for the output.
+        _flash_attn_forward(
+            query,
+            key,
+            value,
+            output,
+            cumulative_prompt_lens,
+            cumulative_prompt_lens,
+            max_prompt_len,
+            max_prompt_len,
+            dropout_p=0.0,
+            softmax_scale=self.scale,
+            causal=True,
+            return_softmax=False,
+        )

     def single_query_cached_kv_attention(
         self,
@@ -90,21 +88,18 @@ class GPTCacheFlowAttention(nn.Module):
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:                      # [num_tokens, num_heads * head_size]
-        # Pre-allocate the output tensor.
-        output = torch.empty_like(query)
-
-        # Prune out paddings if any.
-        query = query[:input_metadata.num_valid_tokens]
-        key = key[:input_metadata.num_valid_tokens]
-        value = value[:input_metadata.num_valid_tokens]
+        # NOTE: The query, key, and value tensors must be sliced from a qkv
+        # tensor of shape [num_tokens, 3 * num_heads * head_size].

-        # Reshape the input tensors.
+        # Reshape the query, key, and value tensors.
         num_heads = value_cache.shape[1]
         head_size = value_cache.shape[2]
         query = query.view(-1, num_heads, head_size)
         key = key.view(-1, num_heads, head_size)
         value = value.view(-1, num_heads, head_size)
-        output = output.view(-1, num_heads, head_size)
+
+        # Pre-allocate the output tensor.
+        output = torch.empty_like(query)

         # Compute the attention op for prompts.
         num_prompt_tokens = input_metadata.num_prompt_tokens
@@ -114,7 +109,8 @@ class GPTCacheFlowAttention(nn.Module):
                 query[:num_prompt_tokens],
                 key[:num_prompt_tokens],
                 value[:num_prompt_tokens],
-                input_metadata.prompt_lens,
+                input_metadata.cumulative_prompt_lens,
+                input_metadata.max_prompt_len,
             )

         # Wait until the cache op is done.
@@ -122,14 +118,22 @@ class GPTCacheFlowAttention(nn.Module):
             cache_event.wait()

         # Reshape the keys and values and store them in the cache.
         num_valid_tokens = input_metadata.num_valid_tokens
         if num_valid_tokens > 0:
+            # The stride is 3 because the key and value are sliced from qkv.
             cache_ops.reshape_and_cache(
-                key, value, key_cache, value_cache, input_metadata.slot_mapping)
+                key[:num_valid_tokens],
+                value[:num_valid_tokens],
+                key_cache,
+                value_cache,
+                input_metadata.slot_mapping,
+            )

         if input_metadata.num_generation_tokens > 0:
             # Compute the attention op for generation tokens.
             self.single_query_cached_kv_attention(
-                output[num_prompt_tokens:],
-                query[num_prompt_tokens:],
+                output[num_prompt_tokens:num_valid_tokens],
+                query[num_prompt_tokens:num_valid_tokens],
                 key_cache,
                 value_cache,
                 input_metadata)
@@ -186,19 +190,15 @@ class LlamaCacheFlowAttention(GPTCacheFlowAttention):
     ) -> torch.Tensor:                      # [num_tokens, num_heads * head_size]
         # Apply rotary embedding to the query and key before passing them
         # to the attention op.
-        out_query = torch.empty_like(query)
-        out_key = torch.empty_like(key)
         pos_encoding_ops.rotary_embedding_neox(
-            out_query,
-            out_key,
             positions,
             query,
             key,
             self.cos_sin_cache,
         )
         return super().forward(
-            out_query,
-            out_key,
+            query,
+            key,
             value,
             key_cache,
             value_cache,
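The "stride is 3" note above refers to the fact that slicing or chunking a fused qkv tensor yields strided views rather than copies. A minimal PyTorch sketch of the property the updated kernels rely on (illustrative only; the shapes are made up):

import torch

num_tokens, num_heads, head_size = 4, 8, 16
# Fused QKV output, as produced by a single merged qkv projection.
qkv = torch.randn(num_tokens, 3 * num_heads * head_size)
q, k, v = qkv.chunk(chunks=3, dim=-1)

# The chunks are views into qkv: no data is copied, but their row stride is
# 3 * num_heads * head_size, i.e. they are not contiguous.
assert q.data_ptr() == qkv.data_ptr()
assert q.stride(0) == 3 * num_heads * head_size
assert not q.is_contiguous()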
cacheflow/models/input_metadata.py

@@ -12,6 +12,7 @@ class InputMetadata:
         seq_groups: List[Tuple[List[int], SamplingParams]],
         seq_logprobs: Dict[int, float],     # Seq id -> cumulative logprobs.
         prompt_lens: List[int],
+        cumulative_prompt_lens: torch.Tensor,
         slot_mapping: torch.Tensor,
         context_lens: torch.Tensor,
         max_context_len: int,
@@ -20,6 +21,7 @@ class InputMetadata:
         self.seq_groups = seq_groups
         self.seq_logprobs = seq_logprobs
         self.prompt_lens = prompt_lens
+        self.cumulative_prompt_lens = cumulative_prompt_lens
         self.slot_mapping = slot_mapping
         self.context_lens = context_lens
         self.max_context_len = max_context_len
@@ -27,6 +29,7 @@ class InputMetadata:
         self.num_prompts = len(prompt_lens)
         self.num_prompt_tokens = sum(prompt_lens)
+        self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
         self.num_generation_tokens = context_lens.shape[0]
         self.num_valid_tokens = slot_mapping.shape[0]
         if block_tables.numel() > 0:
@@ -40,11 +43,13 @@ class InputMetadata:
         return (f'InputMetadata('
                 f'num_prompts={self.num_prompts}, '
                 f'num_prompt_tokens={self.num_prompt_tokens}, '
+                f'max_prompt_len={self.max_prompt_len}, '
                 f'num_generation_tokens={self.num_generation_tokens}, '
                 f'num_valid_tokens={self.num_valid_tokens}, '
                 f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
                 f'max_context_len={self.max_context_len}), '
                 f'prompt_lens={self.prompt_lens}, '
+                f'cumulative_prompt_lens={self.cumulative_prompt_lens}, '
                 f'slot_mapping={self.slot_mapping}, '
                 f'context_lens={self.context_lens}, '
                 f'block_tables={self.block_tables})')
cacheflow/models/llama.py

@@ -11,6 +11,7 @@ from torch import nn
 from transformers import LlamaConfig

 from cacheflow.models import InputMetadata
+from cacheflow.models.activation import SiluAndMul
 from cacheflow.models.attention import LlamaCacheFlowAttention
 from cacheflow.models.layernorm import RMSNorm
 from cacheflow.models.sample import Sampler
@@ -39,16 +40,14 @@ class LlamaMLP(nn.Module):
         self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
                                            bias=False, input_is_parallel=True,
                                            perform_initialization=False)
-        assert hidden_act == 'silu'
-        self.act_fn = nn.SiLU()
+        if hidden_act != 'silu':
+            raise ValueError(f'Unsupported activation: {hidden_act}. '
+                             'Only silu is supported for now.')
+        self.act_fn = SiluAndMul()

     def forward(self, x):
         gate_up, _ = self.gate_up_proj(x)
-        gate_up = gate_up.reshape(gate_up.shape[:-1] + (2, -1))
-        gate, up = torch.split(gate_up, 1, dim=-2)
-        gate = gate.squeeze(dim=-2).contiguous()
-        up = up.squeeze(dim=-2).contiguous()
-        x = self.act_fn(gate) * up
+        x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
         return x
@@ -94,11 +93,7 @@ class LlamaAttention(nn.Module):
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
-        qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))
-        q, k, v = torch.split(qkv, 1, dim=-2)
-        q = q.squeeze(dim=-2).contiguous()
-        k = k.squeeze(dim=-2).contiguous()
-        v = v.squeeze(dim=-2).contiguous()
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
         k_cache, v_cache = kv_cache
         attn_output = self.attn(
             positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
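For intuition, a small sketch (not part of the commit; shapes are arbitrary) comparing the old gate/up split against the fused SiluAndMul layout on a dummy gate_up tensor. Both produce the same values, but the old path materializes two contiguous copies while the fused kernel reads the two halves in place:

import torch
import torch.nn.functional as F

num_tokens, intermediate_size = 4, 32
# Dummy output of the merged gate/up projection: [gate | up] along the last dim.
gate_up = torch.randn(num_tokens, 2 * intermediate_size)

# Old path: reshape, split, and force each half contiguous (extra copies).
tmp = gate_up.reshape(gate_up.shape[:-1] + (2, -1))
gate, up = torch.split(tmp, 1, dim=-2)
old = F.silu(gate.squeeze(dim=-2).contiguous()) * up.squeeze(dim=-2).contiguous()

# New path: what the fused silu_and_mul kernel computes, emulated in PyTorch.
x1, x2 = gate_up.chunk(chunks=2, dim=-1)
new = F.silu(x1) * x2

assert torch.allclose(old, new)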
cacheflow/models/opt.py

@@ -69,17 +69,14 @@ class OPTAttention(nn.Module):
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
-        qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))
-        q, k, v = torch.split(qkv, 1, dim=-2)
-        q = q.squeeze(dim=-2).contiguous()
-        k = k.squeeze(dim=-2).contiguous()
-        v = v.squeeze(dim=-2).contiguous()
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
         key_cache, value_cache = kv_cache
         attn_output = self.attn(
             q, k, v, key_cache, value_cache, input_metadata, cache_event)
         output, _ = self.out_proj(attn_output)
         return output


 class OPTDecoderLayer(nn.Module):

     def __init__(self, config: OPTConfig):
cacheflow/worker/worker.py

@@ -128,6 +128,11 @@ class Worker:
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append(slot)

+        cumulative_prompt_lens: List[int] = [0]
+        for prompt_len in prompt_lens:
+            cumulative_prompt_lens.append(
+                cumulative_prompt_lens[-1] + prompt_len)
+
         # Add generation tokens.
         max_context_len = 0
         max_num_blocks_per_seq = 0
@@ -183,11 +188,14 @@ class Worker:
             for block_table in generation_block_tables]
         block_tables_tensor = torch.tensor(
             padded_block_tables, dtype=torch.int, device='cuda')
+        cumulative_prompt_lens_tensor = torch.tensor(
+            cumulative_prompt_lens, dtype=torch.int, device='cuda')

         input_metadata = InputMetadata(
             seq_groups=seq_groups,
             seq_logprobs=seq_logprobs,
             prompt_lens=prompt_lens,
+            cumulative_prompt_lens=cumulative_prompt_lens_tensor,
             slot_mapping=slot_mapping_tensor,
             context_lens=context_lens_tensor,
             max_context_len=max_context_len,
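As a quick illustration (not part of the commit), the cumulative prompt lengths are just a prefix sum with a leading zero, which is the cu_seqlens layout passed to _flash_attn_forward above:

from itertools import accumulate

prompt_lens = [5, 3, 7]                                  # example lengths of three prompts
cumulative_prompt_lens = [0, *accumulate(prompt_lens)]   # -> [0, 5, 8, 15]
# Tokens of prompt i occupy positions
# cumulative_prompt_lens[i]:cumulative_prompt_lens[i + 1] in the packed batch.
assert cumulative_prompt_lens == [0, 5, 8, 15]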
csrc/activation.cpp (new file, mode 100644)

#include <torch/extension.h>

void silu_and_mul(
  torch::Tensor& out,
  torch::Tensor& input);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
    "silu_and_mul",
    &silu_and_mul,
    "Activation function used in SwiGLU.");
}
csrc/activation_kernels.cu (new file, mode 100644)

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

namespace cacheflow {

template<typename T>
__device__ __forceinline__ T silu(const T& x) {
  // x * sigmoid(x)
  return (T) (((float) x) / (1.0f + expf((float) -x)));
}

template<typename scalar_t>
__global__ void silu_and_mul_kernel(
  scalar_t* __restrict__ out,               // [num_tokens, d]
  const scalar_t* __restrict__ input,       // [num_tokens, 2, d]
  const int d) {
  const int token_idx = blockIdx.x;
  for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]);
    const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]);
    out[token_idx * d + idx] = silu(x) * y;
  }
}

} // namespace cacheflow

void silu_and_mul(
  torch::Tensor& out,     // [num_tokens, d]
  torch::Tensor& input)   // [num_tokens, 2 * d]
{
  int num_tokens = input.size(0);
  int d = input.size(1) / 2;

  dim3 grid(num_tokens);
  dim3 block(std::min(d, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    input.scalar_type(),
    "silu_and_mul_kernel",
    [&] {
      cacheflow::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<scalar_t>(),
        input.data_ptr<scalar_t>(),
        d);
    });
}
csrc/attention_kernels.cu

@@ -25,7 +25,8 @@ __global__ void single_query_cached_kv_attention_kernel(
   const float scale,
   const int* __restrict__ block_tables,    // [num_seqs, max_num_blocks_per_seq]
   const int* __restrict__ context_lens,    // [num_seqs]
-  const int max_num_blocks_per_seq) {
+  const int max_num_blocks_per_seq,
+  const int q_stride) {
   constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE;
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
   const int thread_idx = threadIdx.x;
@@ -56,7 +57,8 @@ __global__ void single_query_cached_kv_attention_kernel(
   // For example, if the the thread group size is 4, then the first thread in the group
   // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
   // th vectors of the query, and so on.
-  const scalar_t* q_ptr = q + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+  // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
+  const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
   Q_vec q_vecs[NUM_VECS_PER_THREAD];
 #pragma unroll
   for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
@@ -264,7 +266,8 @@ __global__ void single_query_cached_kv_attention_kernel(
     scale, \
     block_tables_ptr, \
     context_lens_ptr, \
-    max_num_blocks_per_seq);
+    max_num_blocks_per_seq, \
+    query_stride);

 // TODO(woosuk): Tune NUM_THREADS.
 template<
@@ -284,6 +287,7 @@ void single_query_cached_kv_attention_launcher(
   int num_heads = query.size(1);
   int head_size = query.size(2);
   int max_num_blocks_per_seq = block_tables.size(1);
+  int query_stride = query.stride(0);

   T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
@@ -333,13 +337,13 @@ void single_query_cached_kv_attention_launcher(
 }

 void single_query_cached_kv_attention(
-  torch::Tensor& out,
-  torch::Tensor& query,
-  torch::Tensor& key_cache,
-  torch::Tensor& value_cache,
+  torch::Tensor& out,             // [num_seqs, num_heads, head_size]
+  torch::Tensor& query,           // [num_seqs, num_heads, head_size]
+  torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
+  torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
   float scale,
-  torch::Tensor& block_tables,
-  torch::Tensor& context_lens,
+  torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
+  torch::Tensor& context_lens,    // [num_seqs]
   int block_size,
   int max_context_len) {
   // TODO(woosuk): Support BF16.
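A small PyTorch sketch (illustrative only) of what the new q_stride argument accounts for: when the query is a view into a larger qkv tensor, its row stride exceeds num_heads * head_size, and the kernel must use the real stride in its pointer arithmetic:

import torch

num_seqs, num_heads, head_size = 2, 4, 8
qkv = torch.arange(num_seqs * 3 * num_heads * head_size, dtype=torch.float32)
qkv = qkv.view(num_seqs, 3, num_heads, head_size)
query = qkv[:, 0]              # non-contiguous view: [num_seqs, num_heads, head_size]

q_stride = query.stride(0)     # 3 * num_heads * head_size, not num_heads * head_size
flat = qkv.flatten()           # stand-in for the raw device pointer

seq_idx, head_idx, i = 1, 2, 5
# Same address arithmetic as the kernel: q + seq_idx * q_stride + head_idx * HEAD_SIZE + i
assert flat[seq_idx * q_stride + head_idx * head_size + i] == query[seq_idx, head_idx, i]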
csrc/cache_kernels.cu

@@ -81,6 +81,8 @@ __global__ void reshape_and_cache_kernel(
   scalar_t* __restrict__ key_cache,         // [num_blocks, num_heads, head_size/x, block_size, x]
   scalar_t* __restrict__ value_cache,       // [num_blocks, num_heads, head_size, block_size]
   const int* __restrict__ slot_mapping,     // [num_tokens]
+  const int key_stride,
+  const int value_stride,
   const int num_heads,
   const int head_size,
   const int block_size,
@@ -92,7 +94,8 @@ __global__ void reshape_and_cache_kernel(
   const int n = num_heads * head_size;
   for (int i = threadIdx.x; i < n; i += blockDim.x) {
-    const int src_idx = token_idx * n + i;
+    const int src_key_idx = token_idx * key_stride + i;
+    const int src_value_idx = token_idx * value_stride + i;
     const int head_idx = i / head_size;
     const int head_offset = i % head_size;
@@ -108,25 +111,29 @@ __global__ void reshape_and_cache_kernel(
       + head_idx * head_size * block_size
       + head_offset * block_size
       + block_offset;
-    key_cache[tgt_key_idx] = __ldg(&key[src_idx]);
-    value_cache[tgt_value_idx] = __ldg(&value[src_idx]);
+    key_cache[tgt_key_idx] = __ldg(&key[src_key_idx]);
+    value_cache[tgt_value_idx] = __ldg(&value[src_value_idx]);
   }
 }

 } // namespace cacheflow

 void reshape_and_cache(
-  torch::Tensor& key,
-  torch::Tensor& value,
-  torch::Tensor& key_cache,
-  torch::Tensor& value_cache,
-  torch::Tensor& slot_mapping) {
+  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
+  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
+  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
+  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
+  torch::Tensor& slot_mapping)  // [num_tokens]
+{
   int num_tokens = key.size(0);
   int num_heads = key.size(1);
   int head_size = key.size(2);
   int block_size = key_cache.size(3);
   int x = key_cache.size(4);
+  int key_stride = key.stride(0);
+  int value_stride = value.stride(0);

   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * head_size, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -140,6 +147,8 @@ void reshape_and_cache(
       key_cache.data_ptr<scalar_t>(),
       value_cache.data_ptr<scalar_t>(),
       slot_mapping.data_ptr<int>(),
+      key_stride,
+      value_stride,
       num_heads,
       head_size,
       block_size,
csrc/pos_encoding.cpp

 #include <torch/extension.h>

 void rotary_embedding_neox(
-  torch::Tensor& out_query,
-  torch::Tensor& out_key,
   torch::Tensor& positions,
   torch::Tensor& query,
   torch::Tensor& key,
csrc/pos_encoding_kernels.cu

@@ -5,12 +5,11 @@ namespace cacheflow {
 template<typename scalar_t>
 __global__ void rotary_embedding_neox_kernel(
-  scalar_t* __restrict__ out_query,           // [num_tokens, num_heads, head_size]
-  scalar_t* __restrict__ out_key,             // [num_tokens, num_heads, head_size]
   const int64_t* __restrict__ positions,      // [num_tokens]
-  const scalar_t* __restrict__ query,         // [num_tokens, num_heads, head_size]
-  const scalar_t* __restrict__ key,           // [num_tokens, num_heads, head_size]
+  scalar_t* __restrict__ query,               // [num_tokens, num_heads, head_size]
+  scalar_t* __restrict__ key,                 // [num_tokens, num_heads, head_size]
   const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, head_size // 2]
+  const int stride,
   const int num_heads,
   const int head_size) {
   // Each thread block is responsible for one token.
@@ -19,41 +18,36 @@ __global__ void rotary_embedding_neox_kernel(
   const scalar_t* cache_ptr = cos_sin_cache + pos * head_size;
   const int embed_dim = head_size / 2;
-  const int n = num_heads * head_size;
+  const int n = num_heads * embed_dim;
   for (int i = threadIdx.x; i < n; i += blockDim.x) {
-    const int idx = token_idx * n + i;
+    const int head_idx = i / embed_dim;
+    const int token_head = token_idx * stride + head_idx * head_size;

-    const int head_idx = i / head_size;
-    const int head_offset = i % head_size;
-    const int token_head = token_idx * n + head_idx * head_size;
-    const bool is_first_half = head_offset < embed_dim;
-    const int rot_offset = head_offset % embed_dim;
-
+    const int rot_offset = i % embed_dim;
     const int x_index = rot_offset;
     const int y_index = embed_dim + rot_offset;

+    const int out_x = token_idx * stride + head_idx * head_size + x_index;
+    const int out_y = token_idx * stride + head_idx * head_size + y_index;
+
     const scalar_t cos = __ldg(cache_ptr + x_index);
     const scalar_t sin = __ldg(cache_ptr + y_index);

-    const scalar_t q_x = __ldg(query + token_head + x_index);
-    const scalar_t q_y = __ldg(query + token_head + y_index);
-    const scalar_t q_cos = is_first_half ? q_x : q_y;
-    const scalar_t q_sin = is_first_half ? -q_y : q_x;
-    out_query[idx] = q_cos * cos + q_sin * sin;
+    const scalar_t q_x = query[token_head + x_index];
+    const scalar_t q_y = query[token_head + y_index];
+    query[out_x] = q_x * cos - q_y * sin;
+    query[out_y] = q_y * cos + q_x * sin;

-    const scalar_t k_x = __ldg(key + token_head + x_index);
-    const scalar_t k_y = __ldg(key + token_head + y_index);
-    const scalar_t k_cos = is_first_half ? k_x : k_y;
-    const scalar_t k_sin = is_first_half ? -k_y : k_x;
-    out_key[idx] = k_cos * cos + k_sin * sin;
+    const scalar_t k_x = key[token_head + x_index];
+    const scalar_t k_y = key[token_head + y_index];
+    key[out_x] = k_x * cos - k_y * sin;
+    key[out_y] = k_y * cos + k_x * sin;
   }
 }

 } // namespace cacheflow

 void rotary_embedding_neox(
-  torch::Tensor& out_query,         // [num_tokens, num_heads * head_size]
-  torch::Tensor& out_key,           // [num_tokens, num_heads * head_size]
   torch::Tensor& positions,         // [num_tokens]
   torch::Tensor& query,             // [num_tokens, num_heads * head_size]
   torch::Tensor& key,               // [num_tokens, num_heads * head_size]
@@ -62,21 +56,22 @@ void rotary_embedding_neox(
   int num_tokens = query.size(0);
   int head_size = cos_sin_cache.size(1);
   int num_heads = query.size(1) / head_size;
+  int stride = query.stride(0);
+  TORCH_CHECK(stride == key.stride(0));

   dim3 grid(num_tokens);
-  dim3 block(std::min(num_heads * head_size, 512));
+  dim3 block(std::min(num_heads * head_size / 2, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
     query.scalar_type(),
     "rotary_embedding_neox",
     [&] {
       cacheflow::rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        out_query.data_ptr<scalar_t>(),
-        out_key.data_ptr<scalar_t>(),
         positions.data_ptr<int64_t>(),
         query.data_ptr<scalar_t>(),
         key.data_ptr<scalar_t>(),
         cos_sin_cache.data_ptr<scalar_t>(),
+        stride,
         num_heads,
         head_size);
     });
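As a reference for what the rewritten in-place kernel computes per head, a rough PyTorch sketch of the NeoX-style rotation (illustrative; the helper name is made up, and the real kernel operates on flat strided buffers using a precomputed cos/sin cache):

import torch

def rotary_neox_ref(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [head_size] for one token and head; cos, sin: [head_size // 2] for this position.
    embed_dim = x.shape[-1] // 2
    x1, x2 = x[..., :embed_dim], x[..., embed_dim:]
    # Pair element i of the first half with element i of the second half, as the kernel does.
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)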
setup.py

@@ -39,6 +39,13 @@ layernorm_extension = cpp_extension.CUDAExtension(
 )
 ext_modules.append(layernorm_extension)

+activation_extension = cpp_extension.CUDAExtension(
+    name='cacheflow.activation_ops',
+    sources=['csrc/activation.cpp', 'csrc/activation_kernels.cu'],
+    extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS},
+)
+ext_modules.append(activation_extension)
+
 setuptools.setup(
     name='cacheflow',
     ext_modules=ext_modules,
tests/kernels/activation.py (new file, mode 100644)

import torch
import torch.nn.functional as F

from cacheflow import activation_ops


def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(chunks=2, dim=1)
    return F.silu(x1) * x2


@torch.inference_mode()
def test_silu_and_mul(
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
) -> None:
    x = torch.randn(num_tokens, 2 * d, dtype=dtype, device='cuda')
    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
    activation_ops.silu_and_mul(out, x)
    ref_out = ref_silu_and_mul(x)
    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)


if __name__ == '__main__':
    for dtype in [torch.half, torch.float]:
        for num_tokens in [7, 83, 2048]:
            for d in [512, 4096, 13824]:
                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
                test_silu_and_mul(num_tokens, d, dtype)
tests/kernels/attention.py

 import random
 from typing import List, Optional

-from flash_attn.flash_attention import FlashAttention
+from flash_attn.flash_attn_interface import _flash_attn_forward
 import torch

 from cacheflow import attention_ops

@@ -105,8 +105,9 @@ def test_single_query_cached_kv_attention(
     num_blocks: int,
     dtype: torch.dtype,
 ) -> None:
-    query = torch.randn(
-        num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
+    qkv = torch.randn(
+        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
+    query, _, _ = qkv.unbind(dim=1)

     x = 16 // torch.tensor([], dtype=dtype).element_size()
     key_block_shape = (num_heads, head_size // x, block_size, x)
     key_cache = torch.randn(
@@ -115,6 +116,11 @@ def test_single_query_cached_kv_attention(
     value_cache = torch.randn(
         size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')

+    # Adjust the range of the values to reduce precision errors.
+    query = query / (head_size ** 0.5)
+    key_cache = key_cache / (head_size ** 0.5)
+    value_cache = value_cache / (head_size ** 0.5)
+
     context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
     max_context_len = max(context_lens)
     context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
@@ -130,7 +136,8 @@ def test_single_query_cached_kv_attention(
     block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')

     scale = float(1.0 / (head_size ** 0.5))
-    output = torch.empty_like(query)
+    output = torch.empty(
+        num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
     attention_ops.single_query_cached_kv_attention(
         output,
         query,
@@ -175,19 +182,28 @@ def test_multi_query_kv_attention(
     cu_seq_lens = torch.tensor(cu_seq_lens, dtype=torch.int, device='cuda')

     scale = float(1.0 / (head_size ** 0.5))
-    query = torch.randn(
-        num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
-    key = torch.rand_like(query)
-    value = torch.rand_like(query)
-    qkv = torch.stack([query, key, value], dim=1)
-
-    flash_attn = FlashAttention(softmax_scale=scale)
-    output = flash_attn(
-        qkv,
-        cu_seqlens=cu_seq_lens,
-        max_s=max_seq_len,
-        causal=True,
-    )[0]
+    qkv = torch.randn(
+        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
+    # Adjust the range of the values to reduce precision errors.
+    qkv = qkv / (head_size ** 0.5)
+    query, key, value = qkv.unbind(dim=1)
+
+    output = torch.empty(
+        num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
+    _flash_attn_forward(
+        query,
+        key,
+        value,
+        output,
+        cu_seq_lens,
+        cu_seq_lens,
+        max_seq_len,
+        max_seq_len,
+        dropout_p=0.0,
+        softmax_scale=scale,
+        causal=True,
+        return_softmax=False,
+    )

     cu_seq_lens = cu_seq_lens.cpu().tolist()
     ref_output = ref_multi_query_kv_attention(
tests/kernels/cache.py

@@ -17,9 +17,9 @@ def test_reshape_and_cache(
     slot_mapping = random.sample(range(num_slots), num_tokens)
     slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')

-    kv_shape = (num_tokens, num_heads, head_size)
-    key = torch.randn(size=kv_shape, dtype=dtype, device='cuda')
-    value = torch.randn(size=kv_shape, dtype=dtype, device='cuda')
+    qkv = torch.randn(
+        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
+    _, key, value = qkv.unbind(dim=1)

     x = 16 // torch.tensor([], dtype=dtype).element_size()
     key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
@@ -35,7 +35,7 @@ def test_reshape_and_cache(
     for i in range(num_tokens):
         reshaped_key = key.reshape(num_tokens, num_heads, head_size // x, x)
-        block_idx = slot_mapping[i] // block_size
+        block_idx = torch.div(slot_mapping[i], block_size, rounding_mode='floor')
         block_offset = slot_mapping[i] % block_size
         cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
         cloned_value_cache[block_idx, :, :, block_offset] = value[i]
tests/kernels/pos_encoding.py

@@ -85,15 +85,13 @@ def test_rotary_embedding_neox(
     cos_sin_cache = torch.cat((cos, sin), dim=-1)
     cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda')

-    # Run the kernel.
-    out_query = torch.empty_like(query)
-    out_key = torch.empty_like(key)
+    # Run the kernel. The kernel is in-place, so we need to clone the inputs.
+    out_query = query.clone()
+    out_key = key.clone()
     pos_encoding_ops.rotary_embedding_neox(
+        positions,
         out_query,
         out_key,
-        positions,
-        query,
-        key,
         cos_sin_cache,
     )