Unverified commit 8c755c3b, authored by Simon Mo, committed by GitHub
Browse files

[bugfix] spec decode worker get tp group only when initialized (#13578)

parent ba811639
@@ -12,6 +12,7 @@ from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
 from vllm.distributed.communication_op import (broadcast_tensor_dict,
                                                get_tp_group,
                                                tensor_model_parallel_gather)
+from vllm.distributed.parallel_state import model_parallel_is_initialized
 from vllm.logger import init_logger
 from vllm.model_executor.layers.rejection_sampler import RejectionSampler
 from vllm.model_executor.layers.sampler import SamplerOutput
@@ -366,8 +367,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
                 target_lm_head_weight)
 
         self._metrics.init_tensors(self.rank, device_type=self.device)
-        self.spec_decode_sampler.init_tensors(get_tp_group().local_rank,
-                                              device_type=self.device)
+        if model_parallel_is_initialized():
+            self.spec_decode_sampler.init_tensors(get_tp_group().local_rank,
+                                                  device_type=self.device)
+        else:
+            self.spec_decode_sampler.init_tensors(self.rank,
+                                                  device_type=self.device)
 
         scorer_cls: Type[SpeculativeScorer]
         if self.disable_mqa_scorer:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment