Fix contrib fused_adam to work correctly with multi-GPU (#752)
The cuda kernel used by fused-adam was using the default stream
on the default device. The kernel needs use the same device as
the parameter tensor.
Fixed by using context manager to set correct default device. For
the use_mt case, raised an error. Alternatively, the use_mt
case could launch one kernel per cuda device.
The non-contrib version will also need to be fixed.
Co-authored-by:
Mandeep Singh Baines <msb@fb.com>
Showing
Please register or sign in to comment