Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
894843eb
Unverified
Commit
894843eb
authored
Mar 12, 2026
by
Yan Ma
Committed by
GitHub
Mar 11, 2026
Browse files
replace `with torch.cuda.device` with `with torch.accelerator.device_index` (#36144)
Signed-off-by:
Yan Ma
<
yan.ma@intel.com
>
parent
584a3f56
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
17 additions
and
15 deletions
+17
-15
benchmarks/kernels/benchmark_moe.py
benchmarks/kernels/benchmark_moe.py
+5
-1
tools/pre_commit/check_torch_cuda.py
tools/pre_commit/check_torch_cuda.py
+2
-2
vllm/distributed/device_communicators/pynccl.py
vllm/distributed/device_communicators/pynccl.py
+1
-3
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
...ibuted/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
+2
-2
vllm/model_executor/layers/fla/ops/utils.py
vllm/model_executor/layers/fla/ops/utils.py
+1
-1
vllm/model_executor/layers/mamba/ops/layernorm_gated.py
vllm/model_executor/layers/mamba/ops/layernorm_gated.py
+1
-1
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
+1
-1
vllm/model_executor/layers/mamba/ops/ssd_bmm.py
vllm/model_executor/layers/mamba/ops/ssd_bmm.py
+1
-1
vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py
vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py
+2
-2
vllm/model_executor/layers/mamba/ops/ssd_state_passing.py
vllm/model_executor/layers/mamba/ops/ssd_state_passing.py
+1
-1
No files found.
benchmarks/kernels/benchmark_moe.py
View file @
894843eb
...
...
@@ -626,7 +626,11 @@ class BenchmarkWorker:
if
visible_device
!=
f
"
{
self
.
device_id
}
"
:
need_device_guard
=
True
with
torch
.
cuda
.
device
(
self
.
device_id
)
if
need_device_guard
else
nullcontext
():
with
(
torch
.
accelerator
.
device_index
(
self
.
device_id
)
if
need_device_guard
else
nullcontext
()
):
for
idx
,
config
in
enumerate
(
tqdm
(
search_space
)):
try
:
kernel_time
=
benchmark_config
(
...
...
tools/pre_commit/check_torch_cuda.py
View file @
894843eb
...
...
@@ -8,8 +8,8 @@ import regex as re
# Regex: match `torch.cuda.xxx` but allow `torch.accelerator.xxx`
# --------------------------------------------------------------------------- #
_TORCH_CUDA_PATTERNS
=
[
r
"\btorch\.cuda\.empty_cache\b"
,
r
"\btorch\.cuda\.
synchroniz
e\b"
,
r
"\btorch\.cuda\.
(
empty_cache
|synchronize|device\()
\b"
,
r
"\
bwith\
btorch\.cuda\.
devic
e\b"
,
]
ALLOWED_FILES
=
{
"vllm/platforms/"
,
"vllm/device_allocator/"
}
...
...
vllm/distributed/device_communicators/pynccl.py
View file @
894843eb
...
...
@@ -133,9 +133,7 @@ class PyNcclCommunicator:
assert
isinstance
(
device
,
torch
.
device
)
self
.
device
=
device
# nccl communicator and stream will use this device
# `torch.cuda.device` is a context manager that changes the
# current cuda device to the specified one
with
torch
.
cuda
.
device
(
device
):
with
torch
.
accelerator
.
device_index
(
device
.
index
):
self
.
comm
:
ncclComm_t
=
self
.
nccl
.
ncclCommInitRank
(
self
.
world_size
,
self
.
unique_id
,
self
.
rank
)
...
...
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
View file @
894843eb
...
...
@@ -218,7 +218,7 @@ class P2pNcclEngine:
data
=
{
"cmd"
:
"NEW"
,
"unique_id"
:
bytes
(
unique_id
.
internal
)}
sock
.
send
(
msgpack
.
dumps
(
data
))
with
torch
.
cuda
.
device
(
self
.
device
):
with
torch
.
accelerator
.
device
_index
(
self
.
device
.
index
):
rank
=
0
with
set_p2p_nccl_context
(
self
.
nccl_num_channels
):
comm
:
ncclComm_t
=
self
.
nccl
.
ncclCommInitRank
(
2
,
unique_id
,
rank
)
...
...
@@ -377,7 +377,7 @@ class P2pNcclEngine:
data
=
msgpack
.
loads
(
message
)
if
data
[
"cmd"
]
==
"NEW"
:
unique_id
=
self
.
nccl
.
unique_id_from_bytes
(
bytes
(
data
[
"unique_id"
]))
with
torch
.
cuda
.
device
(
self
.
device
):
with
torch
.
accelerator
.
device
_index
(
self
.
device
.
index
):
rank
=
1
with
set_p2p_nccl_context
(
self
.
nccl_num_channels
):
comm
:
ncclComm_t
=
self
.
nccl
.
ncclCommInitRank
(
...
...
vllm/model_executor/layers/fla/ops/utils.py
View file @
894843eb
...
...
@@ -105,7 +105,7 @@ def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
break
if
tensor
is
not
None
:
ctx
=
torch
.
cuda
.
device
(
tensor
.
device
.
index
)
ctx
=
torch
.
accelerator
.
device
_index
(
tensor
.
device
.
index
)
else
:
ctx
=
contextlib
.
nullcontext
()
...
...
vllm/model_executor/layers/mamba/ops/layernorm_gated.py
View file @
894843eb
...
...
@@ -119,7 +119,7 @@ def _layer_norm_fwd(
# heuristics for number of warps
num_warps
=
min
(
max
(
BLOCK_N
//
256
,
1
),
8
)
grid
=
(
M
,
ngroups
)
with
torch
.
cuda
.
device
(
x
.
device
.
index
):
with
torch
.
accelerator
.
device
_index
(
x
.
device
.
index
):
_layer_norm_fwd_1pass_kernel
[
grid
](
x
,
out
,
...
...
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
View file @
894843eb
...
...
@@ -419,7 +419,7 @@ def selective_state_update(
and
dt
.
stride
(
-
1
)
==
0
and
dt_bias
.
stride
(
-
1
)
==
0
)
with
torch
.
cuda
.
device
(
x
.
device
.
index
):
with
torch
.
accelerator
.
device
_index
(
x
.
device
.
index
):
_selective_scan_update_kernel
[
grid
](
state
,
x
,
...
...
vllm/model_executor/layers/mamba/ops/ssd_bmm.py
View file @
894843eb
...
...
@@ -185,7 +185,7 @@ def _bmm_chunk_fwd(a, b, chunk_size, cu_chunk_seqlens, causal=False, output_dtyp
*
triton
.
cdiv
(
chunk_size
,
META
[
"BLOCK_SIZE_N"
]),
nchunks
*
ngroups
,
)
with
torch
.
cuda
.
device
(
a
.
device
.
index
):
with
torch
.
accelerator
.
device
_index
(
a
.
device
.
index
):
_bmm_chunk_fwd_kernel
[
grid
](
a_ptr
=
a
,
b_ptr
=
b
,
...
...
vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py
View file @
894843eb
...
...
@@ -323,7 +323,7 @@ def _chunk_cumsum_fwd(
nheads
,
nchunks
,
chunk_size
,
device
=
dt
.
device
,
dtype
=
torch
.
float32
)
grid_chunk_cs
=
lambda
META
:
(
nchunks
,
triton
.
cdiv
(
nheads
,
META
[
"BLOCK_SIZE_H"
]))
with
torch
.
cuda
.
device
(
dt
.
device
.
index
):
with
torch
.
accelerator
.
device
_index
(
dt
.
device
.
index
):
_chunk_cumsum_fwd_kernel
[
grid_chunk_cs
](
dt_ptr
=
dt
,
A_ptr
=
A
,
...
...
@@ -378,7 +378,7 @@ def _chunk_state_fwd(
nchunks
,
nheads
,
)
with
torch
.
cuda
.
device
(
x
.
device
.
index
):
with
torch
.
accelerator
.
device
_index
(
x
.
device
.
index
):
_chunk_state_fwd_kernel
[
grid
](
x_ptr
=
x
,
b_ptr
=
B
,
...
...
vllm/model_executor/layers/mamba/ops/ssd_state_passing.py
View file @
894843eb
...
...
@@ -120,7 +120,7 @@ def _state_passing_fwd(
)
grid
=
lambda
META
:
(
triton
.
cdiv
(
dim
,
META
[
"BLOCK_SIZE"
]),
batch
,
nheads
)
with
torch
.
cuda
.
device
(
states
.
device
.
index
):
with
torch
.
accelerator
.
device
_index
(
states
.
device
.
index
):
_state_passing_fwd_kernel
[
grid
](
states_ptr
=
states
,
out_ptr
=
out
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment