Unverified commit 49a4535d authored by Jaemin Choi, committed by GitHub

Add NVTX ranges to categorize execution (#1447)


Signed-off-by: Jaemin Choi <jaeminc@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Jaemin Choi <jaeminc@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
parent b87e539d
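
For reference, a minimal sketch of how the helpers introduced by this change (nvtx_range_push / nvtx_range_pop in transformer_engine.pytorch.utils) are meant to wrap a region of work. The enclosing function and range label below are hypothetical; the ranges are no-ops unless NVTE_NVTX_ENABLED=1 is set in the environment, and the emitted ranges show up in an Nsight Systems timeline.

import os
os.environ["NVTE_NVTX_ENABLED"] = "1"  # must be set before the first range call

from transformer_engine.pytorch.utils import nvtx_range_push, nvtx_range_pop

def my_profiled_step(x):
    # Hypothetical user code; the label follows the "<module>.<scope>" naming
    # convention used throughout this diff.
    nvtx_range_push("my_project.my_profiled_step")
    try:
        out = x * 2  # placeholder for the actual work being profiled
    finally:
        nvtx_range_pop("my_project.my_profiled_step")
    return out
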
......@@ -23,7 +23,11 @@ import torch.nn.functional as F
import transformer_engine_torch as tex
import transformer_engine as te
from transformer_engine.pytorch.utils import get_cudnn_version
from transformer_engine.pytorch.utils import (
get_cudnn_version,
nvtx_range_pop,
nvtx_range_push,
)
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
fused_attn_fwd,
fused_attn_bwd,
......@@ -1834,6 +1838,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
quantizers,
):
# pylint: disable=missing-function-docstring
nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.forward")
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
......@@ -2756,12 +2761,14 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
ctx.fp8_meta = fp8_meta
ctx.is_input_fp8 = is_input_fp8
ctx.is_output_fp8 = is_output_fp8
nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.forward")
return out_ret
@staticmethod
def backward(ctx, dout):
# pylint: disable=missing-function-docstring
nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.backward")
cp_size_a2a = ctx.cp_size_a2a
rank_a2a = ctx.rank_a2a
......@@ -3602,6 +3609,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
dq = ctx.dQKV_quantizer.create_tensor_from_data(dq, ctx.qkv_dtype)
dk = ctx.dQKV_quantizer.create_tensor_from_data(dk, ctx.qkv_dtype)
dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, ctx.qkv_dtype)
nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.backward")
return (
None,
......@@ -3688,6 +3696,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
cp_stream,
):
# pylint: disable=missing-function-docstring
nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.forward")
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
......@@ -3904,11 +3913,13 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
ctx.attn_mask_type = attn_mask_type
ctx.deterministic = deterministic
ctx.use_fused_attention = use_fused_attention
nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.forward")
return out
@staticmethod
def backward(ctx, dout):
# pylint: disable=missing-function-docstring
nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.backward")
cp_size = get_distributed_world_size(ctx.cp_group)
rank = get_distributed_rank(ctx.cp_group)
......@@ -4092,6 +4103,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
dq = dq.view(*dq.shape[:seq_dim], -1, *dq.shape[(seq_dim + 2) :])
dk = dk.movedim(0, seq_dim).contiguous()
dv = dv.movedim(0, seq_dim).contiguous()
nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.backward")
return (
None,
......@@ -4151,6 +4163,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
quantizers,
):
# pylint: disable=missing-function-docstring
nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward")
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
......@@ -4403,11 +4416,13 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
ctx.fp8_meta = fp8_meta
ctx.is_input_fp8 = is_input_fp8
ctx.is_output_fp8 = is_output_fp8
nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward")
return out_ret
@staticmethod
def backward(ctx, dout):
# pylint: disable=missing-function-docstring
nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward")
cp_size = get_distributed_world_size(ctx.cp_group)
(
......@@ -4592,6 +4607,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, fake_dtype=dout_dtype)
if not ctx.is_input_fp8:
dq, dk, dv = [x.dequantize() for x in [dq, dk, dv]]
nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward")
return (
None,
......
......@@ -24,12 +24,14 @@ from .base import (
)
from ..fp8 import FP8GlobalStateManager
from ..utils import (
assert_dim_for_fp8_exec,
cast_if_needed,
clear_tensor_data,
divide,
get_default_init_method,
init_method_constant,
cast_if_needed,
assert_dim_for_fp8_exec,
clear_tensor_data,
nvtx_range_pop,
nvtx_range_push,
requires_grad,
)
from ..distributed import (
......@@ -112,6 +114,12 @@ class _LayerNormLinear(torch.autograd.Function):
skip_fp8_weight_update: bool,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# pylint: disable=missing-function-docstring
# NVTX label for profiling
nvtx_label = "transformer_engine._LayerNormLinear.forward"
if ub_name is not None:
nvtx_label = f"{nvtx_label}.{ub_name}"
# Make sure input dimensions are compatible
out_features, in_features = weight.shape
inp_shape = inp.shape
......@@ -121,10 +129,12 @@ class _LayerNormLinear(torch.autograd.Function):
assert_dim_for_fp8_exec(inputmat, weight)
# Cast for native AMP
nvtx_range_push(f"{nvtx_label}.norm_input_cast")
inputmat = cast_if_needed(inputmat, activation_dtype)
ln_weight = cast_if_needed(ln_weight, activation_dtype)
if ln_bias is not None:
ln_bias = cast_if_needed(ln_bias, activation_dtype)
nvtx_range_pop(f"{nvtx_label}.norm_input_cast")
tp_world_size = get_distributed_world_size(tp_group)
ub_overlap_ag_fprop = (
......@@ -175,6 +185,7 @@ class _LayerNormLinear(torch.autograd.Function):
)
# Apply normalization
nvtx_range_push(f"{nvtx_label}.norm")
ln_out, mu, rsigma = apply_normalization(
inputmat,
ln_out,
......@@ -188,9 +199,11 @@ class _LayerNormLinear(torch.autograd.Function):
zero_centered_gamma,
)
ln_out_return = ln_out if return_layernorm_output else None
nvtx_range_pop(f"{nvtx_label}.norm")
# Prepare GEMM input
# Note: Cast to expected dtype and perform tensor-parallel communication
nvtx_range_push(f"{nvtx_label}.gemm_input_cast_comm")
if with_input_all_gather and not ub_overlap_ag_fprop:
with_quantized_all_gather = fp8
if return_layernorm_output and return_layernorm_output_gathered:
......@@ -217,6 +230,7 @@ class _LayerNormLinear(torch.autograd.Function):
elif backward_needs_input:
ln_out.update_usage(rowwise_usage=True, columnwise_usage=True)
ln_out_total = ln_out
nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm")
# Cast weight to expected dtype
weightmat = weight
......@@ -275,6 +289,7 @@ class _LayerNormLinear(torch.autograd.Function):
assert ub_obj.is_fp8_ubuf(), "AG overlap with FP8 GEMM inputs requires FP8 buffer."
ln_out_total = ub_obj.get_buffer(input_quantizer)
nvtx_range_push(f"{nvtx_label}.gemm")
out, *_, rs_out = general_gemm(
weightmat,
ln_out_total,
......@@ -287,6 +302,8 @@ class _LayerNormLinear(torch.autograd.Function):
ub_type=ub_type,
extra_output=rs_out,
)
nvtx_range_pop(f"{nvtx_label}.gemm")
if not weight.requires_grad:
if not return_layernorm_output:
ln_out = ln_out_total = None
......@@ -307,6 +324,7 @@ class _LayerNormLinear(torch.autograd.Function):
# Scatter intermediate/activation tensors saved for the backward pass
# NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.distributed.FSDP already
# shards/unshards the base weights so we don't do it ourselves
nvtx_range_push(f"{nvtx_label}.fsdp_scatter")
ctx.fsdp_group = fsdp_group
ctx.fsdp_shapes = _fsdp_scatter_tensors(
fsdp_group,
......@@ -315,6 +333,7 @@ class _LayerNormLinear(torch.autograd.Function):
weightmat if quantized_weight else None,
ln_out if weight.requires_grad else None,
)
nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")
tensors_to_save, tensor_objects = prepare_for_saving(
inputmat,
......@@ -372,10 +391,12 @@ class _LayerNormLinear(torch.autograd.Function):
if ub_overlap_rs_fprop:
out = rs_out
elif parallel_mode == "row":
nvtx_range_push(f"{nvtx_label}.row_parallel_comm")
if sequence_parallel:
out, _ = reduce_scatter_along_first_dim(out, tp_group)
elif tensor_parallel:
out, _ = allreduce(out, tp_group)
nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")
# [*, in_features] -> [*, out_features] except first dimension changes for SP
out = out.view(-1, *inp_shape[1:-1], out_features)
......@@ -394,6 +415,11 @@ class _LayerNormLinear(torch.autograd.Function):
) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
# NVTX label for profiling
nvtx_label = "transformer_engine._LayerNormLinear.backward"
if ctx.ub_name is not None:
nvtx_label = f"{nvtx_label}.{ctx.ub_name}"
with torch.cuda.nvtx.range("_LayerNormLinear_backward"):
if (
ctx.fp8
......@@ -433,6 +459,7 @@ class _LayerNormLinear(torch.autograd.Function):
# Gather intermediate/activation tensors if needed
# NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.distributed.FSDP already
# shards/unshards the base weights so we don't do it ourselves
nvtx_range_push(f"{nvtx_label}.fsdp_gather")
_fsdp_gather_tensors(
ctx.fsdp_group,
ctx.fsdp_shapes,
......@@ -441,6 +468,7 @@ class _LayerNormLinear(torch.autograd.Function):
weight if ctx.fp8 and ctx.quantized_weight else None,
ln_out,
)
nvtx_range_pop(f"{nvtx_label}.fsdp_gather")
# For CPU offloading, weight and weight.main_grad were offloaded to different tensors,
# so we need to connect them back into one.
......@@ -515,12 +543,14 @@ class _LayerNormLinear(torch.autograd.Function):
if ctx.fp8:
quantizer = ctx.input_quantizer
quantizer.set_usage(rowwise=True, columnwise=True)
nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
ln_out_total, ln_out_total_work = gather_along_first_dim(
ln_out,
ctx.tp_group,
async_op=True,
quantizer=quantizer,
)
nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
else:
ln_out_total = ln_out
......@@ -536,6 +566,7 @@ class _LayerNormLinear(torch.autograd.Function):
if ctx.grad_input_quantizer is not None:
ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)
nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
dgrad, *_ = general_gemm(
weight,
grad_output,
......@@ -551,12 +582,14 @@ class _LayerNormLinear(torch.autograd.Function):
extra_output=rs_out,
bulk_overlap=ctx.ub_bulk_dgrad,
)
nvtx_range_pop(f"{nvtx_label}.dgrad_gemm")
# Launch tensor-parallel communication
dgrad_work = None
if ctx.ub_overlap_rs_dgrad:
dgrad = rs_out
elif ctx.parallel_mode == "column" and not ctx.ub_bulk_wgrad:
nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad")
if ctx.sequence_parallel:
if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered:
dgrad = dgrad + grad_outputs[1].view_as(dgrad)
......@@ -567,6 +600,7 @@ class _LayerNormLinear(torch.autograd.Function):
)
else:
dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True)
nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad")
# Compute grad weight tensor
wgrad = None
......@@ -603,6 +637,7 @@ class _LayerNormLinear(torch.autograd.Function):
# wgrad GEMM
# Note: Fuse with bgrad computation if needed
nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
wgrad, grad_bias_, *_, rs_out = general_gemm(
ln_out_total,
grad_output,
......@@ -621,6 +656,7 @@ class _LayerNormLinear(torch.autograd.Function):
extra_output=rs_out,
bulk_overlap=ctx.ub_bulk_wgrad,
)
nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")
if ctx.ub_bulk_wgrad:
if ub_obj_wgrad.is_fp8_ubuf():
......@@ -657,6 +693,7 @@ class _LayerNormLinear(torch.autograd.Function):
# Norm gradient
dgamma = None
dbeta = None
nvtx_range_push(f"{nvtx_label}.norm")
if ctx.normalization == "LayerNorm":
dgrad, dgamma, dbeta = tex.layernorm_bwd(
dgrad,
......@@ -679,6 +716,7 @@ class _LayerNormLinear(torch.autograd.Function):
)
dgrad = dgrad.reshape(inputmat.size())
dbeta = None
nvtx_range_pop(f"{nvtx_label}.norm")
clear_tensor_data(mu)
clear_tensor_data(rsigma)
......
......@@ -22,12 +22,14 @@ from .base import (
from ._common import noop_cat, _fix_gathered_fp8_transpose
from ..fp8 import FP8GlobalStateManager
from ..utils import (
divide,
cast_if_needed,
clear_tensor_data,
divide,
init_method_constant,
requires_grad,
non_tn_fp8_gemm_supported,
nvtx_range_pop,
nvtx_range_push,
requires_grad,
)
from ..distributed import (
set_tensor_model_parallel_attributes,
......@@ -100,6 +102,11 @@ class _Linear(torch.autograd.Function):
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
# NVTX label for profiling
nvtx_label = "transformer_engine._Linear.forward"
if ub_name is not None:
nvtx_label = f"{nvtx_label}.{ub_name}"
# Make sure input dimensions are compatible
out_features, in_features = weight.shape
inp_shape = inp.shape
......@@ -110,6 +117,7 @@ class _Linear(torch.autograd.Function):
# Prepare input tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
nvtx_range_push(f"{nvtx_label}.input_cast_comm")
inputmat = inp
inputmat_total = None
with_input_all_gather_nccl = (
......@@ -153,6 +161,7 @@ class _Linear(torch.autograd.Function):
inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
else:
inputmat_total = inputmat
nvtx_range_pop(f"{nvtx_label}.input_cast_comm")
# Cast weight to expected dtype
weightmat = weight
......@@ -216,6 +225,7 @@ class _Linear(torch.autograd.Function):
ub_obj.copy_into_buffer(inputmat_total, input_quantizer, local_chunk=True)
inputmat_total = ub_obj.get_buffer(input_quantizer)
nvtx_range_push(f"{nvtx_label}.gemm")
out, *_, rs_out = general_gemm(
weightmat,
inputmat_total,
......@@ -228,6 +238,7 @@ class _Linear(torch.autograd.Function):
ub_type=ub_type,
extra_output=rs_out,
)
nvtx_range_pop(f"{nvtx_label}.gemm")
if is_grad_enabled:
saved_inputmat = None
......@@ -244,12 +255,14 @@ class _Linear(torch.autograd.Function):
# Scatter intermediate/activation tensors saved for the backward pass
# NOTE: FSDP sharding is not valid for models initialized with primary Fp8 weights
nvtx_range_push(f"{nvtx_label}.fsdp_scatter")
ctx.fsdp_group = fsdp_group
ctx.fsdp_shapes = _fsdp_scatter_tensors(
fsdp_group,
saved_inputmat,
weightmat if fp8 and not isinstance(weight, QuantizedTensor) else None,
)
nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")
# TODO(ksivamani): Check memory usage
tensors_to_save, tensor_objects = prepare_for_saving(
......@@ -299,10 +312,12 @@ class _Linear(torch.autograd.Function):
if ub_overlap_rs_fprop:
out = rs_out
elif parallel_mode == "row":
nvtx_range_push(f"{nvtx_label}.row_parallel_comm")
if sequence_parallel:
out, _ = reduce_scatter_along_first_dim(out, tp_group)
elif tensor_parallel:
out, _ = allreduce(out, tp_group)
nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")
out = out.view(-1, *inp_shape[1:-1], out_features)
return out
......@@ -311,6 +326,11 @@ class _Linear(torch.autograd.Function):
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
# NVTX label for profiling
nvtx_label = "transformer_engine._Linear.backward"
if ctx.ub_name is not None:
nvtx_label = f"{nvtx_label}.{ctx.ub_name}"
with torch.cuda.nvtx.range("_Linear_backward"):
if (
ctx.fp8
......@@ -347,12 +367,14 @@ class _Linear(torch.autograd.Function):
# Gather intermediate/activation tensors if needed
# NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.distributed.FSDP already
# shards/unshards the base weights so we don't do it ourselves
nvtx_range_push(f"{nvtx_label}.fsdp_gather")
_fsdp_gather_tensors(
ctx.fsdp_group,
ctx.fsdp_shapes,
inputmat,
weight_fp8,
)
nvtx_range_pop(f"{nvtx_label}.fsdp_gather")
ctx.ub_obj_gradout = None
ub_obj_dgrad = None
......@@ -424,12 +446,14 @@ class _Linear(torch.autograd.Function):
if ctx.fp8:
quantizer = ctx.input_quantizer
quantizer.set_usage(rowwise=True, columnwise=True)
nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
inputmat_total, inputmat_total_work = gather_along_first_dim(
inputmat,
ctx.tp_group,
async_op=True,
quantizer=quantizer,
)
nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
else:
inputmat_total = inputmat
......@@ -451,6 +475,7 @@ class _Linear(torch.autograd.Function):
ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)
# dgrad GEMM
nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
dgrad, *_, rs_out = general_gemm(
weight_fp8,
grad_output,
......@@ -466,11 +491,13 @@ class _Linear(torch.autograd.Function):
extra_output=rs_out,
bulk_overlap=ctx.ub_bulk_dgrad,
)
nvtx_range_pop(f"{nvtx_label}.dgrad_gemm")
# Launch tensor-parallel communication
if ctx.ub_overlap_rs_dgrad:
dgrad = rs_out
elif ctx.parallel_mode == "column" and not ctx.ub_bulk_wgrad:
nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad")
if ctx.sequence_parallel:
dgrad, dgrad_work = reduce_scatter_along_first_dim(
dgrad,
......@@ -479,6 +506,7 @@ class _Linear(torch.autograd.Function):
)
else:
dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True)
nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad")
# Compute grad weight tensor
wgrad = None
......@@ -515,6 +543,7 @@ class _Linear(torch.autograd.Function):
# wgrad GEMM
# Note: Fuse with bgrad computation if needed
nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
wgrad, grad_bias_, _, rs_out = general_gemm(
inputmat_total,
grad_output,
......@@ -533,6 +562,7 @@ class _Linear(torch.autograd.Function):
extra_output=rs_out,
bulk_overlap=ctx.ub_bulk_wgrad,
)
nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")
if ctx.ub_bulk_wgrad:
if ub_obj_wgrad.is_fp8_ubuf():
......
......@@ -6,6 +6,7 @@
from __future__ import annotations
import functools
import math
import os
from typing import Any, Callable, List, Optional, Tuple
import torch
......@@ -326,3 +327,62 @@ def round_up_to_nearest_multiple(value, multiple):
if multiple == 0:
raise ValueError("multiple cannot be zero.")
return ((value + multiple - 1) // multiple) * multiple
@functools.lru_cache(maxsize=None)
def _nvtx_enabled() -> bool:
"""Check if NVTX range profiling is enabled"""
return bool(int(os.getenv("NVTE_NVTX_ENABLED", "0")))
# Messages associated with active NVTX ranges
_nvtx_range_messages: list[str] = []
def nvtx_range_push(msg: str) -> None:
"""Push NVTX range onto stack, if NVTX range profiling is enabled
Set `NVTE_NVTX_ENABLED=1` in the environment to enable NVTX range
profiling.
Parameters
----------
msg: str
Message to associate with range
"""
if not _nvtx_enabled():
return
_nvtx_range_messages.append(msg)
torch.cuda.nvtx.range_push(msg)
def nvtx_range_pop(msg: Optional[str] = None) -> None:
"""Pop NVTX range from stack, if NVTX range profiling is enabled
Set `NVTE_NVTX_ENABLED=1` in the environment to enable NVTX range
profiling.
Parameters
----------
msg: str, optional
Message associated with range
"""
# Return immediately if NVTX range profiling is not enabled
if not _nvtx_enabled():
return
# Update list of NVTX range messages and check for consistency
if not _nvtx_range_messages:
raise RuntimeError("Attempted to pop NVTX range from empty stack")
last_msg = _nvtx_range_messages.pop()
if msg is not None and msg != last_msg:
raise ValueError(
f"Attempted to pop NVTX range from stack with msg={msg}, "
f"but last range has msg={last_msg}"
)
# Pop NVTX range
torch.cuda.nvtx.range_pop()
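
The pop helper enforces that ranges are popped in the same order they were pushed, which guards against mislabeled or unbalanced instrumentation. A short illustration of that behavior, assuming NVTE_NVTX_ENABLED=1 (with profiling disabled both helpers return immediately and nothing is checked):

nvtx_range_push("outer")
nvtx_range_push("inner")
nvtx_range_pop("inner")     # OK: matches the most recently pushed message
nvtx_range_pop()            # OK: message check is skipped when msg is None
nvtx_range_pop("anything")  # raises RuntimeError: the range stack is now empty
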