Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
2ef23675
Commit
2ef23675
authored
Sep 11, 2019
by
Jared Casper
Committed by
Mohammad Shoeybi
Sep 11, 2019
Browse files
Support latest PyTorch RNG state API. (#8)
Fixes #7.
parent
a0368ddf
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
3 deletions
+20
-3
mpu/random.py
mpu/random.py
+20
-3
No files found.
mpu/random.py
View file @
2ef23675
...
...
@@ -41,9 +41,26 @@ def _set_cuda_rng_state(new_state, device=-1):
with a single change: the input state is not cloned. Cloning caused
major performance issues for +4 GPU cases.
"""
def
cb
():
with
device_ctx_manager
(
device
):
_C
.
_cuda_setRNGState
(
new_state
)
if
hasattr
(
_C
,
'_cuda_setRNGState'
)
and
callable
(
_C
.
_cuda_setRNGState
):
# older PyTorch
def
cb
():
with
device_ctx_manager
(
device
):
_C
.
_cuda_setRNGState
(
new_state
)
else
:
# newer PyTorch
if
device
==
-
1
:
device
=
torch
.
device
(
'cuda'
)
elif
isinstance
(
device
,
str
):
device
=
torch
.
device
(
device
)
elif
isinstance
(
device
,
int
):
device
=
torch
.
device
(
'cuda'
,
device
)
def
cb
():
idx
=
device
.
index
if
idx
is
None
:
idx
=
torch
.
cuda
.
current_device
()
default_generator
=
torch
.
cuda
.
default_generators
[
idx
]
default_generator
.
set_state
(
new_state
)
_lazy_call
(
cb
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment