Commit 2ef23675 authored by Jared Casper's avatar Jared Casper Committed by Mohammad Shoeybi
Browse files

Support latest PyTorch RNG state API. (#8)

Fixes #7.
parent a0368ddf
...@@ -41,9 +41,26 @@ def _set_cuda_rng_state(new_state, device=-1): ...@@ -41,9 +41,26 @@ def _set_cuda_rng_state(new_state, device=-1):
with a single change: the input state is not cloned. Cloning caused with a single change: the input state is not cloned. Cloning caused
major performance issues for +4 GPU cases. major performance issues for +4 GPU cases.
""" """
if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState):
# older PyTorch
def cb(): def cb():
with device_ctx_manager(device): with device_ctx_manager(device):
_C._cuda_setRNGState(new_state) _C._cuda_setRNGState(new_state)
else:
# newer PyTorch
if device == -1:
device = torch.device('cuda')
elif isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device('cuda', device)
def cb():
idx = device.index
if idx is None:
idx = torch.cuda.current_device()
default_generator = torch.cuda.default_generators[idx]
default_generator.set_state(new_state)
_lazy_call(cb) _lazy_call(cb)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment