"cacheflow/vscode:/vscode.git/clone" did not exist on "655a5e48df3937bf793add53aa95ce0c992a24c6"
Unverified commit de06a34c authored by Jaemin Choi, committed by GitHub
Browse files

Add NVTX ranges to FP8 amax AR and grad output preprocessing (#1530)



Add NVTX ranges
Signed-off-by: Jaemin Choi <jaeminc@nvidia.com>
Co-authored-by: Jaemin Choi <jaeminc@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 13bd745b
@@ -522,6 +522,7 @@ class _LayerNormLinear(torch.autograd.Function):
             if ctx.grad_output_quantizer is not None:
                 ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=True)
+            nvtx_range_push(f"{nvtx_label}.grad_output_preprocess")
             (
                 grad_output,
                 grad_bias,
@@ -531,6 +532,7 @@ class _LayerNormLinear(torch.autograd.Function):
                 ctx.parallel_mode == "row",
                 ctx.grad_output_quantizer,
             )
+            nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess")

             # Prepare GEMM input
             # Note: Perform tensor-parallel communication if needed
@@ -747,7 +749,9 @@ class _LayerNormLinear(torch.autograd.Function):
             wgrad = None
         if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
+            nvtx_range_push(f"{nvtx_label}.reduce_and_update_fp8_tensors")
             FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
+            nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors")

         # Scatter fp8 weight buffers
         # if ctx.fp8 and not isinstance(weight, QuantizedTensor):
...
@@ -427,6 +427,7 @@ class _Linear(torch.autograd.Function):
             # Note: Cast to expected dtype and perform tensor-parallel communication
             if ctx.grad_output_quantizer is not None:
                 ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=True)
+            nvtx_range_push(f"{nvtx_label}.grad_output_preprocess")
             (
                 grad_output,
                 grad_bias,
@@ -436,6 +437,7 @@ class _Linear(torch.autograd.Function):
                 ctx.parallel_mode == "row",
                 ctx.grad_output_quantizer,
             )
+            nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess")

             # Prepare input tensor
             # Note: Perform tensor-parallel communication if needed
@@ -623,7 +625,9 @@ class _Linear(torch.autograd.Function):
             wgrad = None
         if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
+            nvtx_range_push(f"{nvtx_label}.reduce_and_update_fp8_tensors")
             FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
+            nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors")

         # Scatter fp8 weight buffers
         if ctx.fp8 and not isinstance(weight, QuantizedTensor):
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment