Commit 5e77e9b1 authored by xiabo's avatar xiabo
Browse files

需改mtp 16卡性能差问题

parent 3812059e
...@@ -57,6 +57,7 @@ class AsyncMetricsCollector: ...@@ -57,6 +57,7 @@ class AsyncMetricsCollector:
def __init__(self, def __init__(self,
spec_decode_sampler: SpecDecodeBaseSampler, spec_decode_sampler: SpecDecodeBaseSampler,
rank: int,
timer: Optional[Timer] = None, timer: Optional[Timer] = None,
collect_interval_s: float = 5.0): collect_interval_s: float = 5.0):
self.spec_decode_sampler = spec_decode_sampler self.spec_decode_sampler = spec_decode_sampler
...@@ -70,6 +71,7 @@ class AsyncMetricsCollector: ...@@ -70,6 +71,7 @@ class AsyncMetricsCollector:
self._in_flight_copy: Optional[torch.cuda.Event] = None self._in_flight_copy: Optional[torch.cuda.Event] = None
pin_memory = is_pin_memory_available() pin_memory = is_pin_memory_available()
torch.cuda.set_device(rank)
self._aggregate_num_accepted_tokens = torch.tensor( self._aggregate_num_accepted_tokens = torch.tensor(
0, dtype=torch.long, device="cpu", pin_memory=pin_memory) 0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
self._aggregate_num_emitted_tokens = torch.tensor( self._aggregate_num_emitted_tokens = torch.tensor(
......
...@@ -315,7 +315,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -315,7 +315,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.spec_decode_sampler = spec_decode_sampler self.spec_decode_sampler = spec_decode_sampler
self._allow_zero_draft_token_step = allow_zero_draft_token_step self._allow_zero_draft_token_step = allow_zero_draft_token_step
self._metrics = AsyncMetricsCollector( self._metrics = AsyncMetricsCollector(
self.spec_decode_sampler self.spec_decode_sampler,
self.rank
) if metrics_collector is None else metrics_collector ) if metrics_collector is None else metrics_collector
# Tracks the sequence IDs that received a bonus token ID in # Tracks the sequence IDs that received a bonus token ID in
# their last forward pass. Needed only if KV cache is being # their last forward pass. Needed only if KV cache is being
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment