"vscode:/vscode.git/clone" did not exist on "728a9eb70ee30b1ab355a98f7e19fb81a0a7b873"
Unverified Commit 8c755c3b authored by Simon Mo's avatar Simon Mo Committed by GitHub
Browse files

[bugfix] spec decode worker get tp group only when initialized (#13578)

parent ba811639
......@@ -12,6 +12,7 @@ from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
from vllm.distributed.communication_op import (broadcast_tensor_dict,
get_tp_group,
tensor_model_parallel_gather)
from vllm.distributed.parallel_state import model_parallel_is_initialized
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.sampler import SamplerOutput
......@@ -366,8 +367,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
target_lm_head_weight)
self._metrics.init_tensors(self.rank, device_type=self.device)
self.spec_decode_sampler.init_tensors(get_tp_group().local_rank,
device_type=self.device)
if model_parallel_is_initialized():
self.spec_decode_sampler.init_tensors(get_tp_group().local_rank,
device_type=self.device)
else:
self.spec_decode_sampler.init_tensors(self.rank,
device_type=self.device)
scorer_cls: Type[SpeculativeScorer]
if self.disable_mqa_scorer:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment