Project: xdb4_94051 / vllm · Commits

Commit cf5cb1e3 (unverified)
Allocate more shared memory to attention kernel (#1154)

Authored Sep 26, 2023 by Antoni Baum; committed by GitHub on Sep 26, 2023
Parent: 03ffd0a0
Showing 7 changed files with 87 additions and 3 deletions (+87 −3)
csrc/attention/attention_kernels.cu    +5   −0
csrc/cuda_utils.cpp                    +13  −0
csrc/cuda_utils_kernels.cu             +14  −0
setup.py                               +11  −0
tests/kernels/test_attention.py        +7   −1
vllm/utils.py                          +12  −1
vllm/worker/worker.py                  +25  −1
csrc/attention/attention_kernels.cu

@@ -341,6 +341,9 @@ __global__ void single_query_cached_kv_attention_kernel(
 } // namespace vllm
 
 #define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS)                      \
+  cudaFuncSetAttribute(                                                                     \
+    vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>,   \
+    cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size);                          \
   vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>      \
   <<<grid, block, shared_mem_size, stream>>>(                                               \
     out_ptr,                                                                                \

@@ -401,6 +404,8 @@ void single_query_cached_kv_attention_launcher(
   int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
   int logits_size = padded_max_context_len * sizeof(float);
   int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
+  // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
+  // Keep that in sync with the logic here!
   int shared_mem_size = std::max(logits_size, outputs_size);
 
   dim3 grid(num_heads, num_seqs);
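By default a CUDA kernel may use at most 48 KB of dynamic shared memory per block; requesting more requires the explicit cudaFuncSetAttribute opt-in added to the launch macro above, after which shared_mem_size can grow with the context length. A minimal Python sketch of the launcher's sizing arithmetic follows; NUM_WARPS = 4 and the example inputs are illustrative assumptions, not values taken from this diff.

# Sketch: mirrors the shared_mem_size arithmetic in
# single_query_cached_kv_attention_launcher above.
# num_warps=4 and the example inputs are assumptions for illustration.
def shared_mem_size_bytes(max_context_len: int, block_size: int,
                          head_size: int, num_warps: int = 4) -> int:
    float_bytes = 4  # sizeof(float)
    padded_max_context_len = (
        (max_context_len + block_size - 1) // block_size) * block_size
    logits_size = padded_max_context_len * float_bytes
    outputs_size = (num_warps // 2) * head_size * float_bytes
    return max(logits_size, outputs_size)

# E.g. a 20,001-token context with block_size=16 pads to 20,016 slots and
# needs ~80 KB of shared memory for the logits, well above the 48 KB default,
# which is why the cudaFuncSetAttribute opt-in above is required.
print(shared_mem_size_bytes(20001, 16, 128))  # 80064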
csrc/cuda_utils.cpp (new file, mode 100644)

#include <torch/extension.h>

int get_device_attribute(
    int attribute,
    int device_id);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
    "get_device_attribute",
    &get_device_attribute,
    "Gets the specified device attribute.");
}
csrc/cuda_utils_kernels.cu (new file, mode 100644)

int get_device_attribute(
    int attribute,
    int device_id)
{
    int device, value;
    if (device_id < 0) {
        cudaGetDevice(&device);
    }
    else {
        device = device_id;
    }
    cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
    return value;
}
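Together these two files expose cudaDeviceGetAttribute to Python as vllm.cuda_utils.get_device_attribute. Below is a hedged usage sketch; it assumes the extension has been built (see the setup.py change that follows), and 97 is the value of cudaDevAttrMaxSharedMemoryPerBlockOptin documented in the CUDA runtime API.

# Hypothetical interactive use; requires the compiled vllm.cuda_utils extension.
from vllm import cuda_utils

# 97 == cudaDevAttrMaxSharedMemoryPerBlockOptin in the CUDA runtime API.
CUDA_DEV_ATTR_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN = 97
smem = cuda_utils.get_device_attribute(
    CUDA_DEV_ATTR_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, 0)  # device 0
print(f"Opt-in shared memory per block: {smem} bytes")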
setup.py

@@ -195,6 +195,17 @@ quantization_extension = CUDAExtension(
 )
 ext_modules.append(quantization_extension)
 
+# Misc. CUDA utils.
+cuda_utils_extension = CUDAExtension(
+    name="vllm.cuda_utils",
+    sources=["csrc/cuda_utils.cpp", "csrc/cuda_utils_kernels.cu"],
+    extra_compile_args={
+        "cxx": CXX_FLAGS,
+        "nvcc": NVCC_FLAGS,
+    },
+)
+ext_modules.append(cuda_utils_extension)
+
 
 def get_path(*filepath) -> str:
     return os.path.join(ROOT_DIR, *filepath)
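For context, here is a standalone sketch of the same CUDAExtension pattern using torch.utils.cpp_extension. CXX_FLAGS and NVCC_FLAGS are defined earlier in vLLM's real setup.py, so placeholder flags and a hypothetical package name are used here.

# Standalone sketch only; vLLM's setup.py defines CXX_FLAGS/NVCC_FLAGS and a
# shared ext_modules list elsewhere in the file.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

cuda_utils_extension = CUDAExtension(
    name="vllm.cuda_utils",
    sources=["csrc/cuda_utils.cpp", "csrc/cuda_utils_kernels.cu"],
    extra_compile_args={"cxx": ["-O2"], "nvcc": ["-O2"]},  # placeholder flags
)

setup(
    name="cuda-utils-demo",  # hypothetical package name
    ext_modules=[cuda_utils_extension],
    cmdclass={"build_ext": BuildExtension},
)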
tests/kernels/test_attention.py

@@ -7,8 +7,12 @@ from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 
 from vllm import attention_ops
+from vllm.utils import get_max_shared_memory_bytes
 
-MAX_SEQ_LEN = 8192
+FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
+# This will change depending on the compute capability.
+# - 512 as a buffer
+MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
 NUM_BLOCKS = 128  # Arbitrary values for testing
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]

@@ -135,6 +139,7 @@ def test_single_query_cached_kv_attention(
                                  device="cuda")
 
     context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
+    context_lens[-1] = MAX_SEQ_LEN
     max_context_len = max(context_lens)
     context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")

@@ -243,6 +248,7 @@ def test_multi_query_kv_attention(
     torch.cuda.manual_seed(seed)
 
     seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
+    seq_lens[-1] = MAX_SEQ_LEN
     num_tokens = sum(seq_lens)
     scale = float(1.0 / (head_size ** 0.5))
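MAX_SEQ_LEN is no longer a fixed 8192; it is derived from the GPU's opt-in shared memory (minus a 512-element buffer), and both tests now pin one sequence at exactly MAX_SEQ_LEN so the largest supported context is exercised. A worked example of the arithmetic, using an assumed device value of 166912 bytes (the real number depends on the compute capability):

# Illustrative only: 166912 is an assumed get_max_shared_memory_bytes() result.
FLOAT32_BYTES = 4
max_shared_memory_bytes = 166912
MAX_SEQ_LEN = max_shared_memory_bytes // FLOAT32_BYTES - 512
print(MAX_SEQ_LEN)  # 41216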
vllm/utils.py
View file @
cf5cb1e3
import
enum
import
enum
from
platform
import
uname
import
uuid
import
uuid
from
platform
import
uname
import
psutil
import
psutil
import
torch
import
torch
from
vllm
import
cuda_utils
class
Device
(
enum
.
Enum
):
class
Device
(
enum
.
Enum
):
GPU
=
enum
.
auto
()
GPU
=
enum
.
auto
()
...
@@ -25,6 +27,15 @@ class Counter:
...
@@ -25,6 +27,15 @@ class Counter:
self
.
counter
=
0
self
.
counter
=
0
def
get_max_shared_memory_bytes
(
gpu
:
int
=
0
)
->
int
:
"""Returns the maximum shared memory per thread block in bytes."""
# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
cudaDevAttrMaxSharedMemoryPerBlockOptin
=
97
# pylint: disable=invalid-name
max_shared_mem
=
cuda_utils
.
get_device_attribute
(
cudaDevAttrMaxSharedMemoryPerBlockOptin
,
gpu
)
return
int
(
max_shared_mem
)
def
get_gpu_memory
(
gpu
:
int
=
0
)
->
int
:
def
get_gpu_memory
(
gpu
:
int
=
0
)
->
int
:
"""Returns the total memory of the GPU in bytes."""
"""Returns the total memory of the GPU in bytes."""
return
torch
.
cuda
.
get_device_properties
(
gpu
).
total_memory
return
torch
.
cuda
.
get_device_properties
(
gpu
).
total_memory
...
...
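get_max_shared_memory_bytes is the public helper around the new binding; it reports the per-block shared memory a kernel can opt into, which on recent architectures is well above the 48 KB default. A small usage sketch (requires a CUDA device and the built extension; the output depends on the GPU):

# Hypothetical usage; the reported value is device-dependent.
from vllm.utils import get_max_shared_memory_bytes

smem = get_max_shared_memory_bytes(gpu=0)
print(f"Opt-in shared memory per block: {smem} bytes (~{smem / 1024:.0f} KiB)")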
vllm/worker/worker.py

@@ -13,7 +13,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
 from vllm.worker.cache_engine import CacheEngine
-from vllm.utils import get_gpu_memory
+from vllm.utils import get_gpu_memory, get_max_shared_memory_bytes
 
 
 class Worker:

@@ -136,6 +136,10 @@ class Worker:
     def init_cache_engine(self, cache_config: CacheConfig) -> None:
         self.cache_config = cache_config
         self.block_size = cache_config.block_size
+        _check_if_can_support_max_seq_len(self.scheduler_config.max_model_len,
+                                          self.block_size)
+
         self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                         self.parallel_config)
         self.cache_events = self.cache_engine.events

@@ -347,3 +351,23 @@ def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:
 def _pad_to_max(x: List[int], max_len: int) -> List[int]:
     return x + [0] * (max_len - len(x))
 
 
+def _check_if_can_support_max_seq_len(max_seq_len: int,
+                                      block_size: int) -> None:
+    # Follows the logic in
+    # attention_kernels.cu::single_query_cached_kv_attention_launcher
+    max_shared_mem = get_max_shared_memory_bytes()
+    float32_bytes = torch.finfo(torch.float).bits // 8
+    padded_max_seq_len = (
+        (max_seq_len + block_size - 1) / block_size) * block_size
+    # padded_max_seq_len + extra buffer
+    required_shared_mem = (padded_max_seq_len + 512) * float32_bytes
+    if padded_max_seq_len * float32_bytes > max_shared_mem:
+        raise RuntimeError(
+            f"vLLM cannot currently support max_model_len={max_seq_len} "
+            f"with block_size={block_size} on GPU with compute "
+            f"capability {torch.cuda.get_device_capability()} "
+            f"(required shared memory {required_shared_mem} > "
+            f"available shared memory {max_shared_mem}). "
+            "This will be fixed in a future release.")
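_check_if_can_support_max_seq_len mirrors the logits_size computation in the CUDA launcher, so an unsupported max_model_len fails fast at cache-engine initialization instead of at kernel launch. A worked example of the comparison, under an assumed 166912-byte shared-memory budget (device-dependent):

# Illustrative check; 166912 bytes is an assumed device value.
max_shared_mem = 166912
max_seq_len, block_size, float32_bytes = 50000, 16, 4
padded_max_seq_len = ((max_seq_len + block_size - 1) // block_size) * block_size  # 50000
required = padded_max_seq_len * float32_bytes                                     # 200000
print(required > max_shared_mem)  # True -> the worker raises RuntimeError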