Unverified Commit b0f65354 authored by Jaemin Choi's avatar Jaemin Choi Committed by GitHub
Browse files

Disable UB bulk wgrad when weights are frozen (#702)


Signed-off-by: Jaemin Choi <jaeminc@nvidia.com>
Co-authored-by: Jaemin Choi <jaeminc@nvidia.com>
parent 3f8baf95
......@@ -355,7 +355,7 @@ class _LayerNormLinear(torch.autograd.Function):
if ctx.ub_bulk_wgrad:
tp_world_size = get_distributed_world_size(ctx.tp_group)
if tp_world_size == 1:
if tp_world_size == 1 or not weight.requires_grad:
ctx.ub_bulk_wgrad = False
# Column Parallel Linear
......
......@@ -604,7 +604,7 @@ class _LayerNormMLP(torch.autograd.Function):
if ctx.ub_bulk_wgrad:
tp_world_size = get_distributed_world_size(ctx.tp_group)
if tp_world_size == 1:
if tp_world_size == 1 or not fc1_weight.requires_grad:
ctx.ub_bulk_wgrad = False
# Column Parallel Linear
# Overlap input AG with dgrad
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment