"vscode:/vscode.git/clone" did not exist on "b90989358507651a6273a7bdfd7b2c9e7f4e6004"
Unverified Commit 0cb595ad authored by GAOXinyu's avatar GAOXinyu Committed by GitHub
Browse files

[bugfix] handle_x not defined when using checkpoint_lvl = 2 (#502)

When using checkpoint_lvl=2, we call all_gather_raw(x) without async_op=True,
so there is no handle to wait on; the handle_x.wait() call is simply skipped in that case.
parent 31920dda
...@@ -435,7 +435,7 @@ class FusedMLPFunc(torch.autograd.Function): ...@@ -435,7 +435,7 @@ class FusedMLPFunc(torch.autograd.Function):
grad_input = None grad_input = None
if ctx.heuristic == -1: if ctx.heuristic == -1:
if ctx.needs_input_grad[1]: if ctx.needs_input_grad[1]:
if process_group is not None and sequence_parallel: if process_group is not None and sequence_parallel and checkpoint_lvl != 2:
handle_x.wait() handle_x.wait()
grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad( grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad(
total_x.reshape(batch_dim, total_x.shape[-1]), total_x.reshape(batch_dim, total_x.shape[-1]),
...@@ -447,7 +447,7 @@ class FusedMLPFunc(torch.autograd.Function): ...@@ -447,7 +447,7 @@ class FusedMLPFunc(torch.autograd.Function):
grad_bias1 = grad_pre_act if ctx.needs_input_grad[2] else None grad_bias1 = grad_pre_act if ctx.needs_input_grad[2] else None
else: else:
if ctx.needs_input_grad[1]: if ctx.needs_input_grad[1]:
if process_group is not None and sequence_parallel: if process_group is not None and sequence_parallel and checkpoint_lvl != 2:
handle_x.wait() handle_x.wait()
grad_weight1 = F.linear( grad_weight1 = F.linear(
grad_pre_act.t(), total_x.reshape(batch_dim, total_x.shape[-1]).t() grad_pre_act.t(), total_x.reshape(batch_dim, total_x.shape[-1]).t()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment