Unverified commit 49a4535d authored by Jaemin Choi, committed by GitHub

Add NVTX ranges to categorize execution (#1447)


Signed-off-by: Jaemin Choi <jaeminc@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Jaemin Choi <jaeminc@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
parent b87e539d
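
The new ranges are off by default and are gated by the NVTE_NVTX_ENABLED environment variable (see the helpers added to transformer_engine.pytorch.utils at the bottom of this diff). As a rough usage sketch — the layer, tensor shapes, and profiler workflow below are illustrative assumptions rather than part of this change — setting the variable before the first forward pass makes the categorized ranges visible to NVTX-aware profilers such as Nsight Systems:

# Illustrative sketch (not part of this commit): enable the NVTX ranges added here
# and exercise a Transformer Engine layer so the ranges appear in a profile.
import os

os.environ["NVTE_NVTX_ENABLED"] = "1"  # must be set before the first range is pushed

import torch
import transformer_engine.pytorch as te

layer = te.Linear(1024, 1024).cuda()
x = torch.randn(32, 1024, device="cuda", requires_grad=True)

# Run under an NVTX-aware profiler (e.g. Nsight Systems) to see ranges such as
# "transformer_engine._Linear.forward.gemm" and "transformer_engine._Linear.backward.dgrad_gemm".
y = layer(x)
y.sum().backward()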
@@ -23,7 +23,11 @@ import torch.nn.functional as F
 import transformer_engine_torch as tex
 import transformer_engine as te
-from transformer_engine.pytorch.utils import get_cudnn_version
+from transformer_engine.pytorch.utils import (
+    get_cudnn_version,
+    nvtx_range_pop,
+    nvtx_range_push,
+)
 from transformer_engine.pytorch.cpp_extensions.fused_attn import (
     fused_attn_fwd,
     fused_attn_bwd,
@@ -1834,6 +1838,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
         quantizers,
     ):
         # pylint: disable=missing-function-docstring
+        nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.forward")
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -2756,12 +2761,14 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
         ctx.fp8_meta = fp8_meta
         ctx.is_input_fp8 = is_input_fp8
         ctx.is_output_fp8 = is_output_fp8
+        nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.forward")
         return out_ret

     @staticmethod
     def backward(ctx, dout):
         # pylint: disable=missing-function-docstring
+        nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.backward")
         cp_size_a2a = ctx.cp_size_a2a
         rank_a2a = ctx.rank_a2a
@@ -3602,6 +3609,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
             dq = ctx.dQKV_quantizer.create_tensor_from_data(dq, ctx.qkv_dtype)
             dk = ctx.dQKV_quantizer.create_tensor_from_data(dk, ctx.qkv_dtype)
             dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, ctx.qkv_dtype)
+        nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.backward")
         return (
             None,
@@ -3688,6 +3696,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
         cp_stream,
     ):
         # pylint: disable=missing-function-docstring
+        nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.forward")
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -3904,11 +3913,13 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
         ctx.attn_mask_type = attn_mask_type
         ctx.deterministic = deterministic
         ctx.use_fused_attention = use_fused_attention
+        nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.forward")
         return out

     @staticmethod
     def backward(ctx, dout):
         # pylint: disable=missing-function-docstring
+        nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.backward")
         cp_size = get_distributed_world_size(ctx.cp_group)
         rank = get_distributed_rank(ctx.cp_group)
@@ -4092,6 +4103,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
         dq = dq.view(*dq.shape[:seq_dim], -1, *dq.shape[(seq_dim + 2) :])
         dk = dk.movedim(0, seq_dim).contiguous()
         dv = dv.movedim(0, seq_dim).contiguous()
+        nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.backward")
         return (
             None,
@@ -4151,6 +4163,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
         quantizers,
     ):
         # pylint: disable=missing-function-docstring
+        nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward")
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -4403,11 +4416,13 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
         ctx.fp8_meta = fp8_meta
         ctx.is_input_fp8 = is_input_fp8
         ctx.is_output_fp8 = is_output_fp8
+        nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward")
         return out_ret

     @staticmethod
     def backward(ctx, dout):
         # pylint: disable=missing-function-docstring
+        nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward")
         cp_size = get_distributed_world_size(ctx.cp_group)
         (
@@ -4592,6 +4607,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
             dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, fake_dtype=dout_dtype)
             if not ctx.is_input_fp8:
                 dq, dk, dv = [x.dequantize() for x in [dq, dk, dv]]
+        nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward")
         return (
             None,
@@ -24,12 +24,14 @@ from .base import (
 )
 from ..fp8 import FP8GlobalStateManager
 from ..utils import (
+    assert_dim_for_fp8_exec,
+    cast_if_needed,
+    clear_tensor_data,
     divide,
     get_default_init_method,
     init_method_constant,
-    cast_if_needed,
-    assert_dim_for_fp8_exec,
-    clear_tensor_data,
+    nvtx_range_pop,
+    nvtx_range_push,
     requires_grad,
 )
 from ..distributed import (
@@ -112,6 +114,12 @@ class _LayerNormLinear(torch.autograd.Function):
         skip_fp8_weight_update: bool,
     ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
         # pylint: disable=missing-function-docstring
+
+        # NVTX label for profiling
+        nvtx_label = "transformer_engine._LayerNormLinear.forward"
+        if ub_name is not None:
+            nvtx_label = f"{nvtx_label}.{ub_name}"
+
         # Make sure input dimensions are compatible
         out_features, in_features = weight.shape
         inp_shape = inp.shape
@@ -121,10 +129,12 @@ class _LayerNormLinear(torch.autograd.Function):
             assert_dim_for_fp8_exec(inputmat, weight)

         # Cast for native AMP
+        nvtx_range_push(f"{nvtx_label}.norm_input_cast")
         inputmat = cast_if_needed(inputmat, activation_dtype)
         ln_weight = cast_if_needed(ln_weight, activation_dtype)
         if ln_bias is not None:
             ln_bias = cast_if_needed(ln_bias, activation_dtype)
+        nvtx_range_pop(f"{nvtx_label}.norm_input_cast")

         tp_world_size = get_distributed_world_size(tp_group)
         ub_overlap_ag_fprop = (
@@ -175,6 +185,7 @@ class _LayerNormLinear(torch.autograd.Function):
         )

         # Apply normalization
+        nvtx_range_push(f"{nvtx_label}.norm")
         ln_out, mu, rsigma = apply_normalization(
             inputmat,
             ln_out,
@@ -188,9 +199,11 @@ class _LayerNormLinear(torch.autograd.Function):
             zero_centered_gamma,
         )
         ln_out_return = ln_out if return_layernorm_output else None
+        nvtx_range_pop(f"{nvtx_label}.norm")

         # Prepare GEMM input
         # Note: Cast to expected dtype and perform tensor-parallel communication
+        nvtx_range_push(f"{nvtx_label}.gemm_input_cast_comm")
         if with_input_all_gather and not ub_overlap_ag_fprop:
             with_quantized_all_gather = fp8
             if return_layernorm_output and return_layernorm_output_gathered:
@@ -217,6 +230,7 @@ class _LayerNormLinear(torch.autograd.Function):
             elif backward_needs_input:
                 ln_out.update_usage(rowwise_usage=True, columnwise_usage=True)
             ln_out_total = ln_out
+        nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm")

         # Cast weight to expected dtype
         weightmat = weight
@@ -275,6 +289,7 @@ class _LayerNormLinear(torch.autograd.Function):
             assert ub_obj.is_fp8_ubuf(), "AG overlap with FP8 GEMM inputs requires FP8 buffer."
             ln_out_total = ub_obj.get_buffer(input_quantizer)

+        nvtx_range_push(f"{nvtx_label}.gemm")
         out, *_, rs_out = general_gemm(
             weightmat,
             ln_out_total,
@@ -287,6 +302,8 @@ class _LayerNormLinear(torch.autograd.Function):
             ub_type=ub_type,
             extra_output=rs_out,
         )
+        nvtx_range_pop(f"{nvtx_label}.gemm")
+
         if not weight.requires_grad:
             if not return_layernorm_output:
                 ln_out = ln_out_total = None
@@ -307,6 +324,7 @@ class _LayerNormLinear(torch.autograd.Function):
             # Scatter intermediate/activation tensors saved for the backward pass
             # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already
             # shards/unshards the base weights so we don't do it ourselves
+            nvtx_range_push(f"{nvtx_label}.fsdp_scatter")
             ctx.fsdp_group = fsdp_group
             ctx.fsdp_shapes = _fsdp_scatter_tensors(
                 fsdp_group,
@@ -315,6 +333,7 @@ class _LayerNormLinear(torch.autograd.Function):
                 weightmat if quantized_weight else None,
                 ln_out if weight.requires_grad else None,
             )
+            nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")

             tensors_to_save, tensor_objects = prepare_for_saving(
                 inputmat,
@@ -372,10 +391,12 @@ class _LayerNormLinear(torch.autograd.Function):
         if ub_overlap_rs_fprop:
             out = rs_out
         elif parallel_mode == "row":
+            nvtx_range_push(f"{nvtx_label}.row_parallel_comm")
             if sequence_parallel:
                 out, _ = reduce_scatter_along_first_dim(out, tp_group)
             elif tensor_parallel:
                 out, _ = allreduce(out, tp_group)
+            nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")

         # [*, in_features] -> [*, out_features] except first dimension changes for SP
         out = out.view(-1, *inp_shape[1:-1], out_features)
@@ -394,6 +415,11 @@ class _LayerNormLinear(torch.autograd.Function):
     ) -> Tuple[Union[torch.Tensor, None], ...]:
         # pylint: disable=missing-function-docstring
+
+        # NVTX label for profiling
+        nvtx_label = "transformer_engine._LayerNormLinear.backward"
+        if ctx.ub_name is not None:
+            nvtx_label = f"{nvtx_label}.{ctx.ub_name}"
+
         with torch.cuda.nvtx.range("_LayerNormLinear_backward"):
             if (
                 ctx.fp8
@@ -433,6 +459,7 @@ class _LayerNormLinear(torch.autograd.Function):
             # Gather intermediate/activation tensors if needed
             # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already
             # shards/unshards the base weights so we don't do it ourselves
+            nvtx_range_push(f"{nvtx_label}.fsdp_gather")
             _fsdp_gather_tensors(
                 ctx.fsdp_group,
                 ctx.fsdp_shapes,
@@ -441,6 +468,7 @@ class _LayerNormLinear(torch.autograd.Function):
                 weight if ctx.fp8 and ctx.quantized_weight else None,
                 ln_out,
             )
+            nvtx_range_pop(f"{nvtx_label}.fsdp_gather")

             # For CPU offloading, we offloaded weight and weight.main_grad to different tensors,
             # we need to connect them into one.
@@ -515,12 +543,14 @@ class _LayerNormLinear(torch.autograd.Function):
                 if ctx.fp8:
                     quantizer = ctx.input_quantizer
                     quantizer.set_usage(rowwise=True, columnwise=True)
+                nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
                 ln_out_total, ln_out_total_work = gather_along_first_dim(
                     ln_out,
                     ctx.tp_group,
                     async_op=True,
                     quantizer=quantizer,
                 )
+                nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
             else:
                 ln_out_total = ln_out
@@ -536,6 +566,7 @@ class _LayerNormLinear(torch.autograd.Function):
             if ctx.grad_input_quantizer is not None:
                 ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)

+            nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
             dgrad, *_ = general_gemm(
                 weight,
                 grad_output,
@@ -551,12 +582,14 @@ class _LayerNormLinear(torch.autograd.Function):
                 extra_output=rs_out,
                 bulk_overlap=ctx.ub_bulk_dgrad,
             )
+            nvtx_range_pop(f"{nvtx_label}.dgrad_gemm")

             # Launch tensor-parallel communication
             dgrad_work = None
             if ctx.ub_overlap_rs_dgrad:
                 dgrad = rs_out
             elif ctx.parallel_mode == "column" and not ctx.ub_bulk_wgrad:
+                nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad")
                 if ctx.sequence_parallel:
                     if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered:
                         dgrad = dgrad + grad_outputs[1].view_as(dgrad)
@@ -567,6 +600,7 @@ class _LayerNormLinear(torch.autograd.Function):
                     )
                 else:
                     dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True)
+                nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad")

             # Compute grad weight tensor
             wgrad = None
@@ -603,6 +637,7 @@ class _LayerNormLinear(torch.autograd.Function):
                 # wgrad GEMM
                 # Note: Fuse with bgrad computation if needed
+                nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
                 wgrad, grad_bias_, *_, rs_out = general_gemm(
                     ln_out_total,
                     grad_output,
@@ -621,6 +656,7 @@ class _LayerNormLinear(torch.autograd.Function):
                     extra_output=rs_out,
                     bulk_overlap=ctx.ub_bulk_wgrad,
                 )
+                nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")

                 if ctx.ub_bulk_wgrad:
                     if ub_obj_wgrad.is_fp8_ubuf():
@@ -657,6 +693,7 @@ class _LayerNormLinear(torch.autograd.Function):
             # Norm gradient
             dgamma = None
             dbeta = None
+            nvtx_range_push(f"{nvtx_label}.norm")
             if ctx.normalization == "LayerNorm":
                 dgrad, dgamma, dbeta = tex.layernorm_bwd(
                     dgrad,
@@ -679,6 +716,7 @@ class _LayerNormLinear(torch.autograd.Function):
                 )
                 dgrad = dgrad.reshape(inputmat.size())
                 dbeta = None
+            nvtx_range_pop(f"{nvtx_label}.norm")

             clear_tensor_data(mu)
             clear_tensor_data(rsigma)
@@ -22,12 +22,14 @@ from .base import (
 from ._common import noop_cat, _fix_gathered_fp8_transpose
 from ..fp8 import FP8GlobalStateManager
 from ..utils import (
-    divide,
     cast_if_needed,
     clear_tensor_data,
+    divide,
     init_method_constant,
-    requires_grad,
     non_tn_fp8_gemm_supported,
+    nvtx_range_pop,
+    nvtx_range_push,
+    requires_grad,
 )
 from ..distributed import (
     set_tensor_model_parallel_attributes,
@@ -100,6 +102,11 @@ class _Linear(torch.autograd.Function):
     ) -> torch.Tensor:
         # pylint: disable=missing-function-docstring
+
+        # NVTX label for profiling
+        nvtx_label = "transformer_engine._Linear.forward"
+        if ub_name is not None:
+            nvtx_label = f"{nvtx_label}.{ub_name}"
+
         # Make sure input dimensions are compatible
         out_features, in_features = weight.shape
         inp_shape = inp.shape
@@ -110,6 +117,7 @@ class _Linear(torch.autograd.Function):
         # Prepare input tensor
         # Note: Cast to expected dtype and perform tensor-parallel communication
+        nvtx_range_push(f"{nvtx_label}.input_cast_comm")
         inputmat = inp
         inputmat_total = None
         with_input_all_gather_nccl = (
@@ -153,6 +161,7 @@ class _Linear(torch.autograd.Function):
             inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
         else:
             inputmat_total = inputmat
+        nvtx_range_pop(f"{nvtx_label}.input_cast_comm")

         # Cast weight to expected dtype
         weightmat = weight
@@ -216,6 +225,7 @@ class _Linear(torch.autograd.Function):
                 ub_obj.copy_into_buffer(inputmat_total, input_quantizer, local_chunk=True)
                 inputmat_total = ub_obj.get_buffer(input_quantizer)

+        nvtx_range_push(f"{nvtx_label}.gemm")
         out, *_, rs_out = general_gemm(
             weightmat,
             inputmat_total,
@@ -228,6 +238,7 @@ class _Linear(torch.autograd.Function):
             ub_type=ub_type,
             extra_output=rs_out,
         )
+        nvtx_range_pop(f"{nvtx_label}.gemm")

         if is_grad_enabled:
             saved_inputmat = None
@@ -244,12 +255,14 @@ class _Linear(torch.autograd.Function):
             # Scatter intermediate/activation tensors saved for the backward pass
             # NOTE: FSDP sharding is not valid for models initialized with primary Fp8 weights
+            nvtx_range_push(f"{nvtx_label}.fsdp_scatter")
             ctx.fsdp_group = fsdp_group
             ctx.fsdp_shapes = _fsdp_scatter_tensors(
                 fsdp_group,
                 saved_inputmat,
                 weightmat if fp8 and not isinstance(weight, QuantizedTensor) else None,
             )
+            nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")

             # TODO(ksivamani): Check memory usage
             tensors_to_save, tensor_objects = prepare_for_saving(
@@ -299,10 +312,12 @@ class _Linear(torch.autograd.Function):
         if ub_overlap_rs_fprop:
             out = rs_out
         elif parallel_mode == "row":
+            nvtx_range_push(f"{nvtx_label}.row_parallel_comm")
             if sequence_parallel:
                 out, _ = reduce_scatter_along_first_dim(out, tp_group)
             elif tensor_parallel:
                 out, _ = allreduce(out, tp_group)
+            nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")

         out = out.view(-1, *inp_shape[1:-1], out_features)
         return out
@@ -311,6 +326,11 @@ class _Linear(torch.autograd.Function):
     def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
         # pylint: disable=missing-function-docstring
+
+        # NVTX label for profiling
+        nvtx_label = "transformer_engine._Linear.backward"
+        if ctx.ub_name is not None:
+            nvtx_label = f"{nvtx_label}.{ctx.ub_name}"
+
         with torch.cuda.nvtx.range("_Linear_backward"):
             if (
                 ctx.fp8
@@ -347,12 +367,14 @@ class _Linear(torch.autograd.Function):
             # Gather intermediate/activation tensors if needed
             # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already
             # shards/unshards the base weights so we don't do it ourselves
+            nvtx_range_push(f"{nvtx_label}.fsdp_gather")
             _fsdp_gather_tensors(
                 ctx.fsdp_group,
                 ctx.fsdp_shapes,
                 inputmat,
                 weight_fp8,
             )
+            nvtx_range_pop(f"{nvtx_label}.fsdp_gather")

             ctx.ub_obj_gradout = None
             ub_obj_dgrad = None
@@ -424,12 +446,14 @@ class _Linear(torch.autograd.Function):
                 if ctx.fp8:
                     quantizer = ctx.input_quantizer
                     quantizer.set_usage(rowwise=True, columnwise=True)
+                nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
                 inputmat_total, inputmat_total_work = gather_along_first_dim(
                     inputmat,
                     ctx.tp_group,
                     async_op=True,
                     quantizer=quantizer,
                 )
+                nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
             else:
                 inputmat_total = inputmat
@@ -451,6 +475,7 @@ class _Linear(torch.autograd.Function):
                 ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)

             # dgrad GEMM
+            nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
             dgrad, *_, rs_out = general_gemm(
                 weight_fp8,
                 grad_output,
@@ -466,11 +491,13 @@ class _Linear(torch.autograd.Function):
                 extra_output=rs_out,
                 bulk_overlap=ctx.ub_bulk_dgrad,
             )
+            nvtx_range_pop(f"{nvtx_label}.dgrad_gemm")

             # Launch tensor-parallel communication
             if ctx.ub_overlap_rs_dgrad:
                 dgrad = rs_out
             elif ctx.parallel_mode == "column" and not ctx.ub_bulk_wgrad:
+                nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad")
                 if ctx.sequence_parallel:
                     dgrad, dgrad_work = reduce_scatter_along_first_dim(
                         dgrad,
@@ -479,6 +506,7 @@ class _Linear(torch.autograd.Function):
                     )
                 else:
                     dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True)
+                nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad")

             # Compute grad weight tensor
             wgrad = None
@@ -515,6 +543,7 @@ class _Linear(torch.autograd.Function):
                 # wgrad GEMM
                 # Note: Fuse with bgrad computation if needed
+                nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
                 wgrad, grad_bias_, _, rs_out = general_gemm(
                     inputmat_total,
                     grad_output,
@@ -533,6 +562,7 @@ class _Linear(torch.autograd.Function):
                     extra_output=rs_out,
                     bulk_overlap=ctx.ub_bulk_wgrad,
                 )
+                nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")

                 if ctx.ub_bulk_wgrad:
                     if ub_obj_wgrad.is_fp8_ubuf():
@@ -6,6 +6,7 @@
 from __future__ import annotations
 import functools
 import math
+import os
 from typing import Any, Callable, List, Optional, Tuple

 import torch
@@ -326,3 +327,62 @@ def round_up_to_nearest_multiple(value, multiple):
     if multiple == 0:
         raise ValueError("multiple cannot be zero.")
     return ((value + multiple - 1) // multiple) * multiple
+
+
+@functools.lru_cache(maxsize=None)
+def _nvtx_enabled() -> bool:
+    """Check if NVTX range profiling is enabled"""
+    return bool(int(os.getenv("NVTE_NVTX_ENABLED", "0")))
+
+
+# Messages associated with active NVTX ranges
+_nvtx_range_messages: list[str] = []
+
+
+def nvtx_range_push(msg: str) -> None:
+    """Push NVTX range onto stack, if NVTX range profiling is enabled
+
+    Set `NVTE_NVTX_ENABLED=1` in the environment to enable NVTX range
+    profiling.
+
+    Parameters
+    ----------
+    msg: str
+        Message to associate with range
+
+    """
+    if not _nvtx_enabled():
+        return
+    _nvtx_range_messages.append(msg)
+    torch.cuda.nvtx.range_push(msg)
+
+
+def nvtx_range_pop(msg: Optional[str] = None) -> None:
+    """Pop NVTX range from stack, if NVTX range profiling is enabled
+
+    Set `NVTE_NVTX_ENABLED=1` in the environment to enable NVTX range
+    profiling.
+
+    Parameters
+    ----------
+    msg: str, optional
+        Message associated with range
+
+    """
+
+    # Return immediately if NVTX range profiling is not enabled
+    if not _nvtx_enabled():
+        return
+
+    # Update list of NVTX range messages and check for consistency
+    if not _nvtx_range_messages:
+        raise RuntimeError("Attempted to pop NVTX range from empty stack")
+    last_msg = _nvtx_range_messages.pop()
+    if msg is not None and msg != last_msg:
+        raise ValueError(
+            f"Attempted to pop NVTX range from stack with msg={msg}, "
+            f"but last range has msg={last_msg}"
+        )
+
+    # Pop NVTX range
+    torch.cuda.nvtx.range_pop()
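
A minimal sketch of how these helpers pair up at a call site, mirroring the pattern used throughout this diff; the function name and range label below are hypothetical:

import os

os.environ["NVTE_NVTX_ENABLED"] = "1"  # _nvtx_enabled() caches its result, so set this before the first push

from transformer_engine.pytorch.utils import nvtx_range_pop, nvtx_range_push


def profiled_step():
    # Hypothetical label; this commit uses the convention
    # "transformer_engine.<Class>.<method>[.<stage>]".
    nvtx_range_push("transformer_engine.example.profiled_step")
    try:
        ...  # work to be categorized in the profile
    finally:
        # Passing the same message lets nvtx_range_pop verify push/pop pairing:
        # a mismatch raises ValueError and popping an empty stack raises RuntimeError.
        nvtx_range_pop("transformer_engine.example.profiled_step")


profiled_step()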