Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
57f09a41
Unverified
Commit
57f09a41
authored
Jun 28, 2024
by
Ilya Lavrenov
Committed by
GitHub
Jun 28, 2024
Browse files
[Hardware][Intel] OpenVINO vLLM backend (#5379)
parent
59326344
Changes
22
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
683 additions
and
0 deletions
+683
-0
vllm/worker/openvino_model_runner.py
vllm/worker/openvino_model_runner.py
+330
-0
vllm/worker/openvino_worker.py
vllm/worker/openvino_worker.py
+353
-0
No files found.
vllm/worker/openvino_model_runner.py
0 → 100644
View file @
57f09a41
from
typing
import
List
,
NamedTuple
,
Optional
,
Tuple
import
openvino
as
ov
import
torch
from
torch
import
nn
from
vllm.attention
import
get_attn_backend
from
vllm.attention.backends.openvino
import
OpenVINOAttentionMetadata
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.model_loader.openvino
import
get_model
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
logger
=
init_logger
(
__name__
)
class
ModelInput
(
NamedTuple
):
input_tokens
:
torch
.
Tensor
input_positions
:
torch
.
Tensor
attn_metadata
:
Optional
[
OpenVINOAttentionMetadata
]
seq_lens
:
List
[
int
]
query_lens
:
List
[
int
]
multi_modal_input
:
Optional
[
torch
.
Tensor
]
@
classmethod
def
empty
(
cls
,
device
):
return
ModelInput
(
input_tokens
=
torch
.
empty
(
0
,
device
=
device
),
input_positions
=
torch
.
empty
(
0
,
device
=
device
),
attn_metadata
=
None
,
seq_lens
=
[],
query_lens
=
[],
multi_modal_input
=
None
)
class
OpenVINOModelRunner
:
def
__init__
(
self
,
model_config
:
ModelConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
cache_config
:
CacheConfig
,
load_config
:
LoadConfig
,
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
],
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
is_driver_worker
:
bool
=
False
,
*
args
,
**
kwargs
,
):
self
.
model_config
=
model_config
self
.
parallel_config
=
parallel_config
self
.
scheduler_config
=
scheduler_config
self
.
device_config
=
device_config
self
.
cache_config
=
cache_config
self
.
lora_config
=
lora_config
self
.
vision_language_config
=
vision_language_config
self
.
load_config
=
load_config
self
.
is_driver_worker
=
is_driver_worker
self
.
device
=
self
.
device_config
.
device
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
sliding_window
=
model_config
.
get_sliding_window
()
self
.
block_size
=
cache_config
.
block_size
self
.
attn_backend
=
get_attn_backend
(
self
.
model_config
.
get_num_attention_heads
(
self
.
parallel_config
),
self
.
model_config
.
get_head_size
(),
self
.
model_config
.
get_num_kv_heads
(
self
.
parallel_config
),
self
.
model_config
.
get_sliding_window
(),
self
.
model_config
.
dtype
,
self
.
kv_cache_dtype
,
self
.
block_size
,
)
# Lazy initialization.
self
.
model
:
nn
.
Module
# Set after init_Model
def
load_model
(
self
)
->
None
:
self
.
model
=
get_model
(
model_config
=
self
.
model_config
,
device_config
=
self
.
device_config
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
)
def
_prepare_model_input
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
ModelInput
:
"""Prepare the model input based on a given sequence group.
The API assumes seq_group_metadata_list is sorted by prefill -> decode.
The result tensors and data structure also batches input in prefill
-> decode order. For example,
- input_tokens[:num_prefill_tokens] contains prefill tokens.
- input_tokens[num_prefill_tokens:] contains decode tokens.
"""
input_tokens
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
seq_lens
:
List
[
int
]
=
[]
past_lens
:
List
[
int
]
=
[]
query_lens
:
List
[
int
]
=
[]
subsequence_begins
:
List
[
int
]
=
[]
block_indices
:
List
[
int
]
=
[]
block_indices_begins
:
List
[
int
]
=
[]
# initialize beginning of prefix sums
subsequence_begins
.
append
(
0
)
block_indices_begins
.
append
(
0
)
if
len
(
seq_group_metadata_list
)
==
0
:
return
ModelInput
.
empty
(
self
.
device
)
for
seq_group_metadata
in
seq_group_metadata_list
:
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
is_prompt
=
seq_group_metadata
.
is_prompt
for
seq_id
in
seq_ids
:
computed_block_nums
=
seq_group_metadata
.
computed_block_nums
if
(
self
.
scheduler_config
is
not
None
and
self
.
scheduler_config
.
chunked_prefill_enabled
and
not
(
computed_block_nums
is
None
or
computed_block_nums
==
[])):
raise
RuntimeError
(
"chunked prefill cannot be used with prefix caching "
"now."
)
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
if
is_prompt
:
computed_len
=
seq_data
.
get_num_computed_tokens
()
else
:
# get_num_computed_tokens is incorrect for spec decoding.
# So, we should have a special logic here.
# TODO(sang): Fix it.
computed_len
=
seq_data
.
get_len
()
-
1
seq_len
=
min
(
seq_data
.
get_len
(),
computed_len
+
seq_group_metadata
.
token_chunk_size
,
)
if
is_prompt
:
tokens
=
seq_data
.
get_token_ids
()[
computed_len
:
seq_len
]
else
:
# Optimization. get_token_ids requires the entire copy of
# tokens.
tokens
=
[
seq_data
.
get_last_token_id
()]
# Prefix cache was hit.
# Prefix is not supported with sliding_window
prefix_cache_hit
=
(
computed_block_nums
is
not
None
and
len
(
computed_block_nums
)
>
0
and
self
.
sliding_window
is
None
and
is_prompt
)
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
if
prefix_cache_hit
:
assert
computed_block_nums
is
not
None
computed_len
=
len
(
computed_block_nums
)
*
self
.
block_size
tokens
=
tokens
[
computed_len
:]
elif
(
self
.
scheduler_config
.
chunked_prefill_enabled
or
not
is_prompt
):
if
seq_group_metadata
.
block_tables
is
not
None
:
# chunked prefill or decode
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
if
self
.
sliding_window
is
not
None
:
# chunked prefill doesn't support sliding window.
assert
not
self
.
scheduler_config
.
chunked_prefill_enabled
# noqa: E501
sliding_window_blocks
=
(
self
.
sliding_window
//
self
.
block_size
)
block_table
=
block_table
[
-
sliding_window_blocks
:]
else
:
# Only happens when memory profiling runs.
block_table
=
[]
else
:
# prompt phase w/o prefix_caching, chunked_prefill
pass
block_indices
.
extend
(
block_table
)
block_indices_begins
.
append
(
block_indices_begins
[
-
1
]
+
len
(
block_table
))
# TODO(sang): This is a hack to make sliding window work with
# paged attn. We can remove it if we make paged attn kernel
# to properly handle slinding window attn.
if
self
.
sliding_window
is
not
None
and
not
is_prompt
:
seq_len
=
min
(
seq_len
,
self
.
sliding_window
)
computed_len
=
seq_len
-
1
seq_lens
.
append
(
seq_len
)
query_len
=
seq_len
-
computed_len
query_lens
.
append
(
query_len
)
input_tokens
.
extend
(
tokens
)
input_positions
.
extend
(
list
(
range
(
computed_len
,
seq_len
)))
past_lens
.
append
(
computed_len
)
subsequence_begins
.
append
(
subsequence_begins
[
-
1
]
+
query_len
)
if
is_prompt
:
assert
len
(
seq_ids
)
==
1
else
:
assert
(
query_len
==
1
),
"seq_len: {}, computed_len: {}, query_len: {}"
.
format
(
seq_len
,
computed_len
,
query_len
)
max_query_len
=
max
(
query_lens
)
assert
max_query_len
>
0
,
"query_lens: {}"
.
format
(
query_lens
)
input_tokens
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
# type: ignore
input_positions
=
torch
.
tensor
(
input_positions
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
# type: ignore
past_lens_tensor
=
torch
.
tensor
(
past_lens
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
# type: ignore
subsequence_begins_tensor
=
torch
.
tensor
(
subsequence_begins
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
# type: ignore
block_indices_tensor
=
torch
.
tensor
(
block_indices
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
# type: ignore
block_indices_begins_tensor
=
torch
.
tensor
(
block_indices_begins
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
# type: ignore
max_context_len
=
max
(
seq_lens
)
max_context_len_tensor
=
torch
.
tensor
(
max_context_len
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
# type: ignore
attn_metadata
=
self
.
attn_backend
.
make_openvino_metadata
(
past_lens
=
past_lens_tensor
,
subsequence_begins
=
subsequence_begins_tensor
,
block_indices
=
block_indices_tensor
,
block_indices_begins
=
block_indices_begins_tensor
,
max_context_len
=
max_context_len_tensor
,
)
return
ModelInput
(
input_tokens
,
input_positions
,
attn_metadata
,
seq_lens
,
query_lens
,
None
,
)
def
prepare_input_tensors
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
OpenVINOAttentionMetadata
,
SamplingMetadata
,
Optional
[
torch
.
Tensor
],
]:
multi_modal_input
=
None
# Prepare input tensors.
(
input_tokens
,
input_positions
,
attn_metadata
,
seq_lens
,
query_lens
,
multi_modal_input
,
)
=
self
.
_prepare_model_input
(
seq_group_metadata_list
)
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
seq_lens
,
query_lens
,
self
.
device
,
pin_memory
=
False
,
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
sampling_metadata
,
multi_modal_input
,
)
@
torch
.
inference_mode
()
def
execute_model
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
kv_caches
:
List
[
Tuple
[
"ov.Tensor"
,
"ov.Tensor"
]],
)
->
Optional
[
SamplerOutput
]:
(
input_tokens
,
input_positions
,
attn_metadata
,
sampling_metadata
,
multi_modal_input
,
)
=
self
.
prepare_input_tensors
(
seq_group_metadata_list
)
model_executable
=
self
.
model
execute_model_kwargs
=
{
"input_ids"
:
input_tokens
,
"positions"
:
input_positions
,
"kv_caches"
:
kv_caches
,
"attn_metadata"
:
attn_metadata
,
}
if
self
.
vision_language_config
:
execute_model_kwargs
.
update
({
"image_input"
:
multi_modal_input
})
hidden_states
=
model_executable
(
**
execute_model_kwargs
)
# Compute the logits.
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
# Sample the next token.
output
=
self
.
model
.
sample
(
logits
=
logits
,
sampling_metadata
=
sampling_metadata
,
)
return
output
vllm/worker/openvino_worker.py
0 → 100644
View file @
57f09a41
"""An OpenVINO worker class."""
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
import
openvino
as
ov
import
torch
import
torch.distributed
from
vllm.attention
import
get_attn_backend
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.distributed
import
(
broadcast_tensor_dict
,
ensure_model_parallel_initialized
,
init_distributed_environment
)
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
set_random_seed
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.worker.openvino_model_runner
import
OpenVINOModelRunner
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
logger
=
init_logger
(
__name__
)
class
OpenVINOCacheEngine
:
"""Manages the KV cache for OpenVINO backend.
This class is responsible for initializing and managing CPU KV
caches. It also provides methods for performing KV cache operations, such
as copying.
"""
def
__init__
(
self
,
cache_config
:
CacheConfig
,
model_config
:
ModelConfig
,
parallel_config
:
ParallelConfig
,
device_config
:
DeviceConfig
,
)
->
None
:
assert
device_config
.
device_type
==
"openvino"
self
.
cache_config
=
cache_config
self
.
model_config
=
model_config
self
.
parallel_config
=
parallel_config
self
.
head_size
=
model_config
.
get_head_size
()
if
device_config
.
device
.
type
==
"cpu"
and
\
cache_config
.
cache_dtype
==
ov
.
Type
.
u8
:
# Scale, zero point and quantized data will be stored together.
# The layout for per token per head:
# |scale(f32)|zeropoint(f32)|quantized data(u8,idx_1)|quantized data(u8,idx_2)|...|quantized data(u8,idx_head_size)| # noqa: E501
# so, we have to extend head_size by 8, which is sizeof(float)
# for scale and sizeof(float) for zeropoint
self
.
head_size
+=
8
self
.
num_layers
=
model_config
.
get_num_layers
(
parallel_config
)
self
.
num_kv_heads
=
model_config
.
get_num_kv_heads
(
parallel_config
)
self
.
block_size
=
cache_config
.
block_size
# Note: In CacheConfig, num_gpu_blocks actual is num_cpu_blocks
# for OpenVINO backend, because we want to reuse KV cache management
# in the scheduler.
self
.
num_cpu_blocks
=
cache_config
.
num_gpu_blocks
# Get attention backend.
self
.
attn_backend
=
get_attn_backend
(
self
.
model_config
.
get_num_attention_heads
(
self
.
parallel_config
),
self
.
head_size
,
self
.
model_config
.
get_num_kv_heads
(
self
.
parallel_config
),
self
.
model_config
.
get_sliding_window
(),
self
.
model_config
.
dtype
,
self
.
cache_config
.
cache_dtype
,
self
.
block_size
,
)
# Initialize the cache.
self
.
kv_cache
:
List
[
Tuple
[
ov
.
Tensor
,
ov
.
Tensor
]]
=
self
.
_allocate_kv_cache
(
self
.
num_cpu_blocks
)
def
_allocate_kv_cache
(
self
,
num_blocks
:
int
,
)
->
List
[
Tuple
[
ov
.
Tensor
,
ov
.
Tensor
]]:
"""Allocates KV cache."""
k_block_shape
=
v_block_shape
=
self
.
attn_backend
.
get_kv_cache_shape
(
num_blocks
,
self
.
block_size
,
self
.
num_kv_heads
,
self
.
head_size
)[
1
:]
kv_cache
:
List
[
Tuple
[
ov
.
Tensor
,
ov
.
Tensor
]]
=
[]
for
_
in
range
(
self
.
num_layers
):
key_blocks
=
ov
.
Tensor
(
self
.
cache_config
.
cache_dtype
,
k_block_shape
)
value_blocks
=
ov
.
Tensor
(
self
.
cache_config
.
cache_dtype
,
v_block_shape
)
kv_cache
.
append
((
key_blocks
,
value_blocks
))
return
kv_cache
def
swap_in
(
self
,
src_to_dst
:
Dict
[
int
,
int
])
->
None
:
raise
NotImplementedError
(
"Swap is not supported in OpenVINOCacheEngine."
)
def
swap_out
(
self
,
src_to_dst
:
Dict
[
int
,
int
])
->
None
:
raise
NotImplementedError
(
"Swap is not supported in OpenVINOCacheEngine."
)
def
copy
(
self
,
src_to_dsts
:
Dict
[
int
,
List
[
int
]])
->
None
:
self
.
attn_backend
.
copy_blocks
(
self
.
kv_cache
,
src_to_dsts
)
@
staticmethod
def
get_cache_block_size
(
block_size
:
int
,
cache_dtype
:
ov
.
Type
,
model_config
:
ModelConfig
,
parallel_config
:
ParallelConfig
,
)
->
int
:
head_size
=
model_config
.
get_head_size
()
num_kv_heads
=
model_config
.
get_num_kv_heads
(
parallel_config
)
num_layers
=
model_config
.
get_num_layers
(
parallel_config
)
if
cache_dtype
==
ov
.
Type
.
u8
:
# Scale, zero point and quantized data will be stored together.
# The layout for per token per head:
# |scale(f32)|zeropoint(f32)|quantized data(u8,idx_1)|quantized data(u8,idx_2)|...|quantized data(u8,idx_head_size)| # noqa: E501
# so, we have to extend head_size by 8, which is sizeof(float)
# for scale and sizeof(float) for zeropoint
head_size
+=
8
key_cache_block
=
block_size
*
num_kv_heads
*
head_size
value_cache_block
=
key_cache_block
total
=
num_layers
*
(
key_cache_block
+
value_cache_block
)
dtype_size
=
cache_dtype
.
size
return
dtype_size
*
total
class
OpenVINOWorker
(
LoraNotSupportedWorkerBase
):
"""A worker class that executes the model on OpenVINO backend.
Each worker is associated with a single OpenVINO device. The worker is
responsible for maintaining the KV cache and executing the model on the
OpenVINO backend.
"""
def
__init__
(
self
,
model_config
:
ModelConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
cache_config
:
CacheConfig
,
load_config
:
LoadConfig
,
local_rank
:
int
,
rank
:
int
,
distributed_init_method
:
str
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
vision_language_config
:
Optional
[
VisionLanguageConfig
]
=
None
,
kv_cache_dtype
:
Optional
[
ov
.
Type
]
=
ov
.
Type
.
undefined
,
is_driver_worker
:
bool
=
False
,
)
->
None
:
self
.
model_config
=
model_config
self
.
parallel_config
=
parallel_config
self
.
scheduler_config
=
scheduler_config
self
.
device_config
=
device_config
self
.
cache_config
=
cache_config
self
.
load_config
=
load_config
self
.
local_rank
=
local_rank
self
.
rank
=
rank
self
.
distributed_init_method
=
distributed_init_method
self
.
lora_config
=
lora_config
self
.
vision_language_config
=
vision_language_config
self
.
is_driver_worker
=
is_driver_worker
if
self
.
is_driver_worker
:
assert
self
.
rank
==
0
,
"The driver worker must have rank 0."
if
self
.
model_config
.
trust_remote_code
:
# note: lazy import to avoid importing torch before initializing
from
vllm.utils
import
init_cached_hf_modules
init_cached_hf_modules
()
self
.
model_runner
=
OpenVINOModelRunner
(
model_config
,
parallel_config
,
scheduler_config
,
device_config
,
cache_config
,
load_config
=
self
.
load_config
,
lora_config
=
self
.
lora_config
,
vision_language_config
=
self
.
vision_language_config
,
kv_cache_dtype
=
kv_cache_dtype
,
is_driver_worker
=
is_driver_worker
,
)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self
.
cache_engine
:
OpenVINOCacheEngine
self
.
kv_cache
:
List
[
Tuple
[
ov
.
Tensor
,
ov
.
Tensor
]]
def
init_device
(
self
)
->
None
:
self
.
init_distributed_environment
()
# Set random seed.
set_random_seed
(
self
.
model_config
.
seed
)
def
load_model
(
self
):
self
.
model_runner
.
load_model
()
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
"""Determine the number of blocks available for the KV cache.
This determines how many KV blocks can fit into the configured
KV cache space.
Note that since vLLM assumes a block resides on GPU if it can be
modified, we return num_gpu_blocks=num_cpu_blocks and num_cpu_blocks=0.
This allows us to reuse the scheduler of vLLM without generalizing it
to different devices.
"""
# For OpenVINO backend, the block number will be calculated based on the
# openvino_kvcache_space_bytes.
cache_block_size
=
self
.
get_cache_block_size_bytes
()
num_cpu_blocks
=
int
(
self
.
cache_config
.
openvino_kvcache_space_bytes
//
cache_block_size
)
num_cpu_blocks
=
max
(
num_cpu_blocks
,
0
)
# Note: To reuse the cache management procedure,
# use cpu cache as 'gpu cache'.
num_gpu_blocks
=
num_cpu_blocks
num_cpu_blocks
=
0
return
num_gpu_blocks
,
num_cpu_blocks
def
initialize_cache
(
self
,
num_gpu_blocks
:
int
,
num_cpu_blocks
:
int
)
->
None
:
"""Initialize the KV cache. Currently, swappable CPU memory is not
supported.
Since this worker does not support GPUs, we use the num_gpu_blocks to
determine how many non-swappable CPU blocks to allocate.
"""
assert
(
num_cpu_blocks
==
0
),
f
"
{
type
(
self
)
}
does not support swappable cache"
# Note: To reuse the cache management procedure,
# use cpu cache as 'gpu cache'.
num_cpu_blocks
=
num_gpu_blocks
self
.
_validate_num_cpu_blocks
(
num_cpu_blocks
)
self
.
cache_config
.
num_gpu_blocks
=
num_cpu_blocks
self
.
cache_config
.
num_cpu_blocks
=
0
# Initialize the cache.
self
.
_init_cache_engine
()
def
_validate_num_cpu_blocks
(
self
,
num_cpu_blocks
:
int
)
->
None
:
"""Raise errors if the num_cpu_blocks is invalid."""
if
num_cpu_blocks
<=
0
:
raise
ValueError
(
"No available memory for the cache blocks. "
"Try increasing `VLLM_OPENVINO_KVCACHE_SPACE` when "
"initializing the engine."
)
max_seq_len
=
self
.
cache_config
.
block_size
*
num_cpu_blocks
if
self
.
model_config
.
max_model_len
>
max_seq_len
:
raise
ValueError
(
f
"The model's max seq len (
{
self
.
model_config
.
max_model_len
}
) "
"is larger than the maximum number of tokens that can be "
f
"stored in KV cache (
{
max_seq_len
}
). Try increasing "
"`VLLM_OPENVINO_KVCACHE_SPACE` or decreasing `max_model_len` "
"when initializing the engine."
)
def
_init_cache_engine
(
self
)
->
None
:
self
.
cache_engine
=
OpenVINOCacheEngine
(
self
.
cache_config
,
self
.
model_config
,
self
.
parallel_config
,
self
.
device_config
,
)
self
.
kv_cache
=
self
.
cache_engine
.
kv_cache
self
.
model_runner
.
block_size
=
self
.
cache_engine
.
block_size
assert
self
.
kv_cache
is
not
None
# Populate the cache to warmup the memory
for
key_cache
,
value_cache
in
self
.
kv_cache
:
key_cache
.
data
[:]
=
0
value_cache
.
data
[:]
=
0
def
cache_copy
(
self
,
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]],
)
->
None
:
self
.
cache_engine
.
copy
(
blocks_to_copy
)
# type: ignore
@
torch
.
inference_mode
()
def
execute_model
(
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
,
)
->
List
[
SamplerOutput
]:
if
execute_model_req
is
None
:
seq_group_metadata_list
=
None
else
:
seq_group_metadata_list
=
execute_model_req
.
seq_group_metadata_list
if
self
.
is_driver_worker
:
assert
seq_group_metadata_list
is
not
None
num_seq_groups
:
int
=
len
(
seq_group_metadata_list
)
assert
execute_model_req
is
not
None
blocks_to_copy
=
execute_model_req
.
blocks_to_copy
assert
len
(
execute_model_req
.
blocks_to_swap_in
)
==
0
assert
len
(
execute_model_req
.
blocks_to_swap_out
)
==
0
data
:
Dict
[
str
,
Any
]
=
{
"num_seq_groups"
:
num_seq_groups
,
"blocks_to_copy"
:
execute_model_req
.
blocks_to_copy
,
}
broadcast_tensor_dict
(
data
,
src
=
0
)
else
:
data
=
broadcast_tensor_dict
(
src
=
0
)
num_seq_groups
=
data
[
"num_seq_groups"
]
blocks_to_copy
=
data
[
"blocks_to_copy"
]
self
.
cache_copy
(
blocks_to_copy
)
# If there is no input, we don't need to execute the model.
if
num_seq_groups
==
0
:
return
[]
output
=
self
.
model_runner
.
execute_model
(
seq_group_metadata_list
,
self
.
kv_cache
)
# OpenVINO worker only supports single-step execution.
return
[
output
]
def
init_distributed_environment
(
self
)
->
None
:
"""Initialize the distributed environment."""
parallel_config
=
self
.
parallel_config
rank
=
self
.
rank
distributed_init_method
=
self
.
distributed_init_method
init_distributed_environment
(
world_size
=
parallel_config
.
world_size
,
rank
=
rank
,
distributed_init_method
=
distributed_init_method
,
backend
=
"gloo"
,
)
# A small all_reduce for warmup.
torch
.
distributed
.
all_reduce
(
torch
.
zeros
(
1
).
cpu
())
ensure_model_parallel_initialized
(
parallel_config
.
tensor_parallel_size
,
parallel_config
.
pipeline_parallel_size
,
)
def
get_cache_block_size_bytes
(
self
)
->
int
:
"""Return the size in bytes of a single KV cache block."""
return
OpenVINOCacheEngine
.
get_cache_block_size
(
self
.
cache_config
.
block_size
,
self
.
cache_config
.
cache_dtype
,
self
.
model_config
,
self
.
parallel_config
,
)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment