xdb4_94051 / vllm / Commits

Commit 0ce8647d (unverified)
Authored Oct 31, 2023 by Woosuk Kwon; committed via GitHub on Oct 31, 2023
Parent: 9cabcb76

Fix integer overflows in attention & cache ops (#1514)

Showing 5 changed files with 53 additions and 47 deletions (+53, -47):
csrc/attention/attention_kernels.cu   +8  -2
csrc/cache_kernels.cu                 +36 -36
tests/kernels/test_attention.py       +1  -1
tests/kernels/test_cache.py           +7  -7
vllm/worker/worker.py                 +1  -1
csrc/attention/attention_kernels.cu

@@ -175,7 +175,10 @@ __device__ void paged_attention_kernel(
   // dot product with the query.
   const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
   for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
-    const int physical_block_number = block_table[block_idx];
+    // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
+    // because int32 can lead to overflow when this variable is multiplied by large numbers
+    // (e.g., kv_block_stride).
+    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
     // Load a key to registers.
     // Each thread in a thread group has a different part of the key.
@@ -285,7 +288,10 @@ __device__ void paged_attention_kernel(
   scalar_t zero_value;
   zero(zero_value);
   for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
-    const int physical_block_number = block_table[block_idx];
+    // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
+    // because int32 can lead to overflow when this variable is multiplied by large numbers
+    // (e.g., kv_block_stride).
+    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
     const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
     const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
     L_vec logits_vec;
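The cast is needed because of C++'s usual arithmetic conversions: when both operands of * are int, the product is computed in 32 bits and wraps before it reaches any 64-bit destination. A minimal host-side sketch of the failure and of the fix; the block count and stride values are assumptions for illustration, not taken from this commit:

#include <cstdint>
#include <cstdio>

int main() {
  // Assumed large-cache values: ~40k physical blocks and a kv_block_stride of
  // num_heads * head_size * block_size = 40 * 128 * 16 = 81,920 elements.
  const int physical_block_number = 40000;
  const int kv_block_stride = 40 * 128 * 16;

  // Both operands are int, so the multiply happens in 32 bits and wraps
  // (formally, signed overflow is undefined behavior), even though the
  // result is stored into an int64_t.
  const int64_t wrapped = physical_block_number * kv_block_stride;

  // Widening one operand first forces a 64-bit multiply; this is exactly
  // what the static_cast<int64_t> in the kernel achieves.
  const int64_t correct =
      static_cast<int64_t>(physical_block_number) * kv_block_stride;

  // On typical targets: wrapped = -1018167296, correct = 3276800000.
  std::printf("%lld vs %lld\n", (long long)wrapped, (long long)correct);
  return 0;
}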
csrc/cache_kernels.cu

@@ -55,26 +55,26 @@ template<typename scalar_t>
 __global__ void copy_blocks_kernel(
   int64_t* key_cache_ptrs,
   int64_t* value_cache_ptrs,
-  const int* __restrict__ block_mapping,
+  const int64_t* __restrict__ block_mapping,
   const int numel_per_block) {
   const int layer_idx = blockIdx.x;
   const int pair_idx = blockIdx.y;

   scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
   scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
-  int src_block_number = block_mapping[2 * pair_idx];
-  int dst_block_number = block_mapping[2 * pair_idx + 1];
+  int64_t src_block_number = block_mapping[2 * pair_idx];
+  int64_t dst_block_number = block_mapping[2 * pair_idx + 1];

-  const int src_block_offset = src_block_number * numel_per_block;
-  const int dst_block_offset = dst_block_number * numel_per_block;
+  const int64_t src_block_offset = src_block_number * numel_per_block;
+  const int64_t dst_block_offset = dst_block_number * numel_per_block;
   for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
-    int src_offset = src_block_offset + i;
-    int dst_offset = dst_block_offset + i;
+    int64_t src_offset = src_block_offset + i;
+    int64_t dst_offset = dst_block_offset + i;
     key_cache[dst_offset] = key_cache[src_offset];
   }
   for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
-    int src_offset = src_block_offset + i;
-    int dst_offset = dst_block_offset + i;
+    int64_t src_offset = src_block_offset + i;
+    int64_t dst_offset = dst_block_offset + i;
     value_cache[dst_offset] = value_cache[src_offset];
   }
 }
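Note how the widening runs through the whole chain of intermediates, block number, block offset, and element offset alike: promoting only the final offset variable would not help, because src_block_number * numel_per_block would already have wrapped in 32-bit arithmetic. A sketch of the distinction:

#include <cstdint>

// Too late: the multiply is still done in 32 bits and wraps for large caches;
// widening the already-wrapped result recovers nothing.
int64_t block_offset_wrong(int block_number, int numel_per_block) {
  int offset = block_number * numel_per_block;
  return offset;
}

// The fix: widen the first operand, so the multiply itself is 64-bit and
// every later addition stays in 64 bits.
int64_t block_offset_right(int64_t block_number, int numel_per_block) {
  return block_number * numel_per_block;
}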
@@ -102,15 +102,15 @@ void copy_blocks(
     value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
   }
   // Create block mapping array.
-  std::vector<int> block_mapping_vec;
+  std::vector<int64_t> block_mapping_vec;
   for (const auto& pair : block_mapping) {
-    int src_block_number = pair.first;
-    for (int dst_block_number : pair.second) {
+    int64_t src_block_number = pair.first;
+    for (int64_t dst_block_number : pair.second) {
       block_mapping_vec.push_back(src_block_number);
       block_mapping_vec.push_back(dst_block_number);
     }
   }
-  int* block_mapping_array = block_mapping_vec.data();
+  int64_t* block_mapping_array = block_mapping_vec.data();
   int num_pairs = block_mapping_vec.size() / 2;

   // Move the data structures to the GPU.
@@ -120,7 +120,7 @@ void copy_blocks(
   torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
     value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
   torch::Tensor block_mapping_tensor = torch::from_blob(
-    block_mapping_array, {2 * num_pairs}, torch::kInt).to(cache_device);
+    block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device);

   // Launch the kernel.
   const int numel_per_block = key_caches[0][0].numel();
@@ -132,7 +132,7 @@ void copy_blocks(
     vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
       key_cache_ptrs_tensor.data_ptr<int64_t>(),
       value_cache_ptrs_tensor.data_ptr<int64_t>(),
-      block_mapping_tensor.data_ptr<int>(),
+      block_mapping_tensor.data_ptr<int64_t>(),
       numel_per_block);
   }));
 }
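The tensor dtype (torch::kInt64), the data_ptr template argument, and the kernel's parameter type have to move in lockstep: data_ptr<T>() checks T against the tensor's scalar type at runtime, so a stale data_ptr<int>() on an int64 tensor fails loudly rather than silently reinterpreting bytes. A standalone libtorch sketch of that invariant, not code from the commit:

#include <torch/torch.h>
#include <cstdint>

int main() {
  torch::Tensor mapping = torch::zeros({8}, torch::kInt64);

  // OK: the template argument matches the tensor's scalar type.
  int64_t* ptr = mapping.data_ptr<int64_t>();
  (void)ptr;

  // mapping.data_ptr<int>();  // would throw: expected Int, found Long

  return 0;
}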
@@ -141,46 +141,46 @@ namespace vllm {
 template<typename scalar_t>
 __global__ void reshape_and_cache_kernel(
   const scalar_t* __restrict__ key,           // [num_tokens, num_heads, head_size]
   const scalar_t* __restrict__ value,         // [num_tokens, num_heads, head_size]
   scalar_t* __restrict__ key_cache,           // [num_blocks, num_heads, head_size/x, block_size, x]
   scalar_t* __restrict__ value_cache,         // [num_blocks, num_heads, head_size, block_size]
-  const int* __restrict__ slot_mapping,       // [num_tokens]
+  const int64_t* __restrict__ slot_mapping,   // [num_tokens]
   const int key_stride,
   const int value_stride,
   const int num_heads,
   const int head_size,
   const int block_size,
   const int x) {
-  const int token_idx = blockIdx.x;
-  const int slot_idx = slot_mapping[token_idx];
+  const int64_t token_idx = blockIdx.x;
+  const int64_t slot_idx = slot_mapping[token_idx];
   if (slot_idx < 0) {
     // Padding token that should be ignored.
     return;
   }
-  const int block_idx = slot_idx / block_size;
-  const int block_offset = slot_idx % block_size;
+  const int64_t block_idx = slot_idx / block_size;
+  const int64_t block_offset = slot_idx % block_size;

   const int n = num_heads * head_size;
   for (int i = threadIdx.x; i < n; i += blockDim.x) {
-    const int src_key_idx = token_idx * key_stride + i;
-    const int src_value_idx = token_idx * value_stride + i;
+    const int64_t src_key_idx = token_idx * key_stride + i;
+    const int64_t src_value_idx = token_idx * value_stride + i;
     const int head_idx = i / head_size;
     const int head_offset = i % head_size;
     const int x_idx = head_offset / x;
     const int x_offset = head_offset % x;
-    const int tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
-                          + head_idx * (head_size / x) * block_size * x
-                          + x_idx * block_size * x
-                          + block_offset * x
-                          + x_offset;
-    const int tgt_value_idx = block_idx * num_heads * head_size * block_size
-                            + head_idx * head_size * block_size
-                            + head_offset * block_size
-                            + block_offset;
+    const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
+                              + head_idx * (head_size / x) * block_size * x
+                              + x_idx * block_size * x
+                              + block_offset * x
+                              + x_offset;
+    const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size
+                                + head_idx * head_size * block_size
+                                + head_offset * block_size
+                                + block_offset;
     key_cache[tgt_key_idx] = key[src_key_idx];
     value_cache[tgt_value_idx] = value[src_value_idx];
   }
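The tgt_key_idx expression is a row-major flattening of the five-dimensional key-cache layout noted in the comments, now evaluated in 64 bits because block_idx leads the product. A host-side sketch checking the expanded formula against a generic row-major index; the helper name and sample dimensions are mine, chosen for illustration:

#include <cassert>
#include <cstdint>

// Row-major flatten of [num_blocks, num_heads, head_size/x, block_size, x].
int64_t flatten5d(int64_t b, int64_t h, int64_t xi, int64_t bo, int64_t xo,
                  int64_t num_heads, int64_t head_size_over_x,
                  int64_t block_size, int64_t x) {
  return (((b * num_heads + h) * head_size_over_x + xi) * block_size + bo) * x + xo;
}

int main() {
  const int64_t num_heads = 8, head_size = 128, block_size = 16, x = 8;
  const int64_t block_idx = 5, head_idx = 3, x_idx = 7, block_offset = 9, x_offset = 2;

  // The kernel's expanded form of tgt_key_idx...
  const int64_t tgt_key_idx =
      block_idx * num_heads * (head_size / x) * block_size * x
    + head_idx * (head_size / x) * block_size * x
    + x_idx * block_size * x
    + block_offset * x
    + x_offset;

  // ...is the same flattened index.
  assert(tgt_key_idx == flatten5d(block_idx, head_idx, x_idx, block_offset,
                                  x_offset, num_heads, head_size / x,
                                  block_size, x));
  return 0;
}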
@@ -216,7 +216,7 @@ void reshape_and_cache(
       value.data_ptr<scalar_t>(),
       key_cache.data_ptr<scalar_t>(),
       value_cache.data_ptr<scalar_t>(),
-      slot_mapping.data_ptr<int>(),
+      slot_mapping.data_ptr<int64_t>(),
       key_stride,
       value_stride,
       num_heads,
tests/kernels/test_attention.py

@@ -13,7 +13,7 @@ FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 # This will change depending on the compute capability.
 # - 512 as a buffer
 MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
-NUM_BLOCKS = 128  # Arbitrary values for testing
+NUM_BLOCKS = 40000  # Arbitrary values for testing
 PARTITION_SIZE = 512
 DTYPES = [torch.half, torch.bfloat16, torch.float]
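Bumping NUM_BLOCKS from 128 to 40,000 is what turns this into a regression test for the overflow: with tens of thousands of physical blocks, block_number * kv_block_stride exceeds the int32 range for realistic head counts. A quick check under assumed dimensions (64 heads, head size 128, block size 16; illustrative, not the test's exact parameterization):

#include <cstdint>

// Assumed per-block KV footprint: num_heads * head_size * block_size.
constexpr int64_t kNumBlocks = 40000;
constexpr int64_t kPerBlock  = 64 * 128 * 16;             // 131,072 elements
constexpr int64_t kMaxOffset = kNumBlocks * kPerBlock;    // 5,242,880,000

// 128 blocks never got near INT32_MAX; 40,000 blocks sail past it.
static_assert(kMaxOffset > 2147483647LL, "offsets exceed int32 range");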
tests/kernels/test_cache.py

@@ -6,13 +6,13 @@ import torch
 from vllm import cache_ops

 DTYPES = [torch.half, torch.bfloat16, torch.float]
-NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
-NUM_LAYERS = [5]  # Arbitrary values for testing
+NUM_TOKENS = [83]  # Arbitrary values for testing
+NUM_LAYERS = [1]  # Arbitrary values for testing
 NUM_HEADS = [8]  # Arbitrary values for testing
 HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 BLOCK_SIZES = [8, 16, 32]
-NUM_BLOCKS = [1024]  # Arbitrary values for testing
-NUM_MAPPINGS = [32, 256]  # Arbitrary values for testing
+NUM_BLOCKS = [1024, 36000]  # Arbitrary values for testing
+NUM_MAPPINGS = [256]  # Arbitrary values for testing
 SEEDS = [0]
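The new 36,000-block configuration is the overflow trigger here: with the largest parameters in this file (NUM_HEADS = 8, head size 256, block size 32), one block holds 8 * 256 * 32 = 65,536 elements, so block offsets reach 36,000 * 65,536 = 2,359,296,000, beyond INT32_MAX = 2,147,483,647, while the 1,024-block case keeps a small, fast configuration in the matrix. The same arithmetic as a compile-time check:

#include <cstdint>

// Largest dims in this test file: 8 heads, head size 256, block size 32.
constexpr int64_t kNumelPerBlock = 8 * 256 * 32;            // 65,536
constexpr int64_t kMaxOffset     = 36000 * kNumelPerBlock;  // 2,359,296,000

static_assert(kMaxOffset > 2147483647LL,
              "36,000 blocks push block offsets past the int32 range");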
@@ -69,9 +69,9 @@ def test_copy_blocks(
     for src, dsts in block_mapping.items():
         for dst in dsts:
             for cloned_key_cache in cloned_key_caches:
-                cloned_key_cache[dst] = cloned_key_cache[src]
+                cloned_key_cache[dst].copy_(cloned_key_cache[src])
             for cloned_value_cache in cloned_value_caches:
-                cloned_value_cache[dst] = cloned_value_cache[src]
+                cloned_value_cache[dst].copy_(cloned_value_cache[src])

     # Compare the results.
     for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
@@ -106,7 +106,7 @@ def test_reshape_and_cache(
     # Create a random slot mapping.
     num_slots = block_size * num_blocks
     slot_mapping = random.sample(range(num_slots), num_tokens)
-    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device="cuda")
+    slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device="cuda")

     qkv = torch.randn(num_tokens,
                       3,
vllm/worker/worker.py

@@ -301,7 +301,7 @@ class Worker:
                                              dtype=torch.long,
                                              device="cuda")
         slot_mapping_tensor = torch.tensor(padded_slot_mapping,
-                                           dtype=torch.int,
+                                           dtype=torch.long,
                                            device="cuda")
         context_lens_tensor = torch.tensor(context_lens,
                                            dtype=torch.int,
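On the Python side, torch.long is the 64-bit integer dtype, so the padded slot mapping now lines up with the kernel's const int64_t* slot_mapping parameter; context_lens stays torch.int, matching the kernels' int parameters. A libtorch sketch of the dtype correspondence, illustrative rather than taken from the commit:

#include <torch/torch.h>
#include <cstdint>

int main() {
  // torch.long (Python) is the same scalar type as torch::kLong / torch::kInt64.
  static_assert(torch::kLong == torch::kInt64, "long is an alias of int64");

  torch::Tensor slot_mapping = torch::tensor({0, 17, -1}, torch::kLong);
  const int64_t* ptr = slot_mapping.data_ptr<int64_t>();  // types agree
  (void)ptr;
  return 0;
}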