norm / vllm · Commits · 0ce8647d

Commit 0ce8647d (Unverified)
Fix integer overflows in attention & cache ops (#1514)
Authored Oct 31, 2023 by Woosuk Kwon; committed via GitHub on Oct 31, 2023
Parent: 9cabcb76

Showing 5 changed files with 53 additions and 47 deletions (+53 −47)
  csrc/attention/attention_kernels.cu   +8 −2
  csrc/cache_kernels.cu                 +36 −36
  tests/kernels/test_attention.py       +1 −1
  tests/kernels/test_cache.py           +7 −7
  vllm/worker/worker.py                 +1 −1
csrc/attention/attention_kernels.cu

@@ -175,7 +175,10 @@ __device__ void paged_attention_kernel(
   // dot product with the query.
   const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
   for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
-    const int physical_block_number = block_table[block_idx];
+    // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
+    // because int32 can lead to overflow when this variable is multiplied by large numbers
+    // (e.g., kv_block_stride).
+    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
     // Load a key to registers.
     // Each thread in a thread group has a different part of the key.
@@ -285,7 +288,10 @@ __device__ void paged_attention_kernel(
   scalar_t zero_value;
   zero(zero_value);
   for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
-    const int physical_block_number = block_table[block_idx];
+    // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
+    // because int32 can lead to overflow when this variable is multiplied by large numbers
+    // (e.g., kv_block_stride).
+    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
     const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
     const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
     L_vec logits_vec;
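Note: the added NOTE(woosuk) comments state the whole point of the commit. The block numbers themselves fit in int32; it is the product with kv_block_stride that does not. A minimal standalone sketch of the failure mode, with illustrative dimensions that are assumptions rather than values taken from this diff:

// Why the static_cast matters: with 40 heads, head_size 128 and block_size 16,
// one KV block spans 40 * 128 * 16 = 81,920 elements, so block numbers above
// INT32_MAX / 81,920 (about 26,214) break 32-bit offset arithmetic.
#include <cstdint>
#include <cstdio>

int main() {
  const int kv_block_stride = 40 * 128 * 16;  // assumed stride, in elements
  const int block_number = 30000;             // plausible in a large KV cache

  // The true element offset, computed in 64 bits as the fixed kernel does.
  const int64_t good = static_cast<int64_t>(block_number) * kv_block_stride;

  // What the old int * int arithmetic effectively produced: the same product
  // truncated to 32 bits, i.e. a wrapped (here negative) offset.
  const int32_t bad = static_cast<int32_t>(good);

  printf("64-bit offset: %lld\n", (long long)good);  // 2457600000
  printf("32-bit offset: %d\n", bad);                // -1837367296
  return 0;
}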
csrc/cache_kernels.cu

@@ -55,26 +55,26 @@ template<typename scalar_t>
 __global__ void copy_blocks_kernel(
   int64_t* key_cache_ptrs,
   int64_t* value_cache_ptrs,
-  const int* __restrict__ block_mapping,
+  const int64_t* __restrict__ block_mapping,
   const int numel_per_block) {
   const int layer_idx = blockIdx.x;
   const int pair_idx = blockIdx.y;

   scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
   scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
-  int src_block_number = block_mapping[2 * pair_idx];
-  int dst_block_number = block_mapping[2 * pair_idx + 1];
+  int64_t src_block_number = block_mapping[2 * pair_idx];
+  int64_t dst_block_number = block_mapping[2 * pair_idx + 1];

-  const int src_block_offset = src_block_number * numel_per_block;
-  const int dst_block_offset = dst_block_number * numel_per_block;
+  const int64_t src_block_offset = src_block_number * numel_per_block;
+  const int64_t dst_block_offset = dst_block_number * numel_per_block;
   for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
-    int src_offset = src_block_offset + i;
-    int dst_offset = dst_block_offset + i;
+    int64_t src_offset = src_block_offset + i;
+    int64_t dst_offset = dst_block_offset + i;
     key_cache[dst_offset] = key_cache[src_offset];
   }
   for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
-    int src_offset = src_block_offset + i;
-    int dst_offset = dst_block_offset + i;
+    int64_t src_offset = src_block_offset + i;
+    int64_t dst_offset = dst_block_offset + i;
     value_cache[dst_offset] = value_cache[src_offset];
   }
 }
@@ -102,15 +102,15 @@ void copy_blocks(
     value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
   }
   // Create block mapping array.
-  std::vector<int> block_mapping_vec;
+  std::vector<int64_t> block_mapping_vec;
   for (const auto& pair : block_mapping) {
-    int src_block_number = pair.first;
-    for (int dst_block_number : pair.second) {
+    int64_t src_block_number = pair.first;
+    for (int64_t dst_block_number : pair.second) {
       block_mapping_vec.push_back(src_block_number);
       block_mapping_vec.push_back(dst_block_number);
     }
   }
-  int* block_mapping_array = block_mapping_vec.data();
+  int64_t* block_mapping_array = block_mapping_vec.data();
   int num_pairs = block_mapping_vec.size() / 2;

   // Move the data structures to the GPU.
@@ -120,7 +120,7 @@ void copy_blocks(
   torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
     value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
   torch::Tensor block_mapping_tensor = torch::from_blob(
-    block_mapping_array, {2 * num_pairs}, torch::kInt).to(cache_device);
+    block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device);

   // Launch the kernel.
   const int numel_per_block = key_caches[0][0].numel();
@@ -132,7 +132,7 @@ void copy_blocks(
     vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
       key_cache_ptrs_tensor.data_ptr<int64_t>(),
       value_cache_ptrs_tensor.data_ptr<int64_t>(),
-      block_mapping_tensor.data_ptr<int>(),
+      block_mapping_tensor.data_ptr<int64_t>(),
       numel_per_block);
   }));
 }
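Note: the host side of copy_blocks now matches the kernel end to end: the mapping is flattened into an int64 buffer, wrapped as a kInt64 tensor, and read back through data_ptr<int64_t>(). A condensed, hypothetical sketch of that pattern (simplified from the hunks above; the mapping values are made up):

#include <torch/torch.h>
#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

int main() {
  // {src_block: [dst_blocks...]}, as copy_blocks receives it.
  std::map<int64_t, std::vector<int64_t>> block_mapping = {{2, {5, 9}}, {7, {1}}};

  // Flatten into [src0, dst0, src1, dst1, ...] — int64 throughout.
  std::vector<int64_t> block_mapping_vec;
  for (const auto& pair : block_mapping) {
    for (int64_t dst : pair.second) {
      block_mapping_vec.push_back(pair.first);
      block_mapping_vec.push_back(dst);
    }
  }

  // The dtype handed to from_blob must describe the buffer's real element
  // type; this is exactly the torch::kInt -> torch::kInt64 change above.
  torch::Tensor t = torch::from_blob(
      block_mapping_vec.data(),
      {static_cast<int64_t>(block_mapping_vec.size())},
      torch::kInt64).clone();  // clone(): from_blob does not own the memory

  std::cout << t << "\n";  // 2, 5, 2, 9, 7, 1
  return 0;
}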
@@ -145,39 +145,39 @@ __global__ void reshape_and_cache_kernel(
   const scalar_t* __restrict__ value,        // [num_tokens, num_heads, head_size]
   scalar_t* __restrict__ key_cache,          // [num_blocks, num_heads, head_size/x, block_size, x]
   scalar_t* __restrict__ value_cache,        // [num_blocks, num_heads, head_size, block_size]
-  const int* __restrict__ slot_mapping,      // [num_tokens]
+  const int64_t* __restrict__ slot_mapping,  // [num_tokens]
   const int key_stride,
   const int value_stride,
   const int num_heads,
   const int head_size,
   const int block_size,
   const int x) {
-  const int token_idx = blockIdx.x;
-  const int slot_idx = slot_mapping[token_idx];
+  const int64_t token_idx = blockIdx.x;
+  const int64_t slot_idx = slot_mapping[token_idx];
   if (slot_idx < 0) {
     // Padding token that should be ignored.
     return;
   }
-  const int block_idx = slot_idx / block_size;
-  const int block_offset = slot_idx % block_size;
+  const int64_t block_idx = slot_idx / block_size;
+  const int64_t block_offset = slot_idx % block_size;

   const int n = num_heads * head_size;
   for (int i = threadIdx.x; i < n; i += blockDim.x) {
-    const int src_key_idx = token_idx * key_stride + i;
-    const int src_value_idx = token_idx * value_stride + i;
+    const int64_t src_key_idx = token_idx * key_stride + i;
+    const int64_t src_value_idx = token_idx * value_stride + i;
     const int head_idx = i / head_size;
     const int head_offset = i % head_size;
     const int x_idx = head_offset / x;
     const int x_offset = head_offset % x;

-    const int tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
-                          + head_idx * (head_size / x) * block_size * x
-                          + x_idx * block_size * x
-                          + block_offset * x
-                          + x_offset;
+    const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
+                              + head_idx * (head_size / x) * block_size * x
+                              + x_idx * block_size * x
+                              + block_offset * x
+                              + x_offset;
-    const int tgt_value_idx = block_idx * num_heads * head_size * block_size
-                            + head_idx * head_size * block_size
-                            + head_offset * block_size
-                            + block_offset;
+    const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size
+                                + head_idx * head_size * block_size
+                                + head_offset * block_size
+                                + block_offset;
@@ -216,7 +216,7 @@ void reshape_and_cache(
       value.data_ptr<scalar_t>(),
       key_cache.data_ptr<scalar_t>(),
       value_cache.data_ptr<scalar_t>(),
-      slot_mapping.data_ptr<int>(),
+      slot_mapping.data_ptr<int64_t>(),
      key_stride,
      value_stride,
      num_heads,
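Note: reshape_and_cache_kernel gets the same treatment: once slot_idx and block_idx are int64, every product in the tgt_key_idx / tgt_value_idx chains is evaluated in 64 bits. A back-of-the-envelope check under assumed dimensions (only the 36000-block figure comes from this commit's tests):

#include <cstdint>
#include <cstdio>

int main() {
  const int num_heads = 40, head_size = 128, block_size = 16;  // assumed dims
  const int64_t slot_idx = 36000LL * block_size - 1;  // last slot in the cache

  const int64_t block_idx = slot_idx / block_size;     // 35999
  const int64_t block_offset = slot_idx % block_size;  // 15

  // block_idx being int64 keeps the whole chain 64-bit, as in the fixed kernel.
  const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size
                              + (num_heads - 1) * head_size * block_size  // last head
                              + (head_size - 1) * block_size              // last head_offset
                              + block_offset;

  printf("tgt_value_idx = %lld (INT32_MAX = 2147483647)\n",
         (long long)tgt_value_idx);  // 2949119999
  return 0;
}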
tests/kernels/test_attention.py

@@ -13,7 +13,7 @@ FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 # This will change depending on the compute capability.
 # - 512 as a buffer
 MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
-NUM_BLOCKS = 128  # Arbitrary values for testing
+NUM_BLOCKS = 40000  # Arbitrary values for testing
 PARTITION_SIZE = 512
 DTYPES = [torch.half, torch.bfloat16, torch.float]
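Note: this one-line change is what makes the overflow reproducible. With NUM_BLOCKS = 128 every offset stays far below 2^31 and the buggy int32 arithmetic passes by accident; 40000 blocks push physical block numbers past the wrap-around point. A quick threshold estimate (head count and sizes are assumptions, not test parameters):

#include <cstdint>
#include <cstdio>

int main() {
  const int num_heads = 40, head_size = 128, block_size = 16;  // assumed dims
  const int64_t kv_block_stride =
      static_cast<int64_t>(num_heads) * head_size * block_size;  // 81,920

  // Smallest block number whose base offset no longer fits in int32.
  const int64_t first_bad_block = INT32_MAX / kv_block_stride + 1;  // 26215
  printf("block numbers >= %lld overflow int32; the test now allocates 40000\n",
         (long long)first_bad_block);
  return 0;
}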
tests/kernels/test_cache.py

@@ -6,13 +6,13 @@ import torch
 from vllm import cache_ops

 DTYPES = [torch.half, torch.bfloat16, torch.float]
-NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
-NUM_LAYERS = [5]  # Arbitrary values for testing
+NUM_TOKENS = [83]  # Arbitrary values for testing
+NUM_LAYERS = [1]  # Arbitrary values for testing
 NUM_HEADS = [8]  # Arbitrary values for testing
 HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 BLOCK_SIZES = [8, 16, 32]
-NUM_BLOCKS = [1024]  # Arbitrary values for testing
-NUM_MAPPINGS = [32, 256]  # Arbitrary values for testing
+NUM_BLOCKS = [1024, 36000]  # Arbitrary values for testing
+NUM_MAPPINGS = [256]  # Arbitrary values for testing
 SEEDS = [0]

@@ -69,9 +69,9 @@ def test_copy_blocks(
     for src, dsts in block_mapping.items():
         for dst in dsts:
             for cloned_key_cache in cloned_key_caches:
-                cloned_key_cache[dst] = cloned_key_cache[src]
+                cloned_key_cache[dst].copy_(cloned_key_cache[src])
             for cloned_value_cache in cloned_value_caches:
-                cloned_value_cache[dst] = cloned_value_cache[src]
+                cloned_value_cache[dst].copy_(cloned_value_cache[src])

     # Compare the results.
     for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):

@@ -106,7 +106,7 @@ def test_reshape_and_cache(
     # Create a random slot mapping.
     num_slots = block_size * num_blocks
     slot_mapping = random.sample(range(num_slots), num_tokens)
-    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device="cuda")
+    slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device="cuda")
     qkv = torch.randn(num_tokens,
                       3,
vllm/worker/worker.py

@@ -301,7 +301,7 @@ class Worker:
                                              dtype=torch.long,
                                              device="cuda")
         slot_mapping_tensor = torch.tensor(padded_slot_mapping,
-                                           dtype=torch.int,
+                                           dtype=torch.long,
                                            device="cuda")
         context_lens_tensor = torch.tensor(context_lens,
                                            dtype=torch.int,
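Note: the Python-side dtype switches (torch.int -> torch.long here and in the tests) are required rather than cosmetic: the extension now reads slot_mapping through data_ptr<int64_t>(), and libtorch rejects that accessor on an int32 tensor at runtime. A standalone sketch of the check (not vLLM code):

#include <torch/torch.h>
#include <cstdint>
#include <iostream>

int main() {
  auto long_tensor = torch::zeros({4}, torch::kInt64);
  int64_t* ok = long_tensor.data_ptr<int64_t>();  // fine: dtypes agree
  (void)ok;

  auto int_tensor = torch::zeros({4}, torch::kInt32);
  try {
    int_tensor.data_ptr<int64_t>();               // throws: dtype mismatch
  } catch (const c10::Error&) {
    std::cout << "int32 tensor rejected by data_ptr<int64_t>()\n";
  }
  return 0;
}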