Unverified commit 0356010c authored by Tim Moon, committed by GitHub

[PyTorch] Optimize MXFP8 all-gathers (#1581)



* Coalesce NCCL all-gathers for MXFP8 all-gather
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add missing import
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Cache quantized input tensor after linear module forward pass
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix linter warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Avoid unnecessarily allocating layernorm output in LayerNormLinear/LayerNormMLP
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 65c2798a
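Note: the headline change in this commit is that the per-tensor all-gathers for MXFP8 data and scale-inverses are issued inside a single NCCL coalescing block instead of being launched (and waited on) one by one. A minimal sketch of the pattern, assuming an initialized NCCL process group; the function name and tensor lists here are illustrative, not TransformerEngine API:

```python
import torch
import torch.distributed as dist

def coalesced_all_gathers(outs, ins, group, async_op=False):
    """Launch several all-gathers as one batched NCCL group call."""
    device = ins[0].device
    with dist._coalescing_manager(group=group, device=device, async_ops=async_op) as cm:
        for out, inp in zip(outs, ins):
            dist.all_gather_into_tensor(out, inp, group=group)
    # With async_ops=True, cm.wait() blocks until the whole group completes.
    return cm if async_op else None
```

`torch.distributed._coalescing_manager` is a private PyTorch API, which is presumably why the diff below calls it directly rather than through a public wrapper.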
@@ -5,6 +5,7 @@
"""Methods needed for distributed training (DP/TP)."""
from __future__ import annotations
from collections.abc import Iterable
from contextlib import contextmanager, AbstractContextManager, ContextDecorator
from functools import lru_cache
import math
@@ -876,10 +877,14 @@ def _all_gather_fp8(
# we cannot directly gather the transposed fp8 tensor
# so we need to disable columnwise usage for the quantizer
# and then set it back to the original value after quantizing
init_rowwise_usage = quantizer.rowwise_usage
init_columnwise_usage = quantizer.columnwise_usage
quantizer.set_usage(rowwise=True, columnwise=False)
inp = quantizer(inp)
quantizer.set_usage(
    rowwise=init_rowwise_usage,
    columnwise=init_columnwise_usage,
)
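As the comment above explains, the transposed (column-wise) FP8 copy cannot simply be all-gathered along the first dimension. A toy illustration with plain dense tensors of why per-rank transposes do not concatenate into the transpose of the gathered tensor:

```python
import torch

world_size = 2
shards = [torch.arange(4).reshape(2, 2) + 4 * r for r in range(world_size)]
gathered = torch.cat(shards, dim=0)                   # what all-gather produces: (4, 2)
gathered_t = torch.cat([s.T for s in shards], dim=0)  # gathering each rank's transpose: (4, 2)
assert gathered_t.shape != gathered.T.shape           # transpose of the gather is (2, 4)
```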
# Construct output tensor
out: Float8TensorBase
@@ -936,9 +941,34 @@ def _all_gather_mxfp8(
) -> tuple[MXFP8TensorBase, Optional[torch.distributed.Work]]:
"""All-gather MXFP8 tensor along first dimension."""
# Input tensor attributes
in_shape: Iterable[int]
device: torch.device
dtype: torch.dtype
if isinstance(inp, torch.Tensor):
    in_shape = inp.size()
    device = inp.device
    dtype = inp.dtype
elif isinstance(inp, MXFP8TensorBase):
    if inp._rowwise_data is not None:
        in_shape = inp._rowwise_data.size()
        device = inp._rowwise_data.device
    elif inp._columnwise_data is not None:
        in_shape = inp._columnwise_data.size()
        device = inp._columnwise_data.device
    else:
        raise ValueError("Got MXFP8 input tensor without any data")
    # Raw MXFP8 data is stored as uint8 and does not carry the original
    # dtype, so fall back to a default high-precision dtype
    dtype = torch.bfloat16
else:
    raise ValueError(
        "Invalid type for input tensor (expected torch.Tensor or MXFP8TensorBase, "
        f"found {inp.__class__.__name__})"
    )
# Output tensor shape
world_size = get_distributed_world_size(process_group)
if out_shape is None:
    out_shape = [in_shape[0] * world_size] + list(in_shape[1:])
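The gather scales only the leading dimension; a quick worked example with hypothetical sizes:

```python
in_shape = [1024, 512]  # per-rank shard (hypothetical)
world_size = 4
out_shape = [in_shape[0] * world_size] + list(in_shape[1:])
assert out_shape == [4096, 512]
```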
@@ -951,25 +981,19 @@
):
out = torch.empty(
    out_shape,
    dtype=dtype,
    device=device,
    memory_format=torch.contiguous_format,
)
torch.distributed.all_gather_into_tensor(out, inp, group=process_group)
out = quantizer(out)
return out, None
# Cast input tensor to MXFP8 with required data
if not isinstance(inp, MXFP8TensorBase):
    inp = quantizer(inp)
elif (quantizer.rowwise_usage and inp._rowwise_data is None) or (
    quantizer.columnwise_usage and inp._columnwise_data is None
):
warnings.warn(
"Input and quantizer do not have matching usages. "
@@ -978,65 +1002,64 @@
inp = quantizer(inp.dequantize())
# Construct MXFP8 output tensor
out = quantizer.make_empty(out_shape, dtype=dtype, device=device)

# Coalesce NCCL collectives
with torch.distributed._coalescing_manager(
    group=process_group,
    device=device,
    async_ops=async_op,
) as coalescing_manager:

    # Gather MXFP8 data for row-wise usage
    if quantizer.rowwise_usage:

        # Remove padding from MXFP8 scale-inverses
        in_scale_inv = inp._rowwise_scale_inv
        out_scale_inv = out._rowwise_scale_inv
        flattened_in_shape0 = math.prod(in_shape[:-1])
        if in_scale_inv.size(0) != flattened_in_shape0:
            in_scale_inv = in_scale_inv[:flattened_in_shape0]
            out_scale_inv[flattened_in_shape0 * world_size :].zero_()
            out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size]

        # Launch all-gathers
        torch.distributed.all_gather_into_tensor(
            out_scale_inv,
            in_scale_inv,
            group=process_group,
        )
        torch.distributed.all_gather_into_tensor(
            out._rowwise_data,
            inp._rowwise_data,
            group=process_group,
        )

    # Gather MXFP8 data for column-wise usage
    if quantizer.columnwise_usage:

        # Remove padding from MXFP8 scale-inverses
        in_scale_inv = inp._columnwise_scale_inv
        out_scale_inv = out._columnwise_scale_inv
        flattened_in_shape0 = math.prod(in_shape[:-1]) // 32
        if in_scale_inv.size(0) != flattened_in_shape0:
            in_scale_inv = in_scale_inv[:flattened_in_shape0]
            out_scale_inv[flattened_in_shape0 * world_size :].zero_()
            out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size]

        # Launch all-gathers
        torch.distributed.all_gather_into_tensor(
            out_scale_inv,
            in_scale_inv,
            group=process_group,
        )
        torch.distributed.all_gather_into_tensor(
            out._columnwise_data,
            inp._columnwise_data,
            group=process_group,
        )

# Async op handle
handle = coalescing_manager if async_op else None
return out, handle
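Because the returned async handle is now the coalescing manager rather than a single collective's work object, one `wait()` covers every all-gather issued in the block (row-wise and column-wise data plus their scale-inverses). A hypothetical call site; the keyword names match identifiers used in the body above, but the exact signature is truncated in this hunk:

```python
out, handle = _all_gather_mxfp8(inp, process_group=process_group, async_op=True, quantizer=quantizer)
# ... overlap independent compute here ...
if handle is not None:
    handle.wait()  # data and scale-inverses are now valid on every rank
```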
@@ -55,7 +55,6 @@ from ..tensor.quantized_tensor import (
prepare_for_saving,
restore_from_saved,
)
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
@@ -137,6 +136,11 @@ class _LayerNormLinear(torch.autograd.Function):
ln_bias = cast_if_needed(ln_bias, activation_dtype)
nvtx_range_pop(f"{nvtx_label}.norm_input_cast")
# Avoid quantized norm kernel if norm output will be returned
with_quantized_norm = (
fp8 and not return_layernorm_output and not return_layernorm_output_gathered
)
tp_world_size = get_distributed_world_size(tp_group)
ub_overlap_ag_fprop = (
ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output
@@ -146,6 +150,7 @@ class _LayerNormLinear(torch.autograd.Function):
backward_needs_input = is_grad_enabled and weight_requires_grad
with_input_all_gather = parallel_mode == "column" and sequence_parallel
# Check if Userbuffers is supported
if fp8:
if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and not (
FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
@@ -155,104 +160,74 @@ class _LayerNormLinear(torch.autograd.Function):
" current scaling"
)
# Configure quantizer for norm output
if fp8:
    if input_quantizer is None:
        raise ValueError("Missing quantizer for input tensor")
    columnwise_usage = backward_needs_input
    if (
        columnwise_usage
        and with_input_all_gather
        and not isinstance(input_quantizer, MXFP8Quantizer)
    ):
        columnwise_usage = False
    input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)

# Reduce duplicated transpose in `_fix_gathered_fp8_transpose`
if (
    fp8
    and FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
    and ub_bulk_dgrad
):
    input_quantizer.set_usage(rowwise=True, columnwise=False)
# Apply normalization
nvtx_range_push(f"{nvtx_label}.norm")
ln_out, mu, rsigma = apply_normalization(
    inputmat,
    None,  # ln_out
    ln_weight,
    ln_bias,
    eps,
    input_quantizer if with_quantized_norm else None,
    inputmat.dtype,
    normalization,
    fwd_ln_sm_margin,
    zero_centered_gamma,
)
ln_out_return = None
if return_layernorm_output or return_layernorm_output_gathered:
    ln_out_return = ln_out
nvtx_range_pop(f"{nvtx_label}.norm")
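Passing `None` for `ln_out` is the "avoid unnecessarily allocating layernorm output" item from the commit message: the normalization wrapper now owns the output buffer, so no high-precision tensor is allocated up front when the fused kernel writes a quantized output directly. A toy model of that contract (assumed semantics, not TE's actual `apply_normalization`):

```python
import torch

def norm_into(x: torch.Tensor, out: torch.Tensor | None = None) -> torch.Tensor:
    """Allocate the output only if the caller did not supply one."""
    y = torch.nn.functional.layer_norm(x, (x.shape[-1],))
    if out is None:
        return y  # caller passed None: no buffer was created up front
    out.copy_(y)
    return out

x = torch.randn(4, 8)
y = norm_into(x)  # new style: the callee decides whether a buffer is needed
```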
# Prepare GEMM input
# Note: Cast to expected dtype and perform tensor-parallel communication
nvtx_range_push(f"{nvtx_label}.gemm_input_cast_comm")
ln_out_total = None
ub_obj_fprop = None
if with_input_all_gather:
    if return_layernorm_output_gathered:
        # Perform all-gather in high precision if gathered
        # norm output will be returned
        ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
        ln_out_return = ln_out_total
        if fp8:
            ln_out = input_quantizer(ln_out)
            input_quantizer.set_usage(rowwise=True, columnwise=False)
            ln_out_total = input_quantizer(ln_out_total)
    else:
        if fp8:
            if not with_quantized_norm:
                ln_out = input_quantizer(ln_out)
            input_quantizer.set_usage(rowwise=True, columnwise=False)
        if ub_overlap_ag_fprop:
            # Copy into Userbuffers buffer
            ub_obj_fprop = get_ub(ub_name + "_fprop")
            ub_obj_fprop.get_buffer(input_quantizer, local_chunk=True).copy_(ln_out)
            ln_out_total = ub_obj_fprop.get_buffer(input_quantizer)
        else:
            # All-gather with NCCL
            ln_out_total, _ = gather_along_first_dim(
                ln_out,
                tp_group,
                quantizer=(input_quantizer if fp8 else None),
            )
else:
    if fp8 and not with_quantized_norm:
        ln_out = input_quantizer(ln_out)
    ln_out_total = ln_out
nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm")
# Cast weight to expected dtype
@@ -397,7 +372,7 @@ class _LayerNormLinear(torch.autograd.Function):
weight,
bias,
ln_weight,
ln_out,
mu,
rsigma,
)
@@ -608,7 +583,7 @@ class _LayerNormLinear(torch.autograd.Function):
quantizer = None
if ctx.fp8:
quantizer = ctx.input_quantizer
quantizer.set_usage(rowwise=False, columnwise=True)
nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
ln_out_total, ln_out_total_work = gather_along_first_dim(
ln_out,
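The switch from `rowwise=True, columnwise=True` to `rowwise=False, columnwise=True` means the backward all-gather re-materializes only the column-wise copy that the wgrad GEMM consumes. A rough byte count with hypothetical sizes (MXFP8 stores one byte per element plus one E8M0 scale per 32-element block, matching the `// 32` in the gather code above):

```python
rows, cols = 8192, 4096      # gathered activation (hypothetical)
payload = rows * cols        # 1 byte per element, per usage
scales = rows * cols // 32   # one E8M0 scale byte per 32-element block
one_usage = payload + scales
print(f"column-wise only: {one_usage / 2**20:.1f} MiB")     # ~33 MiB
print(f"both usages:      {2 * one_usage / 2**20:.1f} MiB")  # ~66 MiB, double the traffic
```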
@@ -60,7 +60,6 @@ from ..tensor.float8_tensor import Float8Tensor
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ._common import apply_normalization, _fix_gathered_fp8_transpose
from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
from ..tensor.quantized_tensor import (
QuantizedTensor,
Quantizer,
@@ -207,112 +206,81 @@ class _LayerNormMLP(torch.autograd.Function):
if ln_bias is not None:
ln_bias = cast_if_needed(ln_bias, activation_dtype)
# Avoid quantized norm kernel if norm output will be returned
with_quantized_norm = (
    fp8 and not return_layernorm_output and not return_layernorm_output_gathered
)
tp_world_size = get_distributed_world_size(tp_group)
ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output_gathered
ub_overlap_rs = ub_overlap_rs and is_grad_enabled
with_input_all_gather_nccl = sequence_parallel and not ub_overlap_ag
backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad
# Configure quantizer for norm output
if fp8:
    if fc1_input_quantizer is None:
        raise ValueError("Missing quantizer for FC1 input tensor")
    columnwise_usage = backwards_needs_fc1_input
    if (
        columnwise_usage
        and sequence_parallel
        and not isinstance(fc1_input_quantizer, MXFP8Quantizer)
    ):
        columnwise_usage = False
    fc1_input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)

# Reduce duplicated transpose in `_fix_gathered_fp8_transpose`
if (
    fp8
    and FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
    and ub_bulk_dgrad
):
    fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
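The branch above distils to a simple rule: quantize the column-wise (transposed) copy before the gather only when backward needs the input and the transpose cannot be rebuilt locally afterwards, which is the MXFP8 case. A sketch of the assumed decision rule:

```python
def columnwise_before_gather(backward_needs_input: bool, gathering: bool, is_mxfp8: bool) -> bool:
    if not backward_needs_input:
        return False
    if gathering and not is_mxfp8:
        # Per-tensor FP8: the transpose can be fixed up after the gather
        # (see `_fix_gathered_fp8_transpose`), so skip it locally.
        return False
    return True
```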
# Apply normalization
ln_out, mu, rsigma = apply_normalization(
    inputmat,
    None,  # ln_out
    ln_weight,
    ln_bias,
    eps,
    fc1_input_quantizer if with_quantized_norm else None,
    inputmat.dtype,
    normalization,
    fwd_ln_sm_margin,
    zero_centered_gamma,
)
ln_out_return = None
if return_layernorm_output or return_layernorm_output_gathered:
    ln_out_return = ln_out
# Prepare GEMM input
# Note: Cast to expected dtype and perform tensor-parallel communication
ln_out_total = None
ub_obj_lnout = None
if sequence_parallel:
    if return_layernorm_output_gathered:
        # Perform all-gather in high precision if gathered
        # norm output will be returned
        ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
        ln_out_return = ln_out_total
        if fp8:
            ln_out = fc1_input_quantizer(ln_out)
            fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
            ln_out_total = fc1_input_quantizer(ln_out_total)
    else:
        if fp8:
            if not with_quantized_norm:
                ln_out = fc1_input_quantizer(ln_out)
            fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
        if ub_overlap_ag:
            # Copy into Userbuffers buffer
            ub_obj_lnout = get_ub("fc1_fprop")
            ub_obj_lnout.get_buffer(fc1_input_quantizer, local_chunk=True).copy_(ln_out)
            ln_out_total = ub_obj_lnout.get_buffer(fc1_input_quantizer)
        else:
            # All-gather with NCCL
            ln_out_total, _ = gather_along_first_dim(
                ln_out,
                tp_group,
                quantizer=(fc1_input_quantizer if fp8 else None),
            )
else:
    if fp8 and not with_quantized_norm:
        ln_out = fc1_input_quantizer(ln_out)
    ln_out_total = ln_out
# Cast weights to expected dtype
if not fp8:
@@ -497,7 +465,7 @@ class _LayerNormMLP(torch.autograd.Function):
tensors_to_save, tensor_objects = prepare_for_saving(
inputmat,
ln_weight,
ln_out,
fc1_weight_final,
fc1_bias,
fc1_out,
@@ -544,7 +512,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.bias_gelu_fusion = bias_gelu_fusion
ctx.return_layernorm_output = return_layernorm_output
ctx.return_layernorm_output_gathered = (
return_layernorm_output_gathered and sequence_parallel
)
ctx.set_parallel_mode = set_parallel_mode
ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
@@ -704,7 +672,7 @@ class _LayerNormMLP(torch.autograd.Function):
quantizer = None
if ctx.fp8:
quantizer = ctx.fc1_input_quantizer
quantizer.set_usage(rowwise=False, columnwise=True)
ln_out_total, ln_out_total_work = gather_along_first_dim(
ln_out,
ctx.tp_group,
@@ -56,6 +56,7 @@ from ..tensor.quantized_tensor import (
prepare_for_saving,
restore_from_saved,
)
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
@@ -140,9 +141,13 @@ class _Linear(torch.autograd.Function):
if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
if with_input_all_gather_nccl:
if not isinstance(inputmat, QuantizedTensor):
    columnwise_usage = backward_needs_input and isinstance(
        input_quantizer, MXFP8Quantizer
    )
    input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
    inputmat = input_quantizer(inputmat)
    own_quantized_input = True
input_quantizer.set_usage(rowwise=True, columnwise=False)
inputmat_total, _ = gather_along_first_dim(
inputmat,
@@ -269,7 +274,7 @@ class _Linear(torch.autograd.Function):
# to gather the input. For MXFP8, columnwise only data
# can be allgathered.
if isinstance(inputmat, MXFP8TensorBase) or not ctx.backward_input_needs_gather:
inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
saved_inputmat = inputmat
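This is the "cache quantized input tensor" item from the commit message: the forward pass quantizes the local shard once (keeping the column-wise copy when an MXFP8 wgrad will need it), saves that shard, and the backward pass all-gathers the cached column-wise data instead of re-quantizing. A schematic sketch with injected stand-ins; none of these names are TE API:

```python
def forward_with_cache(x, quantize, backward_needs_input):
    # quantize(x, rowwise=..., columnwise=...) returns a quantized tensor
    # holding only the requested copies.
    q = quantize(x, rowwise=True, columnwise=backward_needs_input)
    saved = q if backward_needs_input else None  # local shard cached for backward
    return q, saved

def backward_reusing_cache(saved, gather_columnwise):
    # No re-quantization here: gather the cached column-wise shard directly.
    return gather_columnwise(saved)
```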
# Weight with column-wise usage is needed for dgrad GEMM.
@@ -494,7 +499,7 @@ class _Linear(torch.autograd.Function):
quantizer = None
if ctx.fp8:
quantizer = ctx.input_quantizer
quantizer.set_usage(rowwise=False, columnwise=True)
nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
inputmat_total, inputmat_total_work = gather_along_first_dim(
inputmat,