Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8afca508
Unverified
Commit
8afca508
authored
Apr 12, 2024
by
bigPYJ1151
Committed by
GitHub
Apr 11, 2024
Browse files
[Hardware][Intel] Isolate CPUModelRunner and ModelRunner for better maintenance (#3824)
parent
08ccee1e
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
443 additions
and
61 deletions
+443
-61
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+24
-48
vllm/executor/cpu_executor.py
vllm/executor/cpu_executor.py
+10
-0
vllm/utils.py
vllm/utils.py
+0
-1
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+408
-0
vllm/worker/cpu_worker.py
vllm/worker/cpu_worker.py
+1
-12
No files found.
vllm/attention/backends/torch_sdpa.py
View file @
8afca508
...
...
@@ -50,20 +50,15 @@ class TorchSDPABackend(AttentionBackend):
@
dataclass
class
TorchSDPAMetadata
(
AttentionMetadataPerStage
,
PagedAttentionMetadata
):
class
TorchSDPAMetadata
(
AttentionMetadata
,
PagedAttentionMetadata
,
AttentionMetadataPerStage
):
"""Metadata for TorchSDPABackend.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt
:
bool
slot_mapping
:
torch
.
Tensor
prompt_lens
:
Optional
[
List
[
int
]]
prompt_lens_tensor
:
Optional
[
torch
.
Tensor
]
max_subquery_len
:
Optional
[
int
]
=
None
max_prompt_len
:
Optional
[
int
]
=
None
subquery_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
seq_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
use_cuda_graph
:
bool
=
False
def
__post_init__
(
self
):
# Set during the execution of the first attention op.
...
...
@@ -111,7 +106,7 @@ class TorchSDPABackendImpl(AttentionImpl):
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
Optional
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
[
TorchSDPAMetadata
]
,
attn_metadata
:
TorchSDPAMetadata
,
kv_scale
:
float
,
)
->
torch
.
Tensor
:
"""Forward pass with torch SDPA and PagedAttention.
...
...
@@ -140,51 +135,36 @@ class TorchSDPABackendImpl(AttentionImpl):
attn_metadata
.
kv_cache_dtype
,
kv_scale
)
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
assert
key
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
assert
value
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
output
=
torch
.
empty_like
(
query
)
# Query for decode. KV is not needed because it is already cached.
decode_query
=
query
[
num_prefill_tokens
:]
# QKV for prefill.
query
=
query
[:
num_prefill_tokens
]
key
=
key
[:
num_prefill_tokens
]
value
=
value
[:
num_prefill_tokens
]
assert
query
.
shape
[
0
]
==
num_prefill_tokens
assert
decode_query
.
shape
[
0
]
==
num_decode_tokens
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
if
(
kv_cache
is
None
or
prefill_meta
.
block_tables
.
numel
()
==
0
):
if
attn_metadata
.
is_prompt
:
if
(
kv_cache
is
None
or
attn_metadata
.
block_tables
.
numel
()
==
0
):
if
self
.
num_kv_heads
!=
self
.
num_heads
:
key
=
key
.
repeat_interleave
(
self
.
num_queries_per_kv
,
dim
=
1
)
value
=
value
.
repeat_interleave
(
self
.
num_queries_per_kv
,
dim
=
1
)
if
prefill_me
ta
.
attn_bias
is
None
:
if
attn_metada
ta
.
attn_bias
is
None
:
if
self
.
alibi_slopes
is
not
None
:
att_masks
=
_make_alibi_bias
(
self
.
alibi_slopes
,
query
.
dtype
,
prefill_me
ta
.
prompt_lens
)
# type: ignore
attn_metada
ta
.
prompt_lens
)
# type: ignore
elif
self
.
sliding_window
is
not
None
:
att_masks
=
_make_sliding_window_bias
(
prefill_me
ta
.
prompt_lens
,
self
.
sliding_window
,
attn_metada
ta
.
prompt_lens
,
self
.
sliding_window
,
query
.
dtype
)
# type: ignore
else
:
att_masks
=
[
None
]
*
len
(
prefill_me
ta
.
prompt_lens
)
prefill_me
ta
.
attn_bias
=
att_masks
att_masks
=
[
None
]
*
len
(
attn_metada
ta
.
prompt_lens
)
attn_metada
ta
.
attn_bias
=
att_masks
query
=
query
.
movedim
(
0
,
query
.
dim
()
-
2
)
key
=
key
.
movedim
(
0
,
key
.
dim
()
-
2
)
value
=
value
.
movedim
(
0
,
value
.
dim
()
-
2
)
start
=
0
out
=
torch
.
empty
((
num_tokens
,
self
.
num_heads
,
self
.
head_size
),
dtype
=
query
.
dtype
)
for
prompt_len
,
mask
in
zip
(
prefill_meta
.
prompt_lens
,
prefill_meta
.
attn_bias
):
output
=
torch
.
empty
(
(
num_tokens
,
self
.
num_heads
,
self
.
head_size
),
dtype
=
query
.
dtype
)
for
prompt_len
,
mask
in
zip
(
attn_metadata
.
prompt_lens
,
attn_metadata
.
attn_bias
):
end
=
start
+
prompt_len
sub_out
=
scaled_dot_product_attention
(
query
[:,
start
:
end
,
:],
...
...
@@ -194,32 +174,28 @@ class TorchSDPABackendImpl(AttentionImpl):
dropout_p
=
0.0
,
is_causal
=
not
self
.
need_mask
,
scale
=
self
.
scale
).
movedim
(
query
.
dim
()
-
2
,
0
)
out
[
start
:
end
,
:,
:]
=
sub_out
out
put
[
start
:
end
,
:,
:]
=
sub_out
start
=
end
assert
out
.
shape
==
output
[:
num_prefill_tokens
].
shape
output
[:
num_prefill_tokens
]
=
out
else
:
# prefix-enabled attention
raise
RuntimeError
(
"Torch SDPA backend doesn't support prefix decoding."
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
else
:
# Decoding run.
out
=
PagedAttention
.
forward_decode
(
decode_
query
,
out
put
=
PagedAttention
.
forward_decode
(
query
,
key_cache
,
value_cache
,
decode_me
ta
.
block_tables
,
decode_me
ta
.
context_lens
,
decode_me
ta
.
max_context_len
,
attn_metada
ta
.
block_tables
,
attn_metada
ta
.
context_lens
,
attn_metada
ta
.
max_context_len
,
attn_metadata
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
scale
,
self
.
alibi_slopes
,
kv_scale
,
)
assert
out
.
shape
==
output
[
num_prefill_tokens
:].
shape
output
[
num_prefill_tokens
:]
# Reshape the output tensor.
return
output
.
view
(
-
1
,
self
.
num_heads
*
self
.
head_size
)
...
...
@@ -241,7 +217,7 @@ def _make_alibi_bias(
bias
=
bias
[
None
,
:]
-
bias
[:,
None
]
num_heads
=
alibi_slopes
.
shape
[
0
]
bias
=
bias
[
None
,
:].
expand
(
num_heads
,
prompt_len
,
prompt_len
)
bias
=
bias
[
None
,
:].
repeat
(
(
num_heads
,
1
,
1
)
)
bias
.
mul_
(
alibi_slopes
[:,
None
,
None
])
inf_mask
=
torch
.
empty
(
(
1
,
prompt_len
,
prompt_len
),
...
...
vllm/executor/cpu_executor.py
View file @
8afca508
...
...
@@ -25,6 +25,7 @@ class CPUExecutor(ExecutorBase):
assert
lora_config
is
None
,
"cpu backend doesn't support LoRA"
model_config
=
_verify_and_get_model_config
(
model_config
)
cache_config
=
_verify_and_get_cache_config
(
cache_config
)
scheduler_config
=
_verify_and_get_scheduler_config
(
scheduler_config
)
self
.
model_config
=
model_config
self
.
cache_config
=
cache_config
...
...
@@ -116,6 +117,15 @@ def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
return
config
def
_verify_and_get_scheduler_config
(
config
:
SchedulerConfig
)
->
SchedulerConfig
:
if
config
.
chunked_prefill_enabled
:
logger
.
warning
(
"Chunked prefill is not supported on CPU, disable it."
)
config
.
chunked_prefill_enabled
=
False
return
config
def
_verify_and_get_cache_config
(
config
:
CacheConfig
)
->
CacheConfig
:
_GB
=
1
<<
30
if
config
.
enable_prefix_caching
:
...
...
vllm/utils.py
View file @
8afca508
...
...
@@ -372,7 +372,6 @@ def is_pin_memory_available() -> bool:
print_warning_once
(
"Pin memory is not supported on Neuron."
)
return
False
elif
is_cpu
():
print_warning_once
(
"Pin memory is not supported on CPU."
)
return
False
return
True
...
...
vllm/worker/cpu_model_runner.py
0 → 100644
View file @
8afca508
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.config
import
(
DeviceConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
from
vllm.distributed
import
broadcast_tensor_dict
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.model_loader
import
get_model
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sequence
import
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
from
vllm.utils
import
make_tensor_with_pad
,
maybe_expand_dim
logger
=
init_logger
(
__name__
)
_PAD_SLOT_ID
=
-
1
class
CPUModelRunner
:
def
__init__
(
self
,
model_config
:
ModelConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
is_driver_worker
:
bool
=
False
,
*
args
,
**
kwargs
,
):
self
.
model_config
=
model_config
self
.
parallel_config
=
parallel_config
self
.
scheduler_config
=
scheduler_config
self
.
lora_config
=
lora_config
self
.
is_driver_worker
=
is_driver_worker
# model_config can be None in tests/samplers/test_sampler.py.
# FIXME(woosuk): This is a hack to make the tests work. Refactor this.
self
.
sliding_window
=
(
model_config
.
get_sliding_window
()
if
model_config
is
not
None
else
None
)
self
.
device_config
=
(
device_config
if
device_config
is
not
None
else
DeviceConfig
())
self
.
device
=
self
.
device_config
.
device
self
.
model
=
None
self
.
block_size
=
None
# Set after initial profiling.
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
attn_backend
=
get_attn_backend
(
self
.
model_config
.
dtype
if
model_config
is
not
None
else
None
)
def
load_model
(
self
)
->
None
:
self
.
model
=
get_model
(
self
.
model_config
,
self
.
device_config
,
lora_config
=
self
.
lora_config
,
parallel_config
=
self
.
parallel_config
,
scheduler_config
=
self
.
scheduler_config
)
def
_prepare_prompt
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
List
[
int
]]:
assert
len
(
seq_group_metadata_list
)
>
0
input_tokens
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
prompt_lens
:
List
[
int
]
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
assert
seq_group_metadata
.
is_prompt
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
assert
len
(
seq_ids
)
==
1
seq_id
=
seq_ids
[
0
]
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
prompt_tokens
=
seq_data
.
get_token_ids
()
computed_len
=
seq_data
.
get_num_computed_tokens
()
prompt_len
=
len
(
prompt_tokens
)
prompt_lens
.
append
(
prompt_len
)
# Prompt token num
input_tokens
.
extend
(
prompt_tokens
)
# Token ids
# Token position ids
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions
.
extend
(
list
(
range
(
computed_len
,
prompt_len
)))
# Compute the slot mapping.
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0, prompt_len - sliding_window).
# For example, if the prompt len is 10, sliding window is 8, and
# block size is 4, the first two tokens are masked and the slot
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx
=
0
if
self
.
sliding_window
is
not
None
:
start_idx
=
max
(
0
,
prompt_len
-
self
.
sliding_window
)
for
i
in
range
(
computed_len
,
prompt_len
):
if
i
<
start_idx
:
slot_mapping
.
append
(
_PAD_SLOT_ID
)
continue
block_number
=
block_table
[
i
//
self
.
block_size
]
# type: ignore
block_offset
=
i
%
self
.
block_size
# type: ignore
slot
=
block_number
*
self
.
block_size
+
block_offset
slot_mapping
.
append
(
slot
)
num_prompt_tokens
=
len
(
input_tokens
)
input_tokens
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
# type: ignore
input_positions
=
torch
.
tensor
(
input_positions
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
# type: ignore
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
# type: ignore
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
True
,
prompt_lens
=
prompt_lens
,
num_prefills
=
len
(
prompt_lens
),
num_prefill_tokens
=
num_prompt_tokens
,
num_decode_tokens
=
0
,
prefill_metadata
=
None
,
decode_metadata
=
None
,
max_context_len
=
None
,
context_lens
=
None
,
block_tables
=
torch
.
tensor
([]),
slot_mapping
=
slot_mapping
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
prompt_lens
,
)
def
_prepare_decode
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
]:
assert
len
(
seq_group_metadata_list
)
>
0
input_tokens
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
context_lens
:
List
[
int
]
=
[]
block_tables
:
List
[
List
[
int
]]
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
assert
not
seq_group_metadata
.
is_prompt
assert
seq_group_metadata
.
token_chunk_size
==
1
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
for
seq_id
in
seq_ids
:
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
generation_token
=
seq_data
.
get_last_token_id
()
input_tokens
.
append
(
generation_token
)
seq_len
=
seq_data
.
get_len
()
position
=
seq_len
-
1
input_positions
.
append
(
position
)
context_len
=
seq_len
if
self
.
sliding_window
is
None
else
min
(
seq_len
,
self
.
sliding_window
)
context_lens
.
append
(
context_len
)
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
block_number
=
block_table
[
position
//
self
.
block_size
]
block_offset
=
position
%
self
.
block_size
slot
=
block_number
*
self
.
block_size
+
block_offset
slot_mapping
.
append
(
slot
)
if
self
.
sliding_window
is
not
None
:
sliding_window_blocks
=
(
self
.
sliding_window
//
self
.
block_size
)
block_table
=
block_table
[
-
sliding_window_blocks
:]
block_tables
.
append
(
block_table
)
max_context_len
=
max
(
context_lens
)
input_tokens
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
input_positions
=
torch
.
tensor
(
input_positions
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
context_lens
=
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
max_block_table_len
=
max
(
len
(
block_table
)
for
block_table
in
block_tables
)
block_tables
=
make_tensor_with_pad
(
block_tables
,
max_len
=
max_block_table_len
,
pad
=
0
,
dtype
=
torch
.
int
,
device
=
self
.
device
,
)
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
False
,
slot_mapping
=
slot_mapping
,
prompt_lens
=
None
,
num_prefill_tokens
=
0
,
num_decode_tokens
=
len
(
input_tokens
),
max_context_len
=
max_context_len
,
num_prefills
=
0
,
prefill_metadata
=
None
,
decode_metadata
=
None
,
context_lens
=
context_lens
,
block_tables
=
block_tables
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
)
def
_prepare_sample
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
prompt_lens
:
List
[
int
],
)
->
SamplingMetadata
:
seq_groups
:
List
[
Tuple
[
List
[
int
],
SamplingParams
]]
=
[]
selected_token_indices
:
List
[
int
]
=
[]
generators
:
List
[
torch
.
Generator
]
=
[]
selected_token_start_idx
=
0
categorized_sample_indices
=
{
t
:
[]
for
t
in
SamplingType
}
categorized_sample_indices_start_idx
=
0
categorized_sampled_token_indices_start_idx
=
0
for
i
,
seq_group_metadata
in
enumerate
(
seq_group_metadata_list
):
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
sampling_params
=
seq_group_metadata
.
sampling_params
seq_groups
.
append
((
seq_ids
,
sampling_params
))
if
seq_group_metadata
.
is_prompt
:
assert
len
(
seq_ids
)
==
1
subquery_len
=
prompt_lens
[
i
]
if
sampling_params
.
prompt_logprobs
is
not
None
:
# NOTE: prompt token positions do not need sample, skip
categorized_sample_indices_start_idx
+=
subquery_len
-
1
categorized_sample_indices
[
sampling_params
.
sampling_type
].
append
([
categorized_sample_indices_start_idx
,
categorized_sampled_token_indices_start_idx
])
categorized_sample_indices_start_idx
+=
1
categorized_sampled_token_indices_start_idx
+=
1
if
sampling_params
.
prompt_logprobs
is
not
None
:
selected_token_indices
.
extend
(
range
(
selected_token_start_idx
,
selected_token_start_idx
+
subquery_len
-
1
))
selected_token_indices
.
append
(
selected_token_start_idx
+
subquery_len
-
1
)
selected_token_start_idx
+=
subquery_len
if
sampling_params
.
seed
is
not
None
:
seq_group_metadata
.
state
.
generator
=
torch
.
Generator
(
device
=
self
.
device
).
manual_seed
(
sampling_params
.
seed
)
else
:
num_seqs
=
len
(
seq_ids
)
selected_token_indices
.
extend
(
range
(
selected_token_start_idx
,
selected_token_start_idx
+
num_seqs
))
selected_token_start_idx
+=
num_seqs
categorized_sample_indices
[
sampling_params
.
sampling_type
].
extend
(
zip
(
range
(
categorized_sample_indices_start_idx
,
categorized_sample_indices_start_idx
+
num_seqs
),
range
(
categorized_sampled_token_indices_start_idx
,
categorized_sampled_token_indices_start_idx
+
num_seqs
)))
categorized_sample_indices_start_idx
+=
num_seqs
categorized_sampled_token_indices_start_idx
+=
num_seqs
if
sampling_params
.
seed
is
not
None
:
generators
.
append
(
seq_group_metadata
.
state
.
generator
)
selected_token_indices
=
torch
.
tensor
(
selected_token_indices
,
dtype
=
torch
.
long
)
categorized_sample_indices
=
{
t
:
maybe_expand_dim
(
torch
.
tensor
(
seq_ids
,
dtype
=
torch
.
int
),
2
,
2
)
for
t
,
seq_ids
in
categorized_sample_indices
.
items
()
}
seq_data
:
Dict
[
int
,
SequenceData
]
=
{}
for
seq_group_metadata
in
seq_group_metadata_list
:
seq_data
.
update
(
seq_group_metadata
.
seq_data
)
sampling_metadata
=
SamplingMetadata
(
seq_groups
=
seq_groups
,
seq_data
=
seq_data
,
prompt_lens
=
prompt_lens
,
selected_token_indices
=
selected_token_indices
,
categorized_sample_indices
=
categorized_sample_indices
,
generators
=
generators
,
)
return
sampling_metadata
def
prepare_input_tensors
(
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
SamplingMetadata
]:
if
self
.
is_driver_worker
:
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt
=
seq_group_metadata_list
[
0
].
is_prompt
# Prepare input tensors.
if
is_prompt
:
(
input_tokens
,
input_positions
,
attn_metadata
,
prompt_lens
)
=
self
.
_prepare_prompt
(
seq_group_metadata_list
)
else
:
(
input_tokens
,
input_positions
,
attn_metadata
)
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
prompt_lens
=
[]
sampling_metadata
=
self
.
_prepare_sample
(
seq_group_metadata_list
,
prompt_lens
)
# Broadcast the metadata.
metadata_dict
=
{
"input_tokens"
:
input_tokens
,
"input_positions"
:
input_positions
,
"selected_token_indices"
:
sampling_metadata
.
selected_token_indices
,
}
metadata_dict
.
update
(
attn_metadata
.
asdict_zerocopy
())
broadcast_tensor_dict
(
metadata_dict
,
src
=
0
)
else
:
metadata_dict
=
broadcast_tensor_dict
(
src
=
0
)
input_tokens
=
metadata_dict
.
pop
(
"input_tokens"
)
input_positions
=
metadata_dict
.
pop
(
"input_positions"
)
selected_token_indices
=
metadata_dict
.
pop
(
"selected_token_indices"
)
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
**
metadata_dict
)
sampling_metadata
=
SamplingMetadata
(
seq_groups
=
None
,
seq_data
=
None
,
prompt_lens
=
None
,
selected_token_indices
=
selected_token_indices
,
categorized_sample_indices
=
None
,
generators
=
None
,
perform_sampling
=
False
,
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
sampling_metadata
,
)
@
torch
.
inference_mode
()
def
execute_model
(
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]],
kv_caches
:
List
[
torch
.
Tensor
],
)
->
Optional
[
SamplerOutput
]:
(
input_tokens
,
input_positions
,
attn_metadata
,
sampling_metadata
)
=
self
.
prepare_input_tensors
(
seq_group_metadata_list
)
model_executable
=
self
.
model
execute_model_kwargs
=
{
"input_ids"
:
input_tokens
,
"positions"
:
input_positions
,
"kv_caches"
:
kv_caches
,
"attn_metadata"
:
attn_metadata
,
}
hidden_states
=
model_executable
(
**
execute_model_kwargs
)
# Compute the logits.
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
# Only perform sampling in the driver worker.
if
not
sampling_metadata
.
perform_sampling
:
return
None
# Sample the next token.
output
=
self
.
model
.
sample
(
logits
=
logits
,
sampling_metadata
=
sampling_metadata
,
)
return
output
vllm/worker/cpu_worker.py
View file @
8afca508
...
...
@@ -12,25 +12,14 @@ from vllm.distributed import (broadcast_tensor_dict,
init_distributed_environment
)
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
set_random_seed
from
vllm.model_executor.model_loader
import
get_model
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.worker.model_runner
import
ModelRunner
from
vllm.worker.
cpu_
model_runner
import
CPU
ModelRunner
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
logger
=
init_logger
(
__name__
)
class
CPUModelRunner
(
ModelRunner
):
def
load_model
(
self
)
->
None
:
self
.
model
=
get_model
(
self
.
model_config
,
self
.
device_config
,
lora_config
=
self
.
lora_config
,
parallel_config
=
self
.
parallel_config
,
scheduler_config
=
self
.
scheduler_config
)
class
CPUCacheEngine
:
"""Manages the KV cache for CPU backend.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment