Commit 1b3f4b5b authored by 王敏's avatar 王敏
Browse files

[fix]修复开启并行解码后,打印"Current platform rocm does not have 'Event' attribute"问题

parent 85f3aba3
...@@ -86,10 +86,12 @@ class AsyncMetricsCollector: ...@@ -86,10 +86,12 @@ class AsyncMetricsCollector:
if isinstance(device_type, torch.device): if isinstance(device_type, torch.device):
torch.cuda.set_device(device_type) torch.cuda.set_device(device_type)
device_type = device_type.type device_type = device_type.type
stream = current_platform.Stream # stream = current_platform.Stream
if stream is not None: # if stream is not None:
self._copy_stream = stream() # self._copy_stream = stream()
if device_type == 'cuda':
self._copy_stream = torch.cuda.Stream()
pin_memory = is_pin_memory_available() pin_memory = is_pin_memory_available()
self._aggregate_num_accepted_tokens = torch.tensor( self._aggregate_num_accepted_tokens = torch.tensor(
0, dtype=torch.long, device="cpu", pin_memory=pin_memory) 0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
...@@ -99,8 +101,8 @@ class AsyncMetricsCollector: ...@@ -99,8 +101,8 @@ class AsyncMetricsCollector:
def maybe_collect_rejsample_metrics( def maybe_collect_rejsample_metrics(
self, k: int) -> Optional[SpecDecodeWorkerMetrics]: self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
# Skip for any platform that doesn't have device Event # Skip for any platform that doesn't have device Event
if current_platform.Event is None: # if current_platform.Event is None:
return None # return None
# If a copy was initiated in the previous call, collect and return. # If a copy was initiated in the previous call, collect and return.
if self._in_flight_copy is not None: if self._in_flight_copy is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment