Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
bf4a4a21
Commit
bf4a4a21
authored
Jun 18, 2025
by
Shangyan Zhou
Browse files
Set `device_id` to suppress pytorch warning.
parent
77f97f79
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
6 deletions
+11
-6
tests/utils.py
tests/utils.py
+11
-6
No files found.
tests/utils.py
View file @
bf4a4a21
...
...
@@ -14,12 +14,17 @@ def init_dist(local_rank: int, num_local_ranks: int):
node_rank
=
int
(
os
.
getenv
(
'RANK'
,
0
))
assert
(
num_local_ranks
<
8
and
num_nodes
==
1
)
or
num_local_ranks
==
8
dist
.
init_process_group
(
backend
=
'nccl'
,
init_method
=
f
'tcp://
{
ip
}
:
{
port
}
'
,
world_size
=
num_nodes
*
num_local_ranks
,
rank
=
node_rank
*
num_local_ranks
+
local_rank
)
import
inspect
sig
=
inspect
.
signature
(
dist
.
init_process_group
)
params
=
{
'backend'
:
'nccl'
,
'init_method'
:
f
'tcp://
{
ip
}
:
{
port
}
'
,
'world_size'
:
num_nodes
*
num_local_ranks
,
'rank'
:
node_rank
*
num_local_ranks
+
local_rank
,
}
if
'device_id'
in
sig
.
parameters
:
params
[
'device_id'
]
=
torch
.
device
(
f
"cuda:
{
local_rank
}
"
)
dist
.
init_process_group
(
**
params
)
torch
.
set_default_dtype
(
torch
.
bfloat16
)
torch
.
set_default_device
(
'cuda'
)
torch
.
cuda
.
set_device
(
local_rank
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment