Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4f95ffee
Unverified
Commit
4f95ffee
authored
Oct 07, 2024
by
Isotr0py
Committed by
GitHub
Oct 07, 2024
Browse files
[Hardware][CPU] Cross-attention and Encoder-Decoder models support on CPU backend (#9089)
parent
8c6de96e
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
834 additions
and
287 deletions
+834
-287
.buildkite/run-cpu-test.sh
.buildkite/run-cpu-test.sh
+1
-0
tests/models/encoder_decoder/language/test_bart.py
tests/models/encoder_decoder/language/test_bart.py
+211
-217
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+299
-61
vllm/worker/cpu_enc_dec_model_runner.py
vllm/worker/cpu_enc_dec_model_runner.py
+311
-0
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+3
-7
vllm/worker/cpu_worker.py
vllm/worker/cpu_worker.py
+9
-2
No files found.
.buildkite/run-cpu-test.sh
View file @
4f95ffee
...
@@ -23,6 +23,7 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py"
...
@@ -23,6 +23,7 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py"
# Run basic model test
# Run basic model test
docker
exec
cpu-test bash
-c
"
docker
exec
cpu-test bash
-c
"
pip install pytest matplotlib einops transformers_stream_generator datamodel_code_generator
pip install pytest matplotlib einops transformers_stream_generator datamodel_code_generator
pytest -v -s tests/models/encoder_decoder/language
pytest -v -s tests/models/decoder_only/language
\
pytest -v -s tests/models/decoder_only/language
\
--ignore=tests/models/test_fp8.py
\
--ignore=tests/models/test_fp8.py
\
--ignore=tests/models/decoder_only/language/test_jamba.py
\
--ignore=tests/models/decoder_only/language/test_jamba.py
\
...
...
tests/models/encoder_decoder/language/test_bart.py
View file @
4f95ffee
...
@@ -4,220 +4,214 @@ Run `pytest tests/models/encoder_decoder/language/test_bart.py`.
...
@@ -4,220 +4,214 @@ Run `pytest tests/models/encoder_decoder/language/test_bart.py`.
"""
"""
from
typing
import
List
,
Optional
,
Tuple
,
Type
from
typing
import
List
,
Optional
,
Tuple
,
Type
from
vllm.utils
import
is_cpu
import
pytest
from
transformers
import
AutoModelForSeq2SeqLM
if
not
is_cpu
():
# CPU backend is not currently supported with encoder/decoder models
from
vllm.sequence
import
SampleLogprobs
# skip test definitions entirely to avoid importing GPU kernel libs
# (xFormers, etc.)
from
....conftest
import
(
DecoderPromptType
,
ExplicitEncoderDecoderPrompt
,
HfRunner
,
VllmRunner
)
import
pytest
from
....utils
import
multi_gpu_test
from
transformers
import
AutoModelForSeq2SeqLM
from
...utils
import
check_logprobs_close
from
vllm.sequence
import
SampleLogprobs
MODELS
=
[
"facebook/bart-base"
,
"facebook/bart-large-cnn"
]
from
....conftest
import
(
DecoderPromptType
,
ExplicitEncoderDecoderPrompt
,
HfRunner
,
VllmRunner
)
def
vllm_to_hf_output
(
from
....utils
import
multi_gpu_test
vllm_output
:
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]],
from
...utils
import
check_logprobs_close
decoder_prompt_type
:
DecoderPromptType
,
):
MODELS
=
[
"facebook/bart-base"
,
"facebook/bart-large-cnn"
]
"""Sanitize vllm output to be comparable with hf output."""
output_ids
,
output_str
,
out_logprobs
=
vllm_output
def
vllm_to_hf_output
(
vllm_output
:
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]],
hf_output_str
=
output_str
+
"</s>"
decoder_prompt_type
:
DecoderPromptType
,
if
decoder_prompt_type
==
DecoderPromptType
.
NONE
:
):
hf_output_str
=
"<s>"
+
hf_output_str
"""Sanitize vllm output to be comparable with hf output."""
output_ids
,
output_str
,
out_logprobs
=
vllm_output
return
output_ids
,
hf_output_str
,
out_logprobs
hf_output_str
=
output_str
+
"</s>"
if
decoder_prompt_type
==
DecoderPromptType
.
NONE
:
def
run_test
(
hf_output_str
=
"<s>"
+
hf_output_str
hf_runner
:
Type
[
HfRunner
],
vllm_runner
:
Type
[
VllmRunner
],
return
output_ids
,
hf_output_str
,
out_logprobs
prompts
:
List
[
ExplicitEncoderDecoderPrompt
[
str
,
str
]],
decoder_prompt_type
:
DecoderPromptType
,
def
run_test
(
model
:
str
,
hf_runner
:
Type
[
HfRunner
],
*
,
vllm_runner
:
Type
[
VllmRunner
],
dtype
:
str
,
prompts
:
List
[
ExplicitEncoderDecoderPrompt
[
str
,
str
]],
max_tokens
:
int
,
decoder_prompt_type
:
DecoderPromptType
,
num_logprobs
:
int
,
model
:
str
,
tensor_parallel_size
:
int
,
*
,
distributed_executor_backend
:
Optional
[
str
]
=
None
,
dtype
:
str
,
)
->
None
:
max_tokens
:
int
,
'''
num_logprobs
:
int
,
Test the vLLM BART model for a variety of encoder/decoder input prompts,
tensor_parallel_size
:
int
,
by validating it against HuggingFace (HF) BART.
distributed_executor_backend
:
Optional
[
str
]
=
None
,
)
->
None
:
Arguments:
'''
Test the vLLM BART model for a variety of encoder/decoder input prompts,
* hf_runner: HuggingFace (HF) test model runner
by validating it against HuggingFace (HF) BART.
* vllm_runner: vLLM test model runner
* example_encoder_decoder_prompts: test fixture which provides a
Arguments:
dictionary of dummy prompts
* model: the HF ID of the specific BART variant under test
* hf_runner: HuggingFace (HF) test model runner
* dtype: the tensor datatype to employ
* vllm_runner: vLLM test model runner
* max_tokens
* example_encoder_decoder_prompts: test fixture which provides a
* num_logprobs
dictionary of dummy prompts
* decoder_prompt_type: key into the example_encoder_decoder_prompts
* model: the HF ID of the specific BART variant under test
dictionary; selects specific encoder/decoder
* dtype: the tensor datatype to employ
prompt scenarios to test
* max_tokens
* num_logprobs
A note on using HF BART as a baseline for validating vLLM BART,
* decoder_prompt_type: key into the example_encoder_decoder_prompts
specifically when the decoder prompt is None.
dictionary; selects specific encoder/decoder
prompt scenarios to test
The HF GenerationMixin's default behavior is to force the first
decoded token to be <BOS> if the prompt does not already contain
A note on using HF BART as a baseline for validating vLLM BART,
<BOS> (this is accomplished using a logit
specifically when the decoder prompt is None.
processor setting.)
The HF GenerationMixin's default behavior is to force the first
So when we use HF BART as our baseline for comparison, note that
decoded token to be <BOS> if the prompt does not already contain
when the user provides a request with a None decoder prompt
<BOS> (this is accomplished using a logit
(i.e. a singleton encoder prompt, or else an explicit encoder/
processor setting.)
decoder prompt with the decoder sub-prompt set to None), HF and
vLLM handle this in different ways:
So when we use HF BART as our baseline for comparison, note that
when the user provides a request with a None decoder prompt
* HF will (1) tokenize the None prompt as an empty token-list,
(i.e. a singleton encoder prompt, or else an explicit encoder/
(2) append <decoder-start-token> to the beginning, yielding
decoder prompt with the decoder sub-prompt set to None), HF and
[<decoder-start-token>], (3) pass this token list to the model, and
vLLM handle this in different ways:
then (4) after computing logits during prefill, override the model
logits & force <BOS> to be the first generated token.
* HF will (1) tokenize the None prompt as an empty token-list,
(2) append <decoder-start-token> to the beginning, yielding
* vLLM will (1) tokenize the None prompt as [<BOS>], (2) append decoder-
[<decoder-start-token>], (3) pass this token list to the model, and
start-token to the beginning, yielding [<decoder-start-token><BOS>],
then (4) after computing logits during prefill, override the model
(3) pass these tokens to the model & proceed with generation.
logits & force <BOS> to be the first generated token.
The net effect is that compared to vLLM, the list of HF *decoded* tokens
* vLLM will (1) tokenize the None prompt as [<BOS>], (2) append decoder-
will contain one more initial <BOS> than the vLLM generated tokens,
start-token to the beginning, yielding [<decoder-start-token><BOS>],
because vLLM's <BOS> token is injected into the prompt rather than into
(3) pass these tokens to the model & proceed with generation.
the generated output. This is in spite of the fact that overall, the
complete sequences (prompt + decoded tokens) produced by vLLM will match
The net effect is that compared to vLLM, the list of HF *decoded* tokens
HF.
will contain one more initial <BOS> than the vLLM generated tokens,
because vLLM's <BOS> token is injected into the prompt rather than into
So when we use HF decoded token output to validate vLLM's decoded token
the generated output. This is in spite of the fact that overall, the
output, the testing process must account for the difference in decoded
complete sequences (prompt + decoded tokens) produced by vLLM will match
token sequences between vLLM and HF specifically in the
HF.
decoder-prompt-is-None case.
So when we use HF decoded token output to validate vLLM's decoded token
One option is to disable the logit processor feature that forces the
output, the testing process must account for the difference in decoded
<BOS> token to be decoded (forced_bos_token_id = None), eliminating
token sequences between vLLM and HF specifically in the
the problem entirely. However this is not "normal" BART usage.
decoder-prompt-is-None case.
The other option is - only in the decoder-prompt-is-None case - to
One option is to disable the logit processor feature that forces the
discard the first decoded token from the HF output before comparing it
<BOS> token to be decoded (forced_bos_token_id = None), eliminating
to vLLM.
the problem entirely. However this is not "normal" BART usage.
To that end, when testing the scenario where the decoder prompt is None
The other option is - only in the decoder-prompt-is-None case - to
(and only in that one scenario), this test skips the first HF decoded
discard the first decoded token from the HF output before comparing it
token during the process of validating the vLLM decoded output.
to vLLM.
'''
To that end, when testing the scenario where the decoder prompt is None
# NOTE: take care of the order. run vLLM first, and then run HF.
(and only in that one scenario), this test skips the first HF decoded
# vLLM needs a fresh new process without cuda initialization.
token during the process of validating the vLLM decoded output.
# if we run HF first, the cuda initialization will be done and it
'''
# will hurt multiprocessing backend with fork method (the default).
# NOTE: take care of the order. run vLLM first, and then run HF.
# Note: currently encoder/decoder models are only compatible with
# vLLM needs a fresh new process without cuda initialization.
# enforce_eager=True. Normally this is not a problem because
# if we run HF first, the cuda initialization will be done and it
# for encoder/decoder models vLLM will
# will hurt multiprocessing backend with fork method (the default).
# default to enforce_eager=True if enforce_eager
# is left unspecified. However, the
# Note: currently encoder/decoder models are only compatible with
# VllmRunner test fixture (which wraps around the LLM class) defaults to
# enforce_eager=True. Normally this is not a problem because
# enforce_eager=False (a behavior which a number of already-existing
# for encoder/decoder models vLLM will
# decoder-only unit tests expect), so when testing an encoder/decoder
# default to enforce_eager=True if enforce_eager
# model we must explicitly specify enforce_eager=True in the VllmRunner
# is left unspecified. However, the
# constructor.
# VllmRunner test fixture (which wraps around the LLM class) defaults to
with
vllm_runner
(
model
,
# enforce_eager=False (a behavior which a number of already-existing
dtype
=
dtype
,
# decoder-only unit tests expect), so when testing an encoder/decoder
tensor_parallel_size
=
tensor_parallel_size
,
# model we must explicitly specify enforce_eager=True in the VllmRunner
distributed_executor_backend
=
distributed_executor_backend
,
# constructor.
enforce_eager
=
True
)
as
vllm_model
:
with
vllm_runner
(
vllm_outputs
=
vllm_model
.
generate_encoder_decoder_greedy_logprobs
(
model
,
prompts
,
max_tokens
,
num_logprobs
)
dtype
=
dtype
,
tensor_parallel_size
=
tensor_parallel_size
,
# Configuration settings for HF baseline
distributed_executor_backend
=
distributed_executor_backend
,
hf_kwargs
=
{
enforce_eager
=
True
)
as
vllm_model
:
"top_k"
:
None
,
vllm_outputs
=
vllm_model
.
generate_encoder_decoder_greedy_logprobs
(
"num_beams"
:
1
,
prompts
,
max_tokens
,
num_logprobs
)
"repetition_penalty"
:
1.0
,
"top_p"
:
1.0
,
# Configuration settings for HF baseline
"length_penalty"
:
1.0
,
hf_kwargs
=
{
"early_stopping"
:
False
,
"top_k"
:
None
,
"no_repeat_ngram_size"
:
None
,
"num_beams"
:
1
,
"min_length"
:
0
"repetition_penalty"
:
1.0
,
}
"top_p"
:
1.0
,
"length_penalty"
:
1.0
,
with
hf_runner
(
model
,
dtype
=
dtype
,
"early_stopping"
:
False
,
auto_cls
=
AutoModelForSeq2SeqLM
)
as
hf_model
:
"no_repeat_ngram_size"
:
None
,
hf_outputs
=
(
hf_model
.
generate_encoder_decoder_greedy_logprobs_limit
(
"min_length"
:
0
prompts
,
}
max_tokens
,
num_logprobs
,
with
hf_runner
(
model
,
dtype
=
dtype
,
**
hf_kwargs
,
auto_cls
=
AutoModelForSeq2SeqLM
)
as
hf_model
:
))
hf_outputs
=
(
hf_model
.
generate_encoder_decoder_greedy_logprobs_limit
(
hf_skip_tokens
=
(
1
prompts
,
if
decoder_prompt_type
==
DecoderPromptType
.
NONE
else
0
)
max_tokens
,
num_logprobs
,
check_logprobs_close
(
**
hf_kwargs
,
outputs_0_lst
=
hf_outputs
,
))
outputs_1_lst
=
[
vllm_to_hf_output
(
vllm_output
,
decoder_prompt_type
)
hf_skip_tokens
=
(
1
if
decoder_prompt_type
==
DecoderPromptType
.
NONE
for
vllm_output
in
vllm_outputs
else
0
)
],
name_0
=
"hf"
,
check_logprobs_close
(
name_1
=
"vllm"
,
outputs_0_lst
=
hf_outputs
,
num_outputs_0_skip_tokens
=
hf_skip_tokens
,
outputs_1_lst
=
[
)
vllm_to_hf_output
(
vllm_output
,
decoder_prompt_type
)
for
vllm_output
in
vllm_outputs
],
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
name_0
=
"hf"
,
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
,
"bfloat16"
])
name_1
=
"vllm"
,
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
64
])
num_outputs_0_skip_tokens
=
hf_skip_tokens
,
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
)
@
pytest
.
mark
.
parametrize
(
"decoder_prompt_type"
,
list
(
DecoderPromptType
))
def
test_models
(
hf_runner
,
vllm_runner
,
example_encoder_decoder_prompts
,
model
,
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
dtype
,
max_tokens
,
num_logprobs
,
decoder_prompt_type
)
->
None
:
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
,
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
64
])
run_test
(
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
hf_runner
,
@
pytest
.
mark
.
parametrize
(
"decoder_prompt_type"
,
list
(
DecoderPromptType
))
vllm_runner
,
def
test_models
(
hf_runner
,
vllm_runner
,
example_encoder_decoder_prompts
,
example_encoder_decoder_prompts
[
decoder_prompt_type
],
model
,
dtype
,
max_tokens
,
num_logprobs
,
decoder_prompt_type
,
decoder_prompt_type
)
->
None
:
model
,
dtype
=
dtype
,
run_test
(
max_tokens
=
max_tokens
,
hf_runner
,
num_logprobs
=
num_logprobs
,
vllm_runner
,
tensor_parallel_size
=
1
,
example_encoder_decoder_prompts
[
decoder_prompt_type
],
)
decoder_prompt_type
,
model
,
dtype
=
dtype
,
@
multi_gpu_test
(
num_gpus
=
2
)
max_tokens
=
max_tokens
,
@
pytest
.
mark
.
parametrize
(
"distributed_executor_backend"
,
[
"ray"
,
"mp"
])
num_logprobs
=
num_logprobs
,
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"facebook/bart-large-cnn"
])
tensor_parallel_size
=
1
,
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
)
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
multi_gpu_test
(
num_gpus
=
2
)
@
pytest
.
mark
.
parametrize
(
"decoder_prompt_type"
,
[
DecoderPromptType
.
CUSTOM
])
@
pytest
.
mark
.
parametrize
(
"distributed_executor_backend"
,
[
"ray"
,
"mp"
])
def
test_models_distributed
(
hf_runner
,
vllm_runner
,
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"facebook/bart-large-cnn"
])
example_encoder_decoder_prompts
,
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
distributed_executor_backend
,
model
,
dtype
,
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
64
])
max_tokens
,
num_logprobs
,
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
decoder_prompt_type
)
->
None
:
@
pytest
.
mark
.
parametrize
(
"decoder_prompt_type"
,
[
DecoderPromptType
.
CUSTOM
])
run_test
(
def
test_models_distributed
(
hf_runner
,
vllm_runner
,
hf_runner
,
example_encoder_decoder_prompts
,
vllm_runner
,
distributed_executor_backend
,
model
,
dtype
,
example_encoder_decoder_prompts
[
decoder_prompt_type
],
max_tokens
,
num_logprobs
,
decoder_prompt_type
,
decoder_prompt_type
)
->
None
:
model
,
run_test
(
dtype
=
dtype
,
hf_runner
,
max_tokens
=
max_tokens
,
vllm_runner
,
num_logprobs
=
num_logprobs
,
example_encoder_decoder_prompts
[
decoder_prompt_type
],
tensor_parallel_size
=
2
,
decoder_prompt_type
,
distributed_executor_backend
=
distributed_executor_backend
,
model
,
)
dtype
=
dtype
,
max_tokens
=
max_tokens
,
num_logprobs
=
num_logprobs
,
tensor_parallel_size
=
2
,
distributed_executor_backend
=
distributed_executor_backend
,
)
vllm/attention/backends/torch_sdpa.py
View file @
4f95ffee
...
@@ -75,6 +75,22 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -75,6 +75,22 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
slot_mapping
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
seq_lens
:
Optional
[
List
[
int
]]
seq_lens
:
Optional
[
List
[
int
]]
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
encoder_seq_lens
:
Optional
[
List
[
int
]]
=
None
encoder_seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
=
None
# Maximum sequence length among encoder sequences
max_encoder_seq_len
:
Optional
[
int
]
=
None
# Number of tokens input to encoder
num_encoder_tokens
:
Optional
[
int
]
=
None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping
:
Optional
[
torch
.
Tensor
]
=
None
cross_block_tables
:
Optional
[
torch
.
Tensor
]
=
None
def
__post_init__
(
self
):
def
__post_init__
(
self
):
# Set during the execution of the first attention op.
# Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt
# It is a list because it is needed to set per prompt
...
@@ -82,6 +98,28 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -82,6 +98,28 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
# from xformer API.
# from xformer API.
# will not appear in the __repr__ and __init__
# will not appear in the __repr__ and __init__
self
.
attn_bias
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
self
.
attn_bias
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
self
.
encoder_attn_bias
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
self
.
cross_attn_bias
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
@
property
def
is_all_encoder_attn_metadata_set
(
self
):
'''
All attention metadata required for encoder attention is set.
'''
return
((
self
.
encoder_seq_lens
is
not
None
)
and
(
self
.
encoder_seq_lens_tensor
is
not
None
)
and
(
self
.
max_encoder_seq_len
is
not
None
))
@
property
def
is_all_cross_attn_metadata_set
(
self
):
'''
All attention metadata required for enc/dec cross-attention is set.
Superset of encoder attention required metadata.
'''
return
(
self
.
is_all_encoder_attn_metadata_set
and
(
self
.
cross_slot_mapping
is
not
None
)
and
(
self
.
cross_block_tables
is
not
None
))
@
property
@
property
def
prefill_metadata
(
self
)
->
Optional
[
"TorchSDPAMetadata"
]:
def
prefill_metadata
(
self
)
->
Optional
[
"TorchSDPAMetadata"
]:
...
@@ -101,6 +139,136 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -101,6 +139,136 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
return
self
return
self
def
get_seq_lens
(
self
,
attn_type
:
AttentionType
,
):
'''
Extract appropriate sequence lengths from attention metadata
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate sequence lengths tensor for query
* Appropriate sequence lengths tensor for key & value
'''
if
attn_type
==
AttentionType
.
DECODER
:
seq_lens_q
=
self
.
seq_lens
seq_lens_kv
=
self
.
seq_lens
elif
attn_type
==
AttentionType
.
ENCODER
:
seq_lens_q
=
self
.
encoder_seq_lens
seq_lens_kv
=
self
.
encoder_seq_lens
elif
attn_type
==
AttentionType
.
ENCODER_DECODER
:
seq_lens_q
=
self
.
seq_lens
seq_lens_kv
=
self
.
encoder_seq_lens
else
:
raise
AttributeError
(
f
"Invalid attention type
{
str
(
attn_type
)
}
"
)
return
seq_lens_q
,
seq_lens_kv
def
get_attn_bias
(
self
,
attn_type
:
AttentionType
,
)
->
Optional
[
List
[
torch
.
Tensor
]]:
'''
Extract appropriate attention bias from attention metadata
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate attention bias value given the attention type
'''
if
attn_type
==
AttentionType
.
DECODER
:
return
self
.
attn_bias
elif
attn_type
==
AttentionType
.
ENCODER
:
return
self
.
encoder_attn_bias
elif
attn_type
==
AttentionType
.
ENCODER_DECODER
:
return
self
.
cross_attn_bias
else
:
raise
AttributeError
(
f
"Invalid attention type
{
str
(
attn_type
)
}
"
)
def
set_attn_bias
(
self
,
attn_bias
:
List
[
torch
.
Tensor
],
attn_type
:
AttentionType
,
)
->
None
:
'''
Update appropriate attention bias field of attention metadata,
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_bias: The desired attention bias value
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
'''
if
attn_type
==
AttentionType
.
DECODER
:
self
.
attn_bias
=
attn_bias
elif
attn_type
==
AttentionType
.
ENCODER
:
self
.
encoder_attn_bias
=
attn_bias
elif
attn_type
==
AttentionType
.
ENCODER_DECODER
:
self
.
cross_attn_bias
=
attn_bias
else
:
raise
AttributeError
(
f
"Invalid attention type
{
str
(
attn_type
)
}
"
)
def
get_seq_len_block_table_args
(
self
,
attn_type
:
AttentionType
,
)
->
tuple
:
'''
The particular choice of sequence-length- and block-table-related
attributes which should be extracted from attn_metadata is dependent
on the type of attention operation.
Decoder attn -> select entirely decoder self-attention-related fields
Encoder/decoder cross-attn -> select encoder sequence lengths &
cross-attn block-tables fields
Encoder attn -> select encoder sequence lengths fields & no block tables
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* is_prompt: True if prefill, False otherwise
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate sequence-lengths tensor
* Appropriate max sequence-length scalar
* Appropriate block tables (or None)
'''
if
attn_type
==
AttentionType
.
DECODER
:
# Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run
return
(
self
.
seq_lens_tensor
,
self
.
max_decode_seq_len
,
self
.
block_tables
)
elif
attn_type
==
AttentionType
.
ENCODER_DECODER
:
# Enc/dec cross-attention KVs match encoder sequence length;
# cross-attention utilizes special "cross" block tables
return
(
self
.
encoder_seq_lens_tensor
,
self
.
max_encoder_seq_len
,
self
.
cross_block_tables
)
elif
attn_type
==
AttentionType
.
ENCODER
:
# No block tables associated with encoder attention
return
(
self
.
encoder_seq_lens_tensor
,
self
.
max_encoder_seq_len
,
None
)
else
:
raise
AttributeError
(
f
"Invalid attention type
{
str
(
attn_type
)
}
"
)
class
TorchSDPABackendImpl
(
AttentionImpl
[
TorchSDPAMetadata
]):
class
TorchSDPABackendImpl
(
AttentionImpl
[
TorchSDPAMetadata
]):
...
@@ -171,84 +339,101 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
...
@@ -171,84 +339,101 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
shape = [num_tokens, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
"""
assert
k_scale
==
1.0
and
v_scale
==
1.0
assert
k_scale
==
1.0
and
v_scale
==
1.0
if
attn_type
!=
AttentionType
.
DECODER
:
if
(
attn_type
==
AttentionType
.
ENCODER
raise
NotImplementedError
(
"Encoder self-attention and "
and
(
not
attn_metadata
.
is_all_encoder_attn_metadata_set
)):
"encoder/decoder cross-attention "
raise
AttributeError
(
"Encoder attention requires setting "
"are not implemented for "
"encoder metadata attributes."
)
"TorchSDPABackendImpl"
)
elif
(
attn_type
==
AttentionType
.
ENCODER_DECODER
num_tokens
,
hidden_size
=
query
.
shape
and
(
not
attn_metadata
.
is_all_cross_attn_metadata_set
)):
raise
AttributeError
(
"Encoder/decoder cross-attention "
"requires setting cross-attention "
"metadata attributes."
)
# Reshape the query, key, and value tensors.
# Reshape the query, key, and value tensors.
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
if
key
is
not
None
:
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
assert
value
is
not
None
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
if
kv_cache
.
numel
()
>
0
:
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
else
:
assert
value
is
None
if
(
attn_type
!=
AttentionType
.
ENCODER
and
kv_cache
.
numel
()
>
0
):
# KV-cache during decoder-self- or
# encoder-decoder-cross-attention, but not
# during encoder attention.
#
# Even if there are no new key/value pairs to cache,
# we still need to break out key_cache and value_cache
# i.e. for later use by paged attention
key_cache
,
value_cache
=
PagedAttention
.
split_kv_cache
(
key_cache
,
value_cache
=
PagedAttention
.
split_kv_cache
(
kv_cache
,
self
.
num_kv_heads
,
self
.
head_size
)
kv_cache
,
self
.
num_kv_heads
,
self
.
head_size
)
PagedAttention
.
write_to_paged_cache
(
key
,
value
,
key_cache
,
value_cache
,
attn_metadata
.
slot_mapping
,
self
.
kv_cache_dtype
,
k_scale
,
v_scale
)
if
attn_metadata
.
is_prompt
:
if
(
key
is
not
None
)
and
(
value
is
not
None
):
if
attn_type
==
AttentionType
.
ENCODER_DECODER
:
# Update cross-attention KV cache (prefill-only)
# During cross-attention decode, key & value will be None,
# preventing this IF-statement branch from running
updated_slot_mapping
=
attn_metadata
.
cross_slot_mapping
else
:
# Update self-attention KV cache (prefill/decode)
updated_slot_mapping
=
attn_metadata
.
slot_mapping
PagedAttention
.
write_to_paged_cache
(
key
,
value
,
key_cache
,
value_cache
,
updated_slot_mapping
,
self
.
kv_cache_dtype
,
k_scale
,
v_scale
)
if
attn_type
!=
AttentionType
.
ENCODER
:
# Decoder self-attention supports chunked prefill.
# Encoder/decoder cross-attention requires no chunked
# prefill (100% prefill or 100% decode tokens, no mix)
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
else
:
# Encoder attention - chunked prefill is not applicable;
# derive token-count from query shape and treat them
# as 100% prefill tokens
assert
attn_metadata
.
num_encoder_tokens
is
not
None
num_prefill_tokens
=
attn_metadata
.
num_encoder_tokens
num_decode_tokens
=
0
if
attn_type
==
AttentionType
.
DECODER
:
# Only enforce this shape-constraint for decoder
# self-attention
assert
key
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
assert
value
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
assert
attn_metadata
.
seq_lens
is
not
None
assert
attn_metadata
.
seq_lens
is
not
None
if
(
kv_cache
.
numel
()
==
0
if
(
kv_cache
.
numel
()
==
0
or
attn_metadata
.
block_tables
.
numel
()
==
0
):
or
prefill_meta
.
block_tables
.
numel
()
==
0
):
if
self
.
num_kv_heads
!=
self
.
num_heads
:
output
=
self
.
_run_sdpa_forward
(
query
,
key
=
key
.
repeat_interleave
(
self
.
num_queries_per_kv
,
dim
=
1
)
key
,
value
=
value
.
repeat_interleave
(
self
.
num_queries_per_kv
,
value
,
dim
=
1
)
prefill_meta
,
attn_type
=
attn_type
)
if
attn_metadata
.
attn_bias
is
None
:
if
self
.
alibi_slopes
is
not
None
:
att_masks
=
_make_alibi_bias
(
self
.
alibi_slopes
,
query
.
dtype
,
attn_metadata
.
seq_lens
)
# type: ignore
elif
self
.
sliding_window
is
not
None
:
att_masks
=
_make_sliding_window_bias
(
attn_metadata
.
seq_lens
,
self
.
sliding_window
,
query
.
dtype
)
# type: ignore
else
:
att_masks
=
[
None
]
*
len
(
attn_metadata
.
seq_lens
)
attn_metadata
.
attn_bias
=
att_masks
query
=
query
.
movedim
(
0
,
query
.
dim
()
-
2
)
key
=
key
.
movedim
(
0
,
key
.
dim
()
-
2
)
value
=
value
.
movedim
(
0
,
value
.
dim
()
-
2
)
start
=
0
output
=
torch
.
empty
(
(
num_tokens
,
self
.
num_heads
,
self
.
head_size
),
dtype
=
query
.
dtype
)
for
seq_len
,
mask
in
zip
(
attn_metadata
.
seq_lens
,
attn_metadata
.
attn_bias
):
end
=
start
+
seq_len
sub_out
=
scaled_dot_product_attention
(
query
[
None
,
:,
start
:
end
,
:],
key
[
None
,
:,
start
:
end
,
:],
value
[
None
,
:,
start
:
end
,
:],
attn_mask
=
mask
,
dropout_p
=
0.0
,
is_causal
=
not
self
.
need_mask
,
scale
=
self
.
scale
).
squeeze
(
0
).
movedim
(
query
.
dim
()
-
2
,
0
)
output
[
start
:
end
,
:,
:]
=
sub_out
start
=
end
else
:
else
:
# prefix-enabled attention
# prefix-enabled attention
raise
RuntimeError
(
raise
RuntimeError
(
"Torch SDPA backend doesn't support prefix decoding."
)
"Torch SDPA backend doesn't support prefix decoding."
)
else
:
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
# Decoding run.
# Decoding run.
(
seq_lens_arg
,
max_seq_len_arg
,
block_tables_arg
,
)
=
decode_meta
.
get_seq_len_block_table_args
(
attn_type
)
output
=
PagedAttention
.
forward_decode
(
output
=
PagedAttention
.
forward_decode
(
query
,
query
,
key_cache
,
key_cache
,
value_cache
,
value_cache
,
attn_metadata
.
block_tables
,
block_tables
_arg
,
attn_metadata
.
seq_lens_
tensor
,
seq_lens_
arg
,
attn_metadata
.
max_decode
_seq_len
,
max
_seq_len
_arg
,
self
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
num_kv_heads
,
self
.
scale
,
self
.
scale
,
...
@@ -260,6 +445,59 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
...
@@ -260,6 +445,59 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
# Reshape the output tensor.
# Reshape the output tensor.
return
output
.
view
(
-
1
,
self
.
num_heads
*
self
.
head_size
)
return
output
.
view
(
-
1
,
self
.
num_heads
*
self
.
head_size
)
def _run_sdpa_forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_metadata: TorchSDPAMetadata,
    attn_type: AttentionType = AttentionType.DECODER,
):
    """Run torch SDPA one sequence at a time over a packed batch.

    The tensors are 3-D and packed along dim 0; each sequence occupies a
    contiguous slice whose length comes from
    ``attn_metadata.get_seq_lens(attn_type)`` (one list for the query side,
    one for the key/value side).

    Args:
        query: Packed query tensor; the result has the same shape/dtype.
        key: Packed key tensor.
        value: Packed value tensor.
        attn_metadata: Supplies per-sequence lengths and caches the
            per-sequence attention biases for ``attn_type``.
        attn_type: Which attention flavour is being computed. Only
            ``AttentionType.DECODER`` is treated as causal.

    Returns:
        A tensor shaped like ``query`` holding the attention output.
    """
    # When there are fewer KV heads than query heads, replicate K/V along
    # dim 1 (presumably the head axis -- confirm with the caller).
    if self.num_kv_heads != self.num_heads:
        key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
        value = value.repeat_interleave(self.num_queries_per_kv, dim=1)

    attn_masks = attn_metadata.get_attn_bias(attn_type)
    if attn_masks is None:
        # Nothing cached for this attention type yet: build the biases once
        # and store them back on the metadata.
        if self.alibi_slopes is not None:
            attn_masks = _make_alibi_bias(
                self.alibi_slopes,
                query.dtype,
                attn_metadata.seq_lens,  # type: ignore
            )
        elif self.sliding_window is not None:
            assert attn_metadata.seq_lens is not None
            attn_masks = _make_sliding_window_bias(
                attn_metadata.seq_lens,
                self.sliding_window,
                query.dtype,  # type: ignore
            )
        else:
            # No bias needed: one ``None`` placeholder per sequence.
            seq_lens, _ = attn_metadata.get_seq_lens(attn_type)
            attn_masks = [None] * len(seq_lens)
        attn_metadata.set_attn_bias(attn_masks, attn_type)

    # Allocate the result while ``query`` still has its packed layout.
    output = torch.empty_like(query)

    # Move the packed axis (dim 0) next to the last axis so each tensor can
    # be sliced per sequence on its second-to-last dimension.
    query = query.movedim(0, query.dim() - 2)
    key = key.movedim(0, key.dim() - 2)
    value = value.movedim(0, value.dim() - 2)

    causal_attn = (attn_type == AttentionType.DECODER)

    q_lens, kv_lens = attn_metadata.get_seq_lens(attn_type)
    q_begin = 0
    kv_begin = 0
    for q_len, kv_len, bias in zip(q_lens, kv_lens, attn_masks):
        q_stop = q_begin + q_len
        kv_stop = kv_begin + kv_len

        # ``[None, ...]`` adds a leading batch axis of size 1 for SDPA.
        attended = scaled_dot_product_attention(
            query[None, :, q_begin:q_stop, :],
            key[None, :, kv_begin:kv_stop, :],
            value[None, :, kv_begin:kv_stop, :],
            attn_mask=bias,
            dropout_p=0.0,
            is_causal=causal_attn and not self.need_mask,
            scale=self.scale,
        )

        # Drop the batch axis and restore the packed-axis-first layout.
        sub_out = attended.squeeze(0).movedim(query.dim() - 2, 0)
        output[q_begin:q_stop, :, :] = sub_out

        q_begin = q_stop
        kv_begin = kv_stop

    return output
def
_make_alibi_bias
(
def
_make_alibi_bias
(
alibi_slopes
:
torch
.
Tensor
,
alibi_slopes
:
torch
.
Tensor
,
...
...
vllm/worker/cpu_enc_dec_model_runner.py
0 → 100644
View file @
4f95ffee
import
dataclasses
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
cast
import
torch
from
vllm.attention
import
AttentionMetadata
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.multimodal
import
MultiModalInputs
from
vllm.sequence
import
IntermediateTensors
,
SequenceGroupMetadata
from
vllm.utils
import
make_tensor_with_pad
from
vllm.worker.cpu_model_runner
import
(
CPUModelRunner
,
ModelInputForCPUBuilder
,
ModelInputForCPUWithSamplingMetadata
)
from
vllm.worker.model_runner_base
import
(
_add_attn_metadata_broadcastable_dict
,
_add_sampling_metadata_broadcastable_dict
)
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
@dataclasses.dataclass(frozen=True)
class EncoderDecoderModelInputForCPU(ModelInputForCPUWithSamplingMetadata):
    """
    Model input for encoder-decoder execution on CPU.

    Adds the encoder-side token ids and positions to the fields inherited
    from ``ModelInputForCPUWithSamplingMetadata``.
    """
    # Encoder token ids; left as None when not populated.
    encoder_input_tokens: Optional[torch.Tensor] = None
    # Encoder token positions; left as None when not populated.
    encoder_input_positions: Optional[torch.Tensor] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Flatten this input into a dict suitable for broadcasting.

        The decoder and encoder token/position tensors are inserted first;
        the attention- and sampling-metadata helpers then add their own
        entries to the same dict.
        """
        payload: Dict[str, Any] = {}
        payload["input_tokens"] = self.input_tokens
        payload["input_positions"] = self.input_positions
        payload["encoder_input_tokens"] = self.encoder_input_tokens
        payload["encoder_input_positions"] = self.encoder_input_positions

        _add_attn_metadata_broadcastable_dict(payload, self.attn_metadata)
        _add_sampling_metadata_broadcastable_dict(payload,
                                                  self.sampling_metadata)
        return payload

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "EncoderDecoderModelInputForCPU":
        """Rebuild an instance from a broadcasted tensor dict.

        Delegates to the parent classmethod; the ``cast`` only narrows the
        static type and performs no runtime conversion.
        """
        restored = super().from_broadcasted_tensor_dict(
            tensor_dict, attn_backend)
        return cast(EncoderDecoderModelInputForCPU, restored)
class CPUEncoderDecoderModelRunner(CPUModelRunner):
    """CPU model runner for encoder-decoder models.

    Reuses ``CPUModelRunner`` to prepare the decoder-side inputs, then adds
    the encoder tokens/positions and fills in the encoder and
    cross-attention fields of the attention metadata.
    """
    _model_input_cls: Type[EncoderDecoderModelInputForCPU] = (
        EncoderDecoderModelInputForCPU)
    _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder

    def _list_to_int32_tensor(
        self,
        _list: List[int],
    ) -> torch.Tensor:
        """Convert a list of ints to an int32 tensor on ``self.device``."""
        return torch.tensor(_list, dtype=torch.int32, device=self.device)

    def _list_to_long_tensor(
        self,
        _list: List[int],
    ) -> torch.Tensor:
        """Convert a list of ints to a long tensor on ``self.device``."""
        return torch.tensor(_list, dtype=torch.long, device=self.device)

    def _empty_int32_tensor(self) -> torch.Tensor:
        """Return an empty int32 tensor on ``self.device``."""
        return self._list_to_int32_tensor([])

    def _empty_long_tensor(self) -> torch.Tensor:
        """Return an empty long tensor on ``self.device``."""
        return self._list_to_long_tensor([])

    def make_model_input_from_broadcasted_tensor_dict(
            self, tensor_dict: Dict[str,
                                    Any]) -> EncoderDecoderModelInputForCPU:
        """Rebuild the model input from a broadcasted tensor dict."""
        return EncoderDecoderModelInputForCPU.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> EncoderDecoderModelInputForCPU:
        """Prepare decoder inputs via the parent, then add encoder inputs.

        Returns a copy of the parent's model input (via
        ``dataclasses.replace``) carrying the updated attention metadata and
        the encoder token/position tensors.
        """
        model_input = super().prepare_model_input(seq_group_metadata_list,
                                                  virtual_engine,
                                                  finished_requests_ids)
        model_input = cast(EncoderDecoderModelInputForCPU, model_input)
        (
            attn_metadata,
            encoder_input_tokens_tensor,
            encoder_input_positions_tensor,
        ) = self._prepare_encoder_model_input_tensors(seq_group_metadata_list,
                                                      model_input)
        return dataclasses.replace(
            model_input,
            attn_metadata=attn_metadata,
            encoder_input_tokens=encoder_input_tokens_tensor,
            encoder_input_positions=encoder_input_positions_tensor,
        )

    def _prepare_encoder_model_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        model_input: EncoderDecoderModelInputForCPU,
    ) -> Tuple[AttentionMetadata, Optional[torch.Tensor],
               Optional[torch.Tensor]]:
        """Prepare the encoder- and cross-attention-related model inputs.

        These inputs augment an already-computed ``model_input`` whose
        decoder-related fields are populated.

        Sets the following fields on ``model_input.attn_metadata``
        (the metadata object is updated in place):

        * `num_encoder_tokens`
        * `encoder_seq_lens`
        * `encoder_seq_lens_tensor`
        * `max_encoder_seq_len`
        * `cross_slot_mapping`
        * `cross_block_tables`

        Arguments:

        * seq_group_metadata_list: list of sequence groups for which to
                                   compute inputs
        * model_input: model input data structure with decoder-oriented
                       fields already computed.

        Return:

        * A 3-tuple ``(attn_metadata, encoder_input_tokens,
          encoder_input_positions)``. For an empty
          ``seq_group_metadata_list`` the metadata is returned untouched
          and both tensors are ``None``.
        """
        if len(seq_group_metadata_list) == 0:
            return (model_input.attn_metadata, None, None)

        # Since we are not supporting chunked prefill either the entire
        # batch is prefill or it is decode
        is_prompt = seq_group_metadata_list[0].is_prompt

        # Build encoder inputs
        encoder_seq_lens: List[int] = []
        if is_prompt:
            # Prefill phase: no cross-attention block tables are needed, so
            # use an empty (num_seq_groups, 0) int32 tensor.
            cross_block_tables = self._empty_int32_tensor().view(
                len(seq_group_metadata_list), -1)

            # Extract input tokens/positions, cross-attention slot-mapping,
            # & seq len from each sequence group metadata
            encoder_input_tokens: List[int] = []
            encoder_input_positions: List[int] = []
            cross_slot_mapping: List[int] = []
            for seq_group_metadata in seq_group_metadata_list:
                # Build seq lens
                seq_len = seq_group_metadata.encoder_seq_data.get_len()
                token_ids = seq_group_metadata.encoder_seq_data.get_token_ids()
                encoder_seq_lens.append(seq_len)

                # Build slot mapping: slot = block_number * block_size +
                # offset of the token inside its block.
                for i in range(seq_len):
                    block_number = seq_group_metadata.cross_block_table[
                        i // self.block_size]
                    block_offset = i % self.block_size
                    slot = block_number * self.block_size + block_offset
                    cross_slot_mapping.append(slot)

                # Build encoder input tokens
                encoder_input_tokens.extend(token_ids)
                encoder_input_positions.extend(range(seq_len))

            # Convert tokens/positions & cross-attention
            # slot-mapping to encoder input tensors
            encoder_input_tokens_tensor = self._list_to_long_tensor(
                encoder_input_tokens)
            encoder_input_positions_tensor = self._list_to_long_tensor(
                encoder_input_positions)
            cross_slot_mapping_tensor = self._list_to_long_tensor(
                cross_slot_mapping)

        else:
            # Decode phase: the encoder is not re-run, so its inputs and
            # the cross-attention slot mapping are empty.
            encoder_input_tokens_tensor = self._empty_long_tensor()
            encoder_input_positions_tensor = self._empty_long_tensor()
            cross_slot_mapping_tensor = self._empty_long_tensor()

            # Extract cross-attention block tables &
            # seq len from each sequence group metadata.
            # Cross-attention block tables are empty
            # during vLLM memory profiling.
            block_table_rows: List[List[int]] = []
            for seq_group_metadata in seq_group_metadata_list:
                # One entry per decoder sequence in the group; all of them
                # share the group's encoder length and cross block table.
                for _ in range(len(seq_group_metadata.seq_data)):
                    encoder_seq_lens.append(
                        seq_group_metadata.encoder_seq_data.get_len())
                    cross_block_table = seq_group_metadata.cross_block_table
                    block_table_rows.append(
                        [] if cross_block_table is None else cross_block_table)

            # ``default=0`` keeps this from raising if no rows were
            # collected, mirroring ``max_encoder_seq_len`` below.
            max_len_of_block_table = max(
                (len(block_table) for block_table in block_table_rows),
                default=0)

            cross_block_tables = make_tensor_with_pad(
                block_table_rows,
                max_len=max_len_of_block_table,
                pad=0,
                dtype=torch.int32,
                device=self.device,
            )

        # Compute encoder sequence lengths
        max_encoder_seq_len = max(encoder_seq_lens, default=0)
        encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens)

        # Update attention metadata with encoder-oriented attributes
        attn_metadata = model_input.attn_metadata
        assert attn_metadata is not None
        attn_metadata.num_encoder_tokens = sum(encoder_seq_lens)
        attn_metadata.encoder_seq_lens = encoder_seq_lens
        attn_metadata.encoder_seq_lens_tensor = encoder_seq_lens_tensor
        attn_metadata.max_encoder_seq_len = max_encoder_seq_len
        attn_metadata.cross_slot_mapping = cross_slot_mapping_tensor
        attn_metadata.cross_block_tables = cross_block_tables

        return (attn_metadata, encoder_input_tokens_tensor,
                encoder_input_positions_tensor)

    @torch.no_grad()
    def execute_model(
        self,
        model_input: EncoderDecoderModelInputForCPU,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        """Run one forward pass and, on the driver worker, sample.

        Returns:
            ``[]`` on non-driver workers, otherwise a single-element list
            with the sampler output.

        Raises:
            ValueError: if ``num_steps`` is greater than 1.
        """
        if num_steps > 1:
            raise ValueError(
                "CPU worker does not support multi-step execution.")

        model_executable = self.model
        execute_model_kwargs = {
            "input_ids":
            model_input.input_tokens,
            "positions":
            model_input.input_positions,
            "encoder_input_ids":
            model_input.encoder_input_tokens,
            "encoder_positions":
            model_input.encoder_input_positions,
            "kv_caches":
            kv_caches,
            "attn_metadata":
            model_input.attn_metadata,
            **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
                                         device=self.device),
            "intermediate_tensors":
            intermediate_tensors,
        }

        hidden_states = model_executable(**execute_model_kwargs)

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states,
                                           model_input.sampling_metadata)

        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:
            return []

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        return [output]
vllm/worker/cpu_model_runner.py
View file @
4f95ffee
...
@@ -19,7 +19,7 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
...
@@ -19,7 +19,7 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs
)
MultiModalInputs
)
from
vllm.sequence
import
(
IntermediateTensors
,
SequenceData
,
from
vllm.sequence
import
(
IntermediateTensors
,
SequenceData
,
SequenceGroupMetadata
)
SequenceGroupMetadata
)
from
vllm.utils
import
STR_NOT_IMPL_ENC_DEC_ERR_STRS
,
make_tensor_with_pad
from
vllm.utils
import
make_tensor_with_pad
from
vllm.worker.model_runner_base
import
(
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerInputBuilderBase
,
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerInputBuilderBase
,
_add_attn_metadata_broadcastable_dict
,
_add_attn_metadata_broadcastable_dict
,
...
@@ -434,10 +434,6 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
...
@@ -434,10 +434,6 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
# Lazy initialization.
# Lazy initialization.
self
.
model
:
nn
.
Module
# Set after init_Model
self
.
model
:
nn
.
Module
# Set after init_Model
if
self
.
model_config
.
is_encoder_decoder_model
:
raise
NotImplementedError
(
STR_NOT_IMPL_ENC_DEC_ERR_STRS
[
'STR_NOT_IMPL_ENC_DEC_CPU'
])
@
property
@
property
def
model_is_mrope
(
self
)
->
bool
:
def
model_is_mrope
(
self
)
->
bool
:
"""Detect if the model has "mrope" rope_scaling type.
"""Detect if the model has "mrope" rope_scaling type.
...
@@ -459,8 +455,8 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
...
@@ -459,8 +455,8 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
def
make_model_input_from_broadcasted_tensor_dict
(
def
make_model_input_from_broadcasted_tensor_dict
(
self
,
self
,
tensor_dict
:
Dict
[
str
,
Any
],
tensor_dict
:
Dict
[
str
,
Any
],
)
->
ModelInputForCPU
:
)
->
ModelInputForCPU
WithSamplingMetadata
:
return
ModelInputForCPU
.
from_broadcasted_tensor_dict
(
return
ModelInputForCPU
WithSamplingMetadata
.
from_broadcasted_tensor_dict
(
# noqa: E501
tensor_dict
,
tensor_dict
,
attn_backend
=
self
.
attn_backend
,
attn_backend
=
self
.
attn_backend
,
)
)
...
...
vllm/worker/cpu_worker.py
View file @
4f95ffee
"""A CPU worker class."""
"""A CPU worker class."""
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
import
torch
import
torch.distributed
import
torch.distributed
...
@@ -15,6 +15,7 @@ from vllm.logger import init_logger
...
@@ -15,6 +15,7 @@ from vllm.logger import init_logger
from
vllm.model_executor
import
set_random_seed
from
vllm.model_executor
import
set_random_seed
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.worker.cpu_enc_dec_model_runner
import
CPUEncoderDecoderModelRunner
from
vllm.worker.cpu_model_runner
import
CPUModelRunner
from
vllm.worker.cpu_model_runner
import
CPUModelRunner
from
vllm.worker.worker_base
import
(
LocalOrDistributedWorkerBase
,
from
vllm.worker.worker_base
import
(
LocalOrDistributedWorkerBase
,
LoraNotSupportedWorkerBase
,
WorkerInput
)
LoraNotSupportedWorkerBase
,
WorkerInput
)
...
@@ -163,7 +164,10 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
...
@@ -163,7 +164,10 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
else
:
else
:
self
.
local_omp_cpuid
=
omp_cpuids
.
split
(
"|"
)[
rank
]
self
.
local_omp_cpuid
=
omp_cpuids
.
split
(
"|"
)[
rank
]
self
.
model_runner
:
CPUModelRunner
=
CPUModelRunner
(
ModelRunnerClass
:
Type
[
CPUModelRunner
]
=
CPUModelRunner
if
self
.
_is_encoder_decoder_model
():
ModelRunnerClass
=
CPUEncoderDecoderModelRunner
self
.
model_runner
:
CPUModelRunner
=
ModelRunnerClass
(
model_config
,
model_config
,
parallel_config
,
parallel_config
,
scheduler_config
,
scheduler_config
,
...
@@ -205,6 +209,9 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
...
@@ -205,6 +209,9 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
raise
RuntimeError
(
"Profiler is not enabled."
)
raise
RuntimeError
(
"Profiler is not enabled."
)
self
.
profiler
.
stop
()
self
.
profiler
.
stop
()
def
_is_encoder_decoder_model
(
self
):
return
self
.
model_config
.
is_encoder_decoder_model
def
init_device
(
self
)
->
None
:
def
init_device
(
self
)
->
None
:
if
self
.
local_omp_cpuid
!=
"all"
:
if
self
.
local_omp_cpuid
!=
"all"
:
ret
=
torch
.
ops
.
_C_utils
.
init_cpu_threads_env
(
self
.
local_omp_cpuid
)
ret
=
torch
.
ops
.
_C_utils
.
init_cpu_threads_env
(
self
.
local_omp_cpuid
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment