Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
705f6a35
Commit
705f6a35
authored
Jul 16, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.5.2' into v0.5.2-dtk24.04.1
parents
af837396
4cf256ae
Changes
439
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1680 additions
and
123 deletions
+1680
-123
tests/worker/test_model_input.py
tests/worker/test_model_input.py
+152
-0
tests/worker/test_model_runner.py
tests/worker/test_model_runner.py
+46
-39
tests/worker/test_swap.py
tests/worker/test_swap.py
+2
-2
vllm/__init__.py
vllm/__init__.py
+3
-1
vllm/_custom_ops.py
vllm/_custom_ops.py
+72
-7
vllm/_ipex_ops.py
vllm/_ipex_ops.py
+244
-0
vllm/adapter_commons/__init__.py
vllm/adapter_commons/__init__.py
+0
-0
vllm/adapter_commons/layers.py
vllm/adapter_commons/layers.py
+14
-0
vllm/adapter_commons/models.py
vllm/adapter_commons/models.py
+104
-0
vllm/adapter_commons/request.py
vllm/adapter_commons/request.py
+25
-0
vllm/adapter_commons/utils.py
vllm/adapter_commons/utils.py
+90
-0
vllm/adapter_commons/worker_manager.py
vllm/adapter_commons/worker_manager.py
+36
-0
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+13
-1
vllm/attention/backends/blocksparse_attn.py
vllm/attention/backends/blocksparse_attn.py
+10
-3
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+17
-11
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+74
-29
vllm/attention/backends/ipex_attn.py
vllm/attention/backends/ipex_attn.py
+361
-0
vllm/attention/backends/openvino.py
vllm/attention/backends/openvino.py
+101
-0
vllm/attention/backends/pallas.py
vllm/attention/backends/pallas.py
+243
-0
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+73
-30
No files found.
Too many changes to show.
To preserve performance only
439 of 439+
files are displayed.
Plain diff
Email patch
tests/worker/test_model_input.py
0 → 100644
View file @
705f6a35
import
dataclasses
from
typing
import
List
,
Tuple
,
Type
import
torch
from
vllm.attention
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.worker.embedding_model_runner
import
(
ModelInputForGPUWithPoolingMetadata
)
from
vllm.worker.model_runner
import
ModelInputForGPUWithSamplingMetadata
class
MockAttentionBackend
(
AttentionBackend
):
@
staticmethod
def
get_name
()
->
str
:
raise
NotImplementedError
@
staticmethod
def
get_impl_cls
():
raise
NotImplementedError
@
staticmethod
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
return
AttentionMetadata
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
Tuple
[
int
,
...]:
raise
NotImplementedError
@
staticmethod
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
torch
.
Tensor
,
)
->
None
:
pass
@
staticmethod
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
)
->
None
:
pass
def
test_model_runner_input
():
sampling_metadata
=
SamplingMetadata
(
[
"seq_group"
],
"selected_token_indices"
,
"categorized_sample_indices"
,
"num_prompts"
,
)
attn_metadata
=
AttentionMetadata
(
num_prefills
=
1
,
num_prefill_tokens
=
2
,
num_decode_tokens
=
3
,
slot_mapping
=
torch
.
zeros
(
1
),
)
model_input
=
ModelInputForGPUWithSamplingMetadata
(
input_tokens
=
torch
.
ones
(
10
),
input_positions
=
torch
.
ones
(
10
),
sampling_metadata
=
sampling_metadata
,
attn_metadata
=
attn_metadata
)
assert
isinstance
(
model_input
,
ModelInputForGPUWithSamplingMetadata
)
# Test round trip serialization.
tensor_dict
=
model_input
.
as_broadcastable_tensor_dict
()
attn_backend
=
MockAttentionBackend
()
received_model_input
=
(
ModelInputForGPUWithSamplingMetadata
.
from_broadcasted_tensor_dict
(
tensor_dict
,
attn_backend
=
attn_backend
))
# Check that received copy has correct values.
assert
isinstance
(
received_model_input
,
ModelInputForGPUWithSamplingMetadata
)
assert
received_model_input
.
input_tokens
is
not
None
assert
(
received_model_input
.
input_tokens
==
model_input
.
input_tokens
).
all
()
assert
received_model_input
.
input_positions
is
not
None
assert
(
received_model_input
.
input_positions
==
model_input
.
input_positions
).
all
()
assert
received_model_input
.
multi_modal_kwargs
is
None
assert
(
received_model_input
.
multi_modal_kwargs
==
model_input
.
multi_modal_kwargs
)
assert
received_model_input
.
lora_requests
is
None
assert
received_model_input
.
lora_requests
==
model_input
.
lora_requests
assert
received_model_input
.
lora_mapping
is
None
assert
received_model_input
.
lora_mapping
==
model_input
.
lora_mapping
for
field
in
dataclasses
.
fields
(
AttentionMetadata
):
assert
getattr
(
received_model_input
.
attn_metadata
,
field
.
name
,
None
)
==
getattr
(
attn_metadata
,
field
.
name
,
None
)
# For sampling metadata, only selected_token_indices is copied.
assert
(
received_model_input
.
sampling_metadata
.
selected_token_indices
==
sampling_metadata
.
selected_token_indices
)
assert
received_model_input
.
sampling_metadata
.
seq_groups
is
None
def
test_embedding_model_runner_input
():
pooling_metadata
=
PoolingMetadata
(
seq_groups
=
[[
0
]],
seq_data
=
{},
prompt_lens
=
[
1
],
)
attn_metadata
=
AttentionMetadata
(
num_prefills
=
1
,
num_prefill_tokens
=
2
,
num_decode_tokens
=
3
,
slot_mapping
=
torch
.
zeros
(
1
),
)
model_input
=
ModelInputForGPUWithPoolingMetadata
(
input_tokens
=
torch
.
ones
(
10
),
input_positions
=
torch
.
ones
(
10
),
pooling_metadata
=
pooling_metadata
,
attn_metadata
=
attn_metadata
)
assert
isinstance
(
model_input
,
ModelInputForGPUWithPoolingMetadata
)
# Test round trip serialization.
tensor_dict
=
model_input
.
as_broadcastable_tensor_dict
()
attn_backend
=
MockAttentionBackend
()
received_model_input
=
(
ModelInputForGPUWithPoolingMetadata
.
from_broadcasted_tensor_dict
(
tensor_dict
,
attn_backend
=
attn_backend
))
# Check that received copy has correct values.
assert
isinstance
(
received_model_input
,
ModelInputForGPUWithPoolingMetadata
)
assert
received_model_input
.
input_tokens
is
not
None
assert
(
received_model_input
.
input_tokens
==
model_input
.
input_tokens
).
all
()
assert
received_model_input
.
input_positions
is
not
None
assert
(
received_model_input
.
input_positions
==
model_input
.
input_positions
).
all
()
assert
received_model_input
.
multi_modal_kwargs
is
None
assert
(
received_model_input
.
multi_modal_kwargs
==
model_input
.
multi_modal_kwargs
)
assert
received_model_input
.
lora_requests
is
None
assert
received_model_input
.
lora_requests
==
model_input
.
lora_requests
assert
received_model_input
.
lora_mapping
is
None
assert
received_model_input
.
lora_mapping
==
model_input
.
lora_mapping
for
field
in
dataclasses
.
fields
(
AttentionMetadata
):
assert
getattr
(
received_model_input
.
attn_metadata
,
field
.
name
,
None
)
==
getattr
(
attn_metadata
,
field
.
name
,
None
)
# Pooling metadata is not broadcast.
assert
received_model_input
.
pooling_metadata
is
None
tests/worker/test_model_runner.py
View file @
705f6a35
from
typing
import
List
import
pytest
import
pytest
import
torch
import
torch
from
vllm.distributed.parallel_state
import
init_distributed_environment
from
vllm.distributed.parallel_state
import
(
ensure_model_parallel_initialized
,
init_distributed_environment
)
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
...
@@ -20,6 +23,7 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
...
@@ -20,6 +23,7 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
cache_config
=
engine_config
.
cache_config
,
cache_config
=
engine_config
.
cache_config
,
load_config
=
engine_config
.
load_config
,
load_config
=
engine_config
.
load_config
,
lora_config
=
engine_config
.
lora_config
,
lora_config
=
engine_config
.
lora_config
,
prompt_adapter_config
=
engine_config
.
prompt_adapter_config
,
is_driver_worker
=
True
,
is_driver_worker
=
True
,
)
)
return
model_runner
return
model_runner
...
@@ -34,8 +38,8 @@ def test_prepare_prompt(batch_size):
...
@@ -34,8 +38,8 @@ def test_prepare_prompt(batch_size):
enable_chunked_prefill
=
False
,
enable_chunked_prefill
=
False
,
)
)
seq_lens
=
[]
seq_lens
:
List
[
int
]
=
[]
seq_group_metadata_list
=
[]
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
block_tables
=
{
0
:
[
1
]}
block_tables
=
{
0
:
[
1
]}
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
# make sure all tokens fit into one block
# make sure all tokens fit into one block
...
@@ -58,12 +62,13 @@ def test_prepare_prompt(batch_size):
...
@@ -58,12 +62,13 @@ def test_prepare_prompt(batch_size):
expected_selected_token_indices
.
append
(
selected_token_start_idx
+
expected_selected_token_indices
.
append
(
selected_token_start_idx
+
seq_len
-
1
)
seq_len
-
1
)
selected_token_start_idx
+=
seq_len
selected_token_start_idx
+=
seq_len
model_input
=
model_runner
.
_prepare_model_input
(
seq_group_metadata_list
)
model_input
=
model_runner
.
_prepare_model_input_tensors
(
seq_group_metadata_list
)
input_tokens
=
model_input
.
input_tokens
input_tokens
=
model_input
.
input_tokens
input_positions
=
model_input
.
input_positions
input_positions
=
model_input
.
input_positions
attn_metadata
=
model_input
.
attn_metadata
attn_metadata
=
model_input
.
attn_metadata
return_seq_lens
=
model_input
.
seq_lens
return_seq_lens
=
model_input
.
seq_lens
slot_mapping
=
model_input
.
slot_mapping
slot_mapping
=
attn_metadata
.
slot_mapping
assert
return_seq_lens
==
seq_lens
assert
return_seq_lens
==
seq_lens
assert
len
(
slot_mapping
)
==
len
(
input_tokens
)
assert
len
(
slot_mapping
)
==
len
(
input_tokens
)
...
@@ -150,15 +155,14 @@ def test_prepare_decode_cuda_graph(batch_size):
...
@@ -150,15 +155,14 @@ def test_prepare_decode_cuda_graph(batch_size):
enable_chunked_prefill
=
False
,
enable_chunked_prefill
=
False
,
)
)
context_lens
=
[]
context_lens
:
List
[
int
]
=
[]
seq_group_metadata_list
=
[]
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
# Assume each seq group finishes prefill.
# Assume each seq group finishes prefill.
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
# make sure all tokens fit into one block
# make sure all tokens fit into one block
context_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
context_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
context_lens
.
append
(
context_len
)
context_lens
.
append
(
context_len
)
seq_data
=
list
(
range
(
context_len
))
seq_data
=
SequenceData
(
list
(
range
(
context_len
)))
seq_data
=
SequenceData
(
seq_data
)
seq_data
.
update_num_computed_tokens
(
context_len
)
seq_data
.
update_num_computed_tokens
(
context_len
)
# Append one token ID since prefill is finished.
# Append one token ID since prefill is finished.
seq_data
.
append_token_id
(
1
,
0
)
seq_data
.
append_token_id
(
1
,
0
)
...
@@ -172,10 +176,11 @@ def test_prepare_decode_cuda_graph(batch_size):
...
@@ -172,10 +176,11 @@ def test_prepare_decode_cuda_graph(batch_size):
assert
seq_group_metadata
.
token_chunk_size
==
1
assert
seq_group_metadata
.
token_chunk_size
==
1
seq_group_metadata_list
.
append
(
seq_group_metadata
)
seq_group_metadata_list
.
append
(
seq_group_metadata
)
model_input
=
model_runner
.
_prepare_model_input
(
seq_group_metadata_list
)
model_input
=
model_runner
.
_prepare_model_input_tensors
(
seq_group_metadata_list
)
input_tokens
,
input_positions
,
attn_metadata
,
slot_mapping
=
(
input_tokens
,
input_positions
,
attn_metadata
,
slot_mapping
=
(
model_input
.
input_tokens
,
model_input
.
input_positions
,
model_input
.
input_tokens
,
model_input
.
input_positions
,
model_input
.
attn_metadata
,
model_input
.
slot_mapping
)
model_input
.
attn_metadata
,
model_input
.
attn_metadata
.
slot_mapping
)
assert
len
(
slot_mapping
)
==
len
(
input_tokens
)
assert
len
(
slot_mapping
)
==
len
(
input_tokens
)
expected_bs
=
_get_graph_batch_size
(
len
(
seq_group_metadata_list
))
expected_bs
=
_get_graph_batch_size
(
len
(
seq_group_metadata_list
))
...
@@ -256,33 +261,30 @@ def test_empty_seq_group():
...
@@ -256,33 +261,30 @@ def test_empty_seq_group():
dtype
=
"float16"
,
dtype
=
"float16"
,
enforce_eager
=
False
,
enforce_eager
=
False
,
)
)
seq_group_metadata_list
=
[]
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
model_input
=
model_runner
.
_prepare_model_input
(
seq_group_metadata_list
)
model_input
=
model_runner
.
_prepare_model_input_tensors
(
input_tokens
,
input_positions
,
attn_metadata
,
slot_mapping
=
(
seq_group_metadata_list
)
input_tokens
,
input_positions
,
attn_metadata
=
(
model_input
.
input_tokens
,
model_input
.
input_tokens
,
model_input
.
input_positions
,
model_input
.
input_positions
,
model_input
.
attn_metadata
,
model_input
.
attn_metadata
,
model_input
.
slot_mapping
,
)
)
assert
len
(
input_tokens
)
==
0
assert
input_tokens
is
None
assert
len
(
input_positions
)
==
0
assert
input_positions
is
None
assert
attn_metadata
is
None
assert
attn_metadata
is
None
assert
len
(
slot_mapping
)
==
0
model_input
=
model_runner
.
_prepare_model_input_tensors
(
model_input
=
model_runner
.
_prepare_model_input
(
seq_group_metadata_list
)
seq_group_metadata_list
)
(
input_tokens
,
input_positions
,
attn_metadata
,
slot_mapping
,
(
input_tokens
,
input_positions
,
attn_metadata
,
return_seq_lens
)
=
(
return_seq_lens
)
=
(
model_input
.
input_tokens
,
model_input
.
input_tokens
,
model_input
.
input_positions
,
model_input
.
input_positions
,
model_input
.
attn_metadata
,
model_input
.
attn_metadata
,
model_input
.
seq_lens
,
model_input
.
slot_mapping
,
)
model_input
.
seq_lens
,
assert
input_tokens
is
None
)
assert
input_positions
is
None
assert
len
(
input_tokens
)
==
0
assert
len
(
input_positions
)
==
0
assert
attn_metadata
is
None
assert
attn_metadata
is
None
assert
len
(
slot_mapping
)
==
0
assert
return_seq_lens
is
None
assert
len
(
return_seq_lens
)
==
0
@
pytest
.
fixture
@
pytest
.
fixture
...
@@ -292,6 +294,7 @@ def distributed_init():
...
@@ -292,6 +294,7 @@ def distributed_init():
rank
=
0
,
rank
=
0
,
distributed_init_method
=
f
"tcp://127.0.0.1:
{
get_open_port
()
}
"
,
distributed_init_method
=
f
"tcp://127.0.0.1:
{
get_open_port
()
}
"
,
local_rank
=
0
)
local_rank
=
0
)
ensure_model_parallel_initialized
(
1
,
1
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
list
(
range
(
2
,
128
)))
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
list
(
range
(
2
,
128
)))
...
@@ -308,10 +311,10 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
...
@@ -308,10 +311,10 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
)
)
# Add prefill requests.
# Add prefill requests.
seq_lens
=
[]
seq_lens
:
List
[
int
]
=
[]
seq_group_metadata_list
=
[]
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
prefill_metadata_list
=
[]
prefill_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
decode_metadata_list
=
[]
decode_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
block_tables
=
{
0
:
[
1
]}
block_tables
=
{
0
:
[
1
]}
prefill_batch_size
=
batch_size
//
2
prefill_batch_size
=
batch_size
//
2
decode_batch_size
=
batch_size
-
prefill_batch_size
decode_batch_size
=
batch_size
-
prefill_batch_size
...
@@ -350,8 +353,12 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
...
@@ -350,8 +353,12 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
seq_group_metadata_list
.
append
(
seq_group_metadata
)
seq_group_metadata_list
.
append
(
seq_group_metadata
)
decode_metadata_list
.
append
(
seq_group_metadata
)
decode_metadata_list
.
append
(
seq_group_metadata
)
(
input_tokens
,
input_positions
,
attn_metadata
,
_
,
_
,
_
,
model_input
=
model_runner
.
prepare_model_input
(
seq_group_metadata_list
)
_
)
=
model_runner
.
prepare_input_tensors
(
seq_group_metadata_list
)
(
input_tokens
,
input_positions
,
attn_metadata
)
=
(
model_input
.
input_tokens
,
model_input
.
input_positions
,
model_input
.
attn_metadata
,
)
prefill_meta_actual
=
attn_metadata
.
prefill_metadata
prefill_meta_actual
=
attn_metadata
.
prefill_metadata
decode_meta_actual
=
attn_metadata
.
decode_metadata
decode_meta_actual
=
attn_metadata
.
decode_metadata
...
@@ -364,7 +371,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
...
@@ -364,7 +371,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
# Verify attn metadata is consistent. We don't need to test individual
# Verify attn metadata is consistent. We don't need to test individual
# values here because they are tested above.
# values here because they are tested above.
attn_metadata
=
model_runner
.
_prepare_model_input
(
attn_metadata
=
model_runner
.
_prepare_model_input
_tensors
(
seq_group_metadata_list
).
attn_metadata
seq_group_metadata_list
).
attn_metadata
for
attr_expected
,
attr_actual
in
zip
(
vars
(
attn_metadata
.
prefill_metadata
),
for
attr_expected
,
attr_actual
in
zip
(
vars
(
attn_metadata
.
prefill_metadata
),
...
...
tests/worker/test_swap.py
View file @
705f6a35
...
@@ -39,8 +39,8 @@ def test_swap() -> None:
...
@@ -39,8 +39,8 @@ def test_swap() -> None:
num_cpu_blocks
=
engine_config
.
cache_config
.
num_cpu_blocks
)
num_cpu_blocks
=
engine_config
.
cache_config
.
num_cpu_blocks
)
# Randomly initialize the cache.
# Randomly initialize the cache.
gpu_cache
=
worker
.
cache_engine
.
gpu_cache
gpu_cache
=
worker
.
cache_engine
[
0
]
.
gpu_cache
cpu_cache
=
worker
.
cache_engine
.
cpu_cache
cpu_cache
=
worker
.
cache_engine
[
0
]
.
cpu_cache
num_layers
=
len
(
gpu_cache
)
num_layers
=
len
(
gpu_cache
)
for
i
in
range
(
num_layers
):
for
i
in
range
(
num_layers
):
gpu_key_cache
,
gpu_value_cache
=
gpu_cache
[
i
]
gpu_key_cache
,
gpu_value_cache
=
gpu_cache
[
i
]
...
...
vllm/__init__.py
View file @
705f6a35
...
@@ -13,9 +13,11 @@ from vllm.pooling_params import PoolingParams
...
@@ -13,9 +13,11 @@ from vllm.pooling_params import PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.version
import
__dcu_version__
from
vllm.version
import
__dcu_version__
__version__
=
"0.5.0"
from
.version
import
__commit__
,
__version__
__all__
=
[
__all__
=
[
"__commit__"
,
"__version__"
,
"LLM"
,
"LLM"
,
"ModelRegistry"
,
"ModelRegistry"
,
"PromptStrictInputs"
,
"PromptStrictInputs"
,
...
...
vllm/_custom_ops.py
View file @
705f6a35
import
contextlib
import
contextlib
import
functools
from
typing
import
List
,
Optional
,
Tuple
,
Type
from
typing
import
List
,
Optional
,
Tuple
,
Type
import
torch
import
torch
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
try
:
try
:
import
vllm._C
import
vllm._C
except
ImportError
as
e
:
except
ImportError
as
e
:
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
logger
.
warning
(
"Failed to import from vllm._C with %r"
,
e
)
logger
.
warning
(
"Failed to import from vllm._C with %r"
,
e
)
with
contextlib
.
suppress
(
ImportError
):
with
contextlib
.
suppress
(
ImportError
):
...
@@ -23,6 +26,25 @@ def is_custom_op_supported(op_name: str) -> bool:
...
@@ -23,6 +26,25 @@ def is_custom_op_supported(op_name: str) -> bool:
return
op
is
not
None
return
op
is
not
None
def
hint_on_error
(
fn
):
@
functools
.
wraps
(
fn
)
def
wrapper
(
*
args
,
**
kwargs
):
try
:
return
fn
(
*
args
,
**
kwargs
)
except
AttributeError
as
e
:
msg
=
(
"Error in calling custom op %s: %s
\n
"
"Possibly you have built or installed an obsolete version of vllm.
\n
"
"Please try a clean build and install of vllm,"
"or remove old built files such as vllm/*cpython*.so and build/ ."
)
logger
.
error
(
msg
,
fn
.
__name__
,
e
)
raise
e
return
wrapper
# activation ops
# activation ops
def
silu_and_mul
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
def
silu_and_mul
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_C
.
silu_and_mul
(
out
,
x
)
torch
.
ops
.
_C
.
silu_and_mul
(
out
,
x
)
...
@@ -44,6 +66,10 @@ def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
...
@@ -44,6 +66,10 @@ def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
torch
.
ops
.
_C
.
gelu_new
(
out
,
x
)
torch
.
ops
.
_C
.
gelu_new
(
out
,
x
)
def
gelu_quick
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_C
.
gelu_quick
(
out
,
x
)
# page attention ops
# page attention ops
def
paged_attention_v1
(
def
paged_attention_v1
(
out
:
torch
.
Tensor
,
out
:
torch
.
Tensor
,
...
@@ -190,9 +216,16 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
...
@@ -190,9 +216,16 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# cutlass
# cutlass
def
cutlass_scaled_mm_dq
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
def
cutlass_scaled_mm_supports_fp8
(
cuda_device_capability
:
int
)
->
bool
:
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
return
torch
.
ops
.
_C
.
cutlass_scaled_mm_supports_fp8
(
cuda_device_capability
)
out_dtype
:
Type
[
torch
.
dtype
])
->
torch
.
Tensor
:
def
cutlass_scaled_mm
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
Type
[
torch
.
dtype
],
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
assert
(
b
.
shape
[
0
]
%
16
==
0
and
b
.
shape
[
1
]
%
16
==
0
)
assert
(
b
.
shape
[
0
]
%
16
==
0
and
b
.
shape
[
1
]
%
16
==
0
)
assert
(
out_dtype
is
torch
.
bfloat16
or
out_dtype
is
torch
.
float16
)
assert
(
out_dtype
is
torch
.
bfloat16
or
out_dtype
is
torch
.
float16
)
...
@@ -200,7 +233,7 @@ def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
...
@@ -200,7 +233,7 @@ def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
n
=
b
.
shape
[
1
]
n
=
b
.
shape
[
1
]
out
=
torch
.
empty
((
m
,
n
),
dtype
=
out_dtype
,
device
=
a
.
device
)
out
=
torch
.
empty
((
m
,
n
),
dtype
=
out_dtype
,
device
=
a
.
device
)
torch
.
ops
.
_C
.
cutlass_scaled_mm
_dq
(
out
,
a
,
b
,
scale_a
,
scale_b
)
torch
.
ops
.
_C
.
cutlass_scaled_mm
(
out
,
a
,
b
,
scale_a
,
scale_b
,
bias
)
return
out
return
out
...
@@ -238,6 +271,15 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
...
@@ -238,6 +271,15 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
size_k
,
is_k_full
)
size_k
,
is_k_full
)
# fp8 marlin
def
fp8_marlin_gemm
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
num_bits
:
int
,
size_m
:
int
,
size_n
:
int
,
size_k
:
int
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
fp8_marlin_gemm
(
a
,
b_q_weight
,
b_scales
,
workspace
,
num_bits
,
size_m
,
size_n
,
size_k
)
# fp8
# fp8
# def scaled_fp8_quant(
# def scaled_fp8_quant(
# input: torch.Tensor,
# input: torch.Tensor,
...
@@ -352,7 +394,8 @@ def reshape_and_cache_flash(
...
@@ -352,7 +394,8 @@ def reshape_and_cache_flash(
kv_cache_dtype
)
kv_cache_dtype
)
def
copy_blocks
(
key_caches
:
torch
.
Tensor
,
value_caches
:
torch
.
Tensor
,
def
copy_blocks
(
key_caches
:
List
[
torch
.
Tensor
],
value_caches
:
List
[
torch
.
Tensor
],
block_mapping
:
torch
.
Tensor
)
->
None
:
block_mapping
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_C_cache_ops
.
copy_blocks
(
key_caches
,
value_caches
,
block_mapping
)
torch
.
ops
.
_C_cache_ops
.
copy_blocks
(
key_caches
,
value_caches
,
block_mapping
)
...
@@ -459,3 +502,25 @@ def dispatch_bgmv_low_level(
...
@@ -459,3 +502,25 @@ def dispatch_bgmv_low_level(
h_out
,
h_out
,
y_offset
,
y_offset
,
)
)
# temporary fix for https://github.com/vllm-project/vllm/issues/5456
# TODO: remove this in v0.6.0
names_and_values
=
globals
()
names_and_values_to_update
=
{}
# prepare variables to avoid dict size change during iteration
k
,
v
,
arg
=
None
,
None
,
None
fn_type
=
type
(
lambda
x
:
x
)
for
k
,
v
in
names_and_values
.
items
():
# find functions that are defined in this file and have torch.Tensor
# in their annotations. `arg == "torch.Tensor"` is used to handle
# the case when users use `import __annotations__` to turn type
# hints into strings.
if
isinstance
(
v
,
fn_type
)
\
and
v
.
__code__
.
co_filename
==
__file__
\
and
any
(
arg
is
torch
.
Tensor
or
arg
==
"torch.Tensor"
for
arg
in
v
.
__annotations__
.
values
()):
names_and_values_to_update
[
k
]
=
hint_on_error
(
v
)
names_and_values
.
update
(
names_and_values_to_update
)
del
names_and_values_to_update
,
names_and_values
,
v
,
k
,
fn_type
vllm/_ipex_ops.py
0 → 100644
View file @
705f6a35
from
typing
import
List
,
Optional
,
Tuple
import
torch
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
try
:
import
intel_extension_for_pytorch
as
ipex
except
ImportError
as
e
:
logger
.
warning
(
"Import error msg: %s"
,
e
.
msg
)
class
ipex_ops
:
@
staticmethod
def
_reshape_activation_tensor
(
x
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
num
=
x
.
size
(
0
)
d
=
x
.
size
(
1
)
//
2
x
=
x
.
reshape
(
num
,
2
,
d
)
x1
,
x2
=
torch
.
chunk
(
x
,
chunks
=
2
,
dim
=
1
)
x1
=
x1
.
reshape
(
num
,
d
)
x2
=
x2
.
reshape
(
num
,
d
)
return
x1
,
x2
def
silu_and_mul
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
x1
,
x2
=
ipex_ops
.
_reshape_activation_tensor
(
x
)
ipex
.
llm
.
functional
.
silu_mul
(
x1
,
x2
,
out
)
def
gelu_and_mul
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
x1
,
x2
=
ipex_ops
.
_reshape_activation_tensor
(
x
)
ipex
.
llm
.
functional
.
gelu_mul
(
x1
,
x2
,
out
,
"none"
)
def
gelu_tanh_and_mul
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
x1
,
x2
=
ipex_ops
.
_reshape_activation_tensor
(
x
)
ipex
.
llm
.
functional
.
gelu_mul
(
x1
,
x2
,
out
,
"tanh"
)
def
gelu_fast
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
out
.
copy_
(
torch
.
nn
.
functional
.
gelu
(
x
))
def
gelu_new
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
out
.
copy_
(
torch
.
nn
.
functional
.
gelu
(
x
))
# TODO add implementation of gelu_quick here
# def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
def
paged_attention_v1
(
out
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
num_kv_heads
:
int
,
scale
:
float
,
block_tables
:
torch
.
Tensor
,
context_lens
:
torch
.
Tensor
,
block_size
:
int
,
max_context_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
kv_scale
:
float
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_block_size
:
int
=
64
,
blocksparse_head_sliding_step
:
int
=
0
,
)
->
None
:
assert
kv_cache_dtype
==
"auto"
num_heads
=
out
.
size
(
1
)
num_queries_per_tokens
=
num_heads
//
num_kv_heads
head_mapping
=
torch
.
arange
(
0
,
num_kv_heads
,
device
=
query
.
device
,
dtype
=
torch
.
int32
,
).
view
(
num_kv_heads
,
1
).
repeat_interleave
(
num_queries_per_tokens
).
flatten
()
# todo: ipex will refactor namespace
torch
.
xpu
.
paged_attention_v1
(
out
,
query
.
contiguous
(),
key_cache
.
view_as
(
value_cache
),
value_cache
,
head_mapping
,
scale
,
block_tables
,
context_lens
,
block_size
,
max_context_len
,
alibi_slopes
)
def
paged_attention_v2
(
out
:
torch
.
Tensor
,
exp_sum
:
torch
.
Tensor
,
max_logits
:
torch
.
Tensor
,
tmp_out
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
num_kv_heads
:
int
,
scale
:
float
,
block_tables
:
torch
.
Tensor
,
context_lens
:
torch
.
Tensor
,
block_size
:
int
,
max_context_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
kv_scale
:
float
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_block_size
:
int
=
64
,
blocksparse_head_sliding_step
:
int
=
0
,
)
->
None
:
assert
kv_cache_dtype
==
"auto"
num_heads
=
out
.
size
(
1
)
num_queries_per_tokens
=
num_heads
//
num_kv_heads
head_mapping
=
torch
.
arange
(
0
,
num_kv_heads
,
dtype
=
torch
.
int32
,
device
=
query
.
device
,
).
view
(
num_kv_heads
,
1
).
repeat_interleave
(
num_queries_per_tokens
).
flatten
()
# todo: ipex will refactor namespace
torch
.
xpu
.
paged_attention_v2
(
out
,
exp_sum
,
max_logits
,
tmp_out
,
query
.
contiguous
(),
key_cache
.
view_as
(
value_cache
),
value_cache
,
head_mapping
,
block_tables
,
context_lens
,
scale
,
block_size
,
max_context_len
,
alibi_slopes
)
def
rotary_embedding
(
positions
:
torch
.
Tensor
,
# [batch_size, seq_len]
query
:
torch
.
Tensor
,
# [batch_size, seq_len, num_heads*head_size]
key
:
torch
.
Tensor
,
# [batch_size, seq_len, num_kv_heads*head_size]
head_size
:
int
,
cos_sin_cache
:
torch
.
Tensor
,
# [cos_sin_dim, rot_dim]
is_neox
:
bool
,
)
->
None
:
if
positions
.
dim
()
==
1
:
positions
=
positions
.
unsqueeze
(
0
)
query
=
query
.
unsqueeze
(
0
)
key
=
key
.
unsqueeze
(
0
)
rotary_dim
=
cos_sin_cache
.
size
(
1
)
query
=
query
.
view
(
*
query
.
shape
[:
-
1
],
-
1
,
head_size
)
key
=
key
.
view
(
*
key
.
shape
[:
-
1
],
-
1
,
head_size
)
query_rot
=
query
[...,
:
rotary_dim
]
key_rot
=
key
[...,
:
rotary_dim
]
cos_sin
=
cos_sin_cache
[
positions
.
long
()]
cos
,
sin
=
cos_sin
.
chunk
(
2
,
dim
=-
1
)
if
is_neox
:
cos
=
cos
.
repeat
(
1
,
1
,
2
).
unsqueeze
(
-
2
)
sin
=
sin
.
repeat
(
1
,
1
,
2
).
unsqueeze
(
-
2
)
else
:
cos
=
cos
.
repeat_interleave
(
2
,
dim
=-
1
).
unsqueeze
(
-
2
)
sin
=
sin
.
repeat_interleave
(
2
,
dim
=-
1
).
unsqueeze
(
-
2
)
ipex
.
llm
.
functional
.
rotary_embedding
(
query_rot
,
key_rot
,
sin
,
cos
,
rotary_dim
,
is_neox
,
positions
)
def
batched_rotary_embedding
(
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
head_size
:
int
,
cos_sin_cache
:
torch
.
Tensor
,
is_neox
:
bool
,
rot_dim
:
int
,
cos_sin_cache_offsets
:
torch
.
Tensor
)
->
None
:
if
positions
.
dim
()
==
1
:
positions
=
positions
.
unsqueeze
(
0
)
query
=
query
.
unsqueeze
(
0
)
key
=
key
.
unsqueeze
(
0
)
cos_sin_cache_offsets
=
cos_sin_cache_offsets
.
view_as
(
positions
)
rotary_dim
=
cos_sin_cache
.
size
(
1
)
query
=
query
.
view
(
*
query
.
shape
[:
-
1
],
-
1
,
head_size
)
key
=
key
.
view
(
*
key
.
shape
[:
-
1
],
-
1
,
head_size
)
query_rot
=
query
[...,
:
rotary_dim
]
key_rot
=
key
[...,
:
rotary_dim
]
cos_sin
=
cos_sin_cache
[
torch
.
add
(
positions
,
cos_sin_cache_offsets
).
long
()]
cos
,
sin
=
cos_sin
.
chunk
(
2
,
dim
=-
1
)
if
is_neox
:
cos
=
cos
.
repeat
(
1
,
1
,
2
).
unsqueeze
(
-
2
)
sin
=
sin
.
repeat
(
1
,
1
,
2
).
unsqueeze
(
-
2
)
else
:
cos
=
cos
.
repeat_interleave
(
2
,
dim
=-
1
).
unsqueeze
(
-
2
)
sin
=
sin
.
repeat_interleave
(
2
,
dim
=-
1
).
unsqueeze
(
-
2
)
ipex
.
llm
.
functional
.
rotary_embedding
(
query_rot
,
key_rot
,
sin
,
cos
,
rotary_dim
,
is_neox
,
positions
)
def
rms_norm
(
out
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
epsilon
:
float
)
->
None
:
tmp
=
ipex
.
llm
.
functional
.
rms_norm
(
input
,
weight
,
epsilon
)
out
.
copy_
(
tmp
)
def
fused_add_rms_norm
(
input
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
epsilon
:
float
)
->
None
:
tmp
=
ipex
.
llm
.
functional
.
add_rms_norm
(
residual
,
input
,
weight
,
None
,
epsilon
,
True
)
input
.
copy_
(
tmp
)
def
varlen_attention
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
out
:
torch
.
Tensor
,
seqlen_q
:
torch
.
Tensor
,
seqlen_k
:
torch
.
Tensor
,
max_seqlen_q
:
int
,
max_seqlen_k
:
int
,
pdropout
:
float
,
softmax_scale
:
float
,
zero_tensors
:
bool
,
is_causal
:
bool
,
return_softmax
:
bool
,
gen_
:
torch
.
Generator
,
)
->
None
:
ipex
.
llm
.
functional
.
varlen_attention
(
query
,
key
,
value
,
out
,
seqlen_q
,
seqlen_k
,
max_seqlen_q
,
max_seqlen_k
,
pdropout
,
softmax_scale
,
zero_tensors
,
is_causal
,
return_softmax
,
gen_
)
def
reshape_and_cache
(
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
kv_scale
:
float
,
)
->
None
:
assert
kv_cache_dtype
==
"auto"
ipex
.
llm
.
modules
.
PagedAttention
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
)
@
staticmethod
def
copy_blocks
(
key_caches
:
List
[
torch
.
Tensor
],
value_caches
:
List
[
torch
.
Tensor
],
block_mapping
:
torch
.
Tensor
)
->
None
:
torch
.
xpu
.
copy_blocks
(
key_caches
,
value_caches
,
block_mapping
)
def
swap_blocks
(
src
:
torch
.
Tensor
,
dst
:
torch
.
Tensor
,
block_mapping
:
torch
.
Tensor
)
->
None
:
torch
.
xpu
.
swap_blocks
(
src
,
dst
,
block_mapping
)
vllm/adapter_commons/__init__.py
0 → 100644
View file @
705f6a35
vllm/adapter_commons/layers.py
0 → 100644
View file @
705f6a35
from
dataclasses
import
dataclass
from
typing
import
Tuple
@
dataclass
class
AdapterMapping
:
# Per every token in input_ids:
index_mapping
:
Tuple
[
int
,
...]
# Per sampled token:
prompt_mapping
:
Tuple
[
int
,
...]
def
__post_init__
(
self
):
self
.
index_mapping
=
tuple
(
self
.
index_mapping
)
self
.
prompt_mapping
=
tuple
(
self
.
prompt_mapping
)
\ No newline at end of file
vllm/adapter_commons/models.py
0 → 100644
View file @
705f6a35
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
,
Callable
,
Dict
,
Hashable
,
Optional
,
TypeVar
from
torch
import
nn
from
vllm.logger
import
init_logger
from
vllm.utils
import
LRUCache
logger
=
init_logger
(
__name__
)
class
AdapterModel
(
ABC
):
def
__init__
(
self
,
model_id
=
None
):
self
.
id
=
model_id
@
abstractmethod
def
from_local_checkpoint
(
cls
,
model_dir
,
model_id
=
None
,
**
kwargs
):
# Common initialization code
# Load weights or embeddings from local checkpoint
raise
NotImplementedError
(
"Subclasses must implement this method."
)
T
=
TypeVar
(
'T'
)
class
AdapterLRUCache
(
LRUCache
[
T
]):
def
__init__
(
self
,
capacity
:
int
,
deactivate_fn
:
Callable
[[
Hashable
],
None
]):
super
().
__init__
(
capacity
)
self
.
deactivate_fn
=
deactivate_fn
def
_on_remove
(
self
,
key
:
Hashable
,
value
:
T
):
logger
.
debug
(
"Removing adapter int id: %d"
,
key
)
self
.
deactivate_fn
(
key
)
return
super
().
_on_remove
(
key
,
value
)
class
AdapterModelManager
(
ABC
):
def
__init__
(
self
,
model
:
nn
.
Module
,
):
"""Create a AdapterModelManager and adapter for a given model.
Args:
model: the model to be adapted.
"""
self
.
model
:
nn
.
Module
=
model
self
.
_registered_adapters
:
Dict
[
int
,
Any
]
=
{}
# Dict instead of a Set for compatibility with LRUCache.
self
.
_active_adapters
:
Dict
[
int
,
None
]
=
{}
self
.
adapter_type
=
'Adapter'
self
.
_last_mapping
=
None
def
__len__
(
self
)
->
int
:
return
len
(
self
.
_registered_adapters
)
@
property
@
abstractmethod
def
adapter_slots
(
self
):
...
@
property
@
abstractmethod
def
capacity
(
self
):
...
@
abstractmethod
def
activate_adapter
(
self
,
adapter_id
:
int
)
->
bool
:
...
@
abstractmethod
def
deactivate_adapter
(
self
,
adapter_id
:
int
)
->
bool
:
...
@
abstractmethod
def
add_adapter
(
self
,
adapter
:
Any
)
->
bool
:
...
@
abstractmethod
def
set_adapter_mapping
(
self
,
mapping
:
Any
)
->
None
:
...
@
abstractmethod
def
remove_adapter
(
self
,
adapter_id
:
int
)
->
bool
:
...
@
abstractmethod
def
remove_all_adapters
(
self
):
...
@
abstractmethod
def
get_adapter
(
self
,
adapter_id
:
int
)
->
Optional
[
Any
]:
...
@
abstractmethod
def
list_adapters
(
self
)
->
Dict
[
int
,
Any
]:
...
@
abstractmethod
def
pin_adapter
(
self
,
adapter_id
:
int
)
->
bool
:
...
vllm/adapter_commons/request.py
0 → 100644
View file @
705f6a35
from
abc
import
abstractmethod
from
dataclasses
import
dataclass
@
dataclass
class
AdapterRequest
:
"""
Base class for adapter requests.
"""
@
property
@
abstractmethod
def
adapter_id
(
self
):
...
def
__post_init__
(
self
):
if
self
.
adapter_id
<
1
:
raise
ValueError
(
f
"id must be > 0, got
{
self
.
adapter_id
}
"
)
def
__eq__
(
self
,
value
:
object
)
->
bool
:
return
isinstance
(
value
,
self
.
__class__
)
and
self
.
adapter_id
==
value
.
adapter_id
def
__hash__
(
self
)
->
int
:
return
hash
(
self
.
adapter_id
)
vllm/adapter_commons/utils.py
0 → 100644
View file @
705f6a35
from
typing
import
Any
,
Callable
,
Dict
,
Optional
,
Set
## model functions
def
deactivate_adapter
(
adapter_id
:
int
,
active_adapters
:
Dict
[
int
,
None
],
deactivate_func
:
Callable
)
->
bool
:
if
adapter_id
in
active_adapters
:
deactivate_func
(
adapter_id
)
active_adapters
.
pop
(
adapter_id
)
return
True
return
False
def
add_adapter
(
adapter
:
Any
,
registered_adapters
:
Dict
[
int
,
Any
],
capacity
:
int
,
add_func
:
Callable
)
->
bool
:
if
adapter
.
id
not
in
registered_adapters
:
if
len
(
registered_adapters
)
>=
capacity
:
raise
RuntimeError
(
'No free adapter slots.'
)
add_func
(
adapter
)
registered_adapters
[
adapter
.
id
]
=
adapter
return
True
return
False
def
set_adapter_mapping
(
mapping
:
Any
,
last_mapping
:
Any
,
set_mapping_func
:
Callable
)
->
Any
:
if
last_mapping
!=
mapping
:
set_mapping_func
(
mapping
)
return
mapping
return
last_mapping
def
remove_adapter
(
adapter_id
:
int
,
registered_adapters
:
Dict
[
int
,
Any
],
deactivate_func
:
Callable
)
->
bool
:
deactivate_func
(
adapter_id
)
return
bool
(
registered_adapters
.
pop
(
adapter_id
,
None
))
def
list_adapters
(
registered_adapters
:
Dict
[
int
,
Any
])
->
Dict
[
int
,
Any
]:
return
dict
(
registered_adapters
)
def
get_adapter
(
adapter_id
:
int
,
registered_adapters
:
Dict
[
int
,
Any
])
->
Optional
[
Any
]:
return
registered_adapters
.
get
(
adapter_id
,
None
)
## worker functions
def
set_active_adapters_worker
(
requests
:
Set
[
Any
],
mapping
:
Optional
[
Any
],
apply_adapters_func
,
set_adapter_mapping_func
)
->
None
:
apply_adapters_func
(
requests
)
set_adapter_mapping_func
(
mapping
)
def
add_adapter_worker
(
adapter_request
:
Any
,
list_adapters_func
,
load_adapter_func
,
add_adapter_func
,
activate_adapter_func
)
->
bool
:
if
adapter_request
.
adapter_id
in
list_adapters_func
():
return
False
loaded_adapter
=
load_adapter_func
(
adapter_request
)
loaded
=
add_adapter_func
(
loaded_adapter
)
activate_adapter_func
(
loaded_adapter
.
id
)
return
loaded
def
apply_adapters_worker
(
adapter_requests
:
Set
[
Any
],
list_adapters_func
,
adapter_slots
:
int
,
remove_adapter_func
,
add_adapter_func
)
->
None
:
models_that_exist
=
list_adapters_func
()
models_map
=
{
adapter_request
.
adapter_id
:
adapter_request
for
adapter_request
in
adapter_requests
if
adapter_request
}
if
len
(
models_map
)
>
adapter_slots
:
raise
RuntimeError
(
f
"Number of requested models (
{
len
(
models_map
)
}
) is greater "
f
"than the number of GPU model slots "
f
"(
{
adapter_slots
}
)."
)
new_models
=
set
(
models_map
)
models_to_add
=
new_models
-
models_that_exist
models_to_remove
=
models_that_exist
-
new_models
for
adapter_id
in
models_to_remove
:
remove_adapter_func
(
adapter_id
)
for
adapter_id
in
models_to_add
:
add_adapter_func
(
models_map
[
adapter_id
])
def
list_adapters_worker
(
adapter_manager_list_adapters_func
)
->
Set
[
int
]:
return
set
(
adapter_manager_list_adapters_func
())
vllm/adapter_commons/worker_manager.py
0 → 100644
View file @
705f6a35
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
,
Optional
,
Set
import
torch
class
AbstractWorkerManager
(
ABC
):
def
__init__
(
self
,
device
:
torch
.
device
):
self
.
device
=
device
@
property
@
abstractmethod
def
is_enabled
(
self
)
->
bool
:
...
@
abstractmethod
def
set_active_adapters
(
self
,
requests
:
Set
[
Any
],
mapping
:
Optional
[
Any
])
->
None
:
...
@
abstractmethod
def
add_adapter
(
self
,
adapter_request
:
Any
)
->
bool
:
...
@
abstractmethod
def
remove_adapter
(
self
,
adapter_id
:
int
)
->
bool
:
...
@
abstractmethod
def
remove_all_adapters
(
self
):
...
@
abstractmethod
def
list_adapters
(
self
)
->
Set
[
int
]:
...
vllm/attention/backends/abstract.py
View file @
705f6a35
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
,
fields
from
dataclasses
import
dataclass
,
fields
from
enum
import
Enum
,
auto
from
typing
import
(
Any
,
Dict
,
Generic
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
from
typing
import
(
Any
,
Dict
,
Generic
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
TypeVar
)
TypeVar
)
import
torch
import
torch
class
AttentionType
(
Enum
):
DECODER
=
auto
()
# Decoder attention between previous layer Q/K/V
ENCODER
=
auto
()
# Encoder attention between previous layer Q/K/V
ENCODER_DECODER
=
auto
()
# Attention between dec. Q and enc. K/V
class
AttentionBackend
(
ABC
):
class
AttentionBackend
(
ABC
):
"""Abstract class for attention backends."""
"""Abstract class for attention backends."""
...
@@ -21,9 +28,13 @@ class AttentionBackend(ABC):
...
@@ -21,9 +28,13 @@ class AttentionBackend(ABC):
@
staticmethod
@
staticmethod
@
abstractmethod
@
abstractmethod
def
make
_metadata
(
*
args
,
**
kwargs
)
->
"AttentionMetadata"
:
def
get
_metadata
_cls
()
->
Type
[
"AttentionMetadata"
]
:
raise
NotImplementedError
raise
NotImplementedError
@
classmethod
def
make_metadata
(
cls
,
*
args
,
**
kwargs
)
->
"AttentionMetadata"
:
return
cls
.
get_metadata_cls
()(
*
args
,
**
kwargs
)
@
staticmethod
@
staticmethod
@
abstractmethod
@
abstractmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
...
@@ -124,5 +135,6 @@ class AttentionImpl(ABC, Generic[T]):
...
@@ -124,5 +135,6 @@ class AttentionImpl(ABC, Generic[T]):
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
T
,
attn_metadata
:
T
,
kv_scale
:
float
=
1.0
,
kv_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
raise
NotImplementedError
raise
NotImplementedError
vllm/attention/backends/blocksparse_attn.py
View file @
705f6a35
...
@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type
...
@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type
import
torch
import
torch
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
)
AttentionMetadata
,
AttentionType
)
from
vllm.attention.ops.blocksparse_attention.interface
import
(
from
vllm.attention.ops.blocksparse_attention.interface
import
(
LocalStridedBlockSparseAttn
,
get_head_sliding_step
)
LocalStridedBlockSparseAttn
,
get_head_sliding_step
)
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.attention.ops.paged_attn
import
PagedAttention
...
@@ -90,8 +90,8 @@ class BlocksparseFlashAttentionBackend(AttentionBackend):
...
@@ -90,8 +90,8 @@ class BlocksparseFlashAttentionBackend(AttentionBackend):
return
BlocksparseFlashAttentionImpl
return
BlocksparseFlashAttentionImpl
@
staticmethod
@
staticmethod
def
make
_metadata
(
*
args
,
**
kwargs
)
->
"BlocksparseFlash
AttentionMetadata"
:
def
get
_metadata
_cls
()
->
Type
[
"
AttentionMetadata"
]
:
return
BlocksparseFlashAttentionMetadata
(
*
args
,
**
kwargs
)
return
BlocksparseFlashAttentionMetadata
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
...
@@ -328,6 +328,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
...
@@ -328,6 +328,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
BlocksparseFlashAttentionMetadata
,
attn_metadata
:
BlocksparseFlashAttentionMetadata
,
kv_scale
:
float
=
1.0
,
kv_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention and PagedAttention.
"""Forward pass with FlashAttention and PagedAttention.
...
@@ -340,6 +341,12 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
...
@@ -340,6 +341,12 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
Returns:
Returns:
shape = [num_tokens, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
"""
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"BlocksparseFlashAttentionImpl"
)
num_tokens
,
hidden_size
=
query
.
shape
num_tokens
,
hidden_size
=
query
.
shape
# Reshape the query, key, and value tensors.
# Reshape the query, key, and value tensors.
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
...
...
vllm/attention/backends/flash_attn.py
View file @
705f6a35
...
@@ -7,7 +7,7 @@ from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
...
@@ -7,7 +7,7 @@ from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
)
AttentionMetadata
,
AttentionType
)
class
FlashAttentionBackend
(
AttentionBackend
):
class
FlashAttentionBackend
(
AttentionBackend
):
...
@@ -25,8 +25,8 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -25,8 +25,8 @@ class FlashAttentionBackend(AttentionBackend):
return
FlashAttentionImpl
return
FlashAttentionImpl
@
staticmethod
@
staticmethod
def
make
_metadata
(
*
args
,
**
kwargs
)
->
"Flash
AttentionMetadata"
:
def
get
_metadata
_cls
()
->
Type
[
"
AttentionMetadata"
]
:
return
FlashAttentionMetadata
(
*
args
,
**
kwargs
)
return
FlashAttentionMetadata
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
...
@@ -83,7 +83,7 @@ class FlashAttentionMetadata(AttentionMetadata):
...
@@ -83,7 +83,7 @@ class FlashAttentionMetadata(AttentionMetadata):
# |---------------- N iteration ---------------------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------
-
|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
# |-- query_len ---|
# Maximum query length in the batch. None for decoding.
# Maximum query length in the batch. None for decoding.
...
@@ -257,6 +257,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -257,6 +257,7 @@ class FlashAttentionImpl(AttentionImpl):
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
FlashAttentionMetadata
,
attn_metadata
:
FlashAttentionMetadata
,
kv_scale
:
float
=
1.0
,
kv_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention.
"""Forward pass with FlashAttention.
...
@@ -269,6 +270,12 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -269,6 +270,12 @@ class FlashAttentionImpl(AttentionImpl):
Returns:
Returns:
shape = [num_tokens, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
"""
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashAttentionImpl"
)
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
assert
kv_scale
==
1.0
,
"kv_scale is not supported in FlashAttention."
assert
kv_scale
==
1.0
,
"kv_scale is not supported in FlashAttention."
...
@@ -317,7 +324,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -317,7 +324,7 @@ class FlashAttentionImpl(AttentionImpl):
# normal attention
# normal attention
# When block_tables are not filled, it means q and k are the
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
# prompt, and they have the same length.
flash_attn_varlen_func
(
out
=
flash_attn_varlen_func
(
q
=
query
,
q
=
query
,
k
=
key
,
k
=
key
,
v
=
value
,
v
=
value
,
...
@@ -329,13 +336,14 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -329,13 +336,14 @@ class FlashAttentionImpl(AttentionImpl):
causal
=
True
,
causal
=
True
,
window_size
=
self
.
sliding_window
,
window_size
=
self
.
sliding_window
,
alibi_slopes
=
self
.
alibi_slopes
,
alibi_slopes
=
self
.
alibi_slopes
,
out
=
output
[:
num_prefill_tokens
],
)
)
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
output
[:
num_prefill_tokens
]
=
out
else
:
else
:
# prefix-enabled attention
# prefix-enabled attention
assert
prefill_meta
.
seq_lens
is
not
None
assert
prefill_meta
.
seq_lens
is
not
None
max_seq_len
=
max
(
prefill_meta
.
seq_lens
)
max_seq_len
=
max
(
prefill_meta
.
seq_lens
)
flash_attn_varlen_func
(
output
[:
num_prefill_tokens
]
=
flash_attn_varlen_func
(
q
=
query
,
q
=
query
,
k
=
key_cache
,
k
=
key_cache
,
v
=
value_cache
,
v
=
value_cache
,
...
@@ -347,12 +355,11 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -347,12 +355,11 @@ class FlashAttentionImpl(AttentionImpl):
causal
=
True
,
causal
=
True
,
alibi_slopes
=
self
.
alibi_slopes
,
alibi_slopes
=
self
.
alibi_slopes
,
block_table
=
prefill_meta
.
block_tables
,
block_table
=
prefill_meta
.
block_tables
,
out
=
output
[:
num_prefill_tokens
],
)
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
# Decoding run.
# Decoding run.
flash_attn_with_kvcache
(
output
[
num_prefill_tokens
:]
=
flash_attn_with_kvcache
(
decode_query
.
unsqueeze
(
1
),
decode_query
.
unsqueeze
(
1
),
key_cache
,
key_cache
,
value_cache
,
value_cache
,
...
@@ -361,8 +368,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -361,8 +368,7 @@ class FlashAttentionImpl(AttentionImpl):
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
causal
=
True
,
alibi_slopes
=
self
.
alibi_slopes
,
alibi_slopes
=
self
.
alibi_slopes
,
out
=
output
[
num_prefill_tokens
:].
unsqueeze
(
1
),
).
squeeze
(
1
)
)
# Reshape the output tensor.
# Reshape the output tensor.
return
output
.
view
(
num_tokens
,
hidden_size
)
return
output
.
view
(
num_tokens
,
hidden_size
)
vllm/attention/backends/flashinfer.py
View file @
705f6a35
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
import
flashinfer
try
:
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
from
flashinfer.prefill
import
BatchPrefillWithPagedKVCacheWrapper
from
vllm_flash_attn
import
flash_attn_varlen_func
except
ImportError
:
flash_attn_varlen_func
=
None
BatchDecodeWithPagedKVCacheWrapper
=
None
BatchPrefillWithPagedKVCacheWrapper
=
None
import
torch
import
torch
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
from
vllm_flash_attn
import
flash_attn_varlen_func
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
)
AttentionMetadata
,
AttentionType
)
class
FlashInferBackend
(
AttentionBackend
):
class
FlashInferBackend
(
AttentionBackend
):
...
@@ -22,8 +28,8 @@ class FlashInferBackend(AttentionBackend):
...
@@ -22,8 +28,8 @@ class FlashInferBackend(AttentionBackend):
return
FlashInferImpl
return
FlashInferImpl
@
staticmethod
@
staticmethod
def
make
_metadata
(
*
args
,
**
kwargs
)
->
"FlashInfer
Metadata"
:
def
get
_metadata
_cls
()
->
Type
[
"Attention
Metadata"
]
:
return
FlashInferMetadata
(
*
args
,
**
kwargs
)
return
FlashInferMetadata
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
...
@@ -60,19 +66,16 @@ class FlashInferMetadata(AttentionMetadata):
...
@@ -60,19 +66,16 @@ class FlashInferMetadata(AttentionMetadata):
# requests only.
# requests only.
max_prefill_seq_len
:
int
max_prefill_seq_len
:
int
use_cuda_graph
:
bool
=
Fals
e
use_cuda_graph
:
bool
=
Tru
e
prefill_wrapper
:
Optional
[
BatchPrefillWithPagedKVCacheWrapper
]
=
None
decode_wrapper
:
Optional
[
BatchDecodeWithPagedKVCacheWrapper
]
=
None
decode_wrapper
:
Optional
[
BatchDecodeWithPagedKVCacheWrapper
]
=
None
# Metadata for the prefill stage since we still
# Metadata for the prefill stage
# use flash attention for prefill.
seq_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
seq_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
query_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
block_tables
:
Optional
[
torch
.
Tensor
]
=
None
block_tables
:
Optional
[
torch
.
Tensor
]
=
None
# Metadata for the decode stage
# Workspace buffer required by the kernel, the buffer should not
# be allocated/deacollated by the FalshInfermetadata object.
workspace_buffer
:
Optional
[
torch
.
Tensor
]
=
None
# An example for paged_kv_indices, paged_kv_indptr:
# An example for paged_kv_indices, paged_kv_indptr:
# request 1, page indices [0, 5, 8]
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
# request 2, page indices [1, 6, 7]
...
@@ -98,6 +101,9 @@ class FlashInferMetadata(AttentionMetadata):
...
@@ -98,6 +101,9 @@ class FlashInferMetadata(AttentionMetadata):
page_size
:
Optional
[
int
]
=
None
page_size
:
Optional
[
int
]
=
None
# The data type of the paged kv cache
# The data type of the paged kv cache
data_type
:
torch
.
dtype
=
None
data_type
:
torch
.
dtype
=
None
device
:
torch
.
device
=
torch
.
device
(
"cuda"
)
# Only used by gemma2 model
logits_soft_cap
:
Optional
[
float
]
=
None
def
__post_init__
(
self
):
def
__post_init__
(
self
):
# Refer to
# Refer to
...
@@ -109,13 +115,37 @@ class FlashInferMetadata(AttentionMetadata):
...
@@ -109,13 +115,37 @@ class FlashInferMetadata(AttentionMetadata):
f
"Only
{
supported_head_sizes
}
are supported for head_dim,"
,
f
"Only
{
supported_head_sizes
}
are supported for head_dim,"
,
f
"received
{
self
.
head_dim
}
."
)
f
"received
{
self
.
head_dim
}
."
)
# When using flashinfer, we are also creating the FlashInferMetadata,
def
begin_forward
(
self
):
# which will also call post_init by default, here we want to skip the
if
self
.
num_prefill_tokens
>
0
:
# post_init if it's the prefill phase.
if
self
.
paged_kv_indices
is
None
:
if
self
.
num_prefills
==
0
:
return
assert
self
.
num_decode_tokens
>
0
self
.
decode_wrapper
=
flashinfer
.
BatchDecodeWithPagedKVCacheWrapper
(
assert
self
.
prefill_wrapper
is
not
None
self
.
workspace_buffer
,
"NHD"
)
assert
self
.
paged_kv_indices
is
not
None
assert
self
.
paged_kv_indptr
is
not
None
assert
self
.
paged_kv_last_page_len
is
not
None
self
.
paged_kv_indices
=
self
.
paged_kv_indices
.
to
(
self
.
device
)
self
.
paged_kv_indptr
=
self
.
paged_kv_indptr
.
to
(
self
.
device
)
self
.
paged_kv_last_page_len
=
self
.
paged_kv_last_page_len
.
to
(
self
.
device
)
self
.
prefill_wrapper
.
end_forward
()
self
.
prefill_wrapper
.
begin_forward
(
self
.
query_start_loc
,
self
.
paged_kv_indptr
,
self
.
paged_kv_indices
,
self
.
paged_kv_last_page_len
,
self
.
num_qo_heads
,
self
.
num_kv_heads
,
self
.
head_dim
,
self
.
page_size
)
else
:
if
not
self
.
use_cuda_graph
:
assert
self
.
paged_kv_indices
is
not
None
assert
self
.
paged_kv_indptr
is
not
None
assert
self
.
paged_kv_last_page_len
is
not
None
self
.
paged_kv_indices
=
self
.
paged_kv_indices
.
to
(
self
.
device
)
self
.
paged_kv_indptr
=
self
.
paged_kv_indptr
.
to
(
self
.
device
)
self
.
paged_kv_last_page_len
=
self
.
paged_kv_last_page_len
.
to
(
self
.
device
)
assert
self
.
decode_wrapper
is
not
None
self
.
decode_wrapper
.
end_forward
()
self
.
decode_wrapper
.
begin_forward
(
self
.
decode_wrapper
.
begin_forward
(
self
.
paged_kv_indptr
,
self
.
paged_kv_indptr
,
self
.
paged_kv_indices
,
self
.
paged_kv_indices
,
...
@@ -133,8 +163,9 @@ class FlashInferMetadata(AttentionMetadata):
...
@@ -133,8 +163,9 @@ class FlashInferMetadata(AttentionMetadata):
)
->
Dict
[
str
,
Any
]:
)
->
Dict
[
str
,
Any
]:
if
skip_fields
is
None
:
if
skip_fields
is
None
:
skip_fields
=
set
()
skip_fields
=
set
()
# We need to skip the decode_wrapper field since it cannot be
# We need to skip the
prefill/
decode_wrapper field since it cannot be
# broadcasted with nccl when TP is enabled.
# broadcasted with nccl when TP is enabled.
skip_fields
.
add
(
'prefill_wrapper'
)
skip_fields
.
add
(
'decode_wrapper'
)
skip_fields
.
add
(
'decode_wrapper'
)
return
super
().
asdict_zerocopy
(
skip_fields
)
return
super
().
asdict_zerocopy
(
skip_fields
)
...
@@ -168,6 +199,7 @@ class FlashInferImpl(AttentionImpl):
...
@@ -168,6 +199,7 @@ class FlashInferImpl(AttentionImpl):
alibi_slopes
:
Optional
[
List
[
float
]],
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
)
->
None
:
)
->
None
:
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
...
@@ -192,8 +224,14 @@ class FlashInferImpl(AttentionImpl):
...
@@ -192,8 +224,14 @@ class FlashInferImpl(AttentionImpl):
kv_cache
:
Optional
[
torch
.
Tensor
],
kv_cache
:
Optional
[
torch
.
Tensor
],
attn_metadata
:
FlashInferMetadata
,
attn_metadata
:
FlashInferMetadata
,
kv_scale
:
float
=
1.0
,
kv_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
assert
kv_scale
==
1.0
assert
kv_scale
==
1.0
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashInferImpl"
)
num_tokens
,
hidden_size
=
query
.
shape
num_tokens
,
hidden_size
=
query
.
shape
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
...
@@ -217,10 +255,14 @@ class FlashInferImpl(AttentionImpl):
...
@@ -217,10 +255,14 @@ class FlashInferImpl(AttentionImpl):
self
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
)
)
query
=
query
.
contiguous
(
)
# Flashinfer requires query to be contiguous
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# Prompt run.
# We will use flash attention for prefill
assert
prefill_meta
.
block_tables
is
not
None
# when kv_cache is not provided.
if
kv_cache
is
None
or
prefill_meta
.
block_tables
.
numel
()
==
0
:
# This happens when vllm runs the profiling to
# determine the number of blocks.
if
kv_cache
is
None
:
output
=
flash_attn_varlen_func
(
output
=
flash_attn_varlen_func
(
q
=
query
,
q
=
query
,
k
=
key
,
k
=
key
,
...
@@ -235,16 +277,19 @@ class FlashInferImpl(AttentionImpl):
...
@@ -235,16 +277,19 @@ class FlashInferImpl(AttentionImpl):
alibi_slopes
=
self
.
alibi_slopes
,
alibi_slopes
=
self
.
alibi_slopes
,
)
)
else
:
else
:
raise
NotImplementedError
(
assert
prefill_meta
is
not
None
"Prefix caching is not supported with flashinfer yet."
)
assert
prefill_meta
.
prefill_wrapper
is
not
None
output
=
prefill_meta
.
prefill_wrapper
.
forward
(
query
,
kv_cache
,
logits_soft_cap
=
attn_metadata
.
logits_soft_cap
,
causal
=
True
)
else
:
else
:
assert
attn_metadata
.
decode_metadata
is
not
None
assert
attn_metadata
.
decode_metadata
is
not
None
assert
attn_metadata
.
decode_metadata
.
decode_wrapper
is
not
None
assert
attn_metadata
.
decode_metadata
.
decode_wrapper
is
not
None
query
=
query
.
contiguous
(
)
# Flashinfer requires query to be contiguous
output
=
attn_metadata
.
decode_metadata
.
decode_wrapper
.
forward
(
output
=
attn_metadata
.
decode_metadata
.
decode_wrapper
.
forward
(
query
,
query
,
kv_cache
,
kv_cache
,
sm_scale
=
self
.
scale
,
sm_scale
=
self
.
scale
,
)
logits_soft_cap
=
attn_metadata
.
logits_soft_cap
)
return
output
.
view
(
num_tokens
,
hidden_size
)
return
output
.
view
(
num_tokens
,
hidden_size
)
vllm/attention/backends/ipex_attn.py
0 → 100644
View file @
705f6a35
""" Attention layer with torch scaled_dot_product_attention
and PagedAttention."""
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
from
vllm._ipex_ops
import
ipex_ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
_PARTITION_SIZE
=
512
class
IpexAttnBackend
(
AttentionBackend
):
@
staticmethod
def
get_name
()
->
str
:
return
"ipex-attn"
@
staticmethod
def
get_impl_cls
()
->
Type
[
"IpexAttnBackendImpl"
]:
return
IpexAttnBackendImpl
@
staticmethod
def
get_metadata_cls
()
->
Type
[
"IpexAttnMetadata"
]:
return
IpexAttnMetadata
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
Tuple
[
int
,
...]:
return
PagedAttention
.
get_kv_cache_shape
(
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
@
staticmethod
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
torch
.
Tensor
,
)
->
None
:
PagedAttention
.
swap_blocks
(
src_kv_cache
,
dst_kv_cache
,
src_to_dst
)
@
staticmethod
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
)
->
None
:
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
@
dataclass
class
IpexAttnMetadata
(
AttentionMetadata
,
PagedAttentionMetadata
):
"""Metadata for IpexAttnBackend.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt
:
bool
slot_mapping
:
torch
.
Tensor
seq_lens
:
Optional
[
List
[
int
]]
seqlen_q
:
Optional
[
torch
.
Tensor
]
max_seqlen
:
Optional
[
int
]
def
__post_init__
(
self
):
# Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt
# when alibi slopes is used. It is because of the limitation
# from xformer API.
# will not appear in the __repr__ and __init__
self
.
attn_bias
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
@
property
def
prefill_metadata
(
self
)
->
Optional
[
"IpexAttnMetadata"
]:
# Currently chunked prefill is not supported
if
self
.
num_decode_tokens
==
0
:
assert
self
.
num_prefills
>
0
return
self
return
None
@
property
def
decode_metadata
(
self
)
->
Optional
[
"IpexAttnMetadata"
]:
# Currently chunked prefill is not supported
if
self
.
num_prefills
>
0
:
assert
self
.
num_decode_tokens
==
0
return
None
return
self
class
IpexAttnBackendImpl
(
AttentionImpl
[
IpexAttnMetadata
]):
def
__init__
(
self
,
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
)
->
None
:
assert
blocksparse_params
is
None
,
ValueError
(
"Torch SPDA does not support block-sparse attention."
)
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_kv_heads
if
alibi_slopes
is
not
None
:
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
self
.
alibi_slopes
=
alibi_slopes
self
.
sliding_window
=
sliding_window
self
.
kv_cache_dtype
=
kv_cache_dtype
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
need_mask
=
(
self
.
alibi_slopes
is
not
None
or
self
.
sliding_window
is
not
None
)
supported_head_sizes
=
PagedAttention
.
get_supported_head_sizes
()
if
head_size
not
in
supported_head_sizes
:
raise
ValueError
(
f
"Head size
{
head_size
}
is not supported by PagedAttention. "
f
"Supported head sizes are:
{
supported_head_sizes
}
."
)
if
kv_cache_dtype
!=
"auto"
:
raise
NotImplementedError
(
"IPEX backend does not support FP8 KV cache. "
"Please use xFormers backend instead."
)
def
split_kv_cache
(
self
,
kv_cache
:
torch
.
Tensor
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
x
=
1
num_blocks
=
kv_cache
.
shape
[
1
]
key_cache
=
kv_cache
[
0
]
key_cache
=
key_cache
.
view
(
num_blocks
,
num_kv_heads
,
head_size
//
x
,
-
1
,
x
)
value_cache
=
kv_cache
[
1
]
value_cache
=
value_cache
.
view
(
num_blocks
,
num_kv_heads
,
head_size
,
-
1
)
return
key_cache
,
value_cache
def
forward
(
self
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
Optional
[
torch
.
Tensor
],
attn_metadata
:
IpexAttnMetadata
,
# type: ignore
kv_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
"""Forward pass with IPEX varlen_attention and PagedAttention.
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert
kv_scale
==
1.0
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"IpexAttnBackendImpl"
)
num_tokens
,
hidden_size
=
query
.
shape
# Reshape the query, key, and value tensors.
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
if
kv_cache
is
not
None
:
key_cache
,
value_cache
=
self
.
split_kv_cache
(
kv_cache
,
self
.
num_kv_heads
,
self
.
head_size
)
ipex_ops
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
self
.
kv_cache_dtype
,
kv_scale
,
)
if
attn_metadata
.
is_prompt
:
assert
attn_metadata
.
seq_lens
is
not
None
if
(
kv_cache
is
None
or
attn_metadata
.
block_tables
.
numel
()
==
0
):
if
self
.
num_kv_heads
!=
self
.
num_heads
:
key
=
key
.
repeat_interleave
(
self
.
num_queries_per_kv
,
dim
=
1
)
value
=
value
.
repeat_interleave
(
self
.
num_queries_per_kv
,
dim
=
1
)
if
attn_metadata
.
attn_bias
is
None
:
if
self
.
alibi_slopes
is
not
None
:
att_masks
=
_make_alibi_bias
(
self
.
alibi_slopes
,
query
.
dtype
,
attn_metadata
.
seq_lens
)
# type: ignore
elif
self
.
sliding_window
is
not
None
:
att_masks
=
_make_sliding_window_bias
(
attn_metadata
.
seq_lens
,
self
.
sliding_window
,
query
.
dtype
)
# type: ignore
else
:
att_masks
=
_make_sliding_window_bias
(
attn_metadata
.
seq_lens
,
None
,
dtype
=
query
.
dtype
)
attn_metadata
.
attn_bias
=
att_masks
output
=
torch
.
empty
(
(
num_tokens
,
self
.
num_heads
,
self
.
head_size
),
dtype
=
query
.
dtype
,
device
=
query
.
device
)
ipex_ops
.
varlen_attention
(
query
,
key
,
value
,
output
,
attn_metadata
.
seqlen_q
,
attn_metadata
.
seqlen_q
,
attn_metadata
.
max_seqlen
,
attn_metadata
.
max_seqlen
,
pdropout
=
0.0
,
softmax_scale
=
self
.
scale
,
zero_tensors
=
False
,
is_causal
=
True
,
return_softmax
=
False
,
gen_
=
None
)
else
:
# prefix-enabled attention
raise
RuntimeError
(
"IPEX backend doesn't support prefix decoding."
)
else
:
# Decoding run.
max_seq_len
=
attn_metadata
.
max_decode_seq_len
output
=
torch
.
empty_like
(
query
)
block_size
=
value_cache
.
shape
[
3
]
num_seqs
,
num_heads
,
head_size
=
query
.
shape
max_num_partitions
=
((
max_seq_len
+
_PARTITION_SIZE
-
1
)
//
_PARTITION_SIZE
)
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
# TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory
# shortage.
use_v1
=
(
max_seq_len
<=
8192
and
(
max_num_partitions
==
1
or
num_seqs
*
num_heads
>
512
))
if
use_v1
:
# Run PagedAttention V1.
ipex_ops
.
paged_attention_v1
(
output
,
query
,
key_cache
,
value_cache
,
self
.
num_kv_heads
,
self
.
scale
,
attn_metadata
.
block_tables
,
attn_metadata
.
seq_lens_tensor
,
block_size
,
max_seq_len
,
self
.
alibi_slopes
,
self
.
kv_cache_dtype
,
kv_scale
,
)
else
:
# Run PagedAttention V2.
assert
_PARTITION_SIZE
%
block_size
==
0
tmp_output
=
torch
.
empty
(
size
=
(
num_seqs
,
num_heads
,
max_num_partitions
,
head_size
),
dtype
=
output
.
dtype
,
device
=
output
.
device
,
)
exp_sums
=
torch
.
empty
(
size
=
(
num_seqs
,
num_heads
,
max_num_partitions
),
dtype
=
torch
.
float32
,
device
=
output
.
device
,
)
max_logits
=
torch
.
empty_like
(
exp_sums
)
ipex_ops
.
paged_attention_v2
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
self
.
num_kv_heads
,
self
.
scale
,
attn_metadata
.
block_tables
,
attn_metadata
.
seq_lens_tensor
,
block_size
,
max_seq_len
,
self
.
alibi_slopes
,
self
.
kv_cache_dtype
,
kv_scale
,
)
# Reshape the output tensor.
return
output
.
view
(
-
1
,
self
.
num_heads
*
self
.
head_size
)
def
_make_alibi_bias
(
alibi_slopes
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
,
seq_lens
:
List
[
int
],
)
->
List
[
torch
.
Tensor
]:
attn_biases
=
[]
for
seq_len
in
seq_lens
:
bias
=
torch
.
arange
(
seq_len
,
dtype
=
dtype
,
device
=
alibi_slopes
.
device
)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias
=
bias
[
None
,
:]
-
bias
[:,
None
]
num_heads
=
alibi_slopes
.
shape
[
0
]
bias
=
bias
[
None
,
:].
repeat
((
num_heads
,
1
,
1
))
bias
.
mul_
(
alibi_slopes
[:,
None
,
None
])
inf_mask
=
torch
.
empty
(
(
1
,
seq_len
,
seq_len
),
dtype
=
bias
.
dtype
,
device
=
alibi_slopes
.
device
).
fill_
(
-
torch
.
inf
).
triu_
(
diagonal
=
1
)
attn_biases
.
append
((
bias
+
inf_mask
).
to
(
dtype
))
return
attn_biases
def
_make_sliding_window_bias
(
seq_lens
:
List
[
int
],
window_size
:
Optional
[
int
],
dtype
:
torch
.
dtype
,
)
->
List
[
torch
.
Tensor
]:
attn_biases
=
[]
for
seq_len
in
seq_lens
:
tensor
=
torch
.
full
(
(
1
,
seq_len
,
seq_len
),
dtype
=
dtype
,
fill_value
=
1
,
)
shift
=
0
mask
=
torch
.
tril
(
tensor
,
diagonal
=
shift
).
to
(
dtype
)
# type: ignore
if
window_size
is
not
None
:
mask
=
torch
.
triu
(
mask
,
diagonal
=
shift
-
window_size
+
1
)
mask
=
torch
.
log
(
mask
)
attn_biases
.
append
(
mask
.
to
(
dtype
))
return
attn_biases
vllm/attention/backends/openvino.py
0 → 100644
View file @
705f6a35
from
dataclasses
import
dataclass
from
typing
import
List
,
Tuple
import
openvino
as
ov
import
torch
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionMetadata
)
class
OpenVINOAttentionBackend
(
AttentionBackend
):
@
staticmethod
def
get_name
()
->
str
:
return
"openvino"
@
staticmethod
def
get_impl_cls
():
# OpenVINO implements PagedAttention as part of the Optimum
# exported model
raise
NotImplementedError
@
staticmethod
def
make_metadata
(
*
args
,
**
kwargs
)
->
"AttentionMetadata"
:
raise
NotImplementedError
@
staticmethod
def
make_openvino_metadata
(
*
args
,
**
kwargs
)
->
"OpenVINOAttentionMetadata"
:
return
OpenVINOAttentionMetadata
(
*
args
,
**
kwargs
)
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
Tuple
[
int
,
...]:
return
(
2
,
num_blocks
,
num_kv_heads
,
block_size
,
head_size
)
@
staticmethod
def
swap_blocks
(
src_kv_cache
:
ov
.
Tensor
,
dst_kv_cache
:
ov
.
Tensor
,
src_to_dst
:
torch
.
Tensor
,
)
->
None
:
# OpenVINO currently supports only CPU, which does not require
# swap of KV cache blocks
raise
NotImplementedError
@
staticmethod
def
copy_blocks
(
kv_caches
:
List
[
Tuple
[
ov
.
Tensor
,
ov
.
Tensor
]],
src_to_dists
:
List
[
Tuple
[
int
,
int
]],
)
->
None
:
for
src
,
dst
in
src_to_dists
:
for
key_cache
,
value_cache
in
kv_caches
:
key_cache
.
data
[
dst
,
:]
=
key_cache
.
data
[
src
,
:]
value_cache
.
data
[
dst
,
:]
=
value_cache
.
data
[
src
,
:]
@
dataclass
class
OpenVINOAttentionMetadata
:
"""Metadata for OpenVINOAttentionBackend.
Basic terms used below:
- batch_size_in_sequences - total number of sequences to execute
- prompt_lens – per sequence size number of scheduled tokens
- batch_size_in_tokens = sum(prompt_lens)
- max_context_len = max(context_lens)
- max_num_blocks = div_up(max_context_len / BLOCK_SIZE)
- num_blocks – total number of blocks in block_indices
"""
# Describes past KV cache size for each sequence within a batch
# Shape: [batch_size_in_sequences]
# Type: i32
past_lens
:
torch
.
Tensor
# Describes start indices of input / speculative tokens from
# current sequences within a batch sequence
# Shape: [batch_size_in_sequences + 1]
# Type: i32
subsequence_begins
:
torch
.
Tensor
# Describes block tables for each sequence within a batch -
# indices along 0th dimension in key_cache and value_cache inputs
# Shape: [num_blocks]
# Type: i32
block_indices
:
torch
.
Tensor
# Describes block tables for each sequence within a batch -
# for i-th element, it is an index in block_indices with the
# first block belonging to i-th sequence
# Shape: [batch_size_in_sequences + 1]
# Type: i32
block_indices_begins
:
torch
.
Tensor
# Describes max context length
# Shape: scalar
# Type: i32
max_context_len
:
torch
.
Tensor
vllm/attention/backends/pallas.py
0 → 100644
View file @
705f6a35
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
import
torch_xla.experimental.custom_kernel
# Required to register custom ops.
import
torch_xla.experimental.dynamo_set_buffer_donor
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
class
PallasAttentionBackend
(
AttentionBackend
):
@
staticmethod
def
get_impl_cls
()
->
Type
[
"PallasAttentionBackendImpl"
]:
return
PallasAttentionBackendImpl
@
staticmethod
def
get_metadata_cls
()
->
Type
[
"PallasMetadata"
]:
return
PallasMetadata
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
Tuple
[
int
,
...]:
return
(
num_kv_heads
,
num_blocks
,
block_size
,
head_size
)
@
staticmethod
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
torch
.
Tensor
,
)
->
None
:
raise
RuntimeError
(
"swap_blocks is not used for the TPU backend."
)
@
torch
.
compile
(
backend
=
"openxla"
)
@
staticmethod
def
copy_blocks
(
kv_caches
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
src_to_dists
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
)
->
None
:
src_indices
,
dst_indices
=
src_to_dists
for
k_cache
,
v_cache
in
kv_caches
:
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
k_cache
,
True
)
k_cache
[:,
dst_indices
]
=
k_cache
[:,
src_indices
]
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
v_cache
,
True
)
v_cache
[:,
dst_indices
]
=
v_cache
[:,
src_indices
]
@
dataclass
class
PallasMetadata
(
AttentionMetadata
):
# Currently, input sequences can only contain all prefills
# or all decoding.
block_tables
:
Optional
[
torch
.
Tensor
]
context_lens
:
Optional
[
torch
.
Tensor
]
@
property
def
prefill_metadata
(
self
)
->
Optional
[
"PallasMetadata"
]:
if
self
.
num_prefills
==
0
:
return
None
assert
self
.
num_decode_tokens
==
0
assert
self
.
block_tables
is
None
assert
self
.
context_lens
is
None
return
self
@
property
def
decode_metadata
(
self
)
->
Optional
[
"PallasMetadata"
]:
if
self
.
num_decode_tokens
==
0
:
return
None
assert
self
.
num_prefills
==
0
assert
self
.
num_prefill_tokens
==
0
assert
self
.
block_tables
is
not
None
assert
self
.
context_lens
is
not
None
return
self
class
PallasAttentionBackendImpl
(
AttentionImpl
):
def
__init__
(
self
,
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
)
->
None
:
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
if
head_size
%
128
!=
0
:
raise
NotImplementedError
(
"Head size must be a multiple of 128."
)
if
alibi_slopes
is
not
None
:
raise
NotImplementedError
(
"Alibi slopes is not supported."
)
if
sliding_window
is
not
None
:
raise
NotImplementedError
(
"Sliding window is not supported."
)
if
kv_cache_dtype
!=
"auto"
:
raise
NotImplementedError
(
"FP8 KV cache dtype is not supported."
)
if
blocksparse_params
is
not
None
:
raise
NotImplementedError
(
"Blocksparse is not supported."
)
if
torch_xla
.
tpu
.
version
()
<
4
:
raise
NotImplementedError
(
"TPU version must be 4 or higher."
)
self
.
megacore_mode
=
None
tpu_type
=
torch_xla
.
tpu
.
get_tpu_env
()[
"TYPE"
].
lower
()
if
"lite"
not
in
tpu_type
:
if
self
.
num_kv_heads
%
2
==
0
:
self
.
megacore_mode
=
"kv_head"
else
:
# NOTE(woosuk): If the batch size is not a multiple of 2, the
# megacore mode will be None.
self
.
megacore_mode
=
"batch"
def
forward
(
self
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
Tuple
[
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]],
attn_metadata
:
PallasMetadata
,
kv_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
"""Forward pass with Pallas attention.
Args:
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
key_cache = [num_kv_heads, num_blocks, block_size, head_size]
value_cache = [num_kv_heads, num_blocks, block_size, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
"""
assert
kv_scale
==
1.0
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"PallasAttentionBackendImpl"
)
batch_size
,
seq_len
,
hidden_size
=
query
.
shape
query
=
query
.
view
(
batch_size
,
seq_len
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
batch_size
,
seq_len
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
batch_size
,
seq_len
,
self
.
num_kv_heads
,
self
.
head_size
)
if
kv_cache
[
0
]
is
not
None
:
slot_mapping
=
attn_metadata
.
slot_mapping
key_cache
,
value_cache
=
kv_cache
write_to_kv_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
)
query
=
query
*
self
.
scale
if
attn_metadata
.
num_prefills
>
0
:
assert
seq_len
%
16
==
0
,
(
"Pallas FlashAttention kernel requires seq_len to be a "
f
"multiple of 16 but got
{
seq_len
}
"
)
# Handle GQA/MQA.
if
self
.
num_kv_heads
!=
self
.
num_heads
:
key
=
key
.
repeat_interleave
(
self
.
num_queries_per_kv
,
dim
=-
2
)
key
=
key
.
view
(
batch_size
,
seq_len
,
self
.
num_heads
,
self
.
head_size
)
value
=
value
.
repeat_interleave
(
self
.
num_queries_per_kv
,
dim
=-
2
)
value
=
value
.
view
(
batch_size
,
seq_len
,
self
.
num_heads
,
self
.
head_size
)
# FlashAttention requires [batch_size, num_heads, seq_len, d_model]
# while the input is [batch_size, seq_len, num_heads, d_model].
# Permute the input to match the required format.
output
=
torch
.
ops
.
xla
.
flash_attention
(
query
.
permute
(
0
,
2
,
1
,
3
),
key
.
permute
(
0
,
2
,
1
,
3
),
value
.
permute
(
0
,
2
,
1
,
3
),
True
,
)
output
=
output
.
permute
(
0
,
2
,
1
,
3
)
else
:
# Decoding run.
assert
kv_cache
is
not
None
pages_per_compute_block
=
16
# TODO(woosuk): Tune this value.
if
self
.
megacore_mode
==
"batch"
and
batch_size
%
2
!=
0
:
megacore_mode
=
None
else
:
megacore_mode
=
self
.
megacore_mode
# NOTE(woosuk): A temporary workaround to avoid the error:
# "xla::paged_attention() Expected a value of type 'str' for
# argument 'megacore_mode' but instead found type 'NoneType'."
if
megacore_mode
is
not
None
:
output
=
torch
.
ops
.
xla
.
paged_attention
(
query
.
squeeze
(
dim
=
1
),
key_cache
,
value_cache
,
attn_metadata
.
context_lens
,
attn_metadata
.
block_tables
,
pages_per_compute_block
,
megacore_mode
=
megacore_mode
,
)
else
:
output
=
torch
.
ops
.
xla
.
paged_attention
(
query
.
squeeze
(
dim
=
1
),
key_cache
,
value_cache
,
attn_metadata
.
context_lens
,
attn_metadata
.
block_tables
,
pages_per_compute_block
,
)
# Reshape the output tensor.
return
output
.
reshape
(
batch_size
,
seq_len
,
hidden_size
)
def
write_to_kv_cache
(
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
)
->
None
:
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
key_cache
,
True
)
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
value_cache
,
True
)
key
=
key
.
flatten
(
0
,
2
)
value
=
value
.
flatten
(
0
,
2
)
key_cache
=
key_cache
.
flatten
(
0
,
2
)
value_cache
=
value_cache
.
flatten
(
0
,
2
)
key_cache
.
index_copy_
(
0
,
slot_mapping
,
key
)
value_cache
.
index_copy_
(
0
,
slot_mapping
,
value
)
vllm/attention/backends/rocm_flash_attn.py
View file @
705f6a35
...
@@ -6,7 +6,7 @@ import torch
...
@@ -6,7 +6,7 @@ import torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
)
AttentionMetadata
,
AttentionType
)
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
PagedAttentionMetadata
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -25,8 +25,8 @@ class ROCmFlashAttentionBackend(AttentionBackend):
...
@@ -25,8 +25,8 @@ class ROCmFlashAttentionBackend(AttentionBackend):
return
ROCmFlashAttentionImpl
return
ROCmFlashAttentionImpl
@
staticmethod
@
staticmethod
def
make
_metadata
(
*
args
,
**
kwargs
)
->
"ROCmFlash
AttentionMetadata"
:
def
get
_metadata
_cls
()
->
Type
[
"
AttentionMetadata"
]
:
return
ROCmFlashAttentionMetadata
(
*
args
,
**
kwargs
)
return
ROCmFlashAttentionMetadata
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
...
@@ -166,6 +166,37 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -166,6 +166,37 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
return
self
.
_cached_decode_metadata
return
self
.
_cached_decode_metadata
def
_make_alibi_bias
(
alibi_slopes
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
,
seq_lens
:
Optional
[
List
[
int
]],
make_attn_mask
:
bool
=
True
)
->
List
[
torch
.
Tensor
]:
attn_biases
=
[]
if
seq_lens
:
for
seq_len
in
seq_lens
:
bias
=
torch
.
arange
(
seq_len
,
dtype
=
dtype
)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias
=
bias
[
None
,
:]
-
bias
[:,
None
]
num_heads
=
alibi_slopes
.
shape
[
0
]
bias
=
bias
[
None
,
:].
repeat
(
(
num_heads
,
1
,
1
)).
to
(
alibi_slopes
.
device
)
bias
.
mul_
(
alibi_slopes
[:,
None
,
None
])
if
make_attn_mask
:
inf_mask
=
torch
.
empty
(
(
1
,
seq_len
,
seq_len
),
dtype
=
bias
.
dtype
).
fill_
(
-
torch
.
inf
).
triu_
(
diagonal
=
1
).
to
(
alibi_slopes
.
device
)
attn_biases
.
append
((
bias
+
inf_mask
).
to
(
dtype
))
else
:
attn_biases
.
append
(
bias
.
to
(
dtype
))
return
attn_biases
class
ROCmFlashAttentionImpl
(
AttentionImpl
):
class
ROCmFlashAttentionImpl
(
AttentionImpl
):
"""
"""
If the input tensors contain prompt tokens, the layout is as follows:
If the input tensors contain prompt tokens, the layout is as follows:
...
@@ -229,11 +260,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -229,11 +260,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self
.
use_triton_flash_attn
=
envs
.
VLLM_USE_TRITON_FLASH_ATTN
self
.
use_triton_flash_attn
=
envs
.
VLLM_USE_TRITON_FLASH_ATTN
if
self
.
use_triton_flash_attn
:
if
self
.
use_triton_flash_attn
:
#
from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
from
vllm.attention.ops.triton_flash_attention
import
(
# noqa: F401
#
triton_attention)
triton_attention
)
from
vllm.attention.ops.flash_attn_triton_mqa_gqa
import
(
#
from vllm.attention.ops.flash_attn_triton_mqa_gqa import (
flash_attn_varlen_func
)
#
flash_attn_varlen_func)
self
.
attn_func
=
flash_attn_varlen_func
# triton_attention
self
.
attn_func
=
triton_attention
#
flash_attn_varlen_func
logger
.
debug
(
"Using Triton FA in ROCmBackend"
)
logger
.
debug
(
"Using Triton FA in ROCmBackend"
)
else
:
else
:
# if not using triton, navi3x/navi21/navi10 do not use flash-attn
# if not using triton, navi3x/navi21/navi10 do not use flash-attn
...
@@ -268,6 +299,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -268,6 +299,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
ROCmFlashAttentionMetadata
,
attn_metadata
:
ROCmFlashAttentionMetadata
,
kv_scale
:
float
=
1.0
,
kv_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention and PagedAttention.
"""Forward pass with FlashAttention and PagedAttention.
...
@@ -280,6 +312,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -280,6 +312,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
Returns:
Returns:
shape = [num_tokens, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
"""
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"ROCmFlashAttentionImpl"
)
num_tokens
,
hidden_size
=
query
.
shape
num_tokens
,
hidden_size
=
query
.
shape
# Reshape the query, key, and value tensors.
# Reshape the query, key, and value tensors.
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
...
@@ -326,34 +364,36 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -326,34 +364,36 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# triton attention
# triton attention
# When block_tables are not filled, it means q and k are the
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
# prompt, and they have the same length.
attn_masks
=
None
if
self
.
use_triton_flash_attn
:
if
self
.
use_triton_flash_attn
:
# out, _ = self.attn_func(
if
self
.
alibi_slopes
is
not
None
:
# query,
attn_masks
=
_make_alibi_bias
(
# key,
self
.
alibi_slopes
,
# value,
query
.
dtype
,
# None,
attn_metadata
.
seq_lens
,
# prefill_meta.seq_start_loc,
make_attn_mask
=
False
)
# type: ignore
# prefill_meta.seq_start_loc,
# prefill_meta.max_prefill_seq_len,
# prefill_meta.max_prefill_seq_len,
# True,
# self.scale,
out
=
self
.
attn_func
(
out
=
self
.
attn_func
(
q
=
query
,
query
,
k
=
key
,
key
,
v
=
value
,
value
,
cu_seqlens_q
=
prefill_meta
.
seq_
start_loc
,
prefill_meta
.
seq_
lens
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
num_tokens
,
max_seqlens_q
=
prefill_meta
.
max_prefill_seq_len
,
self
.
num_heads
,
max_seqlens_k
=
prefill_meta
.
max_prefill_seq_len
,
self
.
head_size
,
softmax_scale
=
self
.
scale
,
self
.
scale
,
causal
=
True
,
attn_masks
,
)
)
elif
self
.
use_naive_attn
:
elif
self
.
use_naive_attn
:
if
self
.
num_kv_heads
!=
self
.
num_heads
:
if
self
.
num_kv_heads
!=
self
.
num_heads
:
# Interleave for MQA workaround.
# Interleave for MQA workaround.
key
=
self
.
repeat_kv
(
key
,
self
.
num_queries_per_kv
)
key
=
self
.
repeat_kv
(
key
,
self
.
num_queries_per_kv
)
value
=
self
.
repeat_kv
(
value
,
self
.
num_queries_per_kv
)
value
=
self
.
repeat_kv
(
value
,
self
.
num_queries_per_kv
)
if
self
.
alibi_slopes
is
not
None
:
attn_masks
=
_make_alibi_bias
(
self
.
alibi_slopes
,
query
.
dtype
,
attn_metadata
.
seq_lens
,
make_attn_mask
=
True
)
# type: ignore
query
=
query
.
movedim
(
0
,
query
.
dim
()
-
2
)
query
=
query
.
movedim
(
0
,
query
.
dim
()
-
2
)
key
=
key
.
movedim
(
0
,
key
.
dim
()
-
2
)
key
=
key
.
movedim
(
0
,
key
.
dim
()
-
2
)
value
=
value
.
movedim
(
0
,
value
.
dim
()
-
2
)
value
=
value
.
movedim
(
0
,
value
.
dim
()
-
2
)
...
@@ -367,6 +407,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -367,6 +407,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self
.
num_heads
,
self
.
num_heads
,
self
.
head_size
,
self
.
head_size
,
self
.
scale
,
self
.
scale
,
attn_masks
,
)
)
else
:
else
:
out
=
self
.
attn_func
(
out
=
self
.
attn_func
(
...
@@ -430,13 +471,14 @@ def _sdpa_attention(
...
@@ -430,13 +471,14 @@ def _sdpa_attention(
num_heads
:
int
,
num_heads
:
int
,
head_size
:
int
,
head_size
:
int
,
scale
:
float
,
scale
:
float
,
attn_masks
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
start
=
0
start
=
0
output
=
torch
.
empty
((
num_tokens
,
num_heads
,
head_size
),
output
=
torch
.
empty
((
num_tokens
,
num_heads
,
head_size
),
dtype
=
query
.
dtype
,
dtype
=
query
.
dtype
,
device
=
query
.
device
)
device
=
query
.
device
)
for
seq_len
in
seq_lens
:
for
i
,
seq_len
in
enumerate
(
seq_lens
)
:
end
=
start
+
seq_len
end
=
start
+
seq_len
with
torch
.
backends
.
cuda
.
sdp_kernel
(
enable_math
=
True
,
with
torch
.
backends
.
cuda
.
sdp_kernel
(
enable_math
=
True
,
enable_flash
=
False
,
enable_flash
=
False
,
...
@@ -446,7 +488,8 @@ def _sdpa_attention(
...
@@ -446,7 +488,8 @@ def _sdpa_attention(
key
[:,
start
:
end
,
:],
key
[:,
start
:
end
,
:],
value
[:,
start
:
end
,
:],
value
[:,
start
:
end
,
:],
dropout_p
=
0.0
,
dropout_p
=
0.0
,
is_causal
=
True
,
is_causal
=
attn_masks
is
None
,
attn_mask
=
attn_masks
[
i
]
if
attn_masks
else
None
,
scale
=
scale
).
movedim
(
query
.
dim
()
-
2
,
0
)
scale
=
scale
).
movedim
(
query
.
dim
()
-
2
,
0
)
output
[
start
:
end
,
:,
:]
=
sub_out
output
[
start
:
end
,
:,
:]
=
sub_out
start
=
end
start
=
end
...
...
Prev
1
…
12
13
14
15
16
17
18
19
20
…
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment