"transformer_engine/common/triton/cross_entropy.py" did not exist on "e9a5fa4e368464f3b310b90ab7f670f35319344b"
Unverified commit b0ad8ef0, authored by Tim Moon, committed by GitHub

[PyTorch] Debug NCCL communication overlapping in linear backward with FP8 data (#1620)



* Overlap input all-gather with dgrad GEMM in FP8 linear layers

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add missing docstring

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 56653520
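The core idea of the fix, reduced to a minimal PyTorch sketch (hypothetical function and shapes, not TE's actual API): launch the all-gather of the sequence-parallel input asynchronously, run the dgrad GEMM while NCCL progresses on its own stream, and only block where the gathered input is actually consumed by the wgrad GEMM.

import torch
import torch.distributed as dist

def linear_backward_with_overlap(grad_output, inp_local, weight, group):
    """grad_output: (M_total, N); inp_local: (M_local, K); weight: (N, K)."""
    world_size = dist.get_world_size(group)
    inp_total = torch.empty(
        (inp_local.size(0) * world_size, *inp_local.shape[1:]),
        dtype=inp_local.dtype,
        device=inp_local.device,
    )
    # Launch async all-gather of the input shard needed only for wgrad.
    work = dist.all_gather_into_tensor(inp_total, inp_local, group=group, async_op=True)
    # dgrad GEMM overlaps with the communication.
    dgrad = grad_output @ weight                 # (M_total, K)
    # Block only right before the gathered input is consumed.
    work.wait()
    wgrad = grad_output.t() @ inp_total          # (N, K)
    return dgrad, wgrad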
@@ -19,7 +19,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
 from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules
-from .utils import safely_set_viewless_tensor_data
+from .utils import non_tn_fp8_gemm_supported, safely_set_viewless_tensor_data
 from .constants import dist_group_type
 from .fp8 import FP8GlobalStateManager, fp8_autocast
 from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
@@ -860,23 +860,29 @@ def _all_gather_fp8(
     process_group: dist_group_type,
     *,
     async_op: bool = False,
-    quantizer: Optional[Float8Quantizer] = None,
+    quantizer: Optional[Quantizer] = None,
     out_shape: Optional[list[int]] = None,
 ) -> tuple[Float8TensorBase, Optional[torch.distributed.Work]]:
     """All-gather FP8 tensor along first dimension."""
     world_size = get_distributed_world_size(process_group)

+    # Check that quantizer is valid
+    if quantizer is not None and not isinstance(
+        quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
+    ):
+        raise ValueError(f"Got non-FP8 quantizer ({quantizer.__class__.__name__})")
+
     # Output tensor dims
     if out_shape is None:
         out_shape = list(inp.size())
         out_shape[0] *= world_size

-    # Quantize input tensor if needed
+    # Cast input tensor to FP8 if needed
+    # Note: We cannot directly all-gather the transposed FP8 tensor,
+    # so temporarily modify quantizer to avoid creating FP8 transpose.
     if not isinstance(inp, Float8TensorBase):
-        assert isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer))
-        # we cannot directly gather the transposed fp8 tensor
-        # so we need to disable columnwise usage for the quantizer
-        # and then set it back to the original value after quantizing
+        if quantizer is None:
+            raise ValueError("Input tensor is not FP8 and no quantizer was provided")
         init_rowwise_usage = quantizer.rowwise_usage
         init_columnwise_usage = quantizer.columnwise_usage
         quantizer.set_usage(rowwise=True, columnwise=False)
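The hunk above saves the quantizer's usage flags before forcing a row-wise-only cast; the matching restore falls past the end of the excerpt. A self-contained sketch of the save/quantize/restore pattern (MockQuantizer and the float8 cast are stand-ins for TE's Float8Quantizer, and the try/finally is my assumption about the elided restore):

import torch

class MockQuantizer:
    """Stand-in for TE's quantizer usage flags (illustration only)."""

    def __init__(self):
        self.rowwise_usage = True
        self.columnwise_usage = True

    def set_usage(self, rowwise, columnwise):
        self.rowwise_usage, self.columnwise_usage = rowwise, columnwise

    def __call__(self, t):
        # The real quantizer also builds a transpose cache when
        # columnwise_usage is set; this mock only casts.
        return t.to(torch.float8_e4m3fn)

def quantize_rowwise_only(quantizer, t):
    """Quantize without a transpose, restoring the caller's usage flags."""
    init_rowwise = quantizer.rowwise_usage
    init_columnwise = quantizer.columnwise_usage
    quantizer.set_usage(rowwise=True, columnwise=False)
    try:
        return quantizer(t)
    finally:
        quantizer.set_usage(rowwise=init_rowwise, columnwise=init_columnwise)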
@@ -888,7 +894,7 @@ def _all_gather_fp8(
     # Construct output tensor
     out: Float8TensorBase
-    if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
+    if quantizer is not None:
         dtype = torch.float32
         device = "cuda"
         if isinstance(inp, Float8Tensor):
@@ -906,9 +912,8 @@ def _all_gather_fp8(
             out._transpose_invalid = True
     else:
         raise RuntimeError("FP8TensorBase is not supported yet without Quantizer")
-    # For delayed scaling, scale_inv is from history, so we can pass it from inp to out
-    # For current scaling, scale_inv is from doing amax reduction in C++ code, so each rank should have same scale_inv,
-    # so we can just pass it from inp to out
+    # Assume scaling factors are identical across ranks
     out._scale_inv = inp._scale_inv

     # Perform communication
# Perform communication
......@@ -920,12 +925,13 @@ def _all_gather_fp8(
)
# Make sure FP8 transpose is populated if needed
if out._transpose is not None:
needs_transpose = (
quantizer is not None and quantizer.columnwise_usage and not non_tn_fp8_gemm_supported()
)
if needs_transpose:
if handle is not None:
handle.wait()
handle = None
if not isinstance(out, Float8Tensor):
raise RuntimeError("FP8TensorBase does not support FP8 transpose yet")
out._create_transpose()
return out, handle
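Why the transpose cannot simply be all-gathered, as the comments above note: all-gather concatenates along the first dimension, which matches the row-wise FP8 buffer but interleaves columns if applied to per-rank transposes. A toy demonstration with two "ranks" (plain tensors standing in for NCCL buffers):

import torch

a = torch.arange(6).reshape(2, 3)        # rank 0's local tile
b = torch.arange(6, 12).reshape(2, 3)    # rank 1's local tile

full = torch.cat([a, b], dim=0)          # what all-gather produces row-wise
good = full.t()                          # transpose built after the gather

bad = torch.cat([a.t(), b.t()], dim=0)   # "all-gathered transpose", 6x2
# The rows of `bad` are scrambled relative to any reshape of `good`:
assert not torch.equal(bad, good.reshape(6, 2))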
@@ -55,6 +55,7 @@ from ..tensor.quantized_tensor import (
     prepare_for_saving,
     restore_from_saved,
 )
+from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
 from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
@@ -557,12 +558,27 @@ class _LayerNormLinear(torch.autograd.Function):
                 ub_obj_wgrad.set_buffer_params(ctx.grad_input_quantizer)
             dgrad_bulk = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer)

         # Configure quantizer for grad output tensor
+        # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
+        # requires column-wise usage
         if ctx.grad_output_quantizer is not None:
-            # Reduce duplicated transpose, which is performed in grad_output.update_usage
-            if ctx.ub_overlap_ag and ctx.fp8_recipe.float8_per_tensor_scaling():
-                ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=False)
-            else:
-                ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=True)
+            rowwise_usage = True
+            columnwise_usage = True
+            if ctx.ub_overlap_ag and isinstance(
+                ctx.grad_output_quantizer,
+                (Float8Quantizer, Float8CurrentScalingQuantizer),
+            ):
+                # If data is in FP8 and communication is handled
+                # with Userbuffers, we compute FP8 transposes
+                # manually
+                columnwise_usage = False
+            ctx.grad_output_quantizer.set_usage(
+                rowwise=rowwise_usage,
+                columnwise=columnwise_usage,
+            )

         # Prepare grad output tensor
         # Note: Cast to expected dtype and perform tensor-parallel communication
         nvtx_range_push(f"{nvtx_label}.grad_output_preprocess")
         (
             grad_output,
@@ -575,14 +591,18 @@ class _LayerNormLinear(torch.autograd.Function):
         )
         nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess")

-        # Prepare GEMM input
-        # Note: Perform tensor-parallel communication if needed
+        # Launch tensor-parallel communication for LayerNorm out tensor
         ln_out_total = None
         ln_out_total_work = None
         if ctx.ln_out_needs_gather and not ctx.ub_bulk_dgrad:
             quantizer = None
             if ctx.fp8:
                 quantizer = ctx.input_quantizer
+                if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
+                    # If data is in FP8, we compute FP8 transposes manually
+                    quantizer.set_usage(rowwise=True, columnwise=False)
+                else:
+                    # wgrad GEMM requires input with column-wise usage
+                    quantizer.set_usage(rowwise=False, columnwise=True)
             nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
             ln_out_total, ln_out_total_work = gather_along_first_dim(
@@ -652,6 +672,8 @@ class _LayerNormLinear(torch.autograd.Function):
         # Compute grad weight tensor
         wgrad = None
         if ctx.requires_wgrad:
+
+            # Synchronize tensor-parallel communication for input tensor
             if ctx.ub_bulk_dgrad:
                 ln_out_total = ub_obj_dgrad.get_buffer(ctx.input_quantizer)
                 if ctx.fp8:
@@ -665,18 +687,25 @@ class _LayerNormLinear(torch.autograd.Function):
                         # FP8 GEMM on Hopper only supports TN layout so the gathered input must
                         # have a valid transpose.
                         ln_out_total._create_transpose()
             else:
                 if ln_out_total_work is not None:
                     # Synchronize tensor-parallel communication
                     ln_out_total_work.wait()
                     ln_out_total_work = None

             # Make sure GEMM inputs have required data
             if isinstance(ln_out_total, QuantizedTensor):
                 ln_out_total.update_usage(columnwise_usage=True)
             if isinstance(grad_output, QuantizedTensor):
-                # This is a no-op if platform supports non-TN FP8 GEMM or the transpose
-                # already exists.
-                grad_output.update_usage(rowwise_usage=True, columnwise_usage=True)
+                grad_output.update_usage(columnwise_usage=True)

+            # Figure out whether to use split accumulator
+            use_split_accumulator = _2X_ACC_WGRAD
+            if ctx.fp8:
+                recipe = ctx.fp8_recipe
+                if hasattr(recipe, "fp8_gemm_wgrad"):
+                    use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator
+
+            # Output buffer for overlapping grad input
+            # reduce-scatter with wgrad GEMM
             if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf():
                 rs_out = torch.empty(
                     dgrad_shape, dtype=ctx.activation_dtype, device=inputmat.device
@@ -685,14 +714,6 @@ class _LayerNormLinear(torch.autograd.Function):
             # wgrad GEMM
             # Note: Fuse with bgrad computation if needed
             nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
-            wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD
-            if ctx.fp8:
-                recipe = ctx.fp8_recipe
-                if hasattr(recipe, "fp8_gemm_wgrad"):
-                    wgrad_gemm_use_split_accumulator = (
-                        recipe.fp8_gemm_wgrad.use_split_accumulator
-                    )
-
             wgrad, grad_bias_, *_, rs_out = general_gemm(
                 ln_out_total,
                 grad_output,
@@ -704,7 +725,7 @@ class _LayerNormLinear(torch.autograd.Function):
                 ),
                 bias=(bias if (grad_bias is None and not ctx.fp8) else None),
                 out=main_grad if ctx.fuse_wgrad_accumulation else None,
-                use_split_accumulator=wgrad_gemm_use_split_accumulator,
+                use_split_accumulator=use_split_accumulator,
                 accumulate=accumulate_wgrad_into_param_main_grad,
                 ub=ub_obj_wgrad,
                 ub_type=ub_type_wgrad,
@@ -728,7 +749,7 @@ class _LayerNormLinear(torch.autograd.Function):
                 # TODO (pgadzinski) - deallocate transpose only  # pylint: disable=fixme
                 clear_tensor_data(ln_out_total)

-        # Synchronize tensor parallel communication
+        # Make sure all tensor-parallel communication is finished
         if ln_out_total_work is not None:
             ln_out_total_work.wait()
             ln_out_total_work = None
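The recurring note above, "dgrad GEMM requires row-wise usage, wgrad GEMM requires column-wise usage", maps onto the two backward products. A BF16 mock with assumed shapes (tokens M, in_features K, out_features N) makes the layout demands explicit; on Hopper, where FP8 GEMM is TN-only, the column-wise read is what forces a materialized transpose:

import torch

M, K, N = 32, 64, 16
x = torch.randn(M, K, dtype=torch.bfloat16)    # gathered GEMM input
w = torch.randn(N, K, dtype=torch.bfloat16)    # weight
gy = torch.randn(M, N, dtype=torch.bfloat16)   # grad output

dgrad = gy @ w        # (M,N) @ (N,K) -> (M,K): reads gy row-wise
wgrad = gy.t() @ x    # (N,M) @ (M,K) -> (N,K): reads gy column-wise
assert dgrad.shape == (M, K) and wgrad.shape == (N, K)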
@@ -56,7 +56,11 @@ from ..distributed import (
 from ..constants import dist_group_type
 from ..jit import no_torch_dynamo
 from ..graph import is_graph_capturing
-from ..tensor.float8_tensor import Float8Tensor
+from ..tensor.float8_tensor import (
+    Float8CurrentScalingQuantizer,
+    Float8Quantizer,
+    Float8Tensor,
+)
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ._common import apply_normalization, _fix_gathered_fp8_transpose
 from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
@@ -642,15 +646,27 @@ class _LayerNormMLP(torch.autograd.Function):
         ctx.ub_bulk_dgrad = ctx.fc1_weight_requires_grad and ctx.ub_bulk_dgrad
         ctx.ub_bulk_wgrad = ctx.fc1_weight_requires_grad and ctx.ub_bulk_wgrad

-        # Prepare grad output tensor
-        # Note: Cast to expected dtype and perform tensor-parallel communication
+        # Configure quantizer for FC2 grad output tensor
+        # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
+        # requires column-wise usage
         if ctx.grad_fc2_output_quantizer is not None:
-            # Reduce duplicated transpose, which is performed in grad_output.update_usage
-            if ctx.ub_overlap_ag and ctx.fp8_recipe.float8_per_tensor_scaling():
-                ctx.grad_fc2_output_quantizer.set_usage(rowwise=True, columnwise=False)
-            else:
-                ctx.grad_fc2_output_quantizer.set_usage(rowwise=True, columnwise=True)
+            rowwise_usage = True
+            columnwise_usage = True
+            if ctx.ub_overlap_ag and isinstance(
+                ctx.grad_fc2_output_quantizer,
+                (Float8Quantizer, Float8CurrentScalingQuantizer),
+            ):
+                # If data is in FP8 and communication is handled
+                # with Userbuffers, we compute FP8 transposes
+                # manually
+                columnwise_usage = False
+            ctx.grad_fc2_output_quantizer.set_usage(
+                rowwise=rowwise_usage,
+                columnwise=columnwise_usage,
+            )

+        # Prepare FC2 grad output tensor
+        # Note: Cast to expected dtype and perform tensor-parallel communication
         ub_obj_fc2_dgrad = None
         if ctx.ub_overlap_ag:
             ub_obj_fc2_dgrad = get_ub("fc2_dgrad")
@@ -662,8 +678,7 @@ class _LayerNormMLP(torch.autograd.Function):
             ctx, grad_outputs[0], True, ctx.grad_fc2_output_quantizer
         )

-        # Prepare FC1 GEMM input
-        # Note: Perform tensor-parallel communication if needed
+        # Launch tensor-parallel communication for FC1 GEMM input
         ln_out_total = None
         ln_out_total_work = None
         if (
@@ -675,6 +690,11 @@ class _LayerNormMLP(torch.autograd.Function):
             quantizer = None
             if ctx.fp8:
                 quantizer = ctx.fc1_input_quantizer
+                if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
+                    # If data is in FP8, we compute FP8 transposes manually
+                    quantizer.set_usage(rowwise=True, columnwise=False)
+                else:
+                    # wgrad GEMM requires input with column-wise usage
+                    quantizer.set_usage(rowwise=False, columnwise=True)
             ln_out_total, ln_out_total_work = gather_along_first_dim(
                 ln_out,
@@ -868,6 +888,8 @@ class _LayerNormMLP(torch.autograd.Function):
         # FC1 WGRAD
         fc1_wgrad = None
         if ctx.fc1_weight_requires_grad:
+
+            # Synchronize tensor-parallel communication for FC1 GEMM input tensor
             if ctx.ub_bulk_dgrad:
                 ln_out_total = ub_obj_fc1_dgrad.get_buffer(ctx.fc1_input_quantizer)
                 if ctx.fp8:
@@ -879,24 +901,24 @@ class _LayerNormMLP(torch.autograd.Function):
                         # FP8 GEMM on Hopper only supports TN layout so the gathered input must
                         # have a valid transpose.
                         ln_out_total._create_transpose()
             else:
                 if ln_out_total_work is not None:
                     # Synchronize tensor-parallel communication
                     ln_out_total_work.wait()
                     ln_out_total_work = None

-            # Make sure GEMM inputs have expected data
+            # Make sure GEMM inputs have required data
             if isinstance(ln_out_total, QuantizedTensor):
-                ln_out_total.update_usage(rowwise_usage=True, columnwise_usage=True)
+                ln_out_total.update_usage(columnwise_usage=True)
             if isinstance(dact, QuantizedTensor):
-                dact.update_usage(rowwise_usage=True, columnwise_usage=True)
+                dact.update_usage(columnwise_usage=True)

+            # Output buffer for overlapping grad input
+            # reduce-scatter with wgrad GEMM
             if ctx.ub_bulk_wgrad and ub_obj_fc1_wgrad.is_fp8_ubuf():
                 fc1_dgrad_rs_out = torch.empty(
                     fc1_dgrad_shape, dtype=ctx.activation_dtype, device="cuda"
                 )

             # wgrad GEMM
             fc1_wgrad_outputs = general_gemm(
                 ln_out_total,
                 dact,
@@ -930,7 +952,7 @@ class _LayerNormMLP(torch.autograd.Function):
            else:
                fc1_dgrad = ub_obj_fc1_wgrad.get_buffer(None, local_chunk=True)

-        # Synchronize tensor parallel communication
+        # Make sure all tensor-parallel communication is finished
         if ln_out_total_work is not None:
             ln_out_total_work.wait()
             ln_out_total_work = None
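For reference, update_usage(columnwise_usage=True) in the hunks above amounts to a lazy-transpose request. A hedged stand-in for that contract (attribute names taken from the diff; the real method also handles row-wise data and other quantized formats):

def update_usage(tensor, columnwise_usage=False):
    """No-op when the transpose cache is already valid; otherwise build it."""
    if columnwise_usage and (tensor._transpose is None or tensor._transpose_invalid):
        tensor._create_transpose()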
@@ -56,6 +56,7 @@ from ..tensor.quantized_tensor import (
     prepare_for_saving,
     restore_from_saved,
 )
+from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
@@ -471,14 +472,27 @@ class _Linear(torch.autograd.Function):
                 ub_obj_wgrad.set_buffer_params(ctx.grad_input_quantizer)
             dgrad_bulk = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer)

+        # Configure quantizer for grad output tensor
+        # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
+        # requires column-wise usage
+        if ctx.grad_output_quantizer is not None:
+            rowwise_usage = True
+            columnwise_usage = True
+            if ctx.ub_overlap_ag and isinstance(
+                ctx.grad_output_quantizer,
+                (Float8Quantizer, Float8CurrentScalingQuantizer),
+            ):
+                # If data is in FP8 and communication is handled
+                # with Userbuffers, we compute FP8 transposes
+                # manually
+                columnwise_usage = False
+            ctx.grad_output_quantizer.set_usage(
+                rowwise=rowwise_usage,
+                columnwise=columnwise_usage,
+            )
+
         # Prepare grad output tensor
         # Note: Cast to expected dtype and perform tensor-parallel communication
-        if ctx.grad_output_quantizer is not None:
-            # Reduce duplicated transpose, which is performed in grad_output.update_usage
-            if ctx.ub_overlap_ag and ctx.fp8_recipe.float8_per_tensor_scaling():
-                ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=False)
-            else:
-                ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=True)
         nvtx_range_push(f"{nvtx_label}.grad_output_preprocess")
         (
             grad_output,
@@ -491,14 +505,18 @@ class _Linear(torch.autograd.Function):
         )
         nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess")

-        # Prepare input tensor
-        # Note: Perform tensor-parallel communication if needed
+        # Launch tensor-parallel communication for input tensor
         inputmat_total = None
         inputmat_total_work = None
         if ctx.backward_input_needs_gather and not ctx.ub_bulk_dgrad:
             quantizer = None
             if ctx.fp8:
                 quantizer = ctx.input_quantizer
+                if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
+                    # If data is in FP8, we compute FP8 transposes manually
+                    quantizer.set_usage(rowwise=True, columnwise=False)
+                else:
+                    # wgrad GEMM requires input with column-wise usage
+                    quantizer.set_usage(rowwise=False, columnwise=True)
             nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
             inputmat_total, inputmat_total_work = gather_along_first_dim(
@@ -573,6 +591,8 @@ class _Linear(torch.autograd.Function):
         # Compute grad weight tensor
         wgrad = None
         if ctx.requires_wgrad:
+
+            # Synchronize tensor-parallel communication for input tensor
             if ctx.ub_bulk_dgrad:
                 inputmat_total = ub_obj_dgrad.get_buffer(ctx.input_quantizer)
                 if ctx.fp8:
@@ -586,18 +606,25 @@ class _Linear(torch.autograd.Function):
                         # FP8 GEMM on Hopper only supports TN layout so the gathered input must
                         # have a valid transpose.
                         inputmat_total._create_transpose()
             else:
                 if inputmat_total_work is not None:
                     # Synchronize tensor-parallel communication
                     inputmat_total_work.wait()
                     inputmat_total_work = None

             # Make sure GEMM inputs have required data
             if isinstance(inputmat_total, QuantizedTensor):
                 inputmat_total.update_usage(columnwise_usage=True)
             if isinstance(grad_output, QuantizedTensor):
-                # This is a no-op if platform supports non-TN FP8 GEMM or the transpose
-                # already exists.
-                grad_output.update_usage(rowwise_usage=True, columnwise_usage=True)
+                grad_output.update_usage(columnwise_usage=True)

+            # Figure out whether to use split accumulator
+            use_split_accumulator = _2X_ACC_WGRAD
+            if ctx.fp8:
+                recipe = ctx.fp8_recipe
+                if hasattr(recipe, "fp8_gemm_wgrad"):
+                    use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator
+
+            # Output buffer for overlapping grad input
+            # reduce-scatter with wgrad GEMM
             if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf():
                 rs_out = torch.empty(
                     dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device
@@ -606,14 +633,6 @@ class _Linear(torch.autograd.Function):
             # wgrad GEMM
             # Note: Fuse with bgrad computation if needed
             nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
-            wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD
-            if ctx.fp8:
-                recipe = ctx.fp8_recipe
-                if hasattr(recipe, "fp8_gemm_wgrad"):
-                    wgrad_gemm_use_split_accumulator = (
-                        recipe.fp8_gemm_wgrad.use_split_accumulator
-                    )
-
             wgrad, grad_bias_, _, rs_out = general_gemm(
                 inputmat_total,
                 grad_output,
@@ -625,7 +644,7 @@ class _Linear(torch.autograd.Function):
                 ),
                 bias=(bias if (grad_bias is None and not ctx.fp8) else None),
                 out=main_grad if ctx.fuse_wgrad_accumulation else None,
-                use_split_accumulator=wgrad_gemm_use_split_accumulator,
+                use_split_accumulator=use_split_accumulator,
                 accumulate=accumulate_wgrad_into_param_main_grad,
                 ub=ub_obj_wgrad,
                 ub_type=ub_type_wgrad,
@@ -652,7 +671,7 @@ class _Linear(torch.autograd.Function):
         if not ctx.use_bias:
             grad_bias = None

-        # Synchronize tensor parallel communication
+        # Make sure all tensor-parallel communication is finished
         if inputmat_total_work is not None:
             inputmat_total_work.wait()
             inputmat_total_work = None
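The wgrad GEMM call above passes out=main_grad with accumulate=..., meaning the product is added into the parameter's persistent gradient buffer rather than written to a fresh tensor. The equivalent in plain PyTorch (assumed shapes):

import torch

main_grad = torch.zeros(16, 64)    # persistent accumulator on the param
gy = torch.randn(32, 16)           # grad output
x = torch.randn(32, 64)            # gathered input

# accumulate=True: the GEMM adds into the existing buffer,
# i.e. main_grad += gy^T @ x
main_grad.addmm_(gy.t(), x)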
@@ -134,3 +134,11 @@ class Float8TensorBase:
             f"data={self.dequantize()}"
             ")"
         )
+
+    def _create_transpose(self):
+        """Update FP8 transpose cache"""
+        data = self._data
+        if not data.is_contiguous():
+            data = data.contiguous()
+        self._transpose = tex.fp8_transpose(data, self._fp8_dtype, out=self._transpose)
+        self._transpose_invalid = False
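Moving _create_transpose onto Float8TensorBase lets gathered FP8 tensors that are not full Float8Tensor objects rebuild their transpose, which is what the removed RuntimeError in _all_gather_fp8 used to forbid. tex.fp8_transpose writes into the existing out= buffer, so the cache is allocated once and reused across backward passes. A pure-PyTorch stand-in for this caching contract (uint8 in place of FP8 payloads; not the real tex kernel):

import torch

class TransposeCache:
    """Mimics the _transpose/_transpose_invalid bookkeeping above."""

    def __init__(self, data: torch.Tensor):
        self._data = data              # row-wise payload (e.g. uint8 FP8 bits)
        self._transpose = None         # reused once allocated
        self._transpose_invalid = True

    def _create_transpose(self):
        data = self._data if self._data.is_contiguous() else self._data.contiguous()
        if self._transpose is None:
            self._transpose = torch.empty(
                data.shape[::-1], dtype=data.dtype, device=data.device
            )
        self._transpose.copy_(data.t())  # write into the cached buffer
        self._transpose_invalid = False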
@@ -422,13 +422,6 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
         # pylint: disable=missing-function-docstring
         return Float8Tensor.make_like(self)

-    def _create_transpose(self):
-        data = self._data
-        if not data.is_contiguous():
-            data = data.contiguous()
-        self._transpose = tex.fp8_transpose(data, self._fp8_dtype, out=self._transpose)
-        self._transpose_invalid = False
-
     def update_usage(
         self,
         rowwise_usage: Optional[bool] = None,