Unverified Commit 40bc2425 authored by Nicolò Lucchesi's avatar Nicolò Lucchesi Committed by GitHub
Browse files

[Bugfix] Fix OpenVino/Neuron `driver_worker` init (#10779)


Signed-off-by: default avatarNickLucche <nlucches@redhat.com>
Signed-off-by: default avatarCyrus Leung <tlleungac@connect.ust.hk>
Co-authored-by: default avatarCyrus Leung <tlleungac@connect.ust.hk>
parent 661175bc
...@@ -29,11 +29,13 @@ class NeuronExecutor(ExecutorBase): ...@@ -29,11 +29,13 @@ class NeuronExecutor(ExecutorBase):
wrapper = WorkerWrapperBase(vllm_config=self.vllm_config) wrapper = WorkerWrapperBase(vllm_config=self.vllm_config)
distributed_init_method = get_distributed_init_method( distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port()) get_ip(), get_open_port())
self.driver_worker = wrapper.init_worker( wrapper.init_worker(
vllm_config=self.vllm_config, vllm_config=self.vllm_config,
local_rank=0, local_rank=0,
rank=0, rank=0,
distributed_init_method=distributed_init_method) distributed_init_method=distributed_init_method,
)
self.driver_worker = wrapper.worker
self.driver_worker.init_device() self.driver_worker.init_device()
self.driver_worker.load_model() self.driver_worker.load_model()
......
...@@ -36,7 +36,7 @@ class OpenVINOExecutor(ExecutorBase): ...@@ -36,7 +36,7 @@ class OpenVINOExecutor(ExecutorBase):
distributed_init_method = get_distributed_init_method( distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port()) get_ip(), get_open_port())
self.driver_worker = wrapper.init_worker( wrapper.init_worker(
ov_core=ov.Core(), ov_core=ov.Core(),
vllm_config=self.vllm_config, vllm_config=self.vllm_config,
local_rank=0, local_rank=0,
...@@ -45,6 +45,7 @@ class OpenVINOExecutor(ExecutorBase): ...@@ -45,6 +45,7 @@ class OpenVINOExecutor(ExecutorBase):
kv_cache_dtype=self.cache_config.cache_dtype, kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=True, is_driver_worker=True,
) )
self.driver_worker = wrapper.worker
self.driver_worker.init_device() self.driver_worker.init_device()
self.driver_worker.load_model() self.driver_worker.load_model()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment