Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0f8cafe2
Unverified
Commit
0f8cafe2
authored
Jan 13, 2025
by
Chen Zhang
Committed by
GitHub
Jan 13, 2025
Browse files
[Kernel] unified_attention for Attention.forward (#11967)
Signed-off-by:
Chen Zhang
<
zhangch99@outlook.com
>
parent
5340a30d
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
87 additions
and
45 deletions
+87
-45
vllm/attention/layer.py
vllm/attention/layer.py
+14
-12
vllm/utils.py
vllm/utils.py
+0
-1
vllm/worker/hpu_model_runner.py
vllm/worker/hpu_model_runner.py
+11
-2
vllm/worker/hpu_worker.py
vllm/worker/hpu_worker.py
+3
-0
vllm/worker/neuron_model_runner.py
vllm/worker/neuron_model_runner.py
+10
-7
vllm/worker/openvino_model_runner.py
vllm/worker/openvino_model_runner.py
+3
-1
vllm/worker/openvino_worker.py
vllm/worker/openvino_worker.py
+11
-2
vllm/worker/tpu_model_runner.py
vllm/worker/tpu_model_runner.py
+18
-10
vllm/worker/tpu_worker.py
vllm/worker/tpu_worker.py
+5
-1
vllm/worker/xpu_model_runner.py
vllm/worker/xpu_model_runner.py
+12
-9
No files found.
vllm/attention/layer.py
View file @
0f8cafe2
...
...
@@ -134,15 +134,10 @@ class Attention(nn.Module):
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
_
kv_cache
:
torch
.
Tensor
,
_
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
if
self
.
use_direct_call
:
return
self
.
impl
.
forward
(
query
,
key
,
value
,
kv_cache
,
attn_metadata
,
self
.
_k_scale
,
self
.
_v_scale
)
elif
self
.
use_output
:
if
self
.
use_output
:
output
=
torch
.
empty_like
(
query
)
hidden_size
=
query
.
size
(
-
1
)
# Reshape the query, key, and value tensors.
...
...
@@ -154,12 +149,19 @@ class Attention(nn.Module):
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
if
value
is
not
None
:
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
if
self
.
use_direct_call
:
unified_attention_with_output
(
query
,
key
,
value
,
output
,
self
.
layer_name
)
else
:
torch
.
ops
.
vllm
.
unified_attention_with_output
(
query
,
key
,
value
,
output
,
self
.
layer_name
)
return
output
.
view
(
-
1
,
hidden_size
)
else
:
return
torch
.
ops
.
vllm
.
unified_attention
(
query
,
key
,
value
,
self
.
layer_name
)
if
self
.
use_direct_call
:
return
unified_attention
(
query
,
key
,
value
,
self
.
layer_name
)
else
:
return
torch
.
ops
.
vllm
.
unified_attention
(
query
,
key
,
value
,
self
.
layer_name
)
def
extra_repr
(
self
)
->
str
:
s
=
f
"head_size=
{
self
.
impl
.
head_size
}
"
# type: ignore
...
...
vllm/utils.py
View file @
0f8cafe2
...
...
@@ -2171,5 +2171,4 @@ def bind_kv_cache(
forward_ctx
=
ctx
[
layer_name
]
assert
len
(
forward_ctx
.
kv_cache
)
==
len
(
kv_cache
)
for
ve
,
ve_kv_cache
in
enumerate
(
kv_cache
):
assert
forward_ctx
.
kv_cache
[
ve
].
numel
()
==
0
forward_ctx
.
kv_cache
[
ve
]
=
ve_kv_cache
[
kv_cache_idx
]
vllm/worker/hpu_model_runner.py
View file @
0f8cafe2
...
...
@@ -28,6 +28,7 @@ from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.config
import
DeviceConfig
,
VllmConfig
from
vllm.distributed.parallel_state
import
get_world_group
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.lora.layers
import
LoRAMapping
from
vllm.lora.request
import
LoRARequest
...
...
@@ -40,7 +41,8 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
IntermediateTensors
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.utils
import
is_pin_memory_available
,
make_tensor_with_pad
from
vllm.utils
import
(
bind_kv_cache
,
is_pin_memory_available
,
make_tensor_with_pad
)
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
_add_attn_metadata_broadcastable_dict
,
...
...
@@ -1286,6 +1288,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
def
profile_run
(
self
)
->
None
:
num_layers
=
self
.
model_config
.
get_num_layers
(
self
.
parallel_config
)
kv_caches
=
[
None
]
*
num_layers
bind_kv_cache
(
self
.
vllm_config
.
compilation_config
.
static_forward_context
,
[
kv_caches
])
max_seq_len
=
self
.
bucketing_global_state
.
prompt_seq_bucket_cfg
[
-
1
]
max_batch_size
=
min
(
self
.
max_num_batched_tokens
//
max_seq_len
,
self
.
scheduler_config
.
max_num_seqs
)
...
...
@@ -1943,7 +1948,11 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
f
"graphs
{
'T'
if
use_graphs
else
'F'
}
"
)
else
:
model_event_name
=
'model_executable'
with
self
.
profiler
.
record_event
(
'internal'
,
model_event_name
):
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
,
model_input
.
virtual_engine
),
\
self
.
profiler
.
record_event
(
'internal'
,
model_event_name
):
hidden_states
=
self
.
model
.
forward
(
**
execute_model_kwargs
,
selected_token_indices
=
sampling_metadata
.
selected_token_indices
...
...
vllm/worker/hpu_worker.py
View file @
0f8cafe2
...
...
@@ -20,6 +20,7 @@ from vllm.lora.request import LoRARequest
from
vllm.model_executor
import
set_random_seed
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
bind_kv_cache
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.hpu_model_runner
import
HPUModelRunner
from
vllm.worker.model_runner_base
import
ModelRunnerBase
...
...
@@ -215,6 +216,8 @@ class HPUWorker(LocalOrDistributedWorkerBase):
self
.
cache_engine
[
ve
].
gpu_cache
for
ve
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
bind_kv_cache
(
self
.
compilation_config
.
static_forward_context
,
self
.
hpu_cache
)
def
_warm_up_model
(
self
)
->
None
:
# NOTE(kzawora): We should use virtual engine index here
...
...
vllm/worker/neuron_model_runner.py
View file @
0f8cafe2
...
...
@@ -8,6 +8,7 @@ from torch import nn
from
transformers_neuronx.config
import
GenerationConfig
from
vllm.config
import
VllmConfig
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.sampler
import
SamplerOutput
...
...
@@ -314,11 +315,13 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
raise
ValueError
(
"NeuronModelRunner does not support multi-step execution."
)
with
set_forward_context
(
None
,
self
.
vllm_config
,
0
):
hidden_states
=
self
.
model
(
input_ids
=
model_input
.
input_tokens
,
positions
=
model_input
.
input_positions
,
input_block_ids
=
model_input
.
input_block_ids
,
**
MultiModalKwargs
.
as_kwargs
(
model_input
.
multi_modal_kwargs
or
{},
**
MultiModalKwargs
.
as_kwargs
(
model_input
.
multi_modal_kwargs
or
{},
device
=
self
.
device
),
)
...
...
vllm/worker/openvino_model_runner.py
View file @
0f8cafe2
...
...
@@ -8,6 +8,7 @@ from torch import nn
from
vllm.attention
import
get_attn_backend
from
vllm.attention.backends.openvino
import
OpenVINOAttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.sampler
import
SamplerOutput
...
...
@@ -350,6 +351,7 @@ class OpenVINOModelRunner(ModelRunnerBase):
device
=
self
.
device
),
}
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
0
):
hidden_states
=
model_executable
(
**
execute_model_kwargs
)
# Compute the logits.
...
...
vllm/worker/openvino_worker.py
View file @
0f8cafe2
...
...
@@ -20,6 +20,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from
vllm.platforms
import
current_platform
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
ExecuteModelRequest
,
SequenceGroupMetadata
from
vllm.utils
import
bind_kv_cache
from
vllm.worker.openvino_model_runner
import
OpenVINOModelRunner
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
,
WorkerBase
...
...
@@ -339,6 +340,8 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
ov_device
,
)
self
.
kv_cache
=
self
.
cache_engine
.
kv_cache
bind_kv_cache
(
self
.
compilation_config
.
static_forward_context
,
[
self
.
kv_cache
])
self
.
model_runner
.
block_size
=
self
.
cache_engine
.
block_size
assert
self
.
kv_cache
is
not
None
...
...
@@ -507,12 +510,18 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
self
.
model_runner
.
block_size
=
tmp_cache_config
.
block_size
bind_kv_cache
(
self
.
compilation_config
.
static_forward_context
,
profiling_cache_engine
.
kv_cache
)
# Run the model with the dummy inputs.
self
.
model_runner
.
execute_model
(
seqs
,
profiling_cache_engine
.
kv_cache
)
# explicitly delete temporary KV cache manager to free KV cache
# when real inputs will be passed to OV
# Explicitly revert bind_kv_cache and delete temporary KV cache
# manager to free KV cache when real inputs will be passed to OV
bind_kv_cache
(
self
.
compilation_config
.
static_forward_context
,
[[
torch
.
tensor
([])
for
_
in
range
(
len
(
profiling_cache_engine
.
kv_cache
))
]])
del
profiling_cache_engine
logger
.
info
(
...
...
vllm/worker/tpu_model_runner.py
View file @
0f8cafe2
...
...
@@ -13,6 +13,7 @@ import torch_xla.runtime as xr
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.config
import
VllmConfig
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader
import
get_model
...
...
@@ -265,8 +266,9 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
torch
.
_dynamo
.
mark_dynamic
(
t
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
p
,
0
)
# Dummy run.
self
.
model
(
token_ids
,
position_ids
,
attn_metadata
,
input_lens
,
t
,
p
,
num_samples
,
kv_caches
)
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
0
):
self
.
model
(
token_ids
,
position_ids
,
attn_metadata
,
input_lens
,
t
,
p
,
num_samples
,
kv_caches
)
def
warmup_model
(
self
,
...
...
@@ -663,9 +665,12 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
input_lens
=
model_input
.
input_lens
[
i
:
i
+
1
].
to
(
self
.
device
)
t
=
model_input
.
t
[
i
:
i
+
1
].
to
(
self
.
device
)
p
=
model_input
.
p
[
i
:
i
+
1
].
to
(
self
.
device
)
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
,
model_input
.
virtual_engine
):
output_token_ids
=
self
.
model
(
token_ids
,
position_ids
,
attn_metadata
,
input_lens
,
t
,
p
,
model_input
.
num_samples
,
attn_metadata
,
input_lens
,
t
,
p
,
model_input
.
num_samples
,
kv_caches
)
next_token_ids
.
append
(
output_token_ids
[
0
])
start_idx
=
end_idx
...
...
@@ -711,9 +716,12 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
input_lens
=
model_input
.
input_lens
.
to
(
self
.
device
)
for
i
in
range
(
num_steps
):
slot_mapping
=
attn_metadata
.
slot_mapping
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
,
model_input
.
virtual_engine
):
output_token_ids
=
self
.
model
(
token_ids
,
position_ids
,
attn_metadata
,
input_lens
,
t
,
p
,
model_input
.
num_samples
,
attn_metadata
,
input_lens
,
t
,
p
,
model_input
.
num_samples
,
kv_caches
)
self
.
cached_step_outputs
.
append
(
output_token_ids
)
...
...
vllm/worker/tpu_worker.py
View file @
0f8cafe2
...
...
@@ -12,7 +12,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
set_random_seed
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
get_dtype_size
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
bind_kv_cache
,
get_dtype_size
from
vllm.worker.tpu_model_runner
import
ExecutionMode
,
TPUModelRunner
from
vllm.worker.worker_base
import
(
LocalOrDistributedWorkerBase
,
LoraNotSupportedWorkerBase
,
WorkerBase
,
...
...
@@ -108,6 +108,8 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
torch
.
tensor
([],
dtype
=
torch
.
float32
,
device
=
self
.
device
))
for
_
in
range
(
num_layers
)]
bind_kv_cache
(
self
.
compilation_config
.
static_forward_context
,
[
kv_caches
])
self
.
model_runner
.
_dummy_run
(
batch_size
=
1
,
seq_len
=
self
.
scheduler_config
.
max_num_batched_tokens
,
...
...
@@ -170,6 +172,8 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
device
=
"cpu"
)
cpu_v_cache
=
torch
.
zeros_like
(
cpu_k_cache
)
self
.
cpu_cache
.
append
((
cpu_k_cache
,
cpu_v_cache
))
bind_kv_cache
(
self
.
compilation_config
.
static_forward_context
,
[
self
.
tpu_cache
])
self
.
_warmup_model
()
def
_warmup_model
(
self
)
->
None
:
...
...
vllm/worker/xpu_model_runner.py
View file @
0f8cafe2
...
...
@@ -12,6 +12,7 @@ import torch.nn as nn
from
vllm.attention
import
get_attn_backend
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
get_pp_group
from
vllm.forward_context
import
set_forward_context
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadataCache
...
...
@@ -562,14 +563,16 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
if
(
self
.
observability_config
is
not
None
and
self
.
observability_config
.
collect_model_forward_time
):
model_forward_start_time
=
time
.
time
()
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
,
model_input
.
virtual_engine
):
hidden_or_intermediate_states
=
model_executable
(
input_ids
=
model_input
.
input_tokens
,
positions
=
model_input
.
input_positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
model_input
.
attn_metadata
,
intermediate_tensors
=
intermediate_tensors
,
**
MultiModalKwargs
.
as_kwargs
(
model_input
.
multi_modal_kwargs
or
{},
**
MultiModalKwargs
.
as_kwargs
(
model_input
.
multi_modal_kwargs
or
{},
device
=
self
.
device
))
# Compute the logits in the last pipeline stage.
if
not
get_pp_group
().
is_last_rank
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment