Commit 2ef23675 authored by Jared Casper's avatar Jared Casper Committed by Mohammad Shoeybi
Browse files

Support latest PyTorch RNG state API. (#8)

Fixes #7.
parent a0368ddf
......@@ -41,9 +41,26 @@ def _set_cuda_rng_state(new_state, device=-1):
with a single change: the input state is not cloned. Cloning caused
major performance issues for +4 GPU cases.
"""
def cb():
with device_ctx_manager(device):
_C._cuda_setRNGState(new_state)
if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState):
# older PyTorch
def cb():
with device_ctx_manager(device):
_C._cuda_setRNGState(new_state)
else:
# newer PyTorch
if device == -1:
device = torch.device('cuda')
elif isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device('cuda', device)
def cb():
idx = device.index
if idx is None:
idx = torch.cuda.current_device()
default_generator = torch.cuda.default_generators[idx]
default_generator.set_state(new_state)
_lazy_call(cb)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment