Unverified Commit b0f65354 authored by Jaemin Choi, committed by GitHub

Disable UB bulk wgrad when weights are frozen (#702)


Signed-off-by: Jaemin Choi <jaeminc@nvidia.com>
Co-authored-by: Jaemin Choi <jaeminc@nvidia.com>
parent 3f8baf95
@@ -355,7 +355,7 @@ class _LayerNormLinear(torch.autograd.Function):
         if ctx.ub_bulk_wgrad:
             tp_world_size = get_distributed_world_size(ctx.tp_group)
-            if tp_world_size == 1:
+            if tp_world_size == 1 or not weight.requires_grad:
                 ctx.ub_bulk_wgrad = False
         # Column Parallel Linear
@@ -604,7 +604,7 @@ class _LayerNormMLP(torch.autograd.Function):
         if ctx.ub_bulk_wgrad:
             tp_world_size = get_distributed_world_size(ctx.tp_group)
-            if tp_world_size == 1:
+            if tp_world_size == 1 or not fc1_weight.requires_grad:
                 ctx.ub_bulk_wgrad = False
         # Column Parallel Linear
         # Overlap input AG with dgrad
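The change is the same in both backward passes: the userbuffers (UB) bulk weight-gradient overlap is now skipped not only when the tensor-parallel world size is 1, but also when the relevant weight is frozen (requires_grad is False), since no wgrad GEMM is launched for a frozen weight and there is nothing to overlap. Below is a minimal, self-contained sketch of that guard under simplified assumptions; should_use_bulk_wgrad_overlap is a hypothetical helper for illustration, not part of Transformer Engine, where the check lives inline in _LayerNormLinear.backward and _LayerNormMLP.backward.

import torch

def should_use_bulk_wgrad_overlap(ub_bulk_wgrad: bool,
                                  tp_world_size: int,
                                  weight: torch.Tensor) -> bool:
    """Return True only if the UB bulk-wgrad overlap is worthwhile.

    With a single tensor-parallel rank there is no communication to
    overlap, and with a frozen weight (requires_grad=False) no wgrad
    GEMM runs, so both cases fall back to the non-overlapped path.
    """
    if not ub_bulk_wgrad:
        return False
    if tp_world_size == 1 or not weight.requires_grad:
        return False
    return True

# Example: a frozen weight disables the overlap even with tp_world_size > 1.
w = torch.nn.Parameter(torch.empty(4, 4), requires_grad=False)
assert should_use_bulk_wgrad_overlap(True, 8, w) is False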