Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4f95ffee
Unverified
Commit
4f95ffee
authored
Oct 07, 2024
by
Isotr0py
Committed by
GitHub
Oct 07, 2024
Browse files
[Hardware][CPU] Cross-attention and Encoder-Decoder models support on CPU backend (#9089)
parent
8c6de96e
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
834 additions
and
287 deletions
+834
-287
.buildkite/run-cpu-test.sh
.buildkite/run-cpu-test.sh
+1
-0
tests/models/encoder_decoder/language/test_bart.py
tests/models/encoder_decoder/language/test_bart.py
+211
-217
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+299
-61
vllm/worker/cpu_enc_dec_model_runner.py
vllm/worker/cpu_enc_dec_model_runner.py
+311
-0
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+3
-7
vllm/worker/cpu_worker.py
vllm/worker/cpu_worker.py
+9
-2
No files found.
.buildkite/run-cpu-test.sh
View file @
4f95ffee
...
@@ -23,6 +23,7 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py"
...
@@ -23,6 +23,7 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py"
# Run basic model test
# Run basic model test
docker
exec
cpu-test bash
-c
"
docker
exec
cpu-test bash
-c
"
pip install pytest matplotlib einops transformers_stream_generator datamodel_code_generator
pip install pytest matplotlib einops transformers_stream_generator datamodel_code_generator
pytest -v -s tests/models/encoder_decoder/language
pytest -v -s tests/models/decoder_only/language
\
pytest -v -s tests/models/decoder_only/language
\
--ignore=tests/models/test_fp8.py
\
--ignore=tests/models/test_fp8.py
\
--ignore=tests/models/decoder_only/language/test_jamba.py
\
--ignore=tests/models/decoder_only/language/test_jamba.py
\
...
...
tests/models/encoder_decoder/language/test_bart.py
View file @
4f95ffee
...
@@ -4,29 +4,23 @@ Run `pytest tests/models/encoder_decoder/language/test_bart.py`.
...
@@ -4,29 +4,23 @@ Run `pytest tests/models/encoder_decoder/language/test_bart.py`.
"""
"""
from
typing
import
List
,
Optional
,
Tuple
,
Type
from
typing
import
List
,
Optional
,
Tuple
,
Type
from
vllm.utils
import
is_cpu
import
pytest
from
transformers
import
AutoModelForSeq2SeqLM
if
not
is_cpu
():
from
vllm.sequence
import
SampleLogprobs
# CPU backend is not currently supported with encoder/decoder models
# skip test definitions entirely to avoid importing GPU kernel libs
# (xFormers, etc.)
import
pytest
from
....conftest
import
(
DecoderPromptType
,
ExplicitEncoderDecoderPrompt
,
from
transformers
import
AutoModelForSeq2SeqLM
from
vllm.sequence
import
SampleLogprobs
from
....conftest
import
(
DecoderPromptType
,
ExplicitEncoderDecoderPrompt
,
HfRunner
,
VllmRunner
)
HfRunner
,
VllmRunner
)
from
....utils
import
multi_gpu_test
from
....utils
import
multi_gpu_test
from
...utils
import
check_logprobs_close
from
...utils
import
check_logprobs_close
MODELS
=
[
"facebook/bart-base"
,
"facebook/bart-large-cnn"
]
MODELS
=
[
"facebook/bart-base"
,
"facebook/bart-large-cnn"
]
def
vllm_to_hf_output
(
def
vllm_to_hf_output
(
vllm_output
:
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]],
vllm_output
:
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]],
decoder_prompt_type
:
DecoderPromptType
,
decoder_prompt_type
:
DecoderPromptType
,
):
):
"""Sanitize vllm output to be comparable with hf output."""
"""Sanitize vllm output to be comparable with hf output."""
output_ids
,
output_str
,
out_logprobs
=
vllm_output
output_ids
,
output_str
,
out_logprobs
=
vllm_output
...
@@ -36,7 +30,8 @@ if not is_cpu():
...
@@ -36,7 +30,8 @@ if not is_cpu():
return
output_ids
,
hf_output_str
,
out_logprobs
return
output_ids
,
hf_output_str
,
out_logprobs
def
run_test
(
def
run_test
(
hf_runner
:
Type
[
HfRunner
],
hf_runner
:
Type
[
HfRunner
],
vllm_runner
:
Type
[
VllmRunner
],
vllm_runner
:
Type
[
VllmRunner
],
prompts
:
List
[
ExplicitEncoderDecoderPrompt
[
str
,
str
]],
prompts
:
List
[
ExplicitEncoderDecoderPrompt
[
str
,
str
]],
...
@@ -48,7 +43,7 @@ if not is_cpu():
...
@@ -48,7 +43,7 @@ if not is_cpu():
num_logprobs
:
int
,
num_logprobs
:
int
,
tensor_parallel_size
:
int
,
tensor_parallel_size
:
int
,
distributed_executor_backend
:
Optional
[
str
]
=
None
,
distributed_executor_backend
:
Optional
[
str
]
=
None
,
)
->
None
:
)
->
None
:
'''
'''
Test the vLLM BART model for a variety of encoder/decoder input prompts,
Test the vLLM BART model for a variety of encoder/decoder input prompts,
by validating it against HuggingFace (HF) BART.
by validating it against HuggingFace (HF) BART.
...
@@ -131,8 +126,7 @@ if not is_cpu():
...
@@ -131,8 +126,7 @@ if not is_cpu():
# decoder-only unit tests expect), so when testing an encoder/decoder
# decoder-only unit tests expect), so when testing an encoder/decoder
# model we must explicitly specify enforce_eager=True in the VllmRunner
# model we must explicitly specify enforce_eager=True in the VllmRunner
# constructor.
# constructor.
with
vllm_runner
(
with
vllm_runner
(
model
,
model
,
dtype
=
dtype
,
dtype
=
dtype
,
tensor_parallel_size
=
tensor_parallel_size
,
tensor_parallel_size
=
tensor_parallel_size
,
distributed_executor_backend
=
distributed_executor_backend
,
distributed_executor_backend
=
distributed_executor_backend
,
...
@@ -154,16 +148,15 @@ if not is_cpu():
...
@@ -154,16 +148,15 @@ if not is_cpu():
with
hf_runner
(
model
,
dtype
=
dtype
,
with
hf_runner
(
model
,
dtype
=
dtype
,
auto_cls
=
AutoModelForSeq2SeqLM
)
as
hf_model
:
auto_cls
=
AutoModelForSeq2SeqLM
)
as
hf_model
:
hf_outputs
=
(
hf_outputs
=
(
hf_model
.
generate_encoder_decoder_greedy_logprobs_limit
(
hf_model
.
generate_encoder_decoder_greedy_logprobs_limit
(
prompts
,
prompts
,
max_tokens
,
max_tokens
,
num_logprobs
,
num_logprobs
,
**
hf_kwargs
,
**
hf_kwargs
,
))
))
hf_skip_tokens
=
(
1
if
decoder_prompt_type
==
DecoderPromptType
.
NONE
hf_skip_tokens
=
(
1
else
0
)
if
decoder_prompt_type
==
DecoderPromptType
.
NONE
else
0
)
check_logprobs_close
(
check_logprobs_close
(
outputs_0_lst
=
hf_outputs
,
outputs_0_lst
=
hf_outputs
,
...
@@ -176,14 +169,14 @@ if not is_cpu():
...
@@ -176,14 +169,14 @@ if not is_cpu():
num_outputs_0_skip_tokens
=
hf_skip_tokens
,
num_outputs_0_skip_tokens
=
hf_skip_tokens
,
)
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"
dtype"
,
[
"float"
,
"bfloat16"
]
)
@
pytest
.
mark
.
parametrize
(
"
model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"
max_tokens"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"
dtype"
,
[
"float"
,
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"
num_logprob
s"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"
max_token
s"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"
decoder_prompt_type"
,
list
(
DecoderPromptType
)
)
@
pytest
.
mark
.
parametrize
(
"
num_logprobs"
,
[
5
]
)
def
test_models
(
hf_runner
,
vllm_runner
,
example_encoder_d
ecoder
_p
rompt
s
,
@
pytest
.
mark
.
parametrize
(
"decoder_prompt_type"
,
list
(
D
ecoder
P
rompt
Type
))
model
,
dtype
,
max_tokens
,
num_logprobs
,
def
test_models
(
hf_runner
,
vllm_runner
,
example_encoder_decoder_prompts
,
model
,
decoder_prompt_type
)
->
None
:
dtype
,
max_tokens
,
num_logprobs
,
decoder_prompt_type
)
->
None
:
run_test
(
run_test
(
hf_runner
,
hf_runner
,
...
@@ -197,14 +190,15 @@ if not is_cpu():
...
@@ -197,14 +190,15 @@ if not is_cpu():
tensor_parallel_size
=
1
,
tensor_parallel_size
=
1
,
)
)
@
multi_gpu_test
(
num_gpus
=
2
)
@
pytest
.
mark
.
parametrize
(
"distributed_executor_backend"
,
[
"ray"
,
"mp"
])
@
multi_gpu_test
(
num_gpus
=
2
)
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"facebook/bart-large-cnn"
])
@
pytest
.
mark
.
parametrize
(
"distributed_executor_backend"
,
[
"ray"
,
"mp"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"facebook/bart-large-cnn"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"decoder_prompt_type"
,
[
DecoderPromptType
.
CUSTOM
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
def
test_models_distributed
(
hf_runner
,
vllm_runner
,
@
pytest
.
mark
.
parametrize
(
"decoder_prompt_type"
,
[
DecoderPromptType
.
CUSTOM
])
def
test_models_distributed
(
hf_runner
,
vllm_runner
,
example_encoder_decoder_prompts
,
example_encoder_decoder_prompts
,
distributed_executor_backend
,
model
,
dtype
,
distributed_executor_backend
,
model
,
dtype
,
max_tokens
,
num_logprobs
,
max_tokens
,
num_logprobs
,
...
...
vllm/attention/backends/torch_sdpa.py
View file @
4f95ffee
...
@@ -75,6 +75,22 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -75,6 +75,22 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
slot_mapping
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
seq_lens
:
Optional
[
List
[
int
]]
seq_lens
:
Optional
[
List
[
int
]]
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
encoder_seq_lens
:
Optional
[
List
[
int
]]
=
None
encoder_seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
=
None
# Maximum sequence length among encoder sequences
max_encoder_seq_len
:
Optional
[
int
]
=
None
# Number of tokens input to encoder
num_encoder_tokens
:
Optional
[
int
]
=
None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping
:
Optional
[
torch
.
Tensor
]
=
None
cross_block_tables
:
Optional
[
torch
.
Tensor
]
=
None
def
__post_init__
(
self
):
def
__post_init__
(
self
):
# Set during the execution of the first attention op.
# Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt
# It is a list because it is needed to set per prompt
...
@@ -82,6 +98,28 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -82,6 +98,28 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
# from xformer API.
# from xformer API.
# will not appear in the __repr__ and __init__
# will not appear in the __repr__ and __init__
self
.
attn_bias
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
self
.
attn_bias
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
self
.
encoder_attn_bias
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
self
.
cross_attn_bias
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
@property
def is_all_encoder_attn_metadata_set(self):
    '''
    Whether every metadata field needed by encoder attention is populated.

    True only when `encoder_seq_lens`, `encoder_seq_lens_tensor` and
    `max_encoder_seq_len` are all non-None.
    '''
    # Guard clauses: the first missing field settles the answer.
    if self.encoder_seq_lens is None:
        return False
    if self.encoder_seq_lens_tensor is None:
        return False
    return self.max_encoder_seq_len is not None
@property
def is_all_cross_attn_metadata_set(self):
    '''
    Whether every metadata field needed by encoder/decoder
    cross-attention is populated.

    Requires everything encoder attention requires, plus non-None
    `cross_slot_mapping` and `cross_block_tables`.
    '''
    encoder_ready = self.is_all_encoder_attn_metadata_set
    if not encoder_ready:
        # Hand back the falsy encoder result unchanged.
        return encoder_ready
    if self.cross_slot_mapping is None:
        return False
    return self.cross_block_tables is not None
@
property
@
property
def
prefill_metadata
(
self
)
->
Optional
[
"TorchSDPAMetadata"
]:
def
prefill_metadata
(
self
)
->
Optional
[
"TorchSDPAMetadata"
]:
...
@@ -101,6 +139,136 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -101,6 +139,136 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
return
self
return
self
def get_seq_lens(
    self,
    attn_type: AttentionType,
):
    '''
    Pick the query and key/value sequence lengths for an attention type.

    Arguments:

    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns:

    * Sequence lengths for the query
    * Sequence lengths for the key & value

    Decoder self-attention uses `seq_lens` for both, encoder attention
    uses `encoder_seq_lens` for both, and cross-attention pairs
    `seq_lens` (query) with `encoder_seq_lens` (key/value).

    Raises AttributeError for any other attention type.
    '''
    if attn_type == AttentionType.DECODER:
        return self.seq_lens, self.seq_lens
    if attn_type == AttentionType.ENCODER:
        return self.encoder_seq_lens, self.encoder_seq_lens
    if attn_type == AttentionType.ENCODER_DECODER:
        return self.seq_lens, self.encoder_seq_lens
    raise AttributeError(f"Invalid attention type {str(attn_type)}")
def get_attn_bias(
    self,
    attn_type: AttentionType,
) -> Optional[List[torch.Tensor]]:
    '''
    Return the attention bias stored for the given attention type.

    Arguments:

    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns:

    * `attn_bias` for decoder self-attention, `encoder_attn_bias` for
      encoder attention, `cross_attn_bias` for cross-attention
      (whatever is currently stored, possibly None)

    Raises AttributeError for any other attention type.
    '''
    if attn_type == AttentionType.DECODER:
        return self.attn_bias
    if attn_type == AttentionType.ENCODER:
        return self.encoder_attn_bias
    if attn_type == AttentionType.ENCODER_DECODER:
        return self.cross_attn_bias
    raise AttributeError(f"Invalid attention type {str(attn_type)}")
def set_attn_bias(
    self,
    attn_bias: List[torch.Tensor],
    attn_type: AttentionType,
) -> None:
    '''
    Store an attention bias in the field matching the attention type.

    Arguments:

    * attn_bias: The desired attention bias value
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Writes `attn_bias` for decoder self-attention, `encoder_attn_bias`
    for encoder attention, `cross_attn_bias` for cross-attention.

    Raises AttributeError for any other attention type.
    '''
    if attn_type == AttentionType.DECODER:
        self.attn_bias = attn_bias
        return
    if attn_type == AttentionType.ENCODER:
        self.encoder_attn_bias = attn_bias
        return
    if attn_type == AttentionType.ENCODER_DECODER:
        self.cross_attn_bias = attn_bias
        return
    raise AttributeError(f"Invalid attention type {str(attn_type)}")
def get_seq_len_block_table_args(
    self,
    attn_type: AttentionType,
) -> tuple:
    '''
    Select the sequence-length and block-table fields that apply to the
    given attention type.

    Decoder attn -> decoder self-attention fields only
    Encoder/decoder cross-attn -> encoder sequence lengths &
                                  cross-attn block tables
    Encoder attn -> encoder sequence lengths & no block tables

    Arguments:

    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns:

    * Appropriate sequence-lengths tensor
    * Appropriate max sequence-length scalar
    * Appropriate block tables (or None)

    Raises AttributeError for any other attention type.
    '''
    if attn_type == AttentionType.DECODER:
        # Decoder self-attention: decode-time max length and the regular
        # block tables.
        return (self.seq_lens_tensor, self.max_decode_seq_len,
                self.block_tables)

    if attn_type == AttentionType.ENCODER_DECODER:
        # Cross-attention KVs follow the encoder sequence lengths and use
        # the dedicated "cross" block tables.
        return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
                self.cross_block_tables)

    if attn_type == AttentionType.ENCODER:
        # Encoder attention has no block tables.
        return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
                None)

    raise AttributeError(f"Invalid attention type {str(attn_type)}")
class
TorchSDPABackendImpl
(
AttentionImpl
[
TorchSDPAMetadata
]):
class
TorchSDPABackendImpl
(
AttentionImpl
[
TorchSDPAMetadata
]):
...
@@ -171,84 +339,101 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
...
@@ -171,84 +339,101 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
shape = [num_tokens, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
"""
assert
k_scale
==
1.0
and
v_scale
==
1.0
assert
k_scale
==
1.0
and
v_scale
==
1.0
if
attn_type
!=
AttentionType
.
DECODER
:
if
(
attn_type
==
AttentionType
.
ENCODER
raise
NotImplementedError
(
"Encoder self-attention and "
and
(
not
attn_metadata
.
is_all_encoder_attn_metadata_set
)):
"encoder/decoder cross-attention "
raise
AttributeError
(
"Encoder attention requires setting "
"are not implemented for "
"encoder metadata attributes."
)
"TorchSDPABackendImpl"
)
elif
(
attn_type
==
AttentionType
.
ENCODER_DECODER
num_tokens
,
hidden_size
=
query
.
shape
and
(
not
attn_metadata
.
is_all_cross_attn_metadata_set
)):
raise
AttributeError
(
"Encoder/decoder cross-attention "
"requires setting cross-attention "
"metadata attributes."
)
# Reshape the query, key, and value tensors.
# Reshape the query, key, and value tensors.
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
if
key
is
not
None
:
assert
value
is
not
None
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
else
:
if
kv_cache
.
numel
()
>
0
:
assert
value
is
None
if
(
attn_type
!=
AttentionType
.
ENCODER
and
kv_cache
.
numel
()
>
0
):
# KV-cache during decoder-self- or
# encoder-decoder-cross-attention, but not
# during encoder attention.
#
# Even if there are no new key/value pairs to cache,
# we still need to break out key_cache and value_cache
# i.e. for later use by paged attention
key_cache
,
value_cache
=
PagedAttention
.
split_kv_cache
(
key_cache
,
value_cache
=
PagedAttention
.
split_kv_cache
(
kv_cache
,
self
.
num_kv_heads
,
self
.
head_size
)
kv_cache
,
self
.
num_kv_heads
,
self
.
head_size
)
if
(
key
is
not
None
)
and
(
value
is
not
None
):
if
attn_type
==
AttentionType
.
ENCODER_DECODER
:
# Update cross-attention KV cache (prefill-only)
# During cross-attention decode, key & value will be None,
# preventing this IF-statement branch from running
updated_slot_mapping
=
attn_metadata
.
cross_slot_mapping
else
:
# Update self-attention KV cache (prefill/decode)
updated_slot_mapping
=
attn_metadata
.
slot_mapping
PagedAttention
.
write_to_paged_cache
(
key
,
value
,
key_cache
,
PagedAttention
.
write_to_paged_cache
(
key
,
value
,
key_cache
,
value_cache
,
value_cache
,
attn_metadata
.
slot_mapping
,
updated_slot_mapping
,
self
.
kv_cache_dtype
,
k_scale
,
self
.
kv_cache_dtype
,
v_scale
)
k_scale
,
v_scale
)
if
attn_metadata
.
is_prompt
:
if
attn_type
!=
AttentionType
.
ENCODER
:
# Decoder self-attention supports chunked prefill.
# Encoder/decoder cross-attention requires no chunked
# prefill (100% prefill or 100% decode tokens, no mix)
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
else
:
# Encoder attention - chunked prefill is not applicable;
# derive token-count from query shape & and treat them
# as 100% prefill tokens
assert
attn_metadata
.
num_encoder_tokens
is
not
None
num_prefill_tokens
=
attn_metadata
.
num_encoder_tokens
num_decode_tokens
=
0
if
attn_type
==
AttentionType
.
DECODER
:
# Only enforce this shape-constraint for decoder
# self-attention
assert
key
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
assert
value
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
assert
attn_metadata
.
seq_lens
is
not
None
assert
attn_metadata
.
seq_lens
is
not
None
if
(
kv_cache
.
numel
()
==
0
if
(
kv_cache
.
numel
()
==
0
or
attn_metadata
.
block_tables
.
numel
()
==
0
):
or
prefill_meta
.
block_tables
.
numel
()
==
0
):
if
self
.
num_kv_heads
!=
self
.
num_heads
:
output
=
self
.
_run_sdpa_forward
(
query
,
key
=
key
.
repeat_interleave
(
self
.
num_queries_per_kv
,
dim
=
1
)
key
,
value
=
value
.
repeat_interleave
(
self
.
num_queries_per_kv
,
value
,
dim
=
1
)
prefill_meta
,
attn_type
=
attn_type
)
if
attn_metadata
.
attn_bias
is
None
:
if
self
.
alibi_slopes
is
not
None
:
att_masks
=
_make_alibi_bias
(
self
.
alibi_slopes
,
query
.
dtype
,
attn_metadata
.
seq_lens
)
# type: ignore
elif
self
.
sliding_window
is
not
None
:
att_masks
=
_make_sliding_window_bias
(
attn_metadata
.
seq_lens
,
self
.
sliding_window
,
query
.
dtype
)
# type: ignore
else
:
att_masks
=
[
None
]
*
len
(
attn_metadata
.
seq_lens
)
attn_metadata
.
attn_bias
=
att_masks
query
=
query
.
movedim
(
0
,
query
.
dim
()
-
2
)
key
=
key
.
movedim
(
0
,
key
.
dim
()
-
2
)
value
=
value
.
movedim
(
0
,
value
.
dim
()
-
2
)
start
=
0
output
=
torch
.
empty
(
(
num_tokens
,
self
.
num_heads
,
self
.
head_size
),
dtype
=
query
.
dtype
)
for
seq_len
,
mask
in
zip
(
attn_metadata
.
seq_lens
,
attn_metadata
.
attn_bias
):
end
=
start
+
seq_len
sub_out
=
scaled_dot_product_attention
(
query
[
None
,
:,
start
:
end
,
:],
key
[
None
,
:,
start
:
end
,
:],
value
[
None
,
:,
start
:
end
,
:],
attn_mask
=
mask
,
dropout_p
=
0.0
,
is_causal
=
not
self
.
need_mask
,
scale
=
self
.
scale
).
squeeze
(
0
).
movedim
(
query
.
dim
()
-
2
,
0
)
output
[
start
:
end
,
:,
:]
=
sub_out
start
=
end
else
:
else
:
# prefix-enabled attention
# prefix-enabled attention
raise
RuntimeError
(
raise
RuntimeError
(
"Torch SDPA backend doesn't support prefix decoding."
)
"Torch SDPA backend doesn't support prefix decoding."
)
else
:
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
# Decoding run.
# Decoding run.
(
seq_lens_arg
,
max_seq_len_arg
,
block_tables_arg
,
)
=
decode_meta
.
get_seq_len_block_table_args
(
attn_type
)
output
=
PagedAttention
.
forward_decode
(
output
=
PagedAttention
.
forward_decode
(
query
,
query
,
key_cache
,
key_cache
,
value_cache
,
value_cache
,
attn_metadata
.
block_tables
,
block_tables
_arg
,
attn_metadata
.
seq_lens_
tensor
,
seq_lens_
arg
,
attn_metadata
.
max_decode
_seq_len
,
max
_seq_len
_arg
,
self
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
num_kv_heads
,
self
.
scale
,
self
.
scale
,
...
@@ -260,6 +445,59 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
...
@@ -260,6 +445,59 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
# Reshape the output tensor.
# Reshape the output tensor.
return
output
.
view
(
-
1
,
self
.
num_heads
*
self
.
head_size
)
return
output
.
view
(
-
1
,
self
.
num_heads
*
self
.
head_size
)
def _run_sdpa_forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_metadata: TorchSDPAMetadata,
    attn_type: AttentionType = AttentionType.DECODER,
):
    """Run torch SDPA sequence by sequence over a packed batch.

    The tensors hold all sequences concatenated along dim 0; each
    sequence is sliced out using the per-type lengths from
    `attn_metadata.get_seq_lens(attn_type)`, attended independently, and
    written back into the matching rows of the output.

    Arguments:

    * query, key, value: packed tensors with tokens on dim 0 and heads
      on dim 1
    * attn_metadata: metadata providing sequence lengths and the cached
      per-sequence attention biases
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns:

    * Tensor shaped like the incoming `query`
    """
    if self.num_kv_heads != self.num_heads:
        # Repeat each KV head so the head count matches the query's.
        repeats = self.num_queries_per_kv
        key = key.repeat_interleave(repeats, dim=1)
        value = value.repeat_interleave(repeats, dim=1)

    bias_per_seq = attn_metadata.get_attn_bias(attn_type)
    if bias_per_seq is None:
        # First attention op for this type: build the biases once and
        # cache them on the metadata for later calls.
        if self.alibi_slopes is not None:
            bias_per_seq = _make_alibi_bias(
                self.alibi_slopes, query.dtype,
                attn_metadata.seq_lens)  # type: ignore
        elif self.sliding_window is not None:
            assert attn_metadata.seq_lens is not None
            bias_per_seq = _make_sliding_window_bias(
                attn_metadata.seq_lens, self.sliding_window,
                query.dtype)  # type: ignore
        else:
            lens_for_bias, _ = attn_metadata.get_seq_lens(attn_type)
            bias_per_seq = [None] * len(lens_for_bias)
        attn_metadata.set_attn_bias(bias_per_seq, attn_type)

    # Allocate the result while query still has its token-major layout.
    output = torch.empty_like(query)

    # Move the token axis to the second-to-last position for SDPA.
    query = query.movedim(0, query.dim() - 2)
    key = key.movedim(0, key.dim() - 2)
    value = value.movedim(0, value.dim() - 2)
    token_axis = query.dim() - 2

    # Only decoder self-attention is causal; an explicit mask replaces
    # the built-in causal flag when one is needed.
    use_causal = (attn_type == AttentionType.DECODER) and not self.need_mask

    q_lens, kv_lens = attn_metadata.get_seq_lens(attn_type)
    q_begin = 0
    kv_begin = 0
    for q_len, kv_len, bias in zip(q_lens, kv_lens, bias_per_seq):
        q_stop = q_begin + q_len
        kv_stop = kv_begin + kv_len
        q_span = slice(q_begin, q_stop)
        kv_span = slice(kv_begin, kv_stop)

        attended = scaled_dot_product_attention(
            query[None, :, q_span, :],
            key[None, :, kv_span, :],
            value[None, :, kv_span, :],
            attn_mask=bias,
            dropout_p=0.0,
            is_causal=use_causal,
            scale=self.scale)
        # Drop the batch dim and restore the token-major layout.
        output[q_span, :, :] = attended.squeeze(0).movedim(token_axis, 0)

        q_begin = q_stop
        kv_begin = kv_stop
    return output
def
_make_alibi_bias
(
def
_make_alibi_bias
(
alibi_slopes
:
torch
.
Tensor
,
alibi_slopes
:
torch
.
Tensor
,
...
...
vllm/worker/cpu_enc_dec_model_runner.py
0 → 100644
View file @
4f95ffee
import
dataclasses
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
cast
import
torch
from
vllm.attention
import
AttentionMetadata
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.multimodal
import
MultiModalInputs
from
vllm.sequence
import
IntermediateTensors
,
SequenceGroupMetadata
from
vllm.utils
import
make_tensor_with_pad
from
vllm.worker.cpu_model_runner
import
(
CPUModelRunner
,
ModelInputForCPUBuilder
,
ModelInputForCPUWithSamplingMetadata
)
from
vllm.worker.model_runner_base
import
(
_add_attn_metadata_broadcastable_dict
,
_add_sampling_metadata_broadcastable_dict
)
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
@dataclasses.dataclass(frozen=True)
class EncoderDecoderModelInputForCPU(ModelInputForCPUWithSamplingMetadata):
    """
    Model input for the CPU encoder/decoder model runner: the base CPU
    model input plus encoder token and position tensors.
    """
    # Encoder token ids; None when not provided.
    encoder_input_tokens: Optional[torch.Tensor] = None
    # Encoder token positions; None when not provided.
    encoder_input_positions: Optional[torch.Tensor] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Flatten this input into a dict suitable for broadcasting.

        Starts with the decoder and encoder token/position tensors, then
        lets the attention-metadata and sampling-metadata helpers add
        their entries to the same dict.
        """
        payload = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
            "encoder_input_tokens": self.encoder_input_tokens,
            "encoder_input_positions": self.encoder_input_positions,
        }
        _add_attn_metadata_broadcastable_dict(payload, self.attn_metadata)
        _add_sampling_metadata_broadcastable_dict(payload,
                                                  self.sampling_metadata)
        return payload

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "EncoderDecoderModelInputForCPU":
        """Rebuild a model input from a broadcasted tensor dict.

        Delegates to the parent implementation; the cast only narrows
        the static type.
        """
        rebuilt = super().from_broadcasted_tensor_dict(
            tensor_dict, attn_backend)
        return cast(EncoderDecoderModelInputForCPU, rebuilt)
class
CPUEncoderDecoderModelRunner
(
CPUModelRunner
):
_model_input_cls
:
Type
[
EncoderDecoderModelInputForCPU
]
=
(
EncoderDecoderModelInputForCPU
)
_builder_cls
:
Type
[
ModelInputForCPUBuilder
]
=
ModelInputForCPUBuilder
def
_list_to_int32_tensor
(
self
,
_list
:
List
[
int
],
)
->
torch
.
Tensor
:
return
torch
.
tensor
(
_list
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
def
_list_to_long_tensor
(
self
,
_list
:
List
[
int
],
)
->
torch
.
Tensor
:
return
torch
.
tensor
(
_list
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
def
_empty_int32_tensor
(
self
)
->
torch
.
Tensor
:
return
self
.
_list_to_int32_tensor
([])
def
_empty_long_tensor
(
self
)
->
torch
.
Tensor
:
return
self
.
_list_to_long_tensor
([])
def
make_model_input_from_broadcasted_tensor_dict
(
self
,
tensor_dict
:
Dict
[
str
,
Any
])
->
EncoderDecoderModelInputForCPU
:
return
EncoderDecoderModelInputForCPU
.
from_broadcasted_tensor_dict
(
tensor_dict
,
attn_backend
=
self
.
attn_backend
,
)
def
prepare_model_input
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
virtual_engine
:
int
=
0
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
)
->
EncoderDecoderModelInputForCPU
:
model_input
=
super
().
prepare_model_input
(
seq_group_metadata_list
,
virtual_engine
,
finished_requests_ids
)
model_input
=
cast
(
EncoderDecoderModelInputForCPU
,
model_input
)
(
attn_metadata
,
encoder_input_tokens_tensor
,
encoder_input_positions_tensor
,
)
=
self
.
_prepare_encoder_model_input_tensors
(
seq_group_metadata_list
,
model_input
)
return
dataclasses
.
replace
(
model_input
,
attn_metadata
=
attn_metadata
,
encoder_input_tokens
=
encoder_input_tokens_tensor
,
encoder_input_positions
=
encoder_input_positions_tensor
,
)
def
_prepare_encoder_model_input_tensors
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
model_input
:
EncoderDecoderModelInputForCPU
,
)
->
Tuple
[
AttentionMetadata
,
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]]:
"""Helper method to prepare the encoder- and cross-attn-related
model inputs based on a given sequence group. These additional inputs
are used to augment an already-computed `EncoderDecoderModelInput`
data structure which already has decoder-related model inputs
populated.
Sets the following attn_metadata fields:
* `num_encoder_tokens`
* `encoder_seq_lens`
* `encoder_seq_lens_tensor`
* `max_encoder_seq_len`
* `cross_slot_mapping`
* `cross_block_tables`
Constructs a new model inputs data structure, based on
(1) the existing fields in the `model_inputs` argument,
and (2) the following additional fields which are
computed (or in the case of `attn_metadata`, updated)
by this function:
* attn_metadata
* encoder_input_tokens
* encoder_input_positions
Arguments:
* seq_group_metadata_list: list of sequence groups for which to
compute inputs
* model_inputs: model inputs data structure with decoder-oriented
fields already computed.
Return:
* Updated model inputs data structure
"""
if
len
(
seq_group_metadata_list
)
==
0
:
return
(
model_input
.
attn_metadata
,
None
,
None
)
# Since we are not supporting chunked prefill either the entire
# batch is prefill or it is decode
is_prompt
=
seq_group_metadata_list
[
0
].
is_prompt
# Build encoder inputs
encoder_seq_lens
:
List
[
int
]
=
[]
if
is_prompt
:
# Prefill phase.
cross_block_tables
=
self
.
_empty_int32_tensor
().
view
(
len
(
seq_group_metadata_list
),
-
1
)
# Extract input tokens/positions, cross-attention slot-mapping,
# & seq len from each sequence group metadata
(
encoder_input_tokens
,
encoder_input_positions
,
cross_slot_mapping
,
)
=
(
[],
[],
[],
)
for
seq_group_metadata
in
seq_group_metadata_list
:
# Build seq lens
seq_len
=
seq_group_metadata
.
encoder_seq_data
.
get_len
()
token_ids
=
seq_group_metadata
.
encoder_seq_data
.
get_token_ids
()
encoder_seq_lens
.
append
(
seq_len
)
# Build slot mapping
for
i
in
range
(
0
,
seq_len
):
block_number
=
seq_group_metadata
.
cross_block_table
[
i
//
self
.
block_size
]
block_offset
=
i
%
self
.
block_size
slot
=
block_number
*
self
.
block_size
+
block_offset
cross_slot_mapping
.
append
(
slot
)
# Build encoder input tokens
encoder_input_tokens
.
extend
(
token_ids
)
encoder_input_positions
.
extend
(
list
(
range
(
0
,
seq_len
)))
# Convert tokens/positions & cross-attention
# slot-mapping to encoder input tensors
encoder_input_tokens_tensor
=
self
.
_list_to_long_tensor
(
encoder_input_tokens
)
encoder_input_positions_tensor
=
self
.
_list_to_long_tensor
(
encoder_input_positions
)
cross_slot_mapping_tensor
=
self
.
_list_to_long_tensor
(
cross_slot_mapping
)
else
:
# Decode phase.
encoder_input_tokens_tensor
=
self
.
_empty_long_tensor
()
encoder_input_positions_tensor
=
self
.
_empty_long_tensor
()
cross_slot_mapping_tensor
=
self
.
_empty_long_tensor
()
# Extract cross-attention block tables &
# seq len from each sequence group metadata.
# Cross-attention block tables are empty
# during vLLM memory profiling.
cross_block_tables
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
for
_
in
range
(
len
(
seq_group_metadata
.
seq_data
)):
encoder_seq_lens
.
append
(
seq_group_metadata
.
encoder_seq_data
.
get_len
())
cross_block_table
=
seq_group_metadata
.
cross_block_table
cross_block_tables
.
append
([]
if
(
cross_block_table
is
None
)
else
cross_block_table
)
max_len_of_block_table
=
max
(
len
(
block_table
)
for
block_table
in
cross_block_tables
)
cross_block_tables
=
make_tensor_with_pad
(
cross_block_tables
,
max_len
=
max_len_of_block_table
,
pad
=
0
,
dtype
=
torch
.
int32
,
device
=
self
.
device
,
)
# Compute encoder sequence lengths & encoder
# sequence starting offset tensors
max_encoder_seq_len
=
max
(
encoder_seq_lens
,
default
=
0
)
encoder_seq_lens_tensor
=
self
.
_list_to_int32_tensor
(
encoder_seq_lens
)
encoder_seq_start_loc
=
torch
.
zeros
(
encoder_seq_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
torch
.
cumsum
(
encoder_seq_lens_tensor
,
dim
=
0
,
dtype
=
encoder_seq_start_loc
.
dtype
,
out
=
encoder_seq_start_loc
[
1
:])
# Update attention metadata with encoder-oriented attributes
attn_metadata
=
model_input
.
attn_metadata
assert
attn_metadata
is
not
None
(
attn_metadata
.
num_encoder_tokens
,
attn_metadata
.
encoder_seq_lens
,
attn_metadata
.
encoder_seq_lens_tensor
,
attn_metadata
.
max_encoder_seq_len
,
attn_metadata
.
cross_slot_mapping
,
attn_metadata
.
cross_block_tables
,
)
=
(
sum
(
encoder_seq_lens
),
encoder_seq_lens
,
encoder_seq_lens_tensor
,
max_encoder_seq_len
,
cross_slot_mapping_tensor
,
cross_block_tables
,
)
return
(
attn_metadata
,
encoder_input_tokens_tensor
,
encoder_input_positions_tensor
)
@torch.no_grad()
def execute_model(
    self,
    model_input: EncoderDecoderModelInputForCPU,
    kv_caches: List[torch.Tensor],
    intermediate_tensors: Optional[IntermediateTensors] = None,
    num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
    """Run a single forward pass of the model and sample from its logits.

    Args:
        model_input: Prepared inputs. This method reads its decoder
            tokens/positions, encoder tokens/positions, attention
            metadata, optional multi-modal kwargs and sampling metadata.
        kv_caches: KV-cache tensors, forwarded to the model unchanged.
        intermediate_tensors: Forwarded to the model unchanged.
        num_steps: Number of steps requested; only 1 is accepted.

    Returns:
        An empty list when this runner is not the driver worker,
        otherwise a one-element list holding the sampler output.

    Raises:
        ValueError: If ``num_steps`` is greater than 1.
    """
    if num_steps > 1:
        raise ValueError(
            "CPU worker does not support multi-step execution.")

    model = self.model

    # Assemble the forward() keyword arguments in the same order as the
    # call contract expects: core inputs first, then any multi-modal
    # extras, and finally the intermediate tensors (which therefore win
    # over a same-named multi-modal entry).
    forward_kwargs = dict(
        input_ids=model_input.input_tokens,
        positions=model_input.input_positions,
        encoder_input_ids=model_input.encoder_input_tokens,
        encoder_positions=model_input.encoder_input_positions,
        kv_caches=kv_caches,
        attn_metadata=model_input.attn_metadata,
    )
    forward_kwargs.update(
        MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
                                   device=self.device))
    forward_kwargs["intermediate_tensors"] = intermediate_tensors

    hidden_states = model(**forward_kwargs)

    # Logits are computed on every worker, before the driver check.
    logits = self.model.compute_logits(hidden_states,
                                       model_input.sampling_metadata)

    # Sampling happens on the driver worker only.
    if self.is_driver_worker:
        sampled = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        return [sampled]
    return []
vllm/worker/cpu_model_runner.py
View file @
4f95ffee
...
@@ -19,7 +19,7 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
...
@@ -19,7 +19,7 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs
)
MultiModalInputs
)
from
vllm.sequence
import
(
IntermediateTensors
,
SequenceData
,
from
vllm.sequence
import
(
IntermediateTensors
,
SequenceData
,
SequenceGroupMetadata
)
SequenceGroupMetadata
)
from
vllm.utils
import
STR_NOT_IMPL_ENC_DEC_ERR_STRS
,
make_tensor_with_pad
from
vllm.utils
import
make_tensor_with_pad
from
vllm.worker.model_runner_base
import
(
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerInputBuilderBase
,
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerInputBuilderBase
,
_add_attn_metadata_broadcastable_dict
,
_add_attn_metadata_broadcastable_dict
,
...
@@ -434,10 +434,6 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
...
@@ -434,10 +434,6 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
# Lazy initialization.
# Lazy initialization.
self
.
model
:
nn
.
Module
# Set after init_Model
self
.
model
:
nn
.
Module
# Set after init_Model
if
self
.
model_config
.
is_encoder_decoder_model
:
raise
NotImplementedError
(
STR_NOT_IMPL_ENC_DEC_ERR_STRS
[
'STR_NOT_IMPL_ENC_DEC_CPU'
])
@
property
@
property
def
model_is_mrope
(
self
)
->
bool
:
def
model_is_mrope
(
self
)
->
bool
:
"""Detect if the model has "mrope" rope_scaling type.
"""Detect if the model has "mrope" rope_scaling type.
...
@@ -459,8 +455,8 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
...
@@ -459,8 +455,8 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
def
make_model_input_from_broadcasted_tensor_dict
(
def
make_model_input_from_broadcasted_tensor_dict
(
self
,
self
,
tensor_dict
:
Dict
[
str
,
Any
],
tensor_dict
:
Dict
[
str
,
Any
],
)
->
ModelInputForCPU
:
)
->
ModelInputForCPU
WithSamplingMetadata
:
return
ModelInputForCPU
.
from_broadcasted_tensor_dict
(
return
ModelInputForCPU
WithSamplingMetadata
.
from_broadcasted_tensor_dict
(
# noqa: E501
tensor_dict
,
tensor_dict
,
attn_backend
=
self
.
attn_backend
,
attn_backend
=
self
.
attn_backend
,
)
)
...
...
vllm/worker/cpu_worker.py
View file @
4f95ffee
"""A CPU worker class."""
"""A CPU worker class."""
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
import
torch
import
torch.distributed
import
torch.distributed
...
@@ -15,6 +15,7 @@ from vllm.logger import init_logger
...
@@ -15,6 +15,7 @@ from vllm.logger import init_logger
from
vllm.model_executor
import
set_random_seed
from
vllm.model_executor
import
set_random_seed
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.worker.cpu_enc_dec_model_runner
import
CPUEncoderDecoderModelRunner
from
vllm.worker.cpu_model_runner
import
CPUModelRunner
from
vllm.worker.cpu_model_runner
import
CPUModelRunner
from
vllm.worker.worker_base
import
(
LocalOrDistributedWorkerBase
,
from
vllm.worker.worker_base
import
(
LocalOrDistributedWorkerBase
,
LoraNotSupportedWorkerBase
,
WorkerInput
)
LoraNotSupportedWorkerBase
,
WorkerInput
)
...
@@ -163,7 +164,10 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
...
@@ -163,7 +164,10 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
else
:
else
:
self
.
local_omp_cpuid
=
omp_cpuids
.
split
(
"|"
)[
rank
]
self
.
local_omp_cpuid
=
omp_cpuids
.
split
(
"|"
)[
rank
]
self
.
model_runner
:
CPUModelRunner
=
CPUModelRunner
(
ModelRunnerClass
:
Type
[
CPUModelRunner
]
=
CPUModelRunner
if
self
.
_is_encoder_decoder_model
():
ModelRunnerClass
=
CPUEncoderDecoderModelRunner
self
.
model_runner
:
CPUModelRunner
=
ModelRunnerClass
(
model_config
,
model_config
,
parallel_config
,
parallel_config
,
scheduler_config
,
scheduler_config
,
...
@@ -205,6 +209,9 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
...
@@ -205,6 +209,9 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
raise
RuntimeError
(
"Profiler is not enabled."
)
raise
RuntimeError
(
"Profiler is not enabled."
)
self
.
profiler
.
stop
()
self
.
profiler
.
stop
()
def
_is_encoder_decoder_model
(
self
):
return
self
.
model_config
.
is_encoder_decoder_model
def
init_device
(
self
)
->
None
:
def
init_device
(
self
)
->
None
:
if
self
.
local_omp_cpuid
!=
"all"
:
if
self
.
local_omp_cpuid
!=
"all"
:
ret
=
torch
.
ops
.
_C_utils
.
init_cpu_threads_env
(
self
.
local_omp_cpuid
)
ret
=
torch
.
ops
.
_C_utils
.
init_cpu_threads_env
(
self
.
local_omp_cpuid
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment