Unverified Commit 3e305f72 authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

[PyTorch] Debug weight matrix usages for dgrad GEMM (#1637)



Make sure that weight matrix has required usages for dgrad GEMM
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
parent afa1f1b0
......@@ -327,7 +327,6 @@ class _LayerNormLinear(torch.autograd.Function):
ln_out.update_usage(rowwise_usage=False)
# Weight with column-wise usage is needed for dgrad GEMM.
if inp.requires_grad:
if isinstance(weightmat, QuantizedTensor):
weightmat.update_usage(columnwise_usage=True)
......
......@@ -415,7 +415,7 @@ class _LayerNormMLP(torch.autograd.Function):
)
# Weight with column-wise usage is needed for dgrad GEMM.
if is_grad_enabled and inp.requires_grad:
if is_grad_enabled:
if isinstance(fc1_weight_final, QuantizedTensor):
fc1_weight_final.update_usage(columnwise_usage=True)
if isinstance(fc2_weight_final, QuantizedTensor):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment