Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9c1baa5b
Unverified
Commit
9c1baa5b
authored
May 23, 2025
by
Shanshan Shen
Committed by
GitHub
May 23, 2025
Browse files
[Misc] Replace `cuda` hard code with `current_platform` (#16983)
Signed-off-by:
shen-shanshan
<
467638484@qq.com
>
parent
4be2255c
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
11 additions
and
7 deletions
+11
-7
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+3
-2
vllm/forward_context.py
vllm/forward_context.py
+4
-1
vllm/spec_decode/metrics.py
vllm/spec_decode/metrics.py
+4
-4
No files found.
vllm/distributed/parallel_state.py
View file @
9c1baa5b
...
...
@@ -1221,8 +1221,9 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
ray
.
shutdown
()
gc
.
collect
()
from
vllm.platforms
import
current_platform
if
not
current_platform
.
is_cpu
():
torch
.
cuda
.
empty_cache
()
empty_cache
=
current_platform
.
empty_cache
if
empty_cache
is
not
None
:
empty_cache
()
try
:
torch
.
_C
.
_host_emptyCache
()
except
AttributeError
:
...
...
vllm/forward_context.py
View file @
9c1baa5b
...
...
@@ -120,7 +120,10 @@ def set_forward_context(attn_metadata: Any,
# we use synchronous scheduling right now,
# adding a sync point here should not affect
# scheduling of the next batch
torch
.
cuda
.
synchronize
()
from
vllm.platforms
import
current_platform
synchronize
=
current_platform
.
synchronize
if
synchronize
is
not
None
:
synchronize
()
now
=
time
.
perf_counter
()
# time measurement is in milliseconds
batchsize_forward_time
[
batchsize
].
append
(
...
...
vllm/spec_decode/metrics.py
View file @
9c1baa5b
...
...
@@ -126,12 +126,12 @@ class AsyncMetricsCollector:
"""Copy rejection/typical-acceptance sampling metrics
(number of accepted tokens, etc) to CPU asynchronously.
Returns a
CUDA
event recording when the copy is complete.
Returns a
device
event recording when the copy is complete.
"""
assert
self
.
_copy_stream
is
not
None
self
.
_copy_stream
.
wait_stream
(
torch
.
cuda
.
current_stream
())
self
.
_copy_stream
.
wait_stream
(
current_platform
.
current_stream
())
with
torch
.
cuda
.
stream
(
self
.
_copy_stream
):
with
current_platform
.
stream
(
self
.
_copy_stream
):
self
.
_aggregate_num_accepted_tokens
.
copy_
(
self
.
spec_decode_sampler
.
num_accepted_tokens
,
non_blocking
=
True
)
...
...
@@ -142,7 +142,7 @@ class AsyncMetricsCollector:
self
.
_aggregate_num_draft_tokens
=
(
self
.
spec_decode_sampler
.
num_draft_tokens
)
aggregate_metrics_ready
=
torch
.
cuda
.
Event
()
aggregate_metrics_ready
=
current_platform
.
Event
()
aggregate_metrics_ready
.
record
(
self
.
_copy_stream
)
return
aggregate_metrics_ready
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment