Unverified Commit 45cd1001 authored by eqy's avatar eqy Committed by GitHub
Browse files

Fix missing `dtype` in `recv_forward` (#1276)

CC @crcrpar @ptrblck
parent c4e85f7b
...@@ -227,7 +227,7 @@ def forward_backward_pipelining_without_interleaving( ...@@ -227,7 +227,7 @@ def forward_backward_pipelining_without_interleaving(
for i in range(num_warmup_microbatches): for i in range(num_warmup_microbatches):
_logger.debug(f"warmup iter: {i} / {num_warmup_microbatches}") _logger.debug(f"warmup iter: {i} / {num_warmup_microbatches}")
_logger.debug("receive fwd") _logger.debug("receive fwd")
input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes) input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype)
cur_microbatch = get_kth_microbatch(batch, i) cur_microbatch = get_kth_microbatch(batch, i)
output_tensor = forward_step(forward_step_func, cur_microbatch, model, input_tensor, losses_reduced) output_tensor = forward_step(forward_step_func, cur_microbatch, model, input_tensor, losses_reduced)
_logger.debug("send fwd") _logger.debug("send fwd")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment