vllm / Commits

Commit d6e4a130 (unverified)
[Minor] Remove gather_cached_kv kernel (#3043)
Authored Feb 26, 2024 by Woosuk Kwon; committed via GitHub on Feb 26, 2024
Parent: cfc15a10
Showing 3 changed files with 0 additions and 172 deletions (+0, -172):

  csrc/cache.h           +0  -7
  csrc/cache_kernels.cu  +0  -161
  csrc/pybind.cpp        +0  -4
csrc/cache.h

@@ -23,13 +23,6 @@ void reshape_and_cache(
   torch::Tensor& slot_mapping,
   const std::string& kv_cache_dtype);
 
-void gather_cached_kv(
-  torch::Tensor& key,
-  torch::Tensor& value,
-  torch::Tensor& key_cache,
-  torch::Tensor& value_cache,
-  torch::Tensor& slot_mapping);
-
 // Just for unittest
 void convert_fp8_e5m2(
   torch::Tensor& src_cache,
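For orientation, here is a minimal caller sketch for the declaration being removed. It is hypothetical and not part of the commit: gather_example is an illustrative name, and the tensor shapes simply follow the comments on the host function in csrc/cache_kernels.cu further below.

// Hypothetical caller of the removed gather_cached_kv op (illustrative only).
#include <torch/torch.h>
#include "cache.h"  // declared gather_cached_kv before this commit

void gather_example(torch::Tensor& key_cache,      // [num_blocks, num_heads, head_size/x, block_size, x]
                    torch::Tensor& value_cache,    // [num_blocks, num_heads, head_size, block_size]
                    torch::Tensor& slot_mapping) { // [num_tokens], int32
  const int64_t num_tokens = slot_mapping.size(0);
  const int64_t num_heads  = key_cache.size(1);
  const int64_t head_size  = key_cache.size(2) * key_cache.size(4);  // (head_size/x) * x
  // Contiguous output tensors that the kernel gathered cached K/V back into.
  torch::Tensor key   = torch::empty({num_tokens, num_heads, head_size}, key_cache.options());
  torch::Tensor value = torch::empty({num_tokens, num_heads, head_size}, value_cache.options());
  gather_cached_kv(key, value, key_cache, value_cache, slot_mapping);
}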
csrc/cache_kernels.cu

@@ -269,167 +269,6 @@ void reshape_and_cache(
 
 namespace vllm {
 
-// Grid: (num_blocks, block_size).
-template<typename scalar_t>
-__global__ void gather_cached_kv_kernel(
-  scalar_t* __restrict__ key,                // [num_tokens, [stride], num_heads, head_size]
-  scalar_t* __restrict__ value,              // [num_tokens, [stride], num_heads, head_size]
-  const scalar_t* __restrict__ key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
-  const scalar_t* __restrict__ value_cache,  // [num_blocks, num_heads, head_size, block_size]
-  const int* __restrict__ slot_mapping,      // [num_tokens]
-  const int key_stride,
-  const int value_stride,
-  const int num_heads,
-  const int head_size,
-  const int block_size,
-  const int x) {
-    const int token_idx = blockIdx.x;
-    const int slot_idx = slot_mapping[token_idx];
-    const int block_idx = slot_idx / block_size;
-    const int block_offset = slot_idx % block_size;
-
-    const int num_tokens = num_heads * head_size;
-    for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) {
-      const int tgt_key_idx = token_idx * key_stride + i;
-      const int tgt_value_idx = token_idx * value_stride + i;
-
-      const int head_idx = i / head_size;
-      const int head_offset = i % head_size;
-      const int x_idx = head_offset / x;  // the offset of the [head_size/x] dimension
-      const int x_offset = head_offset % x;
-
-      const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
-                              + head_idx * (head_size / x) * block_size * x
-                              + x_idx * block_size * x
-                              + block_offset * x
-                              + x_offset;
-      const int src_value_idx = block_idx * num_heads * head_size * block_size
-                                + head_idx * head_size * block_size
-                                + head_offset * block_size
-                                + block_offset;
-
-      key[tgt_key_idx] = VLLM_LDG(&key_cache[src_key_idx]);
-      value[tgt_value_idx] = VLLM_LDG(&value_cache[src_value_idx]);
-    }
-}
-
-template<typename scalar_t>
-__global__ void gather_cached_kv_kernel_optimized(
-    scalar_t* __restrict__ key,                // [num_tokens, [stride], num_heads, head_size]
-    scalar_t* __restrict__ value,              // [num_tokens, [stride], num_heads, head_size]
-    const scalar_t* __restrict__ key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
-    const scalar_t* __restrict__ value_cache,  // [num_blocks, num_heads, head_size, block_size]
-    const int* __restrict__ slot_mapping,      // [num_tokens]
-    const int key_stride,
-    const int value_stride,
-    const int num_heads,
-    const int head_size,
-    const int block_size,
-    const int x)
-{
-    const int token_idx = blockIdx.x;
-    const int slot_idx = slot_mapping[token_idx];
-    const int block_idx = slot_idx / block_size;
-    const int block_offset = slot_idx % block_size;
-
-    const int dim = num_heads * head_size;
-    assert(dim % 4 == 0);  // this is true for known use cases
-    const int unroll_factor = 4;
-    const int unrolled_dim = dim / unroll_factor;
-
-    for (int i = threadIdx.x; i < unrolled_dim; i += blockDim.x)
-    {
-        int tgt_key_indices[unroll_factor];
-        int tgt_value_indices[unroll_factor];
-        int src_key_indices[unroll_factor];
-        int src_value_indices[unroll_factor];
-        scalar_t keys_to_store[unroll_factor];
-        scalar_t values_to_store[unroll_factor];
-
-        #pragma unroll
-        for (int j = 0; j < unroll_factor; ++j)
-        {
-            int index = i + j * unrolled_dim;
-
-            const int tgt_key_idx = token_idx * key_stride + index;
-            const int tgt_value_idx = token_idx * value_stride + index;
-
-            const int head_idx = index / head_size;
-            const int head_offset = index % head_size;
-            const int x_idx = head_offset / x;
-            const int x_offset = head_offset % x;
-
-            const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
-                                    + head_idx * (head_size / x) * block_size * x
-                                    + x_idx * block_size * x
-                                    + block_offset * x
-                                    + x_offset;
-            const int src_value_idx = block_idx * num_heads * head_size * block_size
-                                      + head_idx * head_size * block_size
-                                      + head_offset * block_size
-                                      + block_offset;
-
-            tgt_key_indices[j] = tgt_key_idx;
-            tgt_value_indices[j] = tgt_value_idx;
-            src_key_indices[j] = src_key_idx;
-            src_value_indices[j] = src_value_idx;
-
-            keys_to_store[j] = VLLM_LDG(&key_cache[src_key_idx]);
-            values_to_store[j] = VLLM_LDG(&value_cache[src_value_idx]);
-        }
-
-        #pragma unroll
-        for (int j = 0; j < unroll_factor; ++j)
-        {
-            key[tgt_key_indices[j]] = keys_to_store[j];
-            value[tgt_value_indices[j]] = values_to_store[j];
-        }
-    }
-}
-
-} // namespace vllm
-
-void gather_cached_kv(
-  torch::Tensor& key,           // [out] [num_tokens, num_heads, head_size]
-  torch::Tensor& value,         // [out] [num_tokens, num_heads, head_size]
-  torch::Tensor& key_cache,     // [in]  [num_blocks, num_heads, head_size/x, block_size, x]
-  torch::Tensor& value_cache,   // [in]  [num_blocks, num_heads, head_size, block_size]
-  torch::Tensor& slot_mapping)  // [in]  [num_tokens]
-{
-  int num_tokens = key.size(0);
-  int num_heads = key.size(1);
-  int head_size = key.size(2);
-  int block_size = key_cache.size(3);
-  int x = key_cache.size(4);
-
-  int key_stride = key.stride(0);
-  int value_stride = value.stride(0);
-
-  dim3 grid(num_tokens);
-  dim3 block(std::min(num_heads * head_size, 512));
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
-    key.scalar_type(),
-    "gather_cached_kv_kernel_optimized",
-    [&] {
-      vllm::gather_cached_kv_kernel_optimized<scalar_t><<<grid, block, 0, stream>>>(
-        key.data_ptr<scalar_t>(),
-        value.data_ptr<scalar_t>(),
-        key_cache.data_ptr<scalar_t>(),
-        value_cache.data_ptr<scalar_t>(),
-        slot_mapping.data_ptr<int>(),
-        key_stride,
-        value_stride,
-        num_heads,
-        head_size,
-        block_size,
-        x);
-    });
-}
-
-namespace vllm {
-
 template<typename Tout, typename Tin>
 __global__ void convert_fp8_e5m2_kernel(
   const Tin* __restrict__ src_cache,
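The src_key_idx / src_value_idx arithmetic in the removed kernels encodes the paged KV-cache layouts noted in the comments: [num_blocks, num_heads, head_size/x, block_size, x] for keys and [num_blocks, num_heads, head_size, block_size] for values. As a sanity check, the standalone sketch below evaluates the same formulas for one element; the parameter values are hypothetical, chosen only to make the arithmetic concrete, and the snippet is not part of the commit.

// Standalone sketch of the removed kernels' source-index math (illustrative values only).
#include <cstdio>

int main() {
  // Hypothetical parameters: 2 heads, head_size 16, block_size 4, x 8.
  const int num_heads = 2, head_size = 16, block_size = 4, x = 8;
  const int slot_idx = 6;                           // one token's slot from slot_mapping
  const int block_idx = slot_idx / block_size;      // 1
  const int block_offset = slot_idx % block_size;   // 2
  const int i = 19;                                 // flattened (head, head_offset) index
  const int head_idx = i / head_size;               // 1
  const int head_offset = i % head_size;            // 3
  const int x_idx = head_offset / x;                // 0 (offset in the head_size/x dimension)
  const int x_offset = head_offset % x;             // 3

  const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                        + head_idx * (head_size / x) * block_size * x
                        + x_idx * block_size * x
                        + block_offset * x
                        + x_offset;
  const int src_value_idx = block_idx * num_heads * head_size * block_size
                          + head_idx * head_size * block_size
                          + head_offset * block_size
                          + block_offset;
  std::printf("src_key_idx=%d src_value_idx=%d\n", src_key_idx, src_value_idx);  // 211 and 206
  return 0;
}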
csrc/pybind.cpp

@@ -79,10 +79,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     "reshape_and_cache",
     &reshape_and_cache,
     "Reshape the key and value tensors and cache them");
-  cache_ops.def(
-    "gather_cached_kv",
-    &gather_cached_kv,
-    "Gather key and value from the cache into contiguous QKV tensors");
   cache_ops.def(
     "convert_fp8_e5m2",
     &convert_fp8_e5m2,
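The cache_ops.def entry above was the only place the kernel was exposed to Python, so dropping it removes the op from the extension module as well. A rough sketch of the registration pattern follows; the def_submodule call and its docstring are assumptions about how csrc/pybind.cpp sets up cache_ops, not text taken from this diff.

// Sketch of the pybind11 registration pattern (assumed structure of csrc/pybind.cpp).
#include <torch/extension.h>
#include <string>

// Declaration normally pulled in from csrc/cache.h (kept by this commit).
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
                       torch::Tensor& key_cache, torch::Tensor& value_cache,
                       torch::Tensor& slot_mapping, const std::string& kv_cache_dtype);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Cache ops live in a submodule; each .def() line is one Python-visible op.
  pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
  cache_ops.def(
    "reshape_and_cache",
    &reshape_and_cache,
    "Reshape the key and value tensors and cache them");
  // The gather_cached_kv entry deleted by this commit sat here before removal.
}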