Commit c7cd5a9c authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.7.2-dev-wm' into 'v0.7.2-dev'

[fix]修复并行解码多机推理报错

See merge request dcutoolkit/deeplearing/vllm!107
parents 5e77e9b1 0858880a
...@@ -57,7 +57,6 @@ class AsyncMetricsCollector: ...@@ -57,7 +57,6 @@ class AsyncMetricsCollector:
def __init__(self, def __init__(self,
spec_decode_sampler: SpecDecodeBaseSampler, spec_decode_sampler: SpecDecodeBaseSampler,
rank: int,
timer: Optional[Timer] = None, timer: Optional[Timer] = None,
collect_interval_s: float = 5.0): collect_interval_s: float = 5.0):
self.spec_decode_sampler = spec_decode_sampler self.spec_decode_sampler = spec_decode_sampler
...@@ -70,12 +69,6 @@ class AsyncMetricsCollector: ...@@ -70,12 +69,6 @@ class AsyncMetricsCollector:
self._in_flight_copy: Optional[torch.cuda.Event] = None self._in_flight_copy: Optional[torch.cuda.Event] = None
pin_memory = is_pin_memory_available()
torch.cuda.set_device(rank)
self._aggregate_num_accepted_tokens = torch.tensor(
0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
self._aggregate_num_emitted_tokens = torch.tensor(
0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
self._aggregate_num_draft_tokens = 0 self._aggregate_num_draft_tokens = 0
self._rejsample_metrics_collect_interval_s = collect_interval_s self._rejsample_metrics_collect_interval_s = collect_interval_s
...@@ -90,10 +83,17 @@ class AsyncMetricsCollector: ...@@ -90,10 +83,17 @@ class AsyncMetricsCollector:
device_type: Union[torch.device, str] = 'cuda') -> None: device_type: Union[torch.device, str] = 'cuda') -> None:
self._rank = rank self._rank = rank
if isinstance(device_type, torch.device): if isinstance(device_type, torch.device):
torch.cuda.set_device(device_type)
device_type = device_type.type device_type = device_type.type
if device_type == 'cuda': if device_type == 'cuda':
self._copy_stream = torch.cuda.Stream() self._copy_stream = torch.cuda.Stream()
pin_memory = is_pin_memory_available()
self._aggregate_num_accepted_tokens = torch.tensor(
0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
self._aggregate_num_emitted_tokens = torch.tensor(
0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
def maybe_collect_rejsample_metrics( def maybe_collect_rejsample_metrics(
self, k: int) -> Optional[SpecDecodeWorkerMetrics]: self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
# currently using cuda.Event, skip for any non_cuda_alike platform # currently using cuda.Event, skip for any non_cuda_alike platform
......
...@@ -315,8 +315,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -315,8 +315,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.spec_decode_sampler = spec_decode_sampler self.spec_decode_sampler = spec_decode_sampler
self._allow_zero_draft_token_step = allow_zero_draft_token_step self._allow_zero_draft_token_step = allow_zero_draft_token_step
self._metrics = AsyncMetricsCollector( self._metrics = AsyncMetricsCollector(
self.spec_decode_sampler, self.spec_decode_sampler
self.rank
) if metrics_collector is None else metrics_collector ) if metrics_collector is None else metrics_collector
# Tracks the sequence IDs that received a bonus token ID in # Tracks the sequence IDs that received a bonus token ID in
# their last forward pass. Needed only if KV cache is being # their last forward pass. Needed only if KV cache is being
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment