"tests/vscode:/vscode.git/clone" did not exist on "c71e3af73f80140023ee399a982678a8b968b5c5"
Unverified commit 3e305f72, authored by Tim Moon, committed by GitHub
Browse files

[PyTorch] Debug weight matrix usages for dgrad GEMM (#1637)



Make sure that weight matrix has required usages for dgrad GEMM
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent afa1f1b0
@@ -327,9 +327,8 @@ class _LayerNormLinear(torch.autograd.Function):
                 ln_out.update_usage(rowwise_usage=False)

             # Weight with column-wise usage is needed for dgrad GEMM.
-            if inp.requires_grad:
-                if isinstance(weightmat, QuantizedTensor):
-                    weightmat.update_usage(columnwise_usage=True)
+            if isinstance(weightmat, QuantizedTensor):
+                weightmat.update_usage(columnwise_usage=True)

             if cpu_offloading:
                 if fp8 and weightmat is not None:
...
@@ -415,7 +415,7 @@ class _LayerNormMLP(torch.autograd.Function):
             )

             # Weight with column-wise usage is needed for dgrad GEMM.
-            if is_grad_enabled and inp.requires_grad:
+            if is_grad_enabled:
                 if isinstance(fc1_weight_final, QuantizedTensor):
                     fc1_weight_final.update_usage(columnwise_usage=True)
                 if isinstance(fc2_weight_final, QuantizedTensor):
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment