Unverified Commit de06a34c authored by Jaemin Choi's avatar Jaemin Choi Committed by GitHub
Browse files

Add NVTX ranges to FP8 amax AR and grad output preprocessing (#1530)



Add NVTX ranges
Signed-off-by: Jaemin Choi <jaeminc@nvidia.com>
Co-authored-by: Jaemin Choi <jaeminc@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 13bd745b
......@@ -522,6 +522,7 @@ class _LayerNormLinear(torch.autograd.Function):
if ctx.grad_output_quantizer is not None:
ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=True)
nvtx_range_push(f"{nvtx_label}.grad_output_preprocess")
(
grad_output,
grad_bias,
......@@ -531,6 +532,7 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.parallel_mode == "row",
ctx.grad_output_quantizer,
)
nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess")
# Prepare GEMM input
# Note: Perform tensor-parallel communication if needed
......@@ -747,7 +749,9 @@ class _LayerNormLinear(torch.autograd.Function):
wgrad = None
if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
nvtx_range_push(f"{nvtx_label}.reduce_and_update_fp8_tensors")
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors")
# Scatter fp8 weight buffers
# if ctx.fp8 and not isinstance(weight, QuantizedTensor):
......
......@@ -427,6 +427,7 @@ class _Linear(torch.autograd.Function):
# Note: Cast to expected dtype and perform tensor-parallel communication
if ctx.grad_output_quantizer is not None:
ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=True)
nvtx_range_push(f"{nvtx_label}.grad_output_preprocess")
(
grad_output,
grad_bias,
......@@ -436,6 +437,7 @@ class _Linear(torch.autograd.Function):
ctx.parallel_mode == "row",
ctx.grad_output_quantizer,
)
nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess")
# Prepare input tensor
# Note: Perform tensor-parallel communication if needed
......@@ -623,7 +625,9 @@ class _Linear(torch.autograd.Function):
wgrad = None
if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
nvtx_range_push(f"{nvtx_label}.reduce_and_update_fp8_tensors")
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors")
# Scatter fp8 weight buffers
if ctx.fp8 and not isinstance(weight, QuantizedTensor):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment