Commit 9ff6f473 authored by mshoeybi's avatar mshoeybi Committed by Deepak Narayanan

Allocate tensor in `communicate()` method directly on GPU (instead of allocating on CPU and then moving to GPU)
parent fdcaeba0
...
@@ -249,11 +249,13 @@ def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward)
     if recv_forward:
         tensor_recv_prev = torch.empty(tensor_shape,
                                        requires_grad=True,
-                                       dtype=args.params_dtype).cuda()
+                                       device=torch.cuda.current_device(),
+                                       dtype=args.params_dtype)
     if recv_backward:
         tensor_recv_next = torch.empty(tensor_shape,
                                        requires_grad=True,
-                                       dtype=args.params_dtype).cuda()
+                                       device=torch.cuda.current_device(),
+                                       dtype=args.params_dtype)
 
     # Send tensors in both the forward and backward directions as appropriate.
     torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
...
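For context, a minimal standalone sketch of the two allocation paths contrasted by this commit. It is not part of the change itself; tensor_shape and params_dtype below are hypothetical stand-ins for the values taken from args in communicate(), and it assumes a CUDA device is available.

# Sketch: allocate a receive buffer on the GPU directly vs. via a CPU staging copy.
import torch

tensor_shape = (1024, 1024)          # placeholder for the real pipeline tensor shape
params_dtype = torch.float16         # placeholder for args.params_dtype

if torch.cuda.is_available():
    # Old path: torch.empty() creates the buffer in host memory first,
    # then .cuda() performs a separate host-to-device copy.
    cpu_then_gpu = torch.empty(tensor_shape,
                               requires_grad=True,
                               dtype=params_dtype).cuda()

    # New path: passing device= makes torch.empty() allocate directly on the
    # current GPU, skipping the intermediate CPU buffer and the extra copy.
    direct_gpu = torch.empty(tensor_shape,
                             requires_grad=True,
                             device=torch.cuda.current_device(),
                             dtype=params_dtype)

    # Both buffers end up on the same device; only the allocation path differs.
    assert cpu_then_gpu.device == direct_gpu.device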