Commit c7096363 authored by zhuwenwen
Browse files

[fix]修复并行解码多机推理报错

parent 4dc24bc8
...@@ -58,7 +58,6 @@ class AsyncMetricsCollector: ...@@ -58,7 +58,6 @@ class AsyncMetricsCollector:
def __init__(self, def __init__(self,
spec_decode_sampler: SpecDecodeBaseSampler, spec_decode_sampler: SpecDecodeBaseSampler,
rank: int,
timer: Optional[Timer] = None, timer: Optional[Timer] = None,
collect_interval_s: float = 5.0): collect_interval_s: float = 5.0):
self.spec_decode_sampler = spec_decode_sampler self.spec_decode_sampler = spec_decode_sampler
...@@ -71,12 +70,6 @@ class AsyncMetricsCollector: ...@@ -71,12 +70,6 @@ class AsyncMetricsCollector:
self._in_flight_copy: Optional[torch.cuda.Event] = None self._in_flight_copy: Optional[torch.cuda.Event] = None
pin_memory = is_pin_memory_available()
torch.cuda.set_device(rank)
self._aggregate_num_accepted_tokens = torch.tensor(
0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
self._aggregate_num_emitted_tokens = torch.tensor(
0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
self._aggregate_num_draft_tokens = 0 self._aggregate_num_draft_tokens = 0
self._rejsample_metrics_collect_interval_s = collect_interval_s self._rejsample_metrics_collect_interval_s = collect_interval_s
...@@ -91,10 +84,17 @@ class AsyncMetricsCollector: ...@@ -91,10 +84,17 @@ class AsyncMetricsCollector:
device_type: Union[torch.device, str] = 'cuda') -> None: device_type: Union[torch.device, str] = 'cuda') -> None:
self._rank = rank self._rank = rank
if isinstance(device_type, torch.device): if isinstance(device_type, torch.device):
torch.cuda.set_device(device_type)
device_type = device_type.type device_type = device_type.type
stream = current_platform.Stream stream = current_platform.Stream
if stream is not None: if stream is not None:
self._copy_stream = stream() self._copy_stream = stream()
pin_memory = is_pin_memory_available()
self._aggregate_num_accepted_tokens = torch.tensor(
0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
self._aggregate_num_emitted_tokens = torch.tensor(
0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
def maybe_collect_rejsample_metrics( def maybe_collect_rejsample_metrics(
self, k: int) -> Optional[SpecDecodeWorkerMetrics]: self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
......
...@@ -326,8 +326,7 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase): ...@@ -326,8 +326,7 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
self._allow_zero_draft_token_step = allow_zero_draft_token_step self._allow_zero_draft_token_step = allow_zero_draft_token_step
self._enable_lm_head_weight_load = enable_lm_head_weight_load self._enable_lm_head_weight_load = enable_lm_head_weight_load
self._metrics = AsyncMetricsCollector( self._metrics = AsyncMetricsCollector(
self.spec_decode_sampler, self.spec_decode_sampler
self.rank
) if metrics_collector is None else metrics_collector ) if metrics_collector is None else metrics_collector
# Tracks the sequence IDs that received a bonus token ID in # Tracks the sequence IDs that received a bonus token ID in
# their last forward pass. Needed only if KV cache is being # their last forward pass. Needed only if KV cache is being
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment