Commit 1b3f4b5b authored by 王敏's avatar 王敏
Browse files

[fix]修复开启并行解码后,打印"Current platform rocm does not have 'Event' attribute"问题

parent 85f3aba3
......@@ -86,9 +86,11 @@ class AsyncMetricsCollector:
if isinstance(device_type, torch.device):
torch.cuda.set_device(device_type)
device_type = device_type.type
stream = current_platform.Stream
if stream is not None:
self._copy_stream = stream()
# stream = current_platform.Stream
# if stream is not None:
# self._copy_stream = stream()
if device_type == 'cuda':
self._copy_stream = torch.cuda.Stream()
pin_memory = is_pin_memory_available()
self._aggregate_num_accepted_tokens = torch.tensor(
......@@ -99,8 +101,8 @@ class AsyncMetricsCollector:
def maybe_collect_rejsample_metrics(
self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
# Skip for any platform that doesn't have device Event
if current_platform.Event is None:
return None
# if current_platform.Event is None:
# return None
# If a copy was initiated in the previous call, collect and return.
if self._in_flight_copy is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment