vllm · Commit c413c41c
Authored Feb 18, 2023 by Woosuk Kwon

Add reshape_and_cache op
Parent: ffad4e1e
Showing 2 changed files with 83 additions and 0 deletions:

csrc/cache.cpp (+11, -0)
csrc/cache_kernels.cu (+72, -0)
csrc/cache.cpp
@@ -5,9 +5,20 @@ void copy_blocks(
  torch::Tensor& dst,
  const std::map<int64_t, int64_t>& block_mapping);

void reshape_and_cache(
  torch::Tensor& key,
  torch::Tensor& value,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  torch::Tensor& slot_mapping);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
    "copy_cache_blocks",
    &copy_blocks,
    "Copy the cache blocks from src to dst");
  m.def(
    "reshape_and_cache",
    &reshape_and_cache,
    "Reshape the key and value tensors and cache them");
}
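
For orientation, here is a minimal sketch of how the new binding might be driven from Python. The module name cache_ops, the shapes, and the slot values are illustrative assumptions (the real module name is set by TORCH_EXTENSION_NAME at build time); only the function signature and cache layouts come from this commit.

import torch
import cache_ops  # hypothetical name for the extension built from csrc/

num_tokens, num_heads, head_size = 4, 8, 64
num_blocks, block_size = 16, 8
x = 8  # chunk width along head_size in the key cache layout

key = torch.randn(num_tokens, num_heads, head_size, device="cuda")
value = torch.randn(num_tokens, num_heads, head_size, device="cuda")
# Key cache: [num_blocks, num_heads, head_size/x, block_size, x]
key_cache = torch.zeros(num_blocks, num_heads, head_size // x, block_size, x, device="cuda")
# Value cache: [num_blocks, num_heads, block_size, head_size]
value_cache = torch.zeros(num_blocks, num_heads, block_size, head_size, device="cuda")
# One cache slot per token; slot s maps to block s // block_size, offset s % block_size.
# int32 to match the kernel's const int* slot_mapping.
slot_mapping = torch.tensor([0, 1, 8, 9], dtype=torch.int32, device="cuda")

cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)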
csrc/cache_kernels.cu
@@ -2,6 +2,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <algorithm>
#include <cassert>
#include <map>

@@ -41,3 +42,74 @@ void copy_blocks(
      stream);
  }
}

template<typename scalar_t>
__global__ void reshape_and_cache_kernel(
  const scalar_t* __restrict__ key,     // [num_tokens, num_heads, head_size]
  const scalar_t* __restrict__ value,   // [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
  scalar_t* __restrict__ value_cache,   // [num_blocks, num_heads, block_size, head_size]
  const int* __restrict__ slot_mapping, // [num_tokens]
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x) {
  const int token_idx = blockIdx.x;
  const int slot_idx = slot_mapping[token_idx];
  const int block_idx = slot_idx / block_size;
  const int block_offset = slot_idx % block_size;

  const int n = num_heads * head_size;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int src_idx = token_idx * n + i;
    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int x_idx = head_offset / x;
    const int x_offset = head_offset % x;

    const int tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                          + head_idx * (head_size / x) * block_size * x
                          + x_idx * block_size * x
                          + block_offset * x
                          + x_offset;
    const int tgt_value_idx = block_idx * num_heads * block_size * head_size
                            + head_idx * block_size * head_size
                            + block_offset * head_size
                            + head_offset;
    key_cache[tgt_key_idx] = __ldg(&key[src_idx]);
    value_cache[tgt_value_idx] = __ldg(&value[src_idx]);
  }
}

void reshape_and_cache(
  torch::Tensor& key,
  torch::Tensor& value,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  torch::Tensor& slot_mapping) {
  int num_tokens = key.size(0);
  int head_num = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  dim3 grid(num_tokens);
  dim3 block(std::min(head_num * head_size, 512));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    key.scalar_type(),
    "reshape_and_cache_kernel",
    [&] {
      reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
        key.data_ptr<scalar_t>(),
        value.data_ptr<scalar_t>(),
        key_cache.data_ptr<scalar_t>(),
        value_cache.data_ptr<scalar_t>(),
        slot_mapping.data_ptr<int>(),
        head_num,
        head_size,
        block_size,
        x);
    });
}
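
The key cache layout [num_blocks, num_heads, head_size/x, block_size, x] splits each head vector into head_size/x chunks of width x, so x consecutive elements of a token's head stay contiguous in the cache. Below is a small host-side sketch (a hypothetical helper, not part of the commit) that replicates the kernel's tgt_key_idx arithmetic, which can be handy when reasoning about or testing the layout.

# Hypothetical helper replicating the kernel's key-cache index arithmetic.
def key_cache_index(slot_idx, head_idx, head_offset,
                    num_heads, head_size, block_size, x):
    block_idx = slot_idx // block_size
    block_offset = slot_idx % block_size
    x_idx = head_offset // x
    x_offset = head_offset % x
    return (block_idx * num_heads * (head_size // x) * block_size * x
            + head_idx * (head_size // x) * block_size * x
            + x_idx * block_size * x
            + block_offset * x
            + x_offset)

# E.g. with num_heads=8, head_size=64, block_size=8, x=8: slot 9 lands in
# block 1 at offset 1, giving 1*4096 + 0 + 0 + 1*8 + 0 = 4104.
print(key_cache_index(9, 0, 0, num_heads=8, head_size=64, block_size=8, x=8))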