Commit 5e77e9b1 authored by xiabo's avatar xiabo
Browse files

需改mtp 16卡性能差问题

parent 3812059e
......@@ -57,6 +57,7 @@ class AsyncMetricsCollector:
def __init__(self,
spec_decode_sampler: SpecDecodeBaseSampler,
rank: int,
timer: Optional[Timer] = None,
collect_interval_s: float = 5.0):
self.spec_decode_sampler = spec_decode_sampler
......@@ -70,6 +71,7 @@ class AsyncMetricsCollector:
self._in_flight_copy: Optional[torch.cuda.Event] = None
pin_memory = is_pin_memory_available()
torch.cuda.set_device(rank)
self._aggregate_num_accepted_tokens = torch.tensor(
0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
self._aggregate_num_emitted_tokens = torch.tensor(
......
......@@ -315,7 +315,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.spec_decode_sampler = spec_decode_sampler
self._allow_zero_draft_token_step = allow_zero_draft_token_step
self._metrics = AsyncMetricsCollector(
self.spec_decode_sampler
self.spec_decode_sampler,
self.rank
) if metrics_collector is None else metrics_collector
# Tracks the sequence IDs that received a bonus token ID in
# their last forward pass. Needed only if KV cache is being
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment