Unverified commit b0ad8ef0 authored by Tim Moon, committed by GitHub

[PyTorch] Debug NCCL communication overlapping in linear backward with FP8 data (#1620)



* Overlap input all-gather with dgrad GEMM in FP8 linear layers
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add missing docstring
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 56653520
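The scheduling pattern this commit debugs, as a minimal sketch (plain torch tensors, a single tensor-parallel group, and illustrative names; the actual implementation gathers quantized FP8 tensors via `gather_along_first_dim`, and here `grad_output` is assumed to already cover the full gathered first dimension):

```python
import torch
import torch.distributed as dist

def backward_with_overlap(inp_local, grad_output, weight, tp_group):
    """Overlap the input all-gather with the dgrad GEMM."""
    world_size = dist.get_world_size(tp_group)
    inp_total = torch.empty(
        (inp_local.size(0) * world_size, *inp_local.shape[1:]),
        dtype=inp_local.dtype,
        device=inp_local.device,
    )
    # Launch the all-gather asynchronously and keep the NCCL Work handle
    handle = dist.all_gather_into_tensor(
        inp_total, inp_local, group=tp_group, async_op=True
    )
    # dgrad GEMM runs while the communication is in flight
    dgrad = grad_output @ weight
    # Only the wgrad GEMM consumes the gathered input, so wait here
    handle.wait()
    wgrad = grad_output.t() @ inp_total
    return dgrad, wgrad
```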
@@ -19,7 +19,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
 from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules
-from .utils import safely_set_viewless_tensor_data
+from .utils import non_tn_fp8_gemm_supported, safely_set_viewless_tensor_data
 from .constants import dist_group_type
 from .fp8 import FP8GlobalStateManager, fp8_autocast
 from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
@@ -860,23 +860,29 @@ def _all_gather_fp8(
     process_group: dist_group_type,
     *,
     async_op: bool = False,
-    quantizer: Optional[Float8Quantizer] = None,
+    quantizer: Optional[Quantizer] = None,
     out_shape: Optional[list[int]] = None,
 ) -> tuple[Float8TensorBase, Optional[torch.distributed.Work]]:
     """All-gather FP8 tensor along first dimension."""
     world_size = get_distributed_world_size(process_group)
+    # Check that quantizer is valid
+    if quantizer is not None and not isinstance(
+        quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
+    ):
+        raise ValueError(f"Got non-FP8 quantizer ({quantizer.__class__.__name__})")
     # Output tensor dims
     if out_shape is None:
         out_shape = list(inp.size())
         out_shape[0] *= world_size
-    # Quantize input tensor if needed
+    # Cast input tensor to FP8 if needed
+    # Note: We cannot directly all-gather the transposed FP8 tensor,
+    # so temporarily modify quantizer to avoid creating FP8 transpose.
     if not isinstance(inp, Float8TensorBase):
-        assert isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer))
-        # we cannot directly gather the transposed fp8 tensor
-        # so we need to disable columnwise usage for the quantizer
-        # and then set it back to the original value after quantizing
+        if quantizer is None:
+            raise ValueError("Input tensor is not FP8 and no quantizer was provided")
         init_rowwise_usage = quantizer.rowwise_usage
         init_columnwise_usage = quantizer.columnwise_usage
         quantizer.set_usage(rowwise=True, columnwise=False)
@@ -888,7 +894,7 @@ def _all_gather_fp8(
     # Construct output tensor
     out: Float8TensorBase
-    if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
+    if quantizer is not None:
         dtype = torch.float32
         device = "cuda"
         if isinstance(inp, Float8Tensor):
@@ -906,9 +912,8 @@ def _all_gather_fp8(
             out._transpose_invalid = True
     else:
         raise RuntimeError("FP8TensorBase is not supported yet without Quantizer")
-    # For delayed scaling, scale_inv is from history, so we can pass it from inp to out
-    # For current scaling, scale_inv is from doing amax reduction in C++ code, so each rank should have same scale_inv,
-    # so we can just pass it from inp to out
+    # Assume scaling factors are identical across ranks
     out._scale_inv = inp._scale_inv
     # Perform communication
@@ -920,12 +925,13 @@ def _all_gather_fp8(
     )
     # Make sure FP8 transpose is populated if needed
-    if out._transpose is not None:
+    needs_transpose = (
+        quantizer is not None and quantizer.columnwise_usage and not non_tn_fp8_gemm_supported()
+    )
+    if needs_transpose:
         if handle is not None:
             handle.wait()
             handle = None
-        if not isinstance(out, Float8Tensor):
-            raise RuntimeError("FP8TensorBase does not support FP8 transpose yet")
         out._create_transpose()
     return out, handle
...
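For context on why `_all_gather_fp8` gathers the row-wise data and only then builds the transpose: concatenating per-rank transposes along the first dimension is not the transpose of the gathered tensor. A toy illustration with plain torch tensors (shapes are arbitrary):

```python
import torch

# Two "ranks", each holding a [2, 3] shard of a [4, 3] tensor
shards = [torch.arange(6).reshape(2, 3) + 6 * r for r in range(2)]

gathered = torch.cat(shards, dim=0)       # [4, 3], what all-gather produces
correct_t = gathered.t()                  # [3, 4], transpose of the full tensor
naive_t = torch.cat([s.t() for s in shards], dim=0)  # [6, 2], wrong shape and layout

print(correct_t.shape, naive_t.shape)  # torch.Size([3, 4]) torch.Size([6, 2])
```

Hence the quantizer's column-wise usage is disabled for the gather itself, and `_create_transpose()` is called on the gathered output afterwards.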
@@ -55,6 +55,7 @@ from ..tensor.quantized_tensor import (
     prepare_for_saving,
     restore_from_saved,
 )
+from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
 from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
@@ -557,12 +558,27 @@ class _LayerNormLinear(torch.autograd.Function):
             ub_obj_wgrad.set_buffer_params(ctx.grad_input_quantizer)
             dgrad_bulk = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer)
+        # Configure quantizer for grad output tensor
+        # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
+        # requires column-wise usage
         if ctx.grad_output_quantizer is not None:
-            # Reduce duplicated transpose, which is performed in grad_output.update_usage
-            if ctx.ub_overlap_ag and ctx.fp8_recipe.float8_per_tensor_scaling():
-                ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=False)
-            else:
-                ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=True)
+            rowwise_usage = True
+            columnwise_usage = True
+            if ctx.ub_overlap_ag and isinstance(
+                ctx.grad_output_quantizer,
+                (Float8Quantizer, Float8CurrentScalingQuantizer),
+            ):
+                # If data is in FP8 and communication is handled
+                # with Userbuffers, we compute FP8 transposes
+                # manually
+                columnwise_usage = False
+            ctx.grad_output_quantizer.set_usage(
+                rowwise=rowwise_usage,
+                columnwise=columnwise_usage,
+            )
+        # Prepare grad output tensor
+        # Note: Cast to expected dtype and perform tensor-parallel communication
         nvtx_range_push(f"{nvtx_label}.grad_output_preprocess")
         (
             grad_output,
@@ -575,15 +591,19 @@ class _LayerNormLinear(torch.autograd.Function):
         )
         nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess")
-        # Prepare GEMM input
-        # Note: Perform tensor-parallel communication if needed
+        # Launch tensor-parallel communication for LayerNorm out tensor
         ln_out_total = None
         ln_out_total_work = None
         if ctx.ln_out_needs_gather and not ctx.ub_bulk_dgrad:
             quantizer = None
             if ctx.fp8:
                 quantizer = ctx.input_quantizer
-                quantizer.set_usage(rowwise=False, columnwise=True)
+                if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
+                    # If data is in FP8, we compute FP8 transposes manually
+                    quantizer.set_usage(rowwise=True, columnwise=False)
+                else:
+                    # wgrad GEMM requires input with column-wise usage
+                    quantizer.set_usage(rowwise=False, columnwise=True)
             nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
             ln_out_total, ln_out_total_work = gather_along_first_dim(
                 ln_out,
@@ -652,6 +672,8 @@ class _LayerNormLinear(torch.autograd.Function):
         # Compute grad weight tensor
         wgrad = None
         if ctx.requires_wgrad:
+            # Synchronize tensor-parallel communication for input tensor
             if ctx.ub_bulk_dgrad:
                 ln_out_total = ub_obj_dgrad.get_buffer(ctx.input_quantizer)
                 if ctx.fp8:
@@ -665,18 +687,25 @@ class _LayerNormLinear(torch.autograd.Function):
                     # FP8 GEMM on Hopper only supports TN layout so the gathered input must
                     # have a valid transpose.
                     ln_out_total._create_transpose()
-            else:
-                if ln_out_total_work is not None:
-                    # Synchronize tensor-parallel communication
-                    ln_out_total_work.wait()
-                    ln_out_total_work = None
+            if ln_out_total_work is not None:
+                ln_out_total_work.wait()
+                ln_out_total_work = None
+            # Make sure GEMM inputs have required data
+            if isinstance(ln_out_total, QuantizedTensor):
+                ln_out_total.update_usage(columnwise_usage=True)
             if isinstance(grad_output, QuantizedTensor):
-                # This is a no-op if platform supports non-TN FP8 GEMM or the transpose
-                # already exists.
-                grad_output.update_usage(rowwise_usage=True, columnwise_usage=True)
+                grad_output.update_usage(columnwise_usage=True)
+            # Figure out whether to use split accumulator
+            use_split_accumulator = _2X_ACC_WGRAD
+            if ctx.fp8:
+                recipe = ctx.fp8_recipe
+                if hasattr(recipe, "fp8_gemm_wgrad"):
+                    use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator
+            # Output buffer for overlapping grad input
+            # reduce-scatter with wgrad GEMM
             if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf():
                 rs_out = torch.empty(
                     dgrad_shape, dtype=ctx.activation_dtype, device=inputmat.device
@@ -685,14 +714,6 @@ class _LayerNormLinear(torch.autograd.Function):
             # wgrad GEMM
             # Note: Fuse with bgrad computation if needed
             nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
-            wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD
-            if ctx.fp8:
-                recipe = ctx.fp8_recipe
-                if hasattr(recipe, "fp8_gemm_wgrad"):
-                    wgrad_gemm_use_split_accumulator = (
-                        recipe.fp8_gemm_wgrad.use_split_accumulator
-                    )
             wgrad, grad_bias_, *_, rs_out = general_gemm(
                 ln_out_total,
                 grad_output,
@@ -704,7 +725,7 @@ class _LayerNormLinear(torch.autograd.Function):
                 ),
                 bias=(bias if (grad_bias is None and not ctx.fp8) else None),
                 out=main_grad if ctx.fuse_wgrad_accumulation else None,
-                use_split_accumulator=wgrad_gemm_use_split_accumulator,
+                use_split_accumulator=use_split_accumulator,
                 accumulate=accumulate_wgrad_into_param_main_grad,
                 ub=ub_obj_wgrad,
                 ub_type=ub_type_wgrad,
@@ -728,7 +749,7 @@ class _LayerNormLinear(torch.autograd.Function):
                 # TODO (pgadzinski) - deallocate transpose only  # pylint: disable=fixme
                 clear_tensor_data(ln_out_total)
-        # Synchronize tensor parallel communication
+        # Make sure all tensor-parallel communication is finished
         if ln_out_total_work is not None:
             ln_out_total_work.wait()
             ln_out_total_work = None
...
@@ -56,7 +56,11 @@ from ..distributed import (
 from ..constants import dist_group_type
 from ..jit import no_torch_dynamo
 from ..graph import is_graph_capturing
-from ..tensor.float8_tensor import Float8Tensor
+from ..tensor.float8_tensor import (
+    Float8CurrentScalingQuantizer,
+    Float8Quantizer,
+    Float8Tensor,
+)
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ._common import apply_normalization, _fix_gathered_fp8_transpose
 from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
@@ -642,15 +646,27 @@ class _LayerNormMLP(torch.autograd.Function):
         ctx.ub_bulk_dgrad = ctx.fc1_weight_requires_grad and ctx.ub_bulk_dgrad
         ctx.ub_bulk_wgrad = ctx.fc1_weight_requires_grad and ctx.ub_bulk_wgrad
-        # Prepare grad output tensor
-        # Note: Cast to expected dtype and perform tensor-parallel communication
+        # Configure quantizer for FC2 grad output tensor
+        # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
+        # requires column-wise usage
         if ctx.grad_fc2_output_quantizer is not None:
-            # Reduce duplicated transpose, which is performed in grad_output.update_usage
-            if ctx.ub_overlap_ag and ctx.fp8_recipe.float8_per_tensor_scaling():
-                ctx.grad_fc2_output_quantizer.set_usage(rowwise=True, columnwise=False)
-            else:
-                ctx.grad_fc2_output_quantizer.set_usage(rowwise=True, columnwise=True)
+            rowwise_usage = True
+            columnwise_usage = True
+            if ctx.ub_overlap_ag and isinstance(
+                ctx.grad_fc2_output_quantizer,
+                (Float8Quantizer, Float8CurrentScalingQuantizer),
+            ):
+                # If data is in FP8 and communication is handled
+                # with Userbuffers, we compute FP8 transposes
+                # manually
+                columnwise_usage = False
+            ctx.grad_fc2_output_quantizer.set_usage(
+                rowwise=rowwise_usage,
+                columnwise=columnwise_usage,
+            )
+        # Prepare FC2 grad output tensor
+        # Note: Cast to expected dtype and perform tensor-parallel communication
         ub_obj_fc2_dgrad = None
         if ctx.ub_overlap_ag:
             ub_obj_fc2_dgrad = get_ub("fc2_dgrad")
@@ -662,8 +678,7 @@ class _LayerNormMLP(torch.autograd.Function):
             ctx, grad_outputs[0], True, ctx.grad_fc2_output_quantizer
         )
-        # Prepare FC1 GEMM input
-        # Note: Perform tensor-parallel communication if needed
+        # Launch tensor-parallel communication for FC1 GEMM input
         ln_out_total = None
         ln_out_total_work = None
         if (
@@ -675,7 +690,12 @@ class _LayerNormMLP(torch.autograd.Function):
             quantizer = None
             if ctx.fp8:
                 quantizer = ctx.fc1_input_quantizer
-                quantizer.set_usage(rowwise=False, columnwise=True)
+                if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
+                    # If data is in FP8, we compute FP8 transposes manually
+                    quantizer.set_usage(rowwise=True, columnwise=False)
+                else:
+                    # wgrad GEMM requires input with column-wise usage
+                    quantizer.set_usage(rowwise=False, columnwise=True)
             ln_out_total, ln_out_total_work = gather_along_first_dim(
                 ln_out,
                 ctx.tp_group,
@@ -868,6 +888,8 @@ class _LayerNormMLP(torch.autograd.Function):
         # FC1 WGRAD
         fc1_wgrad = None
         if ctx.fc1_weight_requires_grad:
+            # Synchronize tensor-parallel communication for FC1 GEMM input tensor
            if ctx.ub_bulk_dgrad:
                 ln_out_total = ub_obj_fc1_dgrad.get_buffer(ctx.fc1_input_quantizer)
                 if ctx.fp8:
@@ -879,24 +901,24 @@ class _LayerNormMLP(torch.autograd.Function):
                         # FP8 GEMM on Hopper only supports TN layout so the gathered input must
                         # have a valid transpose.
                         ln_out_total._create_transpose()
-            else:
-                if ln_out_total_work is not None:
-                    # Synchronize tensor-parallel communication
-                    ln_out_total_work.wait()
-                    ln_out_total_work = None
-            # Make sure GEMM inputs have expected data
+            if ln_out_total_work is not None:
+                ln_out_total_work.wait()
+                ln_out_total_work = None
+            # Make sure GEMM inputs have required data
             if isinstance(ln_out_total, QuantizedTensor):
-                ln_out_total.update_usage(rowwise_usage=True, columnwise_usage=True)
+                ln_out_total.update_usage(columnwise_usage=True)
             if isinstance(dact, QuantizedTensor):
-                dact.update_usage(rowwise_usage=True, columnwise_usage=True)
+                dact.update_usage(columnwise_usage=True)
+            # Output buffer for overlapping grad input
+            # reduce-scatter with wgrad GEMM
             if ctx.ub_bulk_wgrad and ub_obj_fc1_wgrad.is_fp8_ubuf():
                 fc1_dgrad_rs_out = torch.empty(
                     fc1_dgrad_shape, dtype=ctx.activation_dtype, device="cuda"
                 )
+            # wgrad GEMM
             fc1_wgrad_outputs = general_gemm(
                 ln_out_total,
                 dact,
@@ -930,7 +952,7 @@ class _LayerNormMLP(torch.autograd.Function):
             else:
                 fc1_dgrad = ub_obj_fc1_wgrad.get_buffer(None, local_chunk=True)
-        # Synchronize tensor parallel communication
+        # Make sure all tensor-parallel communication is finished
         if ln_out_total_work is not None:
             ln_out_total_work.wait()
             ln_out_total_work = None
...
@@ -56,6 +56,7 @@ from ..tensor.quantized_tensor import (
     prepare_for_saving,
     restore_from_saved,
 )
+from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
@@ -471,14 +472,27 @@ class _Linear(torch.autograd.Function):
                 ub_obj_wgrad.set_buffer_params(ctx.grad_input_quantizer)
                 dgrad_bulk = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer)
+        # Configure quantizer for grad output tensor
+        # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
+        # requires column-wise usage
+        if ctx.grad_output_quantizer is not None:
+            rowwise_usage = True
+            columnwise_usage = True
+            if ctx.ub_overlap_ag and isinstance(
+                ctx.grad_output_quantizer,
+                (Float8Quantizer, Float8CurrentScalingQuantizer),
+            ):
+                # If data is in FP8 and communication is handled
+                # with Userbuffers, we compute FP8 transposes
+                # manually
+                columnwise_usage = False
+            ctx.grad_output_quantizer.set_usage(
+                rowwise=rowwise_usage,
+                columnwise=columnwise_usage,
+            )
         # Prepare grad output tensor
         # Note: Cast to expected dtype and perform tensor-parallel communication
-        if ctx.grad_output_quantizer is not None:
-            # Reduce duplicated transpose, which is performed in grad_output.update_usage
-            if ctx.ub_overlap_ag and ctx.fp8_recipe.float8_per_tensor_scaling():
-                ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=False)
-            else:
-                ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=True)
         nvtx_range_push(f"{nvtx_label}.grad_output_preprocess")
         (
             grad_output,
@@ -491,15 +505,19 @@ class _Linear(torch.autograd.Function):
         )
         nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess")
-        # Prepare input tensor
-        # Note: Perform tensor-parallel communication if needed
+        # Launch tensor-parallel communication for input tensor
         inputmat_total = None
         inputmat_total_work = None
         if ctx.backward_input_needs_gather and not ctx.ub_bulk_dgrad:
             quantizer = None
             if ctx.fp8:
                 quantizer = ctx.input_quantizer
-                quantizer.set_usage(rowwise=False, columnwise=True)
+                if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
+                    # If data is in FP8, we compute FP8 transposes manually
+                    quantizer.set_usage(rowwise=True, columnwise=False)
+                else:
+                    # wgrad GEMM requires input with column-wise usage
+                    quantizer.set_usage(rowwise=False, columnwise=True)
             nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
             inputmat_total, inputmat_total_work = gather_along_first_dim(
                 inputmat,
@@ -573,6 +591,8 @@ class _Linear(torch.autograd.Function):
         # Compute grad weight tensor
         wgrad = None
         if ctx.requires_wgrad:
+            # Synchronize tensor-parallel communication for input tensor
             if ctx.ub_bulk_dgrad:
                 inputmat_total = ub_obj_dgrad.get_buffer(ctx.input_quantizer)
                 if ctx.fp8:
@@ -586,18 +606,25 @@ class _Linear(torch.autograd.Function):
                     # FP8 GEMM on Hopper only supports TN layout so the gathered input must
                     # have a valid transpose.
                     inputmat_total._create_transpose()
-            else:
-                if inputmat_total_work is not None:
-                    # Synchronize tensor-parallel communication
-                    inputmat_total_work.wait()
-                    inputmat_total_work = None
+            if inputmat_total_work is not None:
+                inputmat_total_work.wait()
+                inputmat_total_work = None
+            # Make sure GEMM inputs have required data
+            if isinstance(inputmat_total, QuantizedTensor):
+                inputmat_total.update_usage(columnwise_usage=True)
             if isinstance(grad_output, QuantizedTensor):
-                # This is a no-op if platform supports non-TN FP8 GEMM or the transpose
-                # already exists.
-                grad_output.update_usage(rowwise_usage=True, columnwise_usage=True)
+                grad_output.update_usage(columnwise_usage=True)
+            # Figure out whether to use split accumulator
+            use_split_accumulator = _2X_ACC_WGRAD
+            if ctx.fp8:
+                recipe = ctx.fp8_recipe
+                if hasattr(recipe, "fp8_gemm_wgrad"):
+                    use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator
+            # Output buffer for overlapping grad input
+            # reduce-scatter with wgrad GEMM
             if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf():
                 rs_out = torch.empty(
                     dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device
@@ -606,14 +633,6 @@ class _Linear(torch.autograd.Function):
             # wgrad GEMM
             # Note: Fuse with bgrad computation if needed
             nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
-            wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD
-            if ctx.fp8:
-                recipe = ctx.fp8_recipe
-                if hasattr(recipe, "fp8_gemm_wgrad"):
-                    wgrad_gemm_use_split_accumulator = (
-                        recipe.fp8_gemm_wgrad.use_split_accumulator
-                    )
             wgrad, grad_bias_, _, rs_out = general_gemm(
                 inputmat_total,
                 grad_output,
@@ -625,7 +644,7 @@ class _Linear(torch.autograd.Function):
                 ),
                 bias=(bias if (grad_bias is None and not ctx.fp8) else None),
                 out=main_grad if ctx.fuse_wgrad_accumulation else None,
-                use_split_accumulator=wgrad_gemm_use_split_accumulator,
+                use_split_accumulator=use_split_accumulator,
                 accumulate=accumulate_wgrad_into_param_main_grad,
                 ub=ub_obj_wgrad,
                 ub_type=ub_type_wgrad,
@@ -652,7 +671,7 @@ class _Linear(torch.autograd.Function):
         if not ctx.use_bias:
             grad_bias = None
-        # Synchronize tensor parallel communication
+        # Make sure all tensor-parallel communication is finished
         if inputmat_total_work is not None:
             inputmat_total_work.wait()
             inputmat_total_work = None
...
@@ -134,3 +134,11 @@ class Float8TensorBase:
             f"data={self.dequantize()}"
             ")"
         )
+
+    def _create_transpose(self):
+        """Update FP8 transpose cache"""
+        data = self._data
+        if not data.is_contiguous():
+            data = data.contiguous()
+        self._transpose = tex.fp8_transpose(data, self._fp8_dtype, out=self._transpose)
+        self._transpose_invalid = False
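The method moves up to `Float8TensorBase` so that gathered FP8 tensors can rebuild their transpose cache without being full `Float8Tensor`s. A minimal sketch of the same caching pattern, with a plain tensor standing in for the FP8 payload (`CachedTranspose` is a hypothetical illustration, not TE API):

```python
import torch

class CachedTranspose:
    """Lazily computed, invalidatable transpose cache."""

    def __init__(self, data: torch.Tensor):
        self._data = data
        self._transpose = None          # cache, filled on first use
        self._transpose_invalid = True  # set whenever _data changes

    def create_transpose(self) -> torch.Tensor:
        # Mirror _create_transpose(): ensure contiguous data, recompute
        # only when the cache is stale, then mark the cache valid
        if self._transpose_invalid or self._transpose is None:
            self._transpose = self._data.contiguous().t().contiguous()
            self._transpose_invalid = False
        return self._transpose
```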
@@ -422,13 +422,6 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
         # pylint: disable=missing-function-docstring
         return Float8Tensor.make_like(self)
-    def _create_transpose(self):
-        data = self._data
-        if not data.is_contiguous():
-            data = data.contiguous()
-        self._transpose = tex.fp8_transpose(data, self._fp8_dtype, out=self._transpose)
-        self._transpose_invalid = False
-
     def update_usage(
         self,
         rowwise_usage: Optional[bool] = None,
...