Commit 0858880a authored by 王敏's avatar 王敏
Browse files

[fix]修复并行解码多机推理报错

parent 5e77e9b1
......@@ -57,7 +57,6 @@ class AsyncMetricsCollector:
def __init__(self,
spec_decode_sampler: SpecDecodeBaseSampler,
rank: int,
timer: Optional[Timer] = None,
collect_interval_s: float = 5.0):
self.spec_decode_sampler = spec_decode_sampler
......@@ -70,12 +69,6 @@ class AsyncMetricsCollector:
self._in_flight_copy: Optional[torch.cuda.Event] = None
pin_memory = is_pin_memory_available()
torch.cuda.set_device(rank)
self._aggregate_num_accepted_tokens = torch.tensor(
0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
self._aggregate_num_emitted_tokens = torch.tensor(
0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
self._aggregate_num_draft_tokens = 0
self._rejsample_metrics_collect_interval_s = collect_interval_s
......@@ -90,10 +83,17 @@ class AsyncMetricsCollector:
device_type: Union[torch.device, str] = 'cuda') -> None:
self._rank = rank
if isinstance(device_type, torch.device):
torch.cuda.set_device(device_type)
device_type = device_type.type
if device_type == 'cuda':
self._copy_stream = torch.cuda.Stream()
pin_memory = is_pin_memory_available()
self._aggregate_num_accepted_tokens = torch.tensor(
0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
self._aggregate_num_emitted_tokens = torch.tensor(
0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
def maybe_collect_rejsample_metrics(
self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
# currently using cuda.Event, skip for any non_cuda_alike platform
......
......@@ -315,8 +315,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.spec_decode_sampler = spec_decode_sampler
self._allow_zero_draft_token_step = allow_zero_draft_token_step
self._metrics = AsyncMetricsCollector(
self.spec_decode_sampler,
self.rank
self.spec_decode_sampler
) if metrics_collector is None else metrics_collector
# Tracks the sequence IDs that received a bonus token ID in
# their last forward pass. Needed only if KV cache is being
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment