"docs/vscode:/vscode.git/clone" did not exist on "fecae12cd7deb969dcbba37fda9d2d234697a944"
Unverified Commit 9c1baa5b authored by Shanshan Shen's avatar Shanshan Shen Committed by GitHub
Browse files

[Misc] Replace `cuda` hard code with `current_platform` (#16983)


Signed-off-by: default avatarshen-shanshan <467638484@qq.com>
parent 4be2255c
...@@ -1221,8 +1221,9 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False): ...@@ -1221,8 +1221,9 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
ray.shutdown() ray.shutdown()
gc.collect() gc.collect()
from vllm.platforms import current_platform from vllm.platforms import current_platform
if not current_platform.is_cpu(): empty_cache = current_platform.empty_cache
torch.cuda.empty_cache() if empty_cache is not None:
empty_cache()
try: try:
torch._C._host_emptyCache() torch._C._host_emptyCache()
except AttributeError: except AttributeError:
......
...@@ -120,7 +120,10 @@ def set_forward_context(attn_metadata: Any, ...@@ -120,7 +120,10 @@ def set_forward_context(attn_metadata: Any,
# we use synchronous scheduling right now, # we use synchronous scheduling right now,
# adding a sync point here should not affect # adding a sync point here should not affect
# scheduling of the next batch # scheduling of the next batch
torch.cuda.synchronize() from vllm.platforms import current_platform
synchronize = current_platform.synchronize
if synchronize is not None:
synchronize()
now = time.perf_counter() now = time.perf_counter()
# time measurement is in milliseconds # time measurement is in milliseconds
batchsize_forward_time[batchsize].append( batchsize_forward_time[batchsize].append(
......
...@@ -126,12 +126,12 @@ class AsyncMetricsCollector: ...@@ -126,12 +126,12 @@ class AsyncMetricsCollector:
"""Copy rejection/typical-acceptance sampling metrics """Copy rejection/typical-acceptance sampling metrics
(number of accepted tokens, etc) to CPU asynchronously. (number of accepted tokens, etc) to CPU asynchronously.
Returns a CUDA event recording when the copy is complete. Returns a device event recording when the copy is complete.
""" """
assert self._copy_stream is not None assert self._copy_stream is not None
self._copy_stream.wait_stream(torch.cuda.current_stream()) self._copy_stream.wait_stream(current_platform.current_stream())
with torch.cuda.stream(self._copy_stream): with current_platform.stream(self._copy_stream):
self._aggregate_num_accepted_tokens.copy_( self._aggregate_num_accepted_tokens.copy_(
self.spec_decode_sampler.num_accepted_tokens, self.spec_decode_sampler.num_accepted_tokens,
non_blocking=True) non_blocking=True)
...@@ -142,7 +142,7 @@ class AsyncMetricsCollector: ...@@ -142,7 +142,7 @@ class AsyncMetricsCollector:
self._aggregate_num_draft_tokens = ( self._aggregate_num_draft_tokens = (
self.spec_decode_sampler.num_draft_tokens) self.spec_decode_sampler.num_draft_tokens)
aggregate_metrics_ready = torch.cuda.Event() aggregate_metrics_ready = current_platform.Event()
aggregate_metrics_ready.record(self._copy_stream) aggregate_metrics_ready.record(self._copy_stream)
return aggregate_metrics_ready return aggregate_metrics_ready
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment