xdb4_94051 / vllm / Commits / 897cb2ae

Commit 897cb2ae (unverified), authored Apr 02, 2023 by Woosuk Kwon, committed by GitHub on Apr 02, 2023
Parent: 1f01a18d

    Optimize data movement (#20)

Showing 17 changed files with 275 additions and 135 deletions (+275, -135).
  cacheflow/models/activation.py      +20   -0
  cacheflow/models/attention.py       +46  -46
  cacheflow/models/input_metadata.py   +5   -0
  cacheflow/models/llama.py            +7  -12
  cacheflow/models/opt.py              +2   -5
  cacheflow/worker/worker.py           +8   -0
  csrc/activation.cpp                 +12   -0
  csrc/activation_kernels.cu          +46   -0
  csrc/attention_kernels.cu           +13   -9
  csrc/cache_kernels.cu               +17   -8
  csrc/pos_encoding.cpp                +0   -2
  csrc/pos_encoding_kernels.cu        +22  -27
  setup.py                             +7   -0
  tests/kernels/activation.py         +30   -0
  tests/kernels/attention.py          +31  -15
  tests/kernels/cache.py               +5   -5
  tests/kernels/pos_encoding.py        +4   -6
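Taken together, the changes remove unnecessary data movement on the forward path: q, k, and v are now taken as strided views of the fused qkv projection instead of being split and copied into contiguous buffers, rotary embedding is applied in place, FlashAttention is invoked through _flash_attn_forward so it writes directly into a pre-allocated output, and the CUDA kernels take explicit stride arguments so they can read the non-contiguous slices. The sketch below is not part of the commit (names and sizes are made up); it only illustrates the slicing change several of the diffs share: chunking along the last dimension returns views that share qkv's storage, while the old reshape/split/contiguous sequence materialized three copies.

import torch

# Hypothetical sizes, for illustration only.
num_tokens, num_heads, head_size = 4, 8, 64
qkv = torch.randn(num_tokens, 3 * num_heads * head_size)

# Old path (llama.py/opt.py before this commit): reshape + split + contiguous()
# materializes three new tensors.
q_old, k_old, v_old = [t.squeeze(dim=-2).contiguous()
                       for t in torch.split(qkv.reshape(num_tokens, 3, -1), 1, dim=-2)]

# New path: chunk along the last dim returns views into qkv's storage.
q, k, v = qkv.chunk(chunks=3, dim=-1)
assert q.data_ptr() == qkv.data_ptr()             # no copy was made
assert q.stride(0) == 3 * num_heads * head_size   # why the kernels now take q_stride et al.
assert torch.equal(q, q_old) and torch.equal(k, k_old) and torch.equal(v, v_old)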
cacheflow/models/activation.py (new file, mode 100644)

import torch
import torch.nn as nn

from cacheflow import activation_ops


class SiluAndMul(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(
        self,
        x: torch.Tensor,      # (num_tokens, 2 * d)
    ) -> torch.Tensor:        # (num_tokens, d)
        num_tokens = x.shape[0]
        d = x.shape[1] // 2
        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
        activation_ops.silu_and_mul(out, x)
        return out
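SiluAndMul fuses the SwiGLU gating into a single kernel call with one output allocation; functionally it matches splitting the input in half and multiplying the SiLU of the first half by the second half, as the reference in tests/kernels/activation.py further down does. A usage sketch (requires CUDA and the built cacheflow extensions; sizes are made up):

import torch
import torch.nn.functional as F

num_tokens, d = 16, 4096
x = torch.randn(num_tokens, 2 * d, dtype=torch.half, device='cuda')

fused = SiluAndMul()(x)                # one kernel, one [num_tokens, d] allocation
x1, x2 = x.chunk(chunks=2, dim=1)
reference = F.silu(x1) * x2            # unfused equivalent with extra intermediates
assert torch.allclose(fused, reference, atol=1e-3, rtol=1e-3)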
cacheflow/models/attention.py

-from typing import List, Optional
+from typing import Optional

-from flash_attn.flash_attention import FlashAttention
+from flash_attn.flash_attn_interface import _flash_attn_forward
 import torch
 import torch.nn as nn

@@ -16,40 +16,38 @@ class GPTCacheFlowAttention(nn.Module):
         super().__init__()
         self.scale = float(scale)
-        self.flash_attn = FlashAttention(softmax_scale=self.scale)

     def multi_query_kv_attention(
         self,
         output: torch.Tensor,                   # [num_prompt_tokens, num_heads, head_size]
         query: torch.Tensor,                    # [num_prompt_tokens, num_heads, head_size]
         key: torch.Tensor,                      # [num_prompt_tokens, num_heads, head_size]
         value: torch.Tensor,                    # [num_prompt_tokens, num_heads, head_size]
-        prompt_lens: List[int],
+        cumulative_prompt_lens: torch.Tensor,   # [num_prompts + 1]
+        max_prompt_len: int,
     ) -> None:
         if query.dtype == torch.float:
             raise ValueError('The float data type is not supported by '
                              'FlashAttention. Use the half data type instead.')
-        head_size = query.shape[2]
+        head_size = query.shape[-1]
         if head_size > 128:
             raise ValueError('FlashAttention does not support head_size > 128.')

-        device = query.device
-        prefix_sum = [0]
-        for prompt_len in prompt_lens:
-            prefix_sum.append(prefix_sum[-1] + prompt_len)
-        prefix_sum = torch.tensor(prefix_sum, dtype=torch.int, device=device)
-        max_prompt_len = max(prompt_lens)
-
-        # FIXME(woosuk): Unnecessary copy. Optimize this.
-        qkv = torch.stack([query, key, value], dim=1)
-        out = self.flash_attn(
-            qkv,
-            cu_seqlens=prefix_sum,
-            max_s=max_prompt_len,
+        # Directly call FlashAttention's internal function to avoid allocating
+        # a new tensor for the output.
+        _flash_attn_forward(
+            query,
+            key,
+            value,
+            output,
+            cumulative_prompt_lens,
+            cumulative_prompt_lens,
+            max_prompt_len,
+            max_prompt_len,
+            dropout_p=0.0,
+            softmax_scale=self.scale,
             causal=True,
-        )[0]
-        # FIXME(woosuk): Unnecessary copy. Optimize this.
-        output.copy_(out, non_blocking=True)
+            return_softmax=False,
+        )

     def single_query_cached_kv_attention(
         self,
@@ -90,21 +88,18 @@ class GPTCacheFlowAttention(nn.Module):
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:                          # [num_tokens, num_heads * head_size]
-        # Pre-allocate the output tensor.
-        output = torch.empty_like(query)
-
-        # Prune out paddings if any.
-        query = query[:input_metadata.num_valid_tokens]
-        key = key[:input_metadata.num_valid_tokens]
-        value = value[:input_metadata.num_valid_tokens]
+        # NOTE: The query, key, and value tensors must be sliced from a qkv
+        # tensor of shape [num_tokens, 3 * num_heads * head_size].

-        # Reshape the input tensors.
+        # Reshape the query, key, and value tensors.
         num_heads = value_cache.shape[1]
         head_size = value_cache.shape[2]
         query = query.view(-1, num_heads, head_size)
         key = key.view(-1, num_heads, head_size)
         value = value.view(-1, num_heads, head_size)
-        output = output.view(-1, num_heads, head_size)
+
+        # Pre-allocate the output tensor.
+        output = torch.empty_like(query)

         # Compute the attention op for prompts.
         num_prompt_tokens = input_metadata.num_prompt_tokens
@@ -114,7 +109,8 @@ class GPTCacheFlowAttention(nn.Module):
             query[:num_prompt_tokens],
             key[:num_prompt_tokens],
             value[:num_prompt_tokens],
-            input_metadata.prompt_lens,
+            input_metadata.cumulative_prompt_lens,
+            input_metadata.max_prompt_len,
         )

         # Wait until the cache op is done.
@@ -122,14 +118,22 @@ class GPTCacheFlowAttention(nn.Module):
             cache_event.wait()

         # Reshape the keys and values and store them in the cache.
-        cache_ops.reshape_and_cache(
-            key, value, key_cache, value_cache, input_metadata.slot_mapping)
+        num_valid_tokens = input_metadata.num_valid_tokens
+        if num_valid_tokens > 0:
+            # The stride is 3 because the key and value are sliced from qkv.
+            cache_ops.reshape_and_cache(
+                key[:num_valid_tokens],
+                value[:num_valid_tokens],
+                key_cache,
+                value_cache,
+                input_metadata.slot_mapping,
+            )

         if input_metadata.num_generation_tokens > 0:
             # Compute the attention op for generation tokens.
             self.single_query_cached_kv_attention(
-                output[num_prompt_tokens:],
-                query[num_prompt_tokens:],
+                output[num_prompt_tokens:num_valid_tokens],
+                query[num_prompt_tokens:num_valid_tokens],
                 key_cache,
                 value_cache,
                 input_metadata)
@@ -186,19 +190,15 @@ class LlamaCacheFlowAttention(GPTCacheFlowAttention):
     ) -> torch.Tensor:                          # [num_tokens, num_heads * head_size]
         # Apply rotary embedding to the query and key before passing them
         # to the attention op.
-        out_query = torch.empty_like(query)
-        out_key = torch.empty_like(key)
         pos_encoding_ops.rotary_embedding_neox(
-            out_query,
-            out_key,
             positions,
             query,
             key,
             self.cos_sin_cache,
         )
         return super().forward(
-            out_query,
-            out_key,
+            query,
+            key,
             value,
             key_cache,
             value_cache,
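A side note on why the rewritten forward() above can skip any final gather: output is pre-allocated once, and both attention paths write through slices of it, which are views of the same storage. A toy demonstration (shapes invented; fill_ stands in for the real kernels):

import torch

num_prompt_tokens, num_generation_tokens, num_heads, head_size = 5, 3, 2, 4
num_valid_tokens = num_prompt_tokens + num_generation_tokens

output = torch.empty(num_valid_tokens, num_heads, head_size)
prompt_out = output[:num_prompt_tokens]                    # handed to _flash_attn_forward
gen_out = output[num_prompt_tokens:num_valid_tokens]       # handed to the cached-KV kernel
prompt_out.fill_(1.0)    # stand-in for the FlashAttention write
gen_out.fill_(2.0)       # stand-in for single_query_cached_kv_attention
assert output[0, 0, 0] == 1.0 and output[-1, 0, 0] == 2.0  # both writes landed in output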
cacheflow/models/input_metadata.py

@@ -12,6 +12,7 @@ class InputMetadata:
         seq_groups: List[Tuple[List[int], SamplingParams]],
         seq_logprobs: Dict[int, float],         # Seq id -> cumulative logprobs.
         prompt_lens: List[int],
+        cumulative_prompt_lens: torch.Tensor,
         slot_mapping: torch.Tensor,
         context_lens: torch.Tensor,
         max_context_len: int,
@@ -20,6 +21,7 @@ class InputMetadata:
         self.seq_groups = seq_groups
         self.seq_logprobs = seq_logprobs
         self.prompt_lens = prompt_lens
+        self.cumulative_prompt_lens = cumulative_prompt_lens
         self.slot_mapping = slot_mapping
         self.context_lens = context_lens
         self.max_context_len = max_context_len
@@ -27,6 +29,7 @@ class InputMetadata:
         self.num_prompts = len(prompt_lens)
         self.num_prompt_tokens = sum(prompt_lens)
+        self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
         self.num_generation_tokens = context_lens.shape[0]
         self.num_valid_tokens = slot_mapping.shape[0]
         if block_tables.numel() > 0:
@@ -40,11 +43,13 @@ class InputMetadata:
         return (f'InputMetadata('
                 f'num_prompts={self.num_prompts}, '
                 f'num_prompt_tokens={self.num_prompt_tokens}, '
+                f'max_prompt_len={self.max_prompt_len}, '
                 f'num_generation_tokens={self.num_generation_tokens}, '
                 f'num_valid_tokens={self.num_valid_tokens}, '
                 f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
                 f'max_context_len={self.max_context_len}), '
                 f'prompt_lens={self.prompt_lens}, '
+                f'cumulative_prompt_lens={self.cumulative_prompt_lens}, '
                 f'slot_mapping={self.slot_mapping}, '
                 f'context_lens={self.context_lens}, '
                 f'block_tables={self.block_tables})')
cacheflow/models/llama.py

@@ -11,6 +11,7 @@ from torch import nn
 from transformers import LlamaConfig

 from cacheflow.models import InputMetadata
+from cacheflow.models.activation import SiluAndMul
 from cacheflow.models.attention import LlamaCacheFlowAttention
 from cacheflow.models.layernorm import RMSNorm
 from cacheflow.models.sample import Sampler
@@ -39,16 +40,14 @@ class LlamaMLP(nn.Module):
         self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
                                            bias=False, input_is_parallel=True,
                                            perform_initialization=False)
-        assert hidden_act == 'silu'
-        self.act_fn = nn.SiLU()
+        if hidden_act != 'silu':
+            raise ValueError(f'Unsupported activation: {hidden_act}. '
+                             'Only silu is supported for now.')
+        self.act_fn = SiluAndMul()

     def forward(self, x):
         gate_up, _ = self.gate_up_proj(x)
-        gate_up = gate_up.reshape(gate_up.shape[:-1] + (2, -1))
-        gate, up = torch.split(gate_up, 1, dim=-2)
-        gate = gate.squeeze(dim=-2).contiguous()
-        up = up.squeeze(dim=-2).contiguous()
-        x = self.act_fn(gate) * up
+        x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
         return x
@@ -94,11 +93,7 @@ class LlamaAttention(nn.Module):
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
-        qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))
-        q, k, v = torch.split(qkv, 1, dim=-2)
-        q = q.squeeze(dim=-2).contiguous()
-        k = k.squeeze(dim=-2).contiguous()
-        v = v.squeeze(dim=-2).contiguous()
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
         k_cache, v_cache = kv_cache
         attn_output = self.attn(
             positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
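The LlamaMLP rewrite above relies on gate_up_proj emitting the gate and up projections concatenated along the last dimension, so SiluAndMul(gate_up) equals silu(gate) * up without the reshape/split/contiguous round trip. A rough equivalence check under that layout assumption (tensors and sizes invented; F.silu stands in for the fused CUDA op):

import torch
import torch.nn.functional as F

num_tokens, intermediate_size = 8, 32
gate_up = torch.randn(num_tokens, 2 * intermediate_size)   # stand-in for gate_up_proj(x)

# Old path: reshape, split, squeeze, contiguous, then silu(gate) * up.
stacked = gate_up.reshape(gate_up.shape[:-1] + (2, -1))
gate, up = torch.split(stacked, 1, dim=-2)
gate = gate.squeeze(dim=-2).contiguous()
up = up.squeeze(dim=-2).contiguous()
old = F.silu(gate) * up

# New path: split the concatenated tensor in half, as SiluAndMul does internally.
x1, x2 = gate_up.chunk(chunks=2, dim=-1)
new = F.silu(x1) * x2

assert torch.allclose(old, new)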
cacheflow/models/opt.py

@@ -69,17 +69,14 @@ class OPTAttention(nn.Module):
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
-        qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))
-        q, k, v = torch.split(qkv, 1, dim=-2)
-        q = q.squeeze(dim=-2).contiguous()
-        k = k.squeeze(dim=-2).contiguous()
-        v = v.squeeze(dim=-2).contiguous()
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
         key_cache, value_cache = kv_cache
         attn_output = self.attn(
             q, k, v, key_cache, value_cache, input_metadata, cache_event)
         output, _ = self.out_proj(attn_output)
         return output


 class OPTDecoderLayer(nn.Module):

     def __init__(self, config: OPTConfig):
cacheflow/worker/worker.py

@@ -128,6 +128,11 @@ class Worker:
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append(slot)

+        cumulative_prompt_lens: List[int] = [0]
+        for prompt_len in prompt_lens:
+            cumulative_prompt_lens.append(
+                cumulative_prompt_lens[-1] + prompt_len)
+
         # Add generation tokens.
         max_context_len = 0
         max_num_blocks_per_seq = 0
@@ -183,11 +188,14 @@ class Worker:
             for block_table in generation_block_tables]
         block_tables_tensor = torch.tensor(
             padded_block_tables, dtype=torch.int, device='cuda')
+        cumulative_prompt_lens_tensor = torch.tensor(
+            cumulative_prompt_lens, dtype=torch.int, device='cuda')

         input_metadata = InputMetadata(
             seq_groups=seq_groups,
             seq_logprobs=seq_logprobs,
             prompt_lens=prompt_lens,
+            cumulative_prompt_lens=cumulative_prompt_lens_tensor,
             slot_mapping=slot_mapping_tensor,
             context_lens=context_lens_tensor,
             max_context_len=max_context_len,
csrc/activation.cpp (new file, mode 100644)

#include <torch/extension.h>

void silu_and_mul(
  torch::Tensor& out,
  torch::Tensor& input);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
    "silu_and_mul",
    &silu_and_mul,
    "Activation function used in SwiGLU.");
}
csrc/activation_kernels.cu (new file, mode 100644)

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

namespace cacheflow {

template<typename T>
__device__ __forceinline__ T silu(const T& x) {
  // x * sigmoid(x)
  return (T) (((float) x) / (1.0f + expf((float) -x)));
}

template<typename scalar_t>
__global__ void silu_and_mul_kernel(
  scalar_t* __restrict__ out,               // [num_tokens, d]
  const scalar_t* __restrict__ input,       // [num_tokens, 2, d]
  const int d) {
  const int token_idx = blockIdx.x;
  for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]);
    const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]);
    out[token_idx * d + idx] = silu(x) * y;
  }
}

} // namespace cacheflow

void silu_and_mul(
  torch::Tensor& out,      // [num_tokens, d]
  torch::Tensor& input)    // [num_tokens, 2 * d]
{
  int num_tokens = input.size(0);
  int d = input.size(1) / 2;

  dim3 grid(num_tokens);
  dim3 block(std::min(d, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    input.scalar_type(),
    "silu_and_mul_kernel",
    [&] {
      cacheflow::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<scalar_t>(),
        input.data_ptr<scalar_t>(),
        d);
    });
}
csrc/attention_kernels.cu

@@ -25,7 +25,8 @@ __global__ void single_query_cached_kv_attention_kernel(
   const float scale,
   const int* __restrict__ block_tables,     // [num_seqs, max_num_blocks_per_seq]
   const int* __restrict__ context_lens,     // [num_seqs]
-  const int max_num_blocks_per_seq) {
+  const int max_num_blocks_per_seq,
+  const int q_stride) {
   constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE;
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
   const int thread_idx = threadIdx.x;
@@ -56,7 +57,8 @@ __global__ void single_query_cached_kv_attention_kernel(
   // For example, if the the thread group size is 4, then the first thread in the group
   // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
   // th vectors of the query, and so on.
-  const scalar_t* q_ptr = q + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+  // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
+  const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
   Q_vec q_vecs[NUM_VECS_PER_THREAD];
 #pragma unroll
   for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
@@ -264,7 +266,8 @@ __global__ void single_query_cached_kv_attention_kernel(
     scale,                      \
     block_tables_ptr,           \
     context_lens_ptr,           \
-    max_num_blocks_per_seq);
+    max_num_blocks_per_seq,     \
+    query_stride);

 // TODO(woosuk): Tune NUM_THREADS.
 template<
@@ -284,6 +287,7 @@ void single_query_cached_kv_attention_launcher(
   int num_heads = query.size(1);
   int head_size = query.size(2);
   int max_num_blocks_per_seq = block_tables.size(1);
+  int query_stride = query.stride(0);

   T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
@@ -333,13 +337,13 @@ void single_query_cached_kv_attention_launcher(
 }

 void single_query_cached_kv_attention(
-  torch::Tensor& out,
-  torch::Tensor& query,
-  torch::Tensor& key_cache,
-  torch::Tensor& value_cache,
+  torch::Tensor& out,             // [num_seqs, num_heads, head_size]
+  torch::Tensor& query,           // [num_seqs, num_heads, head_size]
+  torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
+  torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
   float scale,
-  torch::Tensor& block_tables,
-  torch::Tensor& context_lens,
+  torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
+  torch::Tensor& context_lens,    // [num_seqs]
   int block_size,
   int max_context_len) {
   // TODO(woosuk): Support BF16.
csrc/cache_kernels.cu

@@ -81,6 +81,8 @@ __global__ void reshape_and_cache_kernel(
   scalar_t* __restrict__ key_cache,           // [num_blocks, num_heads, head_size/x, block_size, x]
   scalar_t* __restrict__ value_cache,         // [num_blocks, num_heads, head_size, block_size]
   const int* __restrict__ slot_mapping,       // [num_tokens]
+  const int key_stride,
+  const int value_stride,
   const int num_heads,
   const int head_size,
   const int block_size,
@@ -92,7 +94,8 @@ __global__ void reshape_and_cache_kernel(
   const int n = num_heads * head_size;
   for (int i = threadIdx.x; i < n; i += blockDim.x) {
-    const int src_idx = token_idx * n + i;
+    const int src_key_idx = token_idx * key_stride + i;
+    const int src_value_idx = token_idx * value_stride + i;
     const int head_idx = i / head_size;
     const int head_offset = i % head_size;
@@ -108,25 +111,29 @@ __global__ void reshape_and_cache_kernel(
                                   + head_idx * head_size * block_size
                                   + head_offset * block_size
                                   + block_offset;
-    key_cache[tgt_key_idx] = __ldg(&key[src_idx]);
-    value_cache[tgt_value_idx] = __ldg(&value[src_idx]);
+    key_cache[tgt_key_idx] = __ldg(&key[src_key_idx]);
+    value_cache[tgt_value_idx] = __ldg(&value[src_value_idx]);
   }
 }

 } // namespace cacheflow

 void reshape_and_cache(
-  torch::Tensor& key,
-  torch::Tensor& value,
-  torch::Tensor& key_cache,
-  torch::Tensor& value_cache,
-  torch::Tensor& slot_mapping) {
+  torch::Tensor& key,             // [num_tokens, num_heads, head_size]
+  torch::Tensor& value,           // [num_tokens, num_heads, head_size]
+  torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
+  torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
+  torch::Tensor& slot_mapping)    // [num_tokens]
+{
   int num_tokens = key.size(0);
   int num_heads = key.size(1);
   int head_size = key.size(2);
   int block_size = key_cache.size(3);
   int x = key_cache.size(4);
+  int key_stride = key.stride(0);
+  int value_stride = value.stride(0);

   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * head_size, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -140,6 +147,8 @@ void reshape_and_cache(
       key_cache.data_ptr<scalar_t>(),
       value_cache.data_ptr<scalar_t>(),
       slot_mapping.data_ptr<int>(),
+      key_stride,
+      value_stride,
       num_heads,
       head_size,
       block_size,
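For orientation, a plain-PyTorch reference of what reshape_and_cache does per token, mirroring the check in tests/kernels/cache.py further down; all sizes are hypothetical and the Python loop is purely illustrative of the kernel's scatter:

import torch

num_tokens, num_heads, head_size, block_size, num_blocks, x = 4, 2, 16, 8, 3, 8

key = torch.randn(num_tokens, num_heads, head_size)
value = torch.randn(num_tokens, num_heads, head_size)
slot_mapping = torch.randperm(num_blocks * block_size)[:num_tokens]
key_cache = torch.zeros(num_blocks, num_heads, head_size // x, block_size, x)
value_cache = torch.zeros(num_blocks, num_heads, head_size, block_size)

reshaped_key = key.reshape(num_tokens, num_heads, head_size // x, x)
for i in range(num_tokens):
    block_idx = torch.div(slot_mapping[i], block_size, rounding_mode='floor')
    block_offset = slot_mapping[i] % block_size
    # Scatter token i into its slot, matching the cache layouts used by the kernel:
    #   key_cache:   [num_blocks, num_heads, head_size/x, block_size, x]
    #   value_cache: [num_blocks, num_heads, head_size, block_size]
    key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
    value_cache[block_idx, :, :, block_offset] = value[i]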
csrc/pos_encoding.cpp

 #include <torch/extension.h>

 void rotary_embedding_neox(
-  torch::Tensor& out_query,
-  torch::Tensor& out_key,
   torch::Tensor& positions,
   torch::Tensor& query,
   torch::Tensor& key,
csrc/pos_encoding_kernels.cu

@@ -5,12 +5,11 @@ namespace cacheflow {
 template<typename scalar_t>
 __global__ void rotary_embedding_neox_kernel(
-  scalar_t* __restrict__ out_query,           // [num_tokens, num_heads, head_size]
-  scalar_t* __restrict__ out_key,             // [num_tokens, num_heads, head_size]
   const int64_t* __restrict__ positions,      // [num_tokens]
-  const scalar_t* __restrict__ query,         // [num_tokens, num_heads, head_size]
-  const scalar_t* __restrict__ key,           // [num_tokens, num_heads, head_size]
+  scalar_t* __restrict__ query,               // [num_tokens, num_heads, head_size]
+  scalar_t* __restrict__ key,                 // [num_tokens, num_heads, head_size]
   const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, head_size // 2]
+  const int stride,
   const int num_heads,
   const int head_size) {
   // Each thread block is responsible for one token.
@@ -19,41 +18,36 @@ __global__ void rotary_embedding_neox_kernel(
   const scalar_t* cache_ptr = cos_sin_cache + pos * head_size;
   const int embed_dim = head_size / 2;
-  const int n = num_heads * head_size;
+  const int n = num_heads * embed_dim;
   for (int i = threadIdx.x; i < n; i += blockDim.x) {
-    const int idx = token_idx * n + i;
+    const int head_idx = i / embed_dim;
+    const int token_head = token_idx * stride + head_idx * head_size;

-    const int head_idx = i / head_size;
-    const int head_offset = i % head_size;
-    const int token_head = token_idx * n + head_idx * head_size;
-
-    const bool is_first_half = head_offset < embed_dim;
-    const int rot_offset = head_offset % embed_dim;
+    const int rot_offset = i % embed_dim;
     const int x_index = rot_offset;
     const int y_index = embed_dim + rot_offset;

+    const int out_x = token_idx * stride + head_idx * head_size + x_index;
+    const int out_y = token_idx * stride + head_idx * head_size + y_index;
+
     const scalar_t cos = __ldg(cache_ptr + x_index);
     const scalar_t sin = __ldg(cache_ptr + y_index);

-    const scalar_t q_x = __ldg(query + token_head + x_index);
-    const scalar_t q_y = __ldg(query + token_head + y_index);
-    const scalar_t q_cos = is_first_half ? q_x : q_y;
-    const scalar_t q_sin = is_first_half ? -q_y : q_x;
-    out_query[idx] = q_cos * cos + q_sin * sin;
+    const scalar_t q_x = query[token_head + x_index];
+    const scalar_t q_y = query[token_head + y_index];
+    query[out_x] = q_x * cos - q_y * sin;
+    query[out_y] = q_y * cos + q_x * sin;

-    const scalar_t k_x = __ldg(key + token_head + x_index);
-    const scalar_t k_y = __ldg(key + token_head + y_index);
-    const scalar_t k_cos = is_first_half ? k_x : k_y;
-    const scalar_t k_sin = is_first_half ? -k_y : k_x;
-    out_key[idx] = k_cos * cos + k_sin * sin;
+    const scalar_t k_x = key[token_head + x_index];
+    const scalar_t k_y = key[token_head + y_index];
+    key[out_x] = k_x * cos - k_y * sin;
+    key[out_y] = k_y * cos + k_x * sin;
   }
 }

 } // namespace cacheflow

 void rotary_embedding_neox(
-  torch::Tensor& out_query,         // [num_tokens, num_heads * head_size]
-  torch::Tensor& out_key,           // [num_tokens, num_heads * head_size]
   torch::Tensor& positions,         // [num_tokens]
   torch::Tensor& query,             // [num_tokens, num_heads * head_size]
   torch::Tensor& key,               // [num_tokens, num_heads * head_size]
@@ -62,21 +56,22 @@ void rotary_embedding_neox(
   int num_tokens = query.size(0);
   int head_size = cos_sin_cache.size(1);
   int num_heads = query.size(1) / head_size;
+  int stride = query.stride(0);
+  TORCH_CHECK(stride == key.stride(0));

   dim3 grid(num_tokens);
-  dim3 block(std::min(num_heads * head_size, 512));
+  dim3 block(std::min(num_heads * head_size / 2, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
     query.scalar_type(),
     "rotary_embedding_neox",
     [&] {
       cacheflow::rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        out_query.data_ptr<scalar_t>(),
-        out_key.data_ptr<scalar_t>(),
         positions.data_ptr<int64_t>(),
         query.data_ptr<scalar_t>(),
         key.data_ptr<scalar_t>(),
         cos_sin_cache.data_ptr<scalar_t>(),
+        stride,
         num_heads,
         head_size);
     });
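A PyTorch sketch (not from the commit) of the GPT-NeoX-style rotation the kernel now applies in place: for every head, the first and second halves of the head dimension are rotated pairwise by the cos/sin values cached for the token's position. Names, sizes, and the cache construction are assumptions made for the example:

import torch

num_tokens, num_heads, head_size, max_position = 6, 2, 8, 32
embed_dim = head_size // 2

inv_freq = 1.0 / (10000 ** (torch.arange(embed_dim) / embed_dim))
freqs = torch.outer(torch.arange(max_position).float(), inv_freq)
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)   # [max_position, head_size]

def rotary_embedding_neox_ref(positions, x):
    # x: [num_tokens, num_heads, head_size]; modified in place, like the kernel.
    cos = cos_sin_cache[positions, :embed_dim].unsqueeze(1)     # [num_tokens, 1, embed_dim]
    sin = cos_sin_cache[positions, embed_dim:].unsqueeze(1)
    x1 = x[..., :embed_dim].clone()
    x2 = x[..., embed_dim:].clone()
    x[..., :embed_dim] = x1 * cos - x2 * sin                    # matches query[out_x]/key[out_x]
    x[..., embed_dim:] = x2 * cos + x1 * sin                    # matches query[out_y]/key[out_y]
    return x

positions = torch.randint(0, max_position, (num_tokens,))
query = torch.randn(num_tokens, num_heads, head_size)
rotary_embedding_neox_ref(positions, query)                     # query now holds the rotated values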
setup.py

@@ -39,6 +39,13 @@ layernorm_extension = cpp_extension.CUDAExtension(
 )
 ext_modules.append(layernorm_extension)

+activation_extension = cpp_extension.CUDAExtension(
+    name='cacheflow.activation_ops',
+    sources=['csrc/activation.cpp', 'csrc/activation_kernels.cu'],
+    extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS},
+)
+ext_modules.append(activation_extension)
+
 setuptools.setup(
     name='cacheflow',
     ext_modules=ext_modules,
tests/kernels/activation.py (new file, mode 100644)

import torch
import torch.nn.functional as F

from cacheflow import activation_ops


def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(chunks=2, dim=1)
    return F.silu(x1) * x2


@torch.inference_mode()
def test_silu_and_mul(
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
) -> None:
    x = torch.randn(num_tokens, 2 * d, dtype=dtype, device='cuda')
    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
    activation_ops.silu_and_mul(out, x)
    ref_out = ref_silu_and_mul(x)
    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)


if __name__ == '__main__':
    for dtype in [torch.half, torch.float]:
        for num_tokens in [7, 83, 2048]:
            for d in [512, 4096, 13824]:
                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
                test_silu_and_mul(num_tokens, d, dtype)
tests/kernels/attention.py

 import random
 from typing import List, Optional

-from flash_attn.flash_attention import FlashAttention
+from flash_attn.flash_attn_interface import _flash_attn_forward
 import torch

 from cacheflow import attention_ops

@@ -105,8 +105,9 @@ def test_single_query_cached_kv_attention(
     num_blocks: int,
     dtype: torch.dtype,
 ) -> None:
-    query = torch.randn(
-        num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
+    qkv = torch.randn(
+        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
+    query, _, _ = qkv.unbind(dim=1)

     x = 16 // torch.tensor([], dtype=dtype).element_size()
     key_block_shape = (num_heads, head_size // x, block_size, x)
     key_cache = torch.randn(
@@ -115,6 +116,11 @@ def test_single_query_cached_kv_attention(
     value_cache = torch.randn(
         size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')

+    # Adjust the range of the values to reduce precision errors.
+    query = query / (head_size ** 0.5)
+    key_cache = key_cache / (head_size ** 0.5)
+    value_cache = value_cache / (head_size ** 0.5)
+
     context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
     max_context_len = max(context_lens)
     context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
@@ -130,7 +136,8 @@ def test_single_query_cached_kv_attention(
     block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')

     scale = float(1.0 / (head_size ** 0.5))
-    output = torch.empty_like(query)
+    output = torch.empty(
+        num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
     attention_ops.single_query_cached_kv_attention(
         output,
         query,
@@ -175,19 +182,28 @@ def test_multi_query_kv_attention(
     cu_seq_lens = torch.tensor(cu_seq_lens, dtype=torch.int, device='cuda')

     scale = float(1.0 / (head_size ** 0.5))
-    query = torch.randn(
-        num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
-    key = torch.rand_like(query)
-    value = torch.rand_like(query)
-    qkv = torch.stack([query, key, value], dim=1)
-
-    flash_attn = FlashAttention(softmax_scale=scale)
-    output = flash_attn(
-        qkv,
-        cu_seqlens=cu_seq_lens,
-        max_s=max_seq_len,
-        causal=True,
-    )[0]
+    qkv = torch.randn(
+        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
+    # Adjust the range of the values to reduce precision errors.
+    qkv = qkv / (head_size ** 0.5)
+    query, key, value = qkv.unbind(dim=1)
+
+    output = torch.empty(
+        num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
+    _flash_attn_forward(
+        query,
+        key,
+        value,
+        output,
+        cu_seq_lens,
+        cu_seq_lens,
+        max_seq_len,
+        max_seq_len,
+        dropout_p=0.0,
+        softmax_scale=scale,
+        causal=True,
+        return_softmax=False,
+    )

     cu_seq_lens = cu_seq_lens.cpu().tolist()
     ref_output = ref_multi_query_kv_attention(
tests/kernels/cache.py

@@ -17,9 +17,9 @@ def test_reshape_and_cache(
     slot_mapping = random.sample(range(num_slots), num_tokens)
     slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')

-    kv_shape = (num_tokens, num_heads, head_size)
-    key = torch.randn(size=kv_shape, dtype=dtype, device='cuda')
-    value = torch.randn(size=kv_shape, dtype=dtype, device='cuda')
+    qkv = torch.randn(
+        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
+    _, key, value = qkv.unbind(dim=1)

     x = 16 // torch.tensor([], dtype=dtype).element_size()
     key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
@@ -35,7 +35,7 @@ def test_reshape_and_cache(
     for i in range(num_tokens):
         reshaped_key = key.reshape(num_tokens, num_heads, head_size // x, x)
-        block_idx = slot_mapping[i] // block_size
+        block_idx = torch.div(slot_mapping[i], block_size, rounding_mode='floor')
         block_offset = slot_mapping[i] % block_size
         cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
         cloned_value_cache[block_idx, :, :, block_offset] = value[i]
tests/kernels/pos_encoding.py

@@ -85,15 +85,13 @@ def test_rotary_embedding_neox(
     cos_sin_cache = torch.cat((cos, sin), dim=-1)
     cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda')

-    # Run the kernel.
-    out_query = torch.empty_like(query)
-    out_key = torch.empty_like(key)
+    # Run the kernel. The kernel is in-place, so we need to clone the inputs.
+    out_query = query.clone()
+    out_key = key.clone()
     pos_encoding_ops.rotary_embedding_neox(
+        positions,
         out_query,
         out_key,
-        positions,
-        query,
-        key,
         cos_sin_cache,
     )