norm / vllm · Commits · cf5cb1e3

Commit cf5cb1e3 (unverified) · authored Sep 26, 2023 by Antoni Baum · committed by GitHub on Sep 26, 2023

Allocate more shared memory to attention kernel (#1154)
parent 03ffd0a0

Showing 7 changed files with 87 additions and 3 deletions (+87 −3)
csrc/attention/attention_kernels.cu    +5  −0
csrc/cuda_utils.cpp                    +13 −0
csrc/cuda_utils_kernels.cu             +14 −0
setup.py                               +11 −0
tests/kernels/test_attention.py        +7  −1
vllm/utils.py                          +12 −1
vllm/worker/worker.py                  +25 −1
csrc/attention/attention_kernels.cu

@@ -341,6 +341,9 @@ __global__ void single_query_cached_kv_attention_kernel(
 }
 
 } // namespace vllm
 
 #define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS)                     \
+  cudaFuncSetAttribute(                                                                    \
+    vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>,  \
+    cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size);                         \
   vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>     \
   <<<grid, block, shared_mem_size, stream>>>(                                              \
     out_ptr,                                                                               \
...
@@ -401,6 +404,8 @@ void single_query_cached_kv_attention_launcher(
   int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
   int logits_size = padded_max_context_len * sizeof(float);
   int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
+  // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
+  // Keep that in sync with the logic here!
   int shared_mem_size = std::max(logits_size, outputs_size);
 
   dim3 grid(num_heads, num_seqs);
...
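The launcher sizes the kernel's dynamic shared memory as the larger of the softmax-logits buffer (one float per padded context token) and the per-warp output buffer; the new cudaFuncSetAttribute call above then opts the kernel into using more than the 48 KB of dynamic shared memory that CUDA grants by default. A rough restatement of that arithmetic in Python, where BLOCK_SIZE, head_size, and NUM_THREADS are illustrative assumptions rather than values taken from this diff:

FLOAT32_BYTES = 4  # sizeof(float)

def required_shared_mem(max_context_len: int,
                        block_size: int = 16,     # assumed BLOCK_SIZE
                        head_size: int = 128,     # assumed head_size
                        num_threads: int = 128):  # assumed NUM_THREADS
    # Mirrors the sizing logic in single_query_cached_kv_attention_launcher.
    num_warps = num_threads // 32
    padded_max_context_len = (
        (max_context_len + block_size - 1) // block_size) * block_size
    logits_size = padded_max_context_len * FLOAT32_BYTES          # softmax logits
    outputs_size = (num_warps // 2) * head_size * FLOAT32_BYTES   # warp partial outputs
    return max(logits_size, outputs_size)

# A 16384-token context needs 16384 * 4 = 65536 bytes for the logits alone,
# which exceeds the 49152-byte (48 KB) default limit, hence the opt-in above.
print(required_shared_mem(16384))  # 65536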
csrc/cuda_utils.cpp  (new file, 0 → 100644)

#include <torch/extension.h>

int get_device_attribute(
    int attribute,
    int device_id);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
    "get_device_attribute",
    &get_device_attribute,
    "Gets the specified device attribute.");
}
csrc/cuda_utils_kernels.cu  (new file, 0 → 100644)

int get_device_attribute(
    int attribute,
    int device_id)
{
    int device, value;
    if (device_id < 0) {
        cudaGetDevice(&device);
    }
    else {
        device = device_id;
    }
    cudaDeviceGetAttribute(
        &value, static_cast<cudaDeviceAttr>(attribute), device);
    return value;
}
setup.py

@@ -195,6 +195,17 @@ quantization_extension = CUDAExtension(
 )
 ext_modules.append(quantization_extension)
 
+# Misc. CUDA utils.
+cuda_utils_extension = CUDAExtension(
+    name="vllm.cuda_utils",
+    sources=["csrc/cuda_utils.cpp", "csrc/cuda_utils_kernels.cu"],
+    extra_compile_args={
+        "cxx": CXX_FLAGS,
+        "nvcc": NVCC_FLAGS,
+    },
+)
+ext_modules.append(cuda_utils_extension)
+
 
 def get_path(*filepath) -> str:
     return os.path.join(ROOT_DIR, *filepath)
...
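Once this extension target is built along with the rest of the package (for example via a regular pip install -e . of the repo), the new binding can be smoke-tested on its own. A minimal sketch, assuming the build succeeded and a CUDA device is visible:

# Illustrative smoke test for the new vllm.cuda_utils extension.
# 97 is cudaDevAttrMaxSharedMemoryPerBlockOptin, the attribute that
# vllm/utils.py queries below; a negative device_id would fall back to the
# current device, per the branch in csrc/cuda_utils_kernels.cu.
from vllm import cuda_utils

max_smem_optin = cuda_utils.get_device_attribute(97, 0)  # (attribute, device_id)
print(f"Opt-in shared memory per block: {max_smem_optin} bytes")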
...
tests/kernels/test_attention.py

@@ -7,8 +7,12 @@ from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 
 from vllm import attention_ops
+from vllm.utils import get_max_shared_memory_bytes
 
-MAX_SEQ_LEN = 8192
+FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
+# This will change depending on the compute capability.
+# - 512 as a buffer
+MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
 NUM_BLOCKS = 128  # Arbitrary values for testing
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
...
@@ -135,6 +139,7 @@ def test_single_query_cached_kv_attention(
                                device="cuda")
 
     context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
+    context_lens[-1] = MAX_SEQ_LEN
     max_context_len = max(context_lens)
     context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
...
@@ -243,6 +248,7 @@ def test_multi_query_kv_attention(
     torch.cuda.manual_seed(seed)
 
     seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
+    seq_lens[-1] = MAX_SEQ_LEN
     num_tokens = sum(seq_lens)
     scale = float(1.0 / (head_size**0.5))
...
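With this change the test's MAX_SEQ_LEN is derived from the running GPU instead of being hard-coded to 8192, and the last sequence in each test is pinned to that maximum so the largest supported context actually gets exercised. A worked example of the new formula, using assumed device limits rather than measured ones:

# Worked example of MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512.
# The two shared-memory figures below are assumptions for illustration only.
FLOAT32_BYTES = 4
for max_shared_mem in (49152, 166912):  # 48 KB default-class vs. a larger opt-in limit
    max_seq_len = max_shared_mem // FLOAT32_BYTES - 512
    print(max_shared_mem, "->", max_seq_len)  # 49152 -> 11776, 166912 -> 41216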
vllm/utils.py

 import enum
-from platform import uname
 import uuid
+from platform import uname
 
 import psutil
 import torch
 
+from vllm import cuda_utils
+
 
 class Device(enum.Enum):
     GPU = enum.auto()
...
@@ -25,6 +27,15 @@ class Counter:
         self.counter = 0
 
 
+def get_max_shared_memory_bytes(gpu: int = 0) -> int:
+    """Returns the maximum shared memory per thread block in bytes."""
+    # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
+    cudaDevAttrMaxSharedMemoryPerBlockOptin = 97  # pylint: disable=invalid-name
+    max_shared_mem = cuda_utils.get_device_attribute(
+        cudaDevAttrMaxSharedMemoryPerBlockOptin, gpu)
+    return int(max_shared_mem)
+
+
 def get_gpu_memory(gpu: int = 0) -> int:
     """Returns the total memory of the GPU in bytes."""
     return torch.cuda.get_device_properties(gpu).total_memory
...
vllm/worker/worker.py

@@ -13,7 +13,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
 from vllm.worker.cache_engine import CacheEngine
-from vllm.utils import get_gpu_memory
+from vllm.utils import get_gpu_memory, get_max_shared_memory_bytes
 
 
 class Worker:
...
@@ -136,6 +136,10 @@ class Worker:
     def init_cache_engine(self, cache_config: CacheConfig) -> None:
         self.cache_config = cache_config
         self.block_size = cache_config.block_size
+
+        _check_if_can_support_max_seq_len(self.scheduler_config.max_model_len,
+                                          self.block_size)
+
         self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                         self.parallel_config)
         self.cache_events = self.cache_engine.events
...
@@ -347,3 +351,23 @@ def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:
 
 def _pad_to_max(x: List[int], max_len: int) -> List[int]:
     return x + [0] * (max_len - len(x))
+
+
+def _check_if_can_support_max_seq_len(max_seq_len: int,
+                                      block_size: int) -> None:
+    # Follows the logic in
+    # attention_kernels.cu::single_query_cached_kv_attention_launcher
+    max_shared_mem = get_max_shared_memory_bytes()
+    float32_bytes = torch.finfo(torch.float).bits // 8
+    padded_max_seq_len = (
+        (max_seq_len + block_size - 1) / block_size) * block_size
+    # padded_max_seq_len + extra buffer
+    required_shared_mem = (padded_max_seq_len + 512) * float32_bytes
+    if padded_max_seq_len * float32_bytes > max_shared_mem:
+        raise RuntimeError(
+            f"vLLM cannot currently support max_model_len={max_seq_len} "
+            f"with block_size={block_size} on GPU with compute "
+            f"capability {torch.cuda.get_device_capability()} "
+            f"(required shared memory {required_shared_mem} > "
+            f"available shared memory {max_shared_mem}). "
+            "This will be fixed in a future release.")
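The new Python-side guard mirrors the kernel launcher's sizing logic: it pads max_model_len up to a multiple of the block size, charges four bytes per float logit, and raises before any kernel launch if that exceeds the device's opt-in shared-memory limit. A worked sketch of the same check, with an assumed 49152-byte limit standing in for get_max_shared_memory_bytes():

# Illustrative re-run of _check_if_can_support_max_seq_len's arithmetic.
max_shared_mem = 49152  # assumed opt-in limit (48 KB); the real code queries the GPU
float32_bytes = 4
block_size = 16

for max_seq_len in (8192, 16384):
    padded = ((max_seq_len + block_size - 1) // block_size) * block_size
    needed = padded * float32_bytes
    verdict = "fits" if needed <= max_shared_mem else "raises RuntimeError"
    print(f"max_seq_len={max_seq_len}: needs {needed} bytes -> {verdict}")
    # 8192  -> 32768 bytes -> fits
    # 16384 -> 65536 bytes -> raises RuntimeError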