Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
9738b84a
Unverified
Commit
9738b84a
authored
Nov 01, 2023
by
Antoni Baum
Committed by
GitHub
Nov 01, 2023
Browse files
Force paged attention v2 for long contexts (#1510)
parent
1fe09900
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
4 additions
and
29 deletions
+4
-29
vllm/model_executor/layers/attention.py
vllm/model_executor/layers/attention.py
+3
-1
vllm/worker/worker.py
vllm/worker/worker.py
+1
-28
No files found.
vllm/model_executor/layers/attention.py
View file @
9738b84a
...
@@ -156,7 +156,9 @@ class PagedAttention(nn.Module):
...
@@ -156,7 +156,9 @@ class PagedAttention(nn.Module):
# sequences or heads is large, we use V1 since there is enough work
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
# to parallelize.
# TODO(woosuk): Tune this heuristic.
# TODO(woosuk): Tune this heuristic.
use_v1
=
max_num_partitions
==
1
or
num_seqs
*
num_heads
>
512
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
use_v1
=
input_metadata
.
max_context_len
<=
8192
and
(
max_num_partitions
==
1
or
num_seqs
*
num_heads
>
512
)
if
use_v1
:
if
use_v1
:
# Run PagedAttention V1.
# Run PagedAttention V1.
attention_ops
.
paged_attention_v1
(
attention_ops
.
paged_attention_v1
(
...
...
vllm/worker/worker.py
View file @
9738b84a
...
@@ -13,7 +13,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
...
@@ -13,7 +13,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sequence
import
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
from
vllm.sequence
import
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.utils
import
get_gpu_memory
,
get_max_shared_memory_bytes
from
vllm.utils
import
get_gpu_memory
class
Worker
:
class
Worker
:
...
@@ -141,13 +141,6 @@ class Worker:
...
@@ -141,13 +141,6 @@ class Worker:
self
.
block_size
=
cache_config
.
block_size
self
.
block_size
=
cache_config
.
block_size
self
.
sliding_window
=
cache_config
.
sliding_window
self
.
sliding_window
=
cache_config
.
sliding_window
if
self
.
sliding_window
is
None
:
max_seq_len
=
self
.
scheduler_config
.
max_model_len
else
:
max_seq_len
=
min
(
self
.
scheduler_config
.
max_model_len
,
self
.
sliding_window
)
_check_if_can_support_max_seq_len
(
max_seq_len
,
self
.
block_size
)
self
.
cache_engine
=
CacheEngine
(
self
.
cache_config
,
self
.
model_config
,
self
.
cache_engine
=
CacheEngine
(
self
.
cache_config
,
self
.
model_config
,
self
.
parallel_config
)
self
.
parallel_config
)
self
.
cache_events
=
self
.
cache_engine
.
events
self
.
cache_events
=
self
.
cache_engine
.
events
...
@@ -421,26 +414,6 @@ def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
...
@@ -421,26 +414,6 @@ def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
return
x
+
[
pad
]
*
(
max_len
-
len
(
x
))
return
x
+
[
pad
]
*
(
max_len
-
len
(
x
))
def
_check_if_can_support_max_seq_len
(
max_seq_len
:
int
,
block_size
:
int
)
->
None
:
# Follows the logic in
# attention_kernels.cu::single_query_cached_kv_attention_launcher
max_shared_mem
=
get_max_shared_memory_bytes
()
float32_bytes
=
torch
.
finfo
(
torch
.
float
).
bits
//
8
padded_max_seq_len
=
(
(
max_seq_len
+
block_size
-
1
)
/
block_size
)
*
block_size
# padded_max_seq_len + extra buffer
required_shared_mem
=
(
padded_max_seq_len
+
512
)
*
float32_bytes
if
padded_max_seq_len
*
float32_bytes
>
max_shared_mem
:
raise
RuntimeError
(
f
"vLLM cannot currently support max_model_len=
{
max_seq_len
}
"
f
"with block_size=
{
block_size
}
on GPU with compute "
f
"capability
{
torch
.
cuda
.
get_device_capability
()
}
"
f
"(required shared memory
{
required_shared_mem
}
> "
f
"available shared memory
{
max_shared_mem
}
). "
"This will be fixed in a future release."
)
def
_check_if_gpu_supports_dtype
(
torch_dtype
:
torch
.
dtype
):
def
_check_if_gpu_supports_dtype
(
torch_dtype
:
torch
.
dtype
):
# Check if the GPU supports the dtype.
# Check if the GPU supports the dtype.
if
torch_dtype
==
torch
.
bfloat16
:
if
torch_dtype
==
torch
.
bfloat16
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment