Unverified Commit 0356010c authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

[PyTorch] Optimize MXFP8 all-gathers (#1581)



* Coalesce NCCL all-gathers for MXFP8 all-gather
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add missing import
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Cache quantized input tensor after linear module forward pass
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix linter warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Avoid unnecessarily allocating layernorm output in LayerNormLinear/LayerNormMLP
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 65c2798a
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
"""Methods needed for distributed training (DP/TP).""" """Methods needed for distributed training (DP/TP)."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Iterable
from contextlib import contextmanager, AbstractContextManager, ContextDecorator from contextlib import contextmanager, AbstractContextManager, ContextDecorator
from functools import lru_cache from functools import lru_cache
import math import math
...@@ -876,10 +877,14 @@ def _all_gather_fp8( ...@@ -876,10 +877,14 @@ def _all_gather_fp8(
# we cannot directly gather the transposed fp8 tensor # we cannot directly gather the transposed fp8 tensor
# so we need to disable columnwise usage for the quantizer # so we need to disable columnwise usage for the quantizer
# and then set it back to the original value after quantizing # and then set it back to the original value after quantizing
init_rowwise_usage = quantizer.rowwise_usage
init_columnwise_usage = quantizer.columnwise_usage init_columnwise_usage = quantizer.columnwise_usage
quantizer.set_usage(columnwise=False) quantizer.set_usage(rowwise=True, columnwise=False)
inp = quantizer(inp) inp = quantizer(inp)
quantizer.set_usage(columnwise=init_columnwise_usage) quantizer.set_usage(
rowwise=init_rowwise_usage,
columnwise=init_columnwise_usage,
)
# Construct output tensor # Construct output tensor
out: Float8TensorBase out: Float8TensorBase
...@@ -936,9 +941,34 @@ def _all_gather_mxfp8( ...@@ -936,9 +941,34 @@ def _all_gather_mxfp8(
) -> tuple[MXFP8TensorBase, Optional[torch.distributed.Work]]: ) -> tuple[MXFP8TensorBase, Optional[torch.distributed.Work]]:
"""All-gather MXFP8 tensor along first dimension.""" """All-gather MXFP8 tensor along first dimension."""
# Tensor dims # Input tensor attributes
in_shape: Iterable[int]
device: torch.device
dtype: torch.dtype
if isinstance(inp, torch.Tensor):
in_shape = inp.size()
device = inp.device
dtype = inp.dtype
elif isinstance(inp, MXFP8TensorBase):
if inp._rowwise_data is not None:
in_shape = inp._rowwise_data.device.size()
device = inp._rowwise_data.device
dtype = inp._rowwise_data.dtype
elif inp._columnwise_data is not None:
in_shape = inp._columnwise_data.device.size()
device = inp._columnwise_data.device
dtype = inp._columnwise_data.dtype
else:
raise ValueError("Got MXFP8 input tensor without any data")
dtype = torch.bfloat16
else:
raise ValueError(
"Invalid type for input tensor (expected torch.Tensor or MXFP8TensorBase, "
f"found {inp.__class__.__name__})"
)
# Output tensor shape
world_size = get_distributed_world_size(process_group) world_size = get_distributed_world_size(process_group)
in_shape = list(inp.size())
if out_shape is None: if out_shape is None:
out_shape = [in_shape[0] * world_size] + in_shape[1:] out_shape = [in_shape[0] * world_size] + in_shape[1:]
...@@ -951,25 +981,19 @@ def _all_gather_mxfp8( ...@@ -951,25 +981,19 @@ def _all_gather_mxfp8(
): ):
out = torch.empty( out = torch.empty(
out_shape, out_shape,
dtype=inp.dtype, dtype=dtype,
device=inp.device, device=device,
memory_format=torch.contiguous_format, memory_format=torch.contiguous_format,
) )
torch.distributed.all_gather_into_tensor(out, inp, group=process_group) torch.distributed.all_gather_into_tensor(out, inp, group=process_group)
out = quantizer(out) out = quantizer(out)
return out, None return out, None
inp_dtype = inp.dtype
inp_device = inp.device
# Cast input tensor to MXFP8 with required data # Cast input tensor to MXFP8 with required data
if not isinstance(inp, MXFP8TensorBase): if not isinstance(inp, MXFP8TensorBase):
inp = quantizer(inp) inp = quantizer(inp)
elif ( elif (quantizer.rowwise_usage and inp._rowwise_data is None) or (
inp.rowwise_data is None quantizer.columnwise_usage and inp._columnwise_data is None
and quantizer.rowwise_usage
or inp.columnwise_data is None
and quantizer.columnwise_usage
): ):
warnings.warn( warnings.warn(
"Input and quantizer do not have matching usages. " "Input and quantizer do not have matching usages. "
...@@ -978,10 +1002,14 @@ def _all_gather_mxfp8( ...@@ -978,10 +1002,14 @@ def _all_gather_mxfp8(
inp = quantizer(inp.dequantize()) inp = quantizer(inp.dequantize())
# Construct MXFP8 output tensor # Construct MXFP8 output tensor
out = quantizer.make_empty(out_shape, dtype=inp_dtype, device=inp_device) out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
# Async op handle # Coalesce NCCL collectives
handle = None with torch.distributed._coalescing_manager(
group=process_group,
device=device,
async_ops=async_op,
) as coalescing_manager:
# Gather MXFP8 data for row-wise usage # Gather MXFP8 data for row-wise usage
if quantizer.rowwise_usage: if quantizer.rowwise_usage:
...@@ -996,18 +1024,15 @@ def _all_gather_mxfp8( ...@@ -996,18 +1024,15 @@ def _all_gather_mxfp8(
out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size] out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size]
# Launch all-gathers # Launch all-gathers
if handle is not None:
handle.wait()
torch.distributed.all_gather_into_tensor( torch.distributed.all_gather_into_tensor(
out_scale_inv, out_scale_inv,
in_scale_inv, in_scale_inv,
group=process_group, group=process_group,
) )
handle = torch.distributed.all_gather_into_tensor( torch.distributed.all_gather_into_tensor(
out._rowwise_data, out._rowwise_data,
inp._rowwise_data, inp._rowwise_data,
group=process_group, group=process_group,
async_op=async_op,
) )
# Gather MXFP8 data for column-wise usage # Gather MXFP8 data for column-wise usage
...@@ -1023,20 +1048,18 @@ def _all_gather_mxfp8( ...@@ -1023,20 +1048,18 @@ def _all_gather_mxfp8(
out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size] out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size]
# Launch all-gathers # Launch all-gathers
if handle is not None:
handle.wait()
torch.distributed.all_gather_into_tensor( torch.distributed.all_gather_into_tensor(
out_scale_inv, out_scale_inv,
in_scale_inv, in_scale_inv,
group=process_group, group=process_group,
) )
handle = torch.distributed.all_gather_into_tensor( torch.distributed.all_gather_into_tensor(
out._columnwise_data, out._columnwise_data,
inp._columnwise_data, inp._columnwise_data,
group=process_group, group=process_group,
async_op=async_op,
) )
handle = coalescing_manager if async_op else None
return out, handle return out, handle
......
...@@ -55,7 +55,6 @@ from ..tensor.quantized_tensor import ( ...@@ -55,7 +55,6 @@ from ..tensor.quantized_tensor import (
prepare_for_saving, prepare_for_saving,
restore_from_saved, restore_from_saved,
) )
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
...@@ -137,6 +136,11 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -137,6 +136,11 @@ class _LayerNormLinear(torch.autograd.Function):
ln_bias = cast_if_needed(ln_bias, activation_dtype) ln_bias = cast_if_needed(ln_bias, activation_dtype)
nvtx_range_pop(f"{nvtx_label}.norm_input_cast") nvtx_range_pop(f"{nvtx_label}.norm_input_cast")
# Avoid quantized norm kernel if norm output will be returned
with_quantized_norm = (
fp8 and not return_layernorm_output and not return_layernorm_output_gathered
)
tp_world_size = get_distributed_world_size(tp_group) tp_world_size = get_distributed_world_size(tp_group)
ub_overlap_ag_fprop = ( ub_overlap_ag_fprop = (
ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output
...@@ -146,6 +150,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -146,6 +150,7 @@ class _LayerNormLinear(torch.autograd.Function):
backward_needs_input = is_grad_enabled and weight_requires_grad backward_needs_input = is_grad_enabled and weight_requires_grad
with_input_all_gather = parallel_mode == "column" and sequence_parallel with_input_all_gather = parallel_mode == "column" and sequence_parallel
# Check if Userbuffers is supported
if fp8: if fp8:
if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and not ( if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and not (
FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling() FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
...@@ -155,103 +160,73 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -155,103 +160,73 @@ class _LayerNormLinear(torch.autograd.Function):
" current scaling" " current scaling"
) )
# Configure quantizer for norm output
if fp8:
if input_quantizer is None: if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor") raise ValueError("Missing quantizer for input tensor")
columnwise_usage = backward_needs_input
# Configure quantizer for normalization output
with_quantized_norm = fp8 and not return_layernorm_output
if with_quantized_norm:
if with_input_all_gather:
input_quantizer.set_usage(rowwise=True, columnwise=False)
if isinstance(input_quantizer, MXFP8Quantizer):
with_quantized_norm = False
else:
input_quantizer.set_usage(
rowwise=True,
columnwise=backward_needs_input,
)
# Reduce duplicated transpose in `_fix_gathered_fp8_transpose`
if ( if (
fp8 columnwise_usage
and FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling() and with_input_all_gather
and ub_bulk_dgrad and not isinstance(input_quantizer, MXFP8Quantizer)
): ):
input_quantizer.set_usage(rowwise=True, columnwise=False) columnwise_usage = False
input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
ub_obj_fprop = None
ln_out = None
# For DelayScaling, output of normalization will be in fp8.
# For Float8CurrentScaling, we want the output of normalization in high precision, then quantize to fp8.
if ub_overlap_ag_fprop and not isinstance(input_quantizer, Float8CurrentScalingQuantizer):
ub_obj_fprop = get_ub(ub_name + "_fprop")
ln_out = ub_obj_fprop.get_buffer(input_quantizer, local_chunk=True)
elif with_quantized_norm:
if with_input_all_gather:
input_quantizer.set_usage(rowwise=True, columnwise=False)
ln_out = input_quantizer.make_empty(inputmat.shape, dtype=inputmat.dtype, device="cuda")
else:
ln_out = torch.empty_like(
inputmat, dtype=inputmat.dtype, memory_format=torch.contiguous_format, device="cuda"
)
# Apply normalization # Apply normalization
nvtx_range_push(f"{nvtx_label}.norm") nvtx_range_push(f"{nvtx_label}.norm")
ln_out, mu, rsigma = apply_normalization( ln_out, mu, rsigma = apply_normalization(
inputmat, inputmat,
ln_out, None, # ln_out
ln_weight, ln_weight,
ln_bias, ln_bias,
eps, eps,
input_quantizer if with_quantized_norm else None, input_quantizer if with_quantized_norm else None,
inp.dtype, inputmat.dtype,
normalization, normalization,
fwd_ln_sm_margin, fwd_ln_sm_margin,
zero_centered_gamma, zero_centered_gamma,
) )
ln_out_return = ln_out if return_layernorm_output else None ln_out_return = None
if return_layernorm_output or return_layernorm_output_gathered:
ln_out_return = ln_out
nvtx_range_pop(f"{nvtx_label}.norm") nvtx_range_pop(f"{nvtx_label}.norm")
# For Float8CurrentScalingQuantizer, layernorm/rmsnorm has not been fused with quantizer.
# So the output of normalization is in high precision, and we need to quantize it to FP8 and put in the buffer.
if ub_overlap_ag_fprop and isinstance(input_quantizer, Float8CurrentScalingQuantizer):
ub_obj_fprop = get_ub(ub_name + "_fprop")
ln_out_local = ln_out
ln_out = ub_obj_fprop.get_buffer(input_quantizer, local_chunk=True)
input_quantizer.quantize(ln_out_local, out=ln_out)
# Prepare GEMM input # Prepare GEMM input
# Note: Cast to expected dtype and perform tensor-parallel communication # Note: Cast to expected dtype and perform tensor-parallel communication
nvtx_range_push(f"{nvtx_label}.gemm_input_cast_comm") nvtx_range_push(f"{nvtx_label}.gemm_input_cast_comm")
if with_input_all_gather and not ub_overlap_ag_fprop: ln_out_total = None
with_quantized_all_gather = fp8 ub_obj_fprop = None
if return_layernorm_output and return_layernorm_output_gathered: if with_input_all_gather:
with_quantized_all_gather = False if return_layernorm_output_gathered:
# Perform all-gather in high precision if gathered
# norm output will be returned
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
ln_out_return = ln_out_total
if fp8: if fp8:
ln_out = input_quantizer(ln_out)
input_quantizer.set_usage(rowwise=True, columnwise=False) input_quantizer.set_usage(rowwise=True, columnwise=False)
# ln_out in this has two possibilities:
# 1. in FP8 low precision, the cast was done by fusing quantization into layernorm kernel
# 2. in high precision, then we need to cast it and then gather in FP8
# the output ln_out_total will be in FP8, and it's a full tensor
ln_out_total, _ = gather_along_first_dim(
ln_out,
tp_group,
quantizer=(input_quantizer if with_quantized_all_gather else None),
)
if return_layernorm_output and return_layernorm_output_gathered:
ln_out_return = ln_out_total
if fp8 and not with_quantized_all_gather:
ln_out_total = input_quantizer(ln_out_total) ln_out_total = input_quantizer(ln_out_total)
else: else:
if fp8:
if not with_quantized_norm:
ln_out = input_quantizer(ln_out)
input_quantizer.set_usage(rowwise=True, columnwise=False)
if ub_overlap_ag_fprop: if ub_overlap_ag_fprop:
# Copy into Userbuffers buffer
ub_obj_fprop = get_ub(ub_name + "_fprop")
ub_obj_fprop.get_buffer(input_quantizer, local_chunk=True).copy_(ln_out)
ln_out_total = ub_obj_fprop.get_buffer(input_quantizer) ln_out_total = ub_obj_fprop.get_buffer(input_quantizer)
else: else:
if fp8: # All-gather with NCCL
if not isinstance(ln_out, QuantizedTensor): ln_out_total, _ = gather_along_first_dim(
input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) ln_out,
tp_group,
quantizer=(input_quantizer if fp8 else None),
)
else:
if fp8 and not with_quantized_norm:
ln_out = input_quantizer(ln_out) ln_out = input_quantizer(ln_out)
elif backward_needs_input:
ln_out.update_usage(rowwise_usage=True, columnwise_usage=True)
ln_out_total = ln_out ln_out_total = ln_out
nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm") nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm")
...@@ -397,7 +372,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -397,7 +372,7 @@ class _LayerNormLinear(torch.autograd.Function):
weight, weight,
bias, bias,
ln_weight, ln_weight,
ln_out.clone() if ub_overlap_ag_fprop else ln_out, # avoid saving a UB buffer ln_out,
mu, mu,
rsigma, rsigma,
) )
...@@ -608,7 +583,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -608,7 +583,7 @@ class _LayerNormLinear(torch.autograd.Function):
quantizer = None quantizer = None
if ctx.fp8: if ctx.fp8:
quantizer = ctx.input_quantizer quantizer = ctx.input_quantizer
quantizer.set_usage(rowwise=True, columnwise=True) quantizer.set_usage(rowwise=False, columnwise=True)
nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input") nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
ln_out_total, ln_out_total_work = gather_along_first_dim( ln_out_total, ln_out_total_work = gather_along_first_dim(
ln_out, ln_out,
......
...@@ -60,7 +60,6 @@ from ..tensor.float8_tensor import Float8Tensor ...@@ -60,7 +60,6 @@ from ..tensor.float8_tensor import Float8Tensor
from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ._common import apply_normalization, _fix_gathered_fp8_transpose from ._common import apply_normalization, _fix_gathered_fp8_transpose
from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer
from ..tensor.quantized_tensor import ( from ..tensor.quantized_tensor import (
QuantizedTensor, QuantizedTensor,
Quantizer, Quantizer,
...@@ -207,111 +206,80 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -207,111 +206,80 @@ class _LayerNormMLP(torch.autograd.Function):
if ln_bias is not None: if ln_bias is not None:
ln_bias = cast_if_needed(ln_bias, activation_dtype) ln_bias = cast_if_needed(ln_bias, activation_dtype)
# for fp8 DelayedScaling: layernorm output = FP8 # Avoid quantized norm kernel if norm output will be returned
# only output of the linear is returned with_quantized_norm = (
# for return_layernorm_output: layernorm output = High precision, then cast to FP8 fp8 and not return_layernorm_output and not return_layernorm_output_gathered
# high precision layernorm output and output of the linear are returned )
with_quantized_norm = fp8 and not return_layernorm_output
tp_world_size = get_distributed_world_size(tp_group) tp_world_size = get_distributed_world_size(tp_group)
ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output_gathered
ub_overlap_rs = ub_overlap_rs and is_grad_enabled ub_overlap_rs = ub_overlap_rs and is_grad_enabled
with_input_all_gather_nccl = sequence_parallel and not ub_overlap_ag
backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad
# Configure quantizer for normalization output # Configure quantizer for norm output
if fp8 and fc1_input_quantizer is None: if fp8:
raise ValueError("Missing quantizer for input tensor") if fc1_input_quantizer is None:
if with_quantized_norm: raise ValueError("Missing quantizer for FC1 input tensor")
if with_input_all_gather_nccl: columnwise_usage = backwards_needs_fc1_input
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
if isinstance(fc1_input_quantizer, MXFP8Quantizer):
with_quantized_norm = False
else:
fc1_input_quantizer.set_usage(
rowwise=True,
columnwise=backwards_needs_fc1_input,
)
# Reduce duplicated transpose in `_fix_gathered_fp8_transpose`
if ( if (
fp8 columnwise_usage
and FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling() and sequence_parallel
and ub_bulk_dgrad and not isinstance(fc1_input_quantizer, MXFP8Quantizer)
): ):
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) columnwise_usage = False
fc1_input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
ub_obj_lnout = None
ln_out = None
# For DelayScaling, output of normalization will be in fp8.
# For Float8CurrentScaling, we want the output of normalization in high precision, then quantize to fp8.
if ub_overlap_ag and not isinstance(fc1_input_quantizer, Float8CurrentScalingQuantizer):
ub_obj_lnout = get_ub("fc1_fprop")
ln_out = ub_obj_lnout.get_buffer(fc1_input_quantizer, local_chunk=True)
elif not with_quantized_norm:
ln_out = torch.empty_like(
inputmat, dtype=inputmat.dtype, memory_format=torch.contiguous_format, device="cuda"
)
# Apply normalization # Apply normalization
ln_out, mu, rsigma = apply_normalization( ln_out, mu, rsigma = apply_normalization(
inputmat, inputmat,
ln_out, None, # ln_out
ln_weight, ln_weight,
ln_bias, ln_bias,
eps, eps,
fc1_input_quantizer if with_quantized_norm else None, fc1_input_quantizer if with_quantized_norm else None,
inp.dtype, inputmat.dtype,
normalization, normalization,
fwd_ln_sm_margin, fwd_ln_sm_margin,
zero_centered_gamma, zero_centered_gamma,
) )
ln_out_return = None
ln_out_return = ln_out if return_layernorm_output else None if return_layernorm_output or return_layernorm_output_gathered:
ln_out_return = ln_out
# For Float8CurrentScalingQuantizer, layernorm/rmsnorm has not been fused with quantizer.
# So the output of normalization is in high precision, and we need to quantize it to FP8 and put in the buffer.
if ub_overlap_ag and isinstance(fc1_input_quantizer, Float8CurrentScalingQuantizer):
ub_obj_lnout = get_ub("fc1_fprop")
ln_out_local = ln_out
ln_out = ub_obj_lnout.get_buffer(fc1_input_quantizer, local_chunk=True)
fc1_input_quantizer.quantize(ln_out_local, out=ln_out)
# Prepare GEMM input # Prepare GEMM input
# Note: Cast to expected dtype and perform tensor-parallel communication # Note: Cast to expected dtype and perform tensor-parallel communication
ln_out_gathered = False ln_out_total = None
with_quantized_all_gather = fp8 ub_obj_lnout = None
if with_input_all_gather_nccl: if sequence_parallel:
if return_layernorm_output and return_layernorm_output_gathered: if return_layernorm_output_gathered:
with_quantized_all_gather = False # Perform all-gather in high precision if gathered
# norm output will be returned
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
ln_out_return = ln_out_total
if fp8:
ln_out = fc1_input_quantizer(ln_out)
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
ln_out_total = fc1_input_quantizer(ln_out_total)
else:
if fp8: if fp8:
if not with_quantized_norm:
ln_out = fc1_input_quantizer(ln_out)
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
# ln_out in this has two possibilities: if ub_overlap_ag:
# 1. in FP8 low precision, the cast was done by fusing quantization into layernorm kernel # Copy into Userbuffers buffer
# 2. in high precision, then we need to cast it and then gather in FP8 ub_obj_lnout = get_ub("fc1_fprop")
# the output ln_out_total will be in FP8, and it's a full tensor ub_obj_lnout.get_buffer(fc1_input_quantizer, local_chunk=True).copy_(ln_out)
ln_out_total = ub_obj_lnout.get_buffer(fc1_input_quantizer)
else:
# All-gather with NCCL
ln_out_total, _ = gather_along_first_dim( ln_out_total, _ = gather_along_first_dim(
ln_out, ln_out,
tp_group, tp_group,
quantizer=(fc1_input_quantizer if with_quantized_all_gather else None), quantizer=(fc1_input_quantizer if fp8 else None),
) )
ln_out_gathered = True
else:
with_quantized_all_gather = False
if ub_overlap_ag:
ln_out_total = ub_obj_lnout.get_buffer(fc1_input_quantizer, False)
else: else:
if fp8: if fp8 and not with_quantized_norm:
if not isinstance(ln_out, QuantizedTensor):
fc1_input_quantizer.set_usage(
rowwise=True, columnwise=backwards_needs_fc1_input
)
ln_out = fc1_input_quantizer(ln_out) ln_out = fc1_input_quantizer(ln_out)
elif backwards_needs_fc1_input:
ln_out.update_usage(rowwise_usage=True, columnwise_usage=True)
# here ln_out is in FP8 low precision, the cast was either done by fc1_input_quantizer
# or fused into the layernorm kernel
# ln_out_total represents the full fp8 tensor, in this case, it's the same as ln_out
ln_out_total = ln_out ln_out_total = ln_out
# Cast weights to expected dtype # Cast weights to expected dtype
...@@ -497,7 +465,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -497,7 +465,7 @@ class _LayerNormMLP(torch.autograd.Function):
tensors_to_save, tensor_objects = prepare_for_saving( tensors_to_save, tensor_objects = prepare_for_saving(
inputmat, inputmat,
ln_weight, ln_weight,
ln_out.clone() if ub_overlap_ag else ln_out, # avoid saving a UB buffer ln_out,
fc1_weight_final, fc1_weight_final,
fc1_bias, fc1_bias,
fc1_out, fc1_out,
...@@ -544,7 +512,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -544,7 +512,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.bias_gelu_fusion = bias_gelu_fusion ctx.bias_gelu_fusion = bias_gelu_fusion
ctx.return_layernorm_output = return_layernorm_output ctx.return_layernorm_output = return_layernorm_output
ctx.return_layernorm_output_gathered = ( ctx.return_layernorm_output_gathered = (
return_layernorm_output_gathered and ln_out_gathered return_layernorm_output_gathered and sequence_parallel
) )
ctx.set_parallel_mode = set_parallel_mode ctx.set_parallel_mode = set_parallel_mode
ctx.bwd_ln_sm_margin = bwd_ln_sm_margin ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
...@@ -704,7 +672,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -704,7 +672,7 @@ class _LayerNormMLP(torch.autograd.Function):
quantizer = None quantizer = None
if ctx.fp8: if ctx.fp8:
quantizer = ctx.fc1_input_quantizer quantizer = ctx.fc1_input_quantizer
quantizer.set_usage(rowwise=True, columnwise=True) quantizer.set_usage(rowwise=False, columnwise=True)
ln_out_total, ln_out_total_work = gather_along_first_dim( ln_out_total, ln_out_total_work = gather_along_first_dim(
ln_out, ln_out,
ctx.tp_group, ctx.tp_group,
......
...@@ -56,6 +56,7 @@ from ..tensor.quantized_tensor import ( ...@@ -56,6 +56,7 @@ from ..tensor.quantized_tensor import (
prepare_for_saving, prepare_for_saving,
restore_from_saved, restore_from_saved,
) )
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
...@@ -140,9 +141,13 @@ class _Linear(torch.autograd.Function): ...@@ -140,9 +141,13 @@ class _Linear(torch.autograd.Function):
if input_quantizer is None: if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor") raise ValueError("Missing quantizer for input tensor")
if with_input_all_gather_nccl: if with_input_all_gather_nccl:
assert not isinstance( if not isinstance(inputmat, QuantizedTensor):
inputmat, QuantizedTensor columnwise_usage = backward_needs_input and isinstance(
), "All gather of fp8 input is not supported" input_quantizer, MXFP8Quantizer
)
input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
inputmat = input_quantizer(inputmat)
own_quantized_input = True
input_quantizer.set_usage(rowwise=True, columnwise=False) input_quantizer.set_usage(rowwise=True, columnwise=False)
inputmat_total, _ = gather_along_first_dim( inputmat_total, _ = gather_along_first_dim(
inputmat, inputmat,
...@@ -269,7 +274,7 @@ class _Linear(torch.autograd.Function): ...@@ -269,7 +274,7 @@ class _Linear(torch.autograd.Function):
# to gather the input. For MXFP8, columnwise only data # to gather the input. For MXFP8, columnwise only data
# can be allgathered. # can be allgathered.
if isinstance(inputmat, MXFP8TensorBase) or not ctx.backward_input_needs_gather: if isinstance(inputmat, MXFP8TensorBase) or not ctx.backward_input_needs_gather:
inputmat.update_usage(rowwise_usage=False) inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
saved_inputmat = inputmat saved_inputmat = inputmat
# Weight with column-wise usage is needed for dgrad GEMM. # Weight with column-wise usage is needed for dgrad GEMM.
...@@ -494,7 +499,7 @@ class _Linear(torch.autograd.Function): ...@@ -494,7 +499,7 @@ class _Linear(torch.autograd.Function):
quantizer = None quantizer = None
if ctx.fp8: if ctx.fp8:
quantizer = ctx.input_quantizer quantizer = ctx.input_quantizer
quantizer.set_usage(rowwise=True, columnwise=True) quantizer.set_usage(rowwise=False, columnwise=True)
nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input") nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
inputmat_total, inputmat_total_work = gather_along_first_dim( inputmat_total, inputmat_total_work = gather_along_first_dim(
inputmat, inputmat,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment