Unverified Commit 8fac3a72 authored by msbaines, committed by GitHub

Fix contrib fused_adam to work correctly with multi-GPU (#752)



The CUDA kernel used by fused-adam was using the default stream on the
default device. The kernel needs to use the same device as the
parameter tensor.

Fixed by using a context manager to set the correct default device. For
the use_mt case, an error is raised if tensors span multiple devices.
Alternatively, the use_mt case could launch one kernel per CUDA device.

The non-contrib version will also need to be fixed.
Co-authored-by: Mandeep Singh Baines <msb@fb.com>
parent 80b90b9d
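
Below is a minimal sketch of the device-guard pattern the fix relies on;
update_kernel is a hypothetical stand-in for a raw CUDA extension call
such as fused_adam_cuda.adam, not part of the actual change:

import torch

def fused_update(p, update_kernel):
    # A raw CUDA extension call launches on whatever device is
    # current (cuda:0 by default), even when p lives on another GPU.
    # The context manager makes p's device current for the duration
    # of the call, so the kernel and the default stream it runs on
    # belong to the same GPU as the parameter tensor.
    with torch.cuda.device(p.device):
        update_kernel(p.data)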
@@ -130,6 +130,7 @@ class FusedAdam(torch.optim.Optimizer):
                     tensorlists = [[],[],[],[],[]]
                 else:
                     tensorlists = [[],[],[],[]]
+                tensordevice = None
             for p, grad, output_param in zip(group['params'], grads_this_group, output_params_this_group):
                 #note: p.grad should not ever be set for correct operation of mixed precision optimizer that sometimes sends None gradients
@@ -163,7 +164,14 @@ class FusedAdam(torch.optim.Optimizer):
                     for tl, t in zip(tensorlists, pl):
                         tl.append(t)
+                    if tensordevice is None:
+                        tensordevice = p.device
+                    elif tensordevice != p.device:
+                        raise RuntimeError('FusedAdam does not support use_mt with tensors on multiple device')
                 else:
+                    with torch.cuda.device(p.device):
                         fused_adam_cuda.adam(p.data,
                                              out_p,
                                              exp_avg,
@@ -180,6 +188,7 @@
                                              group['weight_decay'])
         if self._use_multi_tensor:
+            with torch.cuda.device(tensordevice):
                 multi_tensor_applier(
                     fused_adam_cuda.adam_mt,
                     self._overflow_buf,
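
For reference, a usage sketch of what the fix enables; the import path
is an assumption and requires fairscale to be built with the
fused_adam_cuda extension:

import torch
from fairscale.optim import FusedAdam  # assumed import path

# Parameters live entirely on a non-default GPU; with this change the
# per-parameter torch.cuda.device guard launches the Adam kernel on
# cuda:1 instead of the default device.
model = torch.nn.Linear(1024, 1024).to("cuda:1")
opt = FusedAdam(model.parameters(), lr=1e-3)

loss = model(torch.randn(8, 1024, device="cuda:1")).sum()
loss.backward()
opt.step()

With use_mt enabled, parameters spread across multiple devices now raise
the RuntimeError added above rather than silently launching on the wrong
device.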