xdb4_94051 / vllm · Commits · 0f40557a

Unverified commit 0f40557a, authored Apr 07, 2023 by Woosuk Kwon, committed by GitHub on Apr 07, 2023

Implement block copy kernel to optimize beam search (#32)

parent a490aafa
Showing 6 changed files with 155 additions and 49 deletions (+155 -49)
benchmark/benchmark_latency.py      +6   -3
cacheflow/models/sample.py          +3   -2
cacheflow/worker/cache_engine.py    +4   -20
csrc/cache.cpp                      +2   -2
csrc/cache_kernels.cu               +82  -22
tests/kernels/cache.py              +58  -0
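Overview of the change: the old copy_blocks issued one cudaMemcpyAsync per (src, dst) block pair and was called once per layer for the key cache and once for the value cache; the new code launches a single fused CUDA kernel that copies every pair for every layer's key and value caches at once. The mapping that drives the copy keeps the same shape throughout the diff: each source block number maps to a list of destination block numbers, which is how beam search forks one sequence's cache blocks into several. A minimal Python sketch of that structure (block numbers are made up for illustration):

# Hypothetical block mapping in the Dict[int, List[int]] form used by
# CacheEngine.copy() and the C++ copy_blocks() below: beam search may fork
# one source block into several destinations in a single step.
src_to_dsts = {
    3: [17, 18],   # block 3 duplicated into blocks 17 and 18
    9: [21],       # block 9 copied once
}

# The host code flattens this into (src, dst) pairs before launching the
# kernel; one CUDA block then handles one (layer, pair) combination.
pairs = [(src, dst) for src, dsts in src_to_dsts.items() for dst in dsts]
print(pairs)   # [(3, 17), (3, 18), (9, 21)]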
benchmark/benchmark_latency.py

@@ -50,14 +50,15 @@ def main(args: argparse.Namespace):
         block_size=args.block_size,
     )
     sampling_params_dict = {
-        'n': 1,
-        'temperature': 0.0,
+        'n': args.n,
+        'temperature': 0.0 if args.use_beam_search else 1.0,
         'top_p': 1.0,
-        'use_beam_search': False,
+        'use_beam_search': args.use_beam_search,
         'stop_token_ids': set(),
         'max_num_steps': args.output_len,
     }
     sampling_params = SamplingParams.from_dict(sampling_params_dict)
+    print(sampling_params)
     input_token_ids = [0] * args.input_len
 
     def profile_step(profile=False):
@@ -93,6 +94,8 @@ if __name__ == '__main__':
     parser.add_argument('--input-len', type=int, default=32)
     parser.add_argument('--output-len', type=int, default=128)
     parser.add_argument('--batch-size', type=int, default=8)
+    parser.add_argument('--n', type=int, default=1)
+    parser.add_argument('--use-beam-search', action='store_true')
     args = parser.parse_args()
     args.max_num_batched_tokens = max(
         args.max_num_batched_tokens, args.batch_size * args.input_len)
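For illustration only, a hypothetical run with the new flags (say --n 4 --use-beam-search --output-len 128) would build the following dictionary; n now comes from the command line, and the temperature drops to 0.0 only when beam search is enabled:

# Illustration with made-up argument values; not part of the diff.
sampling_params_dict = {
    'n': 4,
    'temperature': 0.0,        # 0.0 if args.use_beam_search else 1.0
    'top_p': 1.0,
    'use_beam_search': True,
    'stop_token_ids': set(),
    'max_num_steps': 128,
}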
cacheflow/models/sample.py

@@ -185,9 +185,10 @@ def _sample_from_generation_tokens(
         vocab_size = logprobs.size(-1)
         beam_width = len(seq_ids)
         _, topk_ids = torch.topk(logprobs.flatten(), beam_width)
-        seq_idx = torch.div(topk_ids, vocab_size, rounding_mode='floor').tolist()
+        topk_ids = topk_ids.tolist()
+        seq_idx = [i // vocab_size for i in topk_ids]
         beam_seq_ids = [seq_ids[i] for i in seq_idx]
-        token_ids = (topk_ids % vocab_size).tolist()
+        token_ids = [i % vocab_size for i in topk_ids]
 
         beam_outputs: Dict[int, Tuple[int, int]] = {}
         outstanding_beams: List[Tuple[int, int]] = []
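The replacement keeps the arithmetic identical: each flattened top-k index i encodes a (beam, token) pair as i = beam * vocab_size + token, so i // vocab_size recovers the beam index and i % vocab_size recovers the token id. A small self-contained check of that equivalence (sizes are made up; not part of the diff):

import torch

# Made-up sizes for illustration.
beam_width, vocab_size = 4, 10
logprobs = torch.randn(beam_width, vocab_size)

_, topk_ids = torch.topk(logprobs.flatten(), beam_width)
topk_list = topk_ids.tolist()

# New list-based decoding from the diff.
seq_idx = [i // vocab_size for i in topk_list]    # which beam each candidate came from
token_ids = [i % vocab_size for i in topk_list]   # which token it corresponds to

# Matches the tensor-based version that the diff removes.
assert seq_idx == torch.div(topk_ids, vocab_size, rounding_mode='floor').tolist()
assert token_ids == (topk_ids % vocab_size).tolist()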
cacheflow/worker/cache_engine.py

@@ -120,24 +120,8 @@ class CacheEngine:
     def swap_out(self, src_to_dst: Dict[int, int]) -> None:
         self._swap(self.gpu_cache, self.cpu_cache, src_to_dst)
 
-    def _copy(
-        self,
-        src: List[KVCache],
-        dst: List[KVCache],
-        src_to_dsts: Dict[int, List[int]],
-    ) -> None:
-        with torch.cuda.stream(self.cache_stream):
-            for i in range(self.num_layers):
-                src_key_cache, src_value_cache = src[i]
-                dst_key_cache, dst_value_cache = dst[i]
-                # Copy the key blocks.
-                cache_ops.copy_blocks(
-                    src_key_cache, dst_key_cache, src_to_dsts)
-                # Copy the value blocks.
-                cache_ops.copy_blocks(
-                    src_value_cache, dst_value_cache, src_to_dsts)
-                event = self.events[i]
-                event.record(stream=self.cache_stream)
-
     def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
-        self._copy(self.gpu_cache, self.gpu_cache, src_to_dsts)
+        key_caches = [key_cache for key_cache, _ in self.gpu_cache]
+        value_caches = [value_cache for _, value_cache in self.gpu_cache]
+        # NOTE(woosuk): This operation implicitly synchronizes the CPU and GPU.
+        cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)
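To make the new copy() path concrete, here is a small stand-alone sketch (shapes and block numbers are made up, and the cache_ops call is left commented out because the extension only runs on CUDA): it splits the per-layer (key, value) pairs into two lists exactly as the method now does, so a single cache_ops.copy_blocks() call can cover every layer instead of the old per-layer Python loop.

import torch

# Toy stand-in for CacheEngine.gpu_cache: one (key_cache, value_cache) pair
# per layer; the shapes here are arbitrary.
gpu_cache = [(torch.zeros(4, 8), torch.zeros(4, 8)) for _ in range(3)]

key_caches = [key_cache for key_cache, _ in gpu_cache]
value_caches = [value_cache for _, value_cache in gpu_cache]
src_to_dsts = {0: [2], 1: [3]}   # made-up block mapping

# On a CUDA build this would be the single fused call:
# cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)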
csrc/cache.cpp

@@ -9,8 +9,8 @@ void swap_blocks(
   const std::map<int64_t, int64_t>& block_mapping);
 
 void copy_blocks(
-  torch::Tensor& src,
-  torch::Tensor& dst,
+  std::vector<torch::Tensor>& key_caches,
+  std::vector<torch::Tensor>& value_caches,
   const std::map<int64_t, std::vector<int64_t>>& block_mapping);
 
 void reshape_and_cache(
csrc/cache_kernels.cu

@@ -43,33 +43,93 @@ void swap_blocks(
   }
 }
 
-void copy_blocks(
-  torch::Tensor& src,
-  torch::Tensor& dst,
-  const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
-  torch::Device src_device = src.device();
-  torch::Device dst_device = dst.device();
-  assert(src_device.is_cuda() && dst_device.is_cuda());
-  cudaMemcpyKind memcpy_type = cudaMemcpyDeviceToDevice;
-
-  void *src_ptr = src.data_ptr();
-  void *dst_ptr = dst.data_ptr();
-
-  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  for (const auto& pair : block_mapping) {
-    int64_t src_block_number = pair.first;
-    for (int64_t dst_block_number : pair.second) {
-      int64_t src_offset = src_block_number * block_size_in_bytes;
-      int64_t dst_offset = dst_block_number * block_size_in_bytes;
-      cudaMemcpyAsync(
-        dst_ptr + dst_offset,
-        src_ptr + src_offset,
-        block_size_in_bytes,
-        memcpy_type,
-        stream);
-    }
-  }
-}
+namespace cacheflow {
+
+// Grid: (num_layers, num_pairs)
+template<typename scalar_t>
+__global__ void copy_blocks_kernel(
+  int64_t* key_cache_ptrs,
+  int64_t* value_cache_ptrs,
+  const int* __restrict__ block_mapping,
+  const int numel_per_block) {
+  const int layer_idx = blockIdx.x;
+  const int pair_idx = blockIdx.y;
+
+  scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
+  scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
+  int src_block_number = block_mapping[2 * pair_idx];
+  int dst_block_number = block_mapping[2 * pair_idx + 1];
+
+  const int src_block_offset = src_block_number * numel_per_block;
+  const int dst_block_offset = dst_block_number * numel_per_block;
+  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
+    int src_offset = src_block_offset + i;
+    int dst_offset = dst_block_offset + i;
+    key_cache[dst_offset] = key_cache[src_offset];
+  }
+  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
+    int src_offset = src_block_offset + i;
+    int dst_offset = dst_block_offset + i;
+    value_cache[dst_offset] = value_cache[src_offset];
+  }
+}
+
+} // namespace cacheflow
+
+void copy_blocks(
+  std::vector<torch::Tensor>& key_caches,
+  std::vector<torch::Tensor>& value_caches,
+  const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
+  int num_layers = key_caches.size();
+  TORCH_CHECK(num_layers == value_caches.size());
+  if (num_layers == 0) {
+    return;
+  }
+  torch::Device cache_device = key_caches[0].device();
+  TORCH_CHECK(cache_device.is_cuda());
+
+  // Create data structures for the kernel.
+  // Create an array of pointers to the key and value caches.
+  int64_t key_cache_ptrs[num_layers];
+  int64_t value_cache_ptrs[num_layers];
+  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
+    key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
+    value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
+  }
+  // Create block mapping array.
+  std::vector<int> block_mapping_vec;
+  for (const auto& pair : block_mapping) {
+    int src_block_number = pair.first;
+    for (int dst_block_number : pair.second) {
+      block_mapping_vec.push_back(src_block_number);
+      block_mapping_vec.push_back(dst_block_number);
+    }
+  }
+  int* block_mapping_array = block_mapping_vec.data();
+  int num_pairs = block_mapping_vec.size() / 2;
+
+  // Move the data structures to the GPU.
+  // NOTE: This synchronizes the CPU and GPU.
+  torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
+    key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
+  torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
+    value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
+  torch::Tensor block_mapping_tensor = torch::from_blob(
+    block_mapping_array, {2 * num_pairs}, torch::kInt).to(cache_device);
+
+  // Launch the kernel.
+  const int numel_per_block = key_caches[0][0].numel();
+  dim3 grid(num_layers, num_pairs);
+  dim3 block(std::min(1024, numel_per_block));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
+      cacheflow::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        key_cache_ptrs_tensor.data_ptr<int64_t>(),
+        value_cache_ptrs_tensor.data_ptr<int64_t>(),
+        block_mapping_tensor.data_ptr<int>(),
+        numel_per_block);
+    }));
+}
 
 namespace cacheflow {
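Read as plain Python, the kernel does the following: grid dimension x selects a layer, grid dimension y selects one (src, dst) block pair, and the threads of the block stride over the numel_per_block contiguous elements of that block in the flattened key and value caches. A CPU sketch of that behavior (sizes and mapping are made up, and the real kernel works on raw device pointers rather than tensor views):

import torch

# Made-up sizes: 2 layers, 8 blocks per cache, 16 elements per block.
num_layers, num_blocks, numel_per_block = 2, 8, 16
key_caches = [torch.randn(num_blocks, numel_per_block) for _ in range(num_layers)]
value_caches = [torch.randn(num_blocks, numel_per_block) for _ in range(num_layers)]

# Flattened (src, dst) pairs, as the host code builds them from block_mapping,
# e.g. {1: [4, 5], 3: [6]}.
pairs = [(1, 4), (1, 5), (3, 6)]

for layer_idx in range(num_layers):        # gridDim.x
    for src, dst in pairs:                 # gridDim.y
        flat_key = key_caches[layer_idx].view(-1)
        flat_value = value_caches[layer_idx].view(-1)
        src_off = src * numel_per_block
        dst_off = dst * numel_per_block
        # The kernel's strided thread loop amounts to copying one whole block.
        flat_key[dst_off:dst_off + numel_per_block] = flat_key[src_off:src_off + numel_per_block]
        flat_value[dst_off:dst_off + numel_per_block] = flat_value[src_off:src_off + numel_per_block]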
tests/kernels/cache.py

@@ -5,6 +5,61 @@ import torch
 from cacheflow import cache_ops
 
 
+def test_copy_blocks(
+    num_mappings: int,
+    num_layers: int,
+    num_heads: int,
+    head_size: int,
+    block_size: int,
+    num_blocks: int,
+    dtype: torch.dtype,
+) -> None:
+    # Generate random block mappings.
+    src_blocks = random.sample(range(num_blocks), num_mappings)
+    remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
+    dst_blocks = random.sample(remainig_blocks, num_mappings)
+    block_mapping = {src: [dst] for src, dst in zip(src_blocks, dst_blocks)}
+
+    # Create the KV cache.
+    x = 16 // torch.tensor([], dtype=dtype).element_size()
+    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
+    key_caches = []
+    for _ in range(num_layers):
+        key_cache = torch.randn(
+            size=key_cache_shape, dtype=dtype, device='cuda')
+        key_caches.append(key_cache)
+    cloned_key_caches = []
+    for key_cache in key_caches:
+        cloned_key_caches.append(key_cache.clone())
+
+    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
+    value_caches = []
+    for _ in range(num_layers):
+        value_cache = torch.randn(
+            size=value_cache_shape, dtype=dtype, device='cuda')
+        value_caches.append(value_cache)
+    cloned_value_caches = []
+    for value_cache in value_caches:
+        cloned_value_caches.append(value_cache.clone())
+
+    # Call the copy blocks kernel.
+    cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
+
+    # Reference implementation.
+    for src, dsts in block_mapping.items():
+        for dst in dsts:
+            for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
+                cloned_key_cache[dst] = cloned_key_cache[src]
+            for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
+                cloned_value_cache[dst] = cloned_value_cache[src]
+
+    # Compare the results.
+    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
+        assert torch.allclose(key_cache, cloned_key_cache)
+    for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
+        assert torch.allclose(value_cache, cloned_value_cache)
+
+
 def test_reshape_and_cache(
     num_tokens: int,
     num_heads: int,
@@ -46,6 +101,9 @@ def test_reshape_and_cache(
 
 @torch.inference_mode()
 def test_cache() -> None:
+    test_copy_blocks(
+        num_mappings=23, num_layers=7, num_heads=17, head_size=16,
+        block_size=8, num_blocks=1024, dtype=torch.half)
     test_reshape_and_cache(
         num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
         dtype=torch.half)
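The key-cache shape in the test packs the head dimension into groups of x elements, where x = 16 // element_size, i.e. 16 bytes per group. Spelled out for the dtype the test actually uses (torch.half, 2 bytes per element) and the sizes passed from test_cache():

import torch

dtype = torch.half
x = 16 // torch.tensor([], dtype=dtype).element_size()   # 16 // 2 == 8
num_blocks, num_heads, head_size, block_size = 1024, 17, 16, 8

key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
print(key_cache_shape)    # (1024, 17, 2, 8, 8)
print(value_cache_shape)  # (1024, 17, 16, 8)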