Unverified commit ce0b46c4 authored by Tim Moon, committed by GitHub

MXFP8 support in Userbuffers (#1711)



* Initial work toward restoring UB support in te.Sequential
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Forward UB linear runs, but has numerical error
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug UB forward tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Minor tweaks
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove Python checks for MXFP8 UB linear forward
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add dim check for MXFP8 full tiles
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Move QuantizedTensor logic out of UB comm and into Python helper function
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Support MXFP8 AGs
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Coalesce NCCL all-gathers for MXFP8 all-gather
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Initial impl of backward UB linear in te.Sequential
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug UB linear backward with no quantization

dgrad GEMM + dx RS is still broken.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix chunk dims for dgrad GEMM + dx RS
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debugging MXFP8 UB cases

Still failing with dy AG + wgrad GEMM
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use NCCL to overlap dy AG with dgrad GEMM
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug UB GEMM tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Initial refactoring of linear module forward
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Refactor linear module backward
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug linear module UB tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Tweak test tensor dims
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Do not store autograd context within wgrad GEMM closure
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update LayerNormLinear
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update LayerNormMLP
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug UB tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Debug test failures
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Minor style tweaks
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix incorrect usage for GEMM input with block-scaled FP8
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix RS out dims
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Disable dgrad GEMM + UB AG + NCCL AG overlapping
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Disable dgrad GEMM + UB AG + NCCL AG overlap in te.Sequential
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Restore support for internal quantized tensors
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add tests for MXFP8 GEMM with UB
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix linter warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug test failures
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug test failures
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 3d2152c2
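
For orientation, here is a minimal sketch of how the MXFP8 + Userbuffers path touched by this commit might be exercised. Only the MXFP8BlockScaling recipe, te.module.base.initialize_ub, and te.fp8_autocast calls appear in the diff below; the process-group setup, tensor shapes, and the te.Linear arguments (parallel_mode, sequence_parallel, ub_overlap_ag, ub_name) are illustrative assumptions, not part of this change.

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import MXFP8BlockScaling

# Assumes launch via torchrun so the NCCL process group can rendezvous.
torch.distributed.init_process_group(backend="nccl")
tp_group = torch.distributed.group.WORLD
tp_size = torch.distributed.get_world_size(tp_group)

# MXFP8 with Userbuffers requires tensor dims divisible by 128 (see the check added below).
seq_len, batch_size, hidden_size = 1024, 2, 4096

# Register Userbuffers communication buffers; use_fp8=True allocates FP8-capable buffers.
te.module.base.initialize_ub(
    [seq_len * batch_size, hidden_size],
    tp_size,
    use_fp8=True,
    dtype=torch.bfloat16,
    bootstrap_backend="nccl",
)

# Column-parallel layer whose input all-gather overlaps with the GEMM
# (illustrative layer configuration).
layer = te.Linear(
    hidden_size,
    3 * hidden_size,
    parallel_mode="column",
    sequence_parallel=True,
    tp_group=tp_group,
    tp_size=tp_size,
    ub_overlap_ag=True,
    ub_name="qkv",
)

recipe = MXFP8BlockScaling()  # MXFP8 recipe, as added to the test scripts below
x = torch.randn(seq_len // tp_size, batch_size, hidden_size, device="cuda", dtype=torch.bfloat16)
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = layer(x)
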
......@@ -26,7 +26,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PAT
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
# python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" ### TODO Debug UB support with te.Sequential
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn_with_cp.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
......
......@@ -20,7 +20,11 @@ from torch.distributed.elastic.multiprocessing.errors import record
import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.module.base import get_cublas_workspace_size_bytes
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.module.base import (
fill_userbuffers_buffer_for_all_gather,
get_cublas_workspace_size_bytes,
)
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
......@@ -56,7 +60,11 @@ def _parse_args(argv=None, namespace=None):
)
parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
parser.add_argument(
"--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context."
"--quantization",
type=str.lower,
default="none",
choices=["none", "fp8", "mxfp8"],
help="Quantization recipe",
)
parser.add_argument(
"--fp8-output", action="store_true", default=False, help="Get FP8 output from GEMM."
......@@ -154,9 +162,9 @@ def _parse_args(argv=None, namespace=None):
if opts.atomic:
warnings.warn("Atomic GEMM is not supported with bulk overlap.")
opts.atomic = False
if opts.fp8:
if opts.quantization != "none":
warnings.warn("Bulk overlap is supported in FP8 but only tested in BF16.")
opts.fp8 = False
opts.quantization = "none"
elif opts.comm_type == tex.CommOverlapType.AG:
if opts.atomic:
setattr(opts, "atomic_rs_p2p", opts.p2p)
......@@ -164,8 +172,11 @@ def _parse_args(argv=None, namespace=None):
if opts.atomic:
if not te.fp8.check_fp8_support():
assert not opts.fp8, "Atomic GEMM is only supported in FP8."
opts.fp8 = True
assert opts.quantization == "none", "Atomic GEMM is only supported in FP8."
opts.quantization = "fp8"
if opts.fp8_output:
assert opts.quantization == "fp8", "FP8 output is only supported with FP8 compute."
return opts
......@@ -302,7 +313,11 @@ def _main(opts):
inp_shape = (opts.seq_length, opts.batch_size, hidden_size)
outer_size = reduce(operator.mul, inp_shape[:-1], 1)
buffer_dtype = torch.bfloat16
if opts.fp8 and not opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.AG:
if (
opts.quantization != "none"
and not opts.bulk_overlap
and opts.comm_type == tex.CommOverlapType.AG
):
buffer_dtype = torch.uint8
ub_obj = (
tex.CommOverlapP2P(
......@@ -447,6 +462,8 @@ def _main(opts):
inp2_g = torch.nn.functional.gelu(ref_g) # pylint: disable=not-callable
ref2_g = torch.matmul(inp2_g, ker2_g)
# Initialize quantizers
with_quantized_compute = opts.quantization != "none"
inp_quantizer = None
ker_quantizer = None
out_quantizer = None
......@@ -454,7 +471,7 @@ def _main(opts):
inp2_quantizer = None
ker2_quantizer = None
out2_quantizer = None
if opts.fp8:
if opts.quantization == "fp8":
# Structure to maintain amax and scale/scale_inv information for the kernel and input
num_gemms = 6 if ub_obj2 is not None else 3
fp8_dtype = tex.DType.kFloat8E4M3
......@@ -499,11 +516,23 @@ def _main(opts):
out2_quantizer = Float8Quantizer(
fp8_scales[5].clone(), fp8_amaxes[5].clone(), fp8_dtype
)
elif opts.quantization == "mxfp8":
fp8_dtype = tex.DType.kFloat8E4M3
inp_quantizer = MXFP8Quantizer(fp8_dtype, columnwise=False)
ker_quantizer = MXFP8Quantizer(fp8_dtype)
if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS:
bulk_inp_quantizer = MXFP8Quantizer(fp8_dtype, columnwise=False)
elif ub_obj2 is not None:
inp2_quantizer = MXFP8Quantizer(fp8_dtype, columnwise=False)
ker2_quantizer = MXFP8Quantizer(fp8_dtype)
# Quantize tensors
if with_quantized_compute:
# Cast input to Float8Tensor
# Quantize input tensor
inp_fp8 = inp_quantizer(inp)
# Cast kernel to Float8Tensor
# Quantize kernel tensor
kernel_t_fp8 = ker_quantizer(kernel_t)
if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS:
bulk_inp_fp8 = bulk_inp_quantizer(bulk_inp)
......@@ -540,31 +569,40 @@ def _main(opts):
)
# Set up comm/compute buffers
ag_out = None
rs_out = None
rs_out2 = None
if opts.comm_type == tex.CommOverlapType.AG:
if opts.bulk_overlap:
ub_obj.copy_into_buffer(bulk_inp, bulk_inp_quantizer, True)
ag_out, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj,
bulk_inp,
bulk_inp_quantizer,
tp_group,
)
gemm_inp = inp
else:
ub_obj.copy_into_buffer(inp_fp8 if opts.fp8 else inp, inp_quantizer, True)
gemm_inp = ub_obj.get_buffer(inp_quantizer, False, inp_g.size())
ag_out, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj,
inp_fp8 if with_quantized_compute else inp,
inp_quantizer,
tp_group,
)
gemm_inp = ag_out
if ub_obj2 is not None:
if opts.fp8 and opts.fp8_output:
ub_obj2.set_buffer_params(out_quantizer)
rs_out2 = torch.empty(
(outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda"
)
else:
if opts.bulk_overlap:
ub_obj.copy_into_buffer(
bulk_inp_fp8 if opts.fp8 else bulk_inp, bulk_inp_quantizer, False
)
if opts.fp8:
ub_obj.set_buffer_params(bulk_inp_quantizer)
elif opts.fp8 and opts.fp8_output:
ub_obj.set_buffer_params(out_quantizer)
gemm_inp = inp_fp8 if opts.fp8 else inp
if opts.quantization == "none":
ub_obj.copy_into_buffer(bulk_inp, local_chunk=False)
if opts.quantization == "fp8":
ub_obj.copy_into_buffer(bulk_inp_fp8._data, local_chunk=False)
elif opts.quantization == "mxfp8":
ub_obj.copy_into_buffer(bulk_inp_fp8._rowwise_data, local_chunk=False)
gemm_inp = inp_fp8 if with_quantized_compute else inp
rs_out = torch.empty(
(outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda"
)
......@@ -623,7 +661,7 @@ def _main(opts):
if opts.use_cuda_graphs:
# Trace the CUDA graph first
g = torch.cuda.CUDAGraph()
if opts.fp8:
if with_quantized_compute:
if ub_obj is None:
with torch.cuda.graph(g):
all_outputs = _fp8_gemm()
......@@ -643,7 +681,7 @@ def _main(opts):
else:
for i in range(total_iters):
if opts.fp8:
if with_quantized_compute:
start_events[i].record()
all_outputs = _fp8_gemm()
end_events[i].record()
......@@ -688,10 +726,22 @@ def _main(opts):
output_info = ""
if opts.comm_type == tex.CommOverlapType.AG:
# Bulk overlap AG output is already gathered
test_out = ub_obj.get_buffer(bulk_inp_quantizer, False)
test_out = ag_out
if bulk_inp_quantizer is None:
test_out = ub_obj.get_buffer(False)
else:
test_out = Float8Tensor(
shape=test_out.shape,
dtype=torch.bfloat16,
data=ub_obj.get_buffer(False),
fp8_scale=bulk_inp_quantizer.scale,
fp8_dtype=bulk_inp_quantizer.dtype,
quantizer=bulk_inp_quantizer,
)
else:
# Bulk overlap RS output needs to be gathered
out_local = ub_obj.get_buffer(bulk_inp_quantizer, True)
out_local = ub_obj.get_buffer(True)
output_info += f"rs_output: {list(out_local.shape)} | "
test_out = te.distributed.gather_along_first_dim(out_local, tp_group)[0]
......@@ -762,8 +812,8 @@ def _main(opts):
m = torch.argmax(diff)
abs_err = diff[m].item()
rel_err = abs_err / max(abs(ref_out.flatten()[m].item()), 1e-5)
rtol = 0.125 if opts.fp8 else 0.02
atol = 0.0625 if opts.fp8 else 0.001
rtol = 0.02 if opts.quantization == "none" else 0.125
atol = 0.001 if opts.quantization == "none" else 0.0625
if rel_err > rtol and abs_err > atol:
numerics_failed = True
numerics_info = (
......
......@@ -17,7 +17,12 @@ import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling, Float8CurrentScaling
from transformer_engine.common.recipe import (
DelayedScaling,
Float8CurrentScaling,
Format,
MXFP8BlockScaling,
)
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
......@@ -163,7 +168,7 @@ def _parse_args(argv=None, namespace=None):
"--quantization",
type=str.lower,
default="none",
choices=["none", "fp8_delayed_scaling", "fp8_current_scaling"],
choices=["none", "fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"],
help="Quantization recipe",
)
parser.add_argument(
......@@ -414,6 +419,8 @@ def _train(opts):
)
elif opts.quantization == "fp8_current_scaling":
fp8_recipe = Float8CurrentScaling(fp8_format=fp8_format)
elif opts.quantization == "mxfp8":
fp8_recipe = MXFP8BlockScaling()
# Prepare random input tensors
test_x = torch.randn(input_shape, dtype=torch.float32, device="cuda", requires_grad=True)
......
......@@ -15,11 +15,12 @@ if torch.cuda.device_count() < 2:
pytest.skip("Comm+GEMM overlap requires at least 2 GPUs.")
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
RNG_SEED: int = 42
SEQ_LENGTH: int = 1024
BATCH_SIZE: int = 2
NUM_HEADS: int = 16
NUM_HEADS: int = 32
HEAD_DIM: int = 48
TE_LAYERS = [
te.Linear,
......@@ -50,7 +51,7 @@ os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
torch._dynamo.reset()
def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8):
def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, quantization):
test_path = TEST_ROOT / "run_gemm_with_overlap.py"
test_cmd = LAUNCH_CMD + [
str(test_path),
......@@ -66,10 +67,11 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8):
if bulk:
test_cmd.append("--bulk-overlap")
else:
if fp8:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
test_cmd.append("--fp8")
if quantization == "fp8" and not fp8_available:
pytest.skip(reason_for_no_fp8)
if quantization == "mxfp8" and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
test_cmd.append(f"--quantization={quantization}")
if p2p:
test_cmd.append("--p2p")
if atomic:
......@@ -107,8 +109,10 @@ def _run_layer_with_overlap(
test_cmd.append("--overlap-rs-dgrad")
if fp8:
if not fp8_available:
if quantization in ("fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
pytest.skip(reason_for_no_fp8)
if quantization == "mxfp8" and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
test_cmd.append("--fp8")
test_cmd.append(f"--quantization={quantization}")
......@@ -130,51 +134,34 @@ def _run_layer_with_overlap(
raise AssertionError(result.stderr.decode())
@pytest.mark.parametrize(
"fp8",
(False, True),
ids=[" BF16 - RING-EXCHANGE ", " FP8 - RING-EXCHANGE "],
)
def test_split_all_gather_overlaps(fp8):
@pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8"))
def test_split_all_gather_overlaps(quantization):
"""
Test (split GEMM -> all-gather) overlaps with direct calls to te.cpp_extensions.gemm or
te.cpp_extensions.fp8_gemm.
"""
_run_gemm_with_overlap("AG", False, True, False, fp8)
_run_gemm_with_overlap("AG", False, True, False, quantization)
@pytest.mark.parametrize(
"fp8,p2p",
[
(False, False),
(False, True),
(True, False),
(True, True),
],
ids=[
" BF16 - PIPELINE ",
" BF16 - RING-EXCHANGE ",
" FP8 - PIPELINE ",
" FP8 - RING-EXCHANGE ",
],
)
def test_split_reduce_scatter_overlaps(fp8, p2p):
@pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8"))
@pytest.mark.parametrize("p2p", (False, True))
def test_split_reduce_scatter_overlaps(quantization, p2p):
"""
Test (reduce-scatter -> split GEMM) overlaps with direct calls to te.cpp_extensions.gemm or
te.cpp_extensions.fp8_gemm.
"""
_run_gemm_with_overlap("RS", False, p2p, False, fp8)
_run_gemm_with_overlap("RS", False, p2p, False, quantization)
@pytest.mark.parametrize(
"comm_type, fp8, connections",
"comm_type, quantization, connections",
[
("AG", False, 1),
("RS", False, 1),
("RS", True, 1),
("AG", False, 8),
("RS", False, 8),
("RS", True, 8),
("AG", "none", 1),
("RS", "none", 1),
("RS", "fp8", 1),
("AG", "none", 8),
("RS", "none", 8),
("RS", "fp8", 8),
],
ids=[
"ALL-GATHER - BF16 - 1 connections",
......@@ -185,7 +172,7 @@ def test_split_reduce_scatter_overlaps(fp8, p2p):
"REDUCE-SCATTER - FP8 - 8 connections",
],
)
def test_bulk_overlaps(comm_type, fp8, connections):
def test_bulk_overlaps(comm_type, quantization, connections):
"""
Test bulk overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm.
"""
......@@ -196,10 +183,10 @@ def test_bulk_overlaps(comm_type, fp8, connections):
" 9.0 (HOPPER ARCH)."
)
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8"
_run_gemm_with_overlap(comm_type, True, False, False, fp8)
_run_gemm_with_overlap(comm_type, True, False, False, quantization)
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
else:
_run_gemm_with_overlap(comm_type, True, False, False, fp8)
_run_gemm_with_overlap(comm_type, True, False, False, quantization)
@pytest.mark.parametrize(
......@@ -251,15 +238,7 @@ def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_d
@pytest.mark.parametrize(
"quantization",
["fp8_delayed_scaling", "fp8_current_scaling"],
ids=[" DELAYED SCALING ", " CURRENT SCALING "],
)
@pytest.mark.parametrize(
"fp8",
(True,),
ids=[
" FP8 ",
],
["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"],
)
@pytest.mark.parametrize(
"layer_type,linear_parallel_mode,overlap_rs_dgrad",
......@@ -279,15 +258,15 @@ def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_d
)
),
ids=[
f" {te.Linear.__name__} - ROW-PARALLEL ",
f" {te.Linear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ",
f" {te.Linear.__name__} - COL-PARLALEL - DGRAD+RS ",
f" {te.LayerNormLinear.__name__} - ROW-PARALLEL ",
f" {te.LayerNormLinear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ",
f" {te.LayerNormLinear.__name__} - COL-PARALLEL - DGRAD+RS ",
f"{te.Linear.__name__}-row_tensor_parallel",
f"{te.Linear.__name__}-col_tensor_parallel-BULK DGRAD/WGRAD",
f"{te.Linear.__name__}-col_tensor_parallel-DGRAD+RS",
f"{te.LayerNormLinear.__name__}-row_tensor_parallel",
f"{te.LayerNormLinear.__name__}-col_tensor_parallel-BULK DGRAD/WGRAD",
f"{te.LayerNormLinear.__name__}-col_tensor_parallel-DGRAD+RS",
]
+ [
" " + " - ".join(test_name_parts) + " "
"-".join(test_name_parts)
for test_name_parts in zip(
[layer.__name__ for layer in TE_LAYERS[2:] for _ in range(2)],
["BULK DGRAD/WGRAD", "DGRAD+RS"] * len(TE_LAYERS[2:]),
......@@ -295,12 +274,15 @@ def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_d
],
)
def test_layers_with_overlap_fp8(
layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization
layer_type,
linear_parallel_mode,
overlap_rs_dgrad,
quantization,
):
"""
Test Transformer Engine layers with comm+GEMM overlap.
"""
_run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization)
_run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, True, quantization)
@pytest.mark.parametrize(
......@@ -347,22 +329,11 @@ def test_multi_layer_with_overlap_bf16(
@pytest.mark.parametrize(
"quantization",
["fp8_delayed_scaling", "fp8_current_scaling"],
ids=[" DELAYED SCALING ", " CURRENT SCALING "],
)
@pytest.mark.parametrize(
"fp8",
(True,),
ids=[
" FP8 ",
],
["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"],
)
@pytest.mark.parametrize(
"num_layers",
(2,),
ids=[
" 2 layers ",
],
)
@pytest.mark.parametrize(
"layer_type,linear_parallel_mode,overlap_rs_dgrad",
......@@ -374,7 +345,7 @@ def test_multi_layer_with_overlap_bf16(
)
),
ids=[
" " + " - ".join(test_name_parts) + " "
"-".join(test_name_parts)
for test_name_parts in zip(
[te.TransformerLayer.__name__ for _ in range(2)],
["BULK DGRAD/WGRAD", "DGRAD+RS"],
......@@ -382,11 +353,11 @@ def test_multi_layer_with_overlap_bf16(
],
)
def test_multi_layer_with_overlap_fp8(
layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers
layer_type, linear_parallel_mode, overlap_rs_dgrad, quantization, num_layers
):
"""
Test Transformer Engine layers with comm+GEMM overlap.
"""
_run_layer_with_overlap(
layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers
layer_type, linear_parallel_mode, overlap_rs_dgrad, True, quantization, num_layers
)
......@@ -19,7 +19,6 @@ import torch
import transformer_engine
import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import is_float8_tensor
......@@ -27,6 +26,8 @@ from transformer_engine.pytorch.ops.fused import (
UserbuffersBackwardLinear,
UserbuffersForwardLinear,
)
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
from transformer_engine.pytorch.utils import is_bf16_compatible
# Import utility functions
......@@ -36,6 +37,13 @@ from utils import dtype_tols, str_to_dtype
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
quantization_list: list[Optional[str]] = [None]
if fp8_available:
quantization_list.append("fp8")
if mxfp8_available:
quantization_list.append("mxfp8")
# Check if there are multiple GPUs
if torch.cuda.device_count() < 2:
......@@ -51,7 +59,7 @@ class ModelConfig:
num_heads: int
head_dim: int
dtype: torch.dtype
fp8: bool
quantization: Optional[str]
@property
def hidden_size(self):
......@@ -129,12 +137,16 @@ def make_reference_and_test_tensors(
ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
# Make copy of tensor
test = ref.to(device=test_device, dtype=test_dtype)
if test_is_fp8:
test = Float8Tensor.to_float8(ref)
else:
test = ref.to(device=test_device, dtype=test_dtype)
if test.data_ptr() == ref.data_ptr():
test = test.clone()
quantizer = Float8Quantizer(
scale=torch.ones(1, dtype=torch.float32, device=test_device),
amax=torch.zeros(1, dtype=torch.float32, device=test_device),
fp8_dtype=tex.DType.kFloat8E4M3,
)
test = quantizer(test)
elif test.data_ptr() == ref.data_ptr():
test = test.clone()
# Make sure reference and test tensors represent exact same values
ref.copy_(test)
......@@ -145,6 +157,21 @@ def make_reference_and_test_tensors(
return ref, test
def make_recipe(name: Optional[str] = None) -> Optional[Recipe]:
"""Make recipe for quantization scheme"""
if name is None:
return None
if name == "fp8":
return transformer_engine.common.recipe.DelayedScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
if name == "mxfp8":
return transformer_engine.common.recipe.MXFP8BlockScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
raise ValueError(f"Unsupported quantization scheme ({name})")
def _test_linear(
*,
model_config: ModelConfig,
......@@ -155,7 +182,8 @@ def _test_linear(
weight_requires_grad: bool = True,
) -> None:
dtype = model_config.dtype
fp8_compute = model_config.fp8
quantization = model_config.quantization
quantized_compute = quantization is not None
# Distributed process group
process_group = world_group()
......@@ -175,14 +203,19 @@ def _test_linear(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8_compute,
test_is_fp8=quantized_compute,
)
if isinstance(x_test, QuantizedTensor):
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8_compute,
test_is_fp8=quantized_compute,
)
if isinstance(w_test, QuantizedTensor):
w_test = w_test.dequantize()
b_ref, b_test = None, None
if bias:
if tensor_parallel_mode == "row":
......@@ -198,9 +231,11 @@ def _test_linear(
out_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8_compute,
test_is_fp8=quantized_compute,
requires_grad=False,
)
if isinstance(dy_test, QuantizedTensor):
dy_test = dy_test.dequantize()
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref)
......@@ -265,21 +300,15 @@ def _test_linear(
x_test.requires_grad_()
# Implementation with fusible operation
with te.fp8_model_init(enabled=fp8_compute):
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_compute, recipe=recipe):
ops = []
linear_op = None
bias_op = None
if tensor_parallel_mode == "column":
userbuffers_options = {}
if not weight_requires_grad:
if fp8_compute:
userbuffers_options["comm_name"] = "fc1"
else:
# There is a correctness bug with overlapping
# dgrad reduce-scatter with dgrad GEMM. Fall back
# to overlapping dgrad reduce-scatter with wgrad
# GEMM, even though wgrad isn't needed.
userbuffers_options["comm_name"] = "qkv"
userbuffers_options["comm_name"] = "fc1"
else:
userbuffers_options["comm_name"] = "qkv"
linear_op = te_ops.BasicLinear(
......@@ -322,7 +351,7 @@ def _test_linear(
bias_op.bias.copy_(b_test)
del w_test
del b_test
with te.fp8_autocast(enabled=fp8_compute):
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
y_test = model(x_test)
y_test.backward(dy_test)
......@@ -338,7 +367,7 @@ def _test_linear(
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if fp8_compute:
if quantized_compute:
tols = dtype_tols(
model[0].weight._fp8_dtype
if is_float8_tensor(model[0].weight)
......@@ -370,7 +399,7 @@ def run_parallel_tests(model_config: ModelConfig) -> None:
for test_config in itertools.product(
(False, True), # bias
("column", "row"), # tensor_parallel_mode
(False, True), # weight_requires_grad
(True, False), # weight_requires_grad
):
if rank == 0:
print(f"Running _test_linear with {test_config=}")
......@@ -390,19 +419,15 @@ if torch.cuda.device_count() > 1:
@pytest.mark.parametrize("world_size", _world_sizes)
@pytest.mark.parametrize("fp8", (False, True))
@pytest.mark.parametrize("quantization", quantization_list)
def test_fuser_ops_with_userbuffers(
*,
world_size: int,
dtype: torch.dtype = torch.bfloat16,
fp8: bool,
quantization: Optional[str],
) -> None:
"""Launch parallel job and run tests"""
# Skip invalid configurations
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
# Parallel job launcher
command = []
if tex.ubuf_built_with_mpi():
......@@ -424,8 +449,8 @@ def test_fuser_ops_with_userbuffers(
str(dtype),
)
)
if fp8:
command.append("--fp8")
if quantization is not None:
command.extend(("--quantization", quantization))
# Environment
env = dict(os.environ)
......@@ -445,12 +470,12 @@ def main() -> None:
# Parse command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--parallel", action="store_true", help="Run parallel tests")
parser.add_argument("--sequence-length", type=int, default=32)
parser.add_argument("--sequence-length", type=int, default=256)
parser.add_argument("--batch-size", type=int, default=16)
parser.add_argument("--num-heads", type=int, default=16)
parser.add_argument("--head-dim", type=int, default=32)
parser.add_argument("--head-dim", type=int, default=256)
parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--fp8", action="store_true")
parser.add_argument("--quantization", type=str, default=None)
args = parser.parse_args()
# Run parallel tests if needed
......@@ -463,14 +488,17 @@ def main() -> None:
num_heads=args.num_heads,
head_dim=args.head_dim,
dtype=str_to_dtype(args.dtype),
fp8=args.fp8,
quantization=args.quantization,
)
# Initialize Userbuffers
group = world_group() # Initialize NCCL
bootstrap_backend = "mpi" if launcher() == "ompi" else "nccl"
userbuffer_configs = {
"fc1_dgrad": {"method": "pipeline"}, # Overlap dgrad RS with dgrad GEMM
"fc1_dgrad": {
"method": "ring_exchange",
"fp8_buf": False,
}, # Overlap dgrad RS with dgrad GEMM
}
te.module.base.initialize_ub(
[
......@@ -478,7 +506,7 @@ def main() -> None:
model_config.num_heads * model_config.head_dim,
],
torch.distributed.get_world_size(group),
use_fp8=model_config.fp8,
use_fp8=model_config.quantization is not None,
dtype=model_config.dtype,
bootstrap_backend=bootstrap_backend,
ub_cfgs=userbuffer_configs,
......
......@@ -147,7 +147,44 @@ CommOverlapCore::~CommOverlapCore() {
TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, size_t chunk_offset,
const std::vector<size_t> &chunk_shape) {
TensorWrapper chunk;
const auto scaling_mode = source.scaling_mode();
// Tensor dimensions
std::vector<size_t> shape = shape_to_vector(source.shape());
auto flatten_shape_to_2d = [](const std::vector<size_t> &shape) -> std::pair<size_t, size_t> {
if (shape.empty()) {
return {1, 1};
}
size_t height = 1;
for (size_t i = 0; i < shape.size() - 1; ++i) {
height *= shape[i];
}
return {height, shape.back()};
};
size_t height, width, chunk_height, chunk_width;
std::tie(height, width) = flatten_shape_to_2d(shape);
std::tie(chunk_height, chunk_width) = flatten_shape_to_2d(chunk_shape);
// Check tensor dimensions
#define NVTE_DIM_CHECK(cond, message) \
NVTE_CHECK(cond, message, " (tensor shape=", shape, ", chunk shape=", chunk_shape, \
", chunk offset=", chunk_offset, ")")
NVTE_DIM_CHECK(height > 0 && width > 0, "Attempted to get chunk from empty tensor");
NVTE_DIM_CHECK(chunk_height > 0 && chunk_width > 0, "Attempted to get empty tensor chunk");
NVTE_DIM_CHECK(chunk_height <= height && chunk_width <= width,
"Attempted to get out-of-bounds tensor chunk");
if (scaling_mode == NVTEScalingMode::NVTE_MXFP8_1D_SCALING) {
// MXFP8 scale-inverses are padded to a 2D matrix with dims that
// are divisible by 128. UB doesn't handle this padding yet.
NVTE_DIM_CHECK(height % 128 == 0 && width % 128 == 0,
"Userbuffers requires MXFP8 tensor dims that are divisible by 128");
NVTE_DIM_CHECK(chunk_height % 128 == 0 && chunk_width % 128 == 0,
"Userbuffers requires MXFP8 tensor chunk dims that are divisible by 128");
}
#undef NVTE_DIM_CHECK
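// Illustrative note (not part of this change): under the checks above, a
// 2048 x 4096 MXFP8 tensor can be chunked here, while a 2048 x 4000 tensor is
// rejected because 4000 % 128 != 0.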
// Construct tensor chunk
TensorWrapper chunk(scaling_mode);
for (int param_id = 0; param_id < NVTETensorParam::kNVTENumTensorParams; param_id++) {
auto param_type = static_cast<NVTETensorParam>(param_id);
auto param = source.get_parameter(param_type);
......@@ -163,8 +200,8 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz
param_shape = chunk_shape;
if (param_type == NVTETensorParam::kNVTEColumnwiseData &&
source.scaling_mode() != NVTEScalingMode::NVTE_MXFP8_1D_SCALING) {
// Columnwise shape for non-block scaled tensors shifts the last dimension to the front
source.scaling_mode() == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING) {
// Columnwise shape for FP8 tensor-scaled tensors shifts the last dimension to the front
auto last_dim = param_shape.back();
param_shape.pop_back();
param_shape.insert(param_shape.begin(), last_dim);
......@@ -172,18 +209,16 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz
} else if (source.scaling_mode() == NVTEScalingMode::NVTE_MXFP8_1D_SCALING &&
(param_type == NVTETensorParam::kNVTERowwiseScaleInv ||
param_type == NVTETensorParam::kNVTEColumnwiseScaleInv)) {
// Calculate block scaling offset and size
auto scaled_tensor_dim_size = (param_type == NVTETensorParam::kNVTERowwiseScaleInv)
? source.shape().data[0]
: source.columnwise_shape().data[0];
auto scaled_chunk_dim_size = (param_type == NVTETensorParam::kNVTERowwiseScaleInv)
? chunk_shape.front()
: chunk_shape.back();
auto chunk_scale_start = chunk_offset / 32;
auto chunk_scale_end = (chunk_offset + scaled_chunk_dim_size) / 32;
auto chunk_scale_size = chunk_scale_end - chunk_scale_start;
param_dptr += chunk_scale_start * typeToSize(param_dtype);
param_shape = std::vector<size_t>{chunk_scale_size};
// Calculate offset and size for MXFP8 scale-invs
size_t chunk_scale_height = chunk_height;
size_t chunk_scale_width = chunk_width;
if (param_type == NVTETensorParam::kNVTERowwiseScaleInv) {
chunk_scale_width /= 32;
} else {
chunk_scale_height /= 32;
}
param_dptr += (chunk_offset / 32) * typeToSize(param_dtype);
param_shape = {chunk_scale_height, chunk_scale_width};
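// Illustrative example of the mapping above (not part of this change): MXFP8 keeps
// one scale per 32 elements, so a 4096 x 8192 rowwise chunk carries a 4096 x 256
// rowwise scale-inv view (8192 / 32), and a chunk_offset of 1048576 elements
// advances the scale-inv pointer by 32768 entries.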
}
// Set chunked source parameters into the chunked tensor output
......@@ -422,10 +457,21 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
size_t k = transa ? A.size(1) : A.size(0);
size_t n = _ubuf.size(0);
size_t m_chunk = m / _num_splits;
const std::vector<size_t> input_a_chunk_shape =
(transa ? std::vector<size_t>{m_chunk, k} : std::vector<size_t>{k, m_chunk});
const std::vector<size_t> output_chunk_shape = {n, m_chunk};
size_t input_a_chunk_size = m_chunk * k;
size_t output_chunk_size = n * m_chunk;
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
// Helper function to get bias chunk if needed
auto maybe_get_bias_chunk = [this, &bias, m_chunk](size_t chunk_id) -> TensorWrapper {
if (bias.dptr() == nullptr) {
return TensorWrapper();
}
return get_tensor_chunk(bias, chunk_id * m_chunk, {m_chunk});
};
// Catch up the default torch stream
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
for (size_t i = 0; i < _stream_compute.size(); i++) {
......@@ -437,21 +483,23 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
if (_rs_overlap_first_gemm) {
auto input_a_chunk = get_tensor_chunk(A, 0, {m_chunk, k});
auto output_chunk = get_buffer_chunk_like(D, 0, {m, m_chunk});
auto input_a_chunk = get_tensor_chunk(A, 0, input_a_chunk_shape);
auto output_chunk = get_buffer_chunk_like(D, 0, output_chunk_shape);
auto bias_chunk = maybe_get_bias_chunk(0);
auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk});
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _stream_compute[0]);
for (int i = 1; i < _num_splits; i++) {
input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k});
output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, {n, m_chunk});
input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, input_a_chunk_shape);
output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, output_chunk_shape);
bias_chunk = maybe_get_bias_chunk(i);
workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]);
......@@ -494,12 +542,13 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
}
} else {
for (int i = 0; i < _num_splits; i++) {
auto input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k});
auto output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, {n, m_chunk});
auto input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, input_a_chunk_shape);
auto output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, output_chunk_shape);
auto bias_chunk = maybe_get_bias_chunk(i);
auto workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]);
......@@ -759,8 +808,6 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
// Get communication and GEMM output chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
const bool do_gelu = pre_gelu_out.numel() > 0;
size_t input_chunk_size = n_chunk * k;
size_t output_chunk_size = n_chunk * m;
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
......@@ -771,8 +818,13 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
}
if (_aggregate) {
const int num_steps = _tp_size / 2;
input_chunk_size *= 2;
output_chunk_size *= 2;
// Chunk dims
std::vector<size_t> input_b_chunk_shape =
(transb ? std::vector<size_t>{k, 2 * n_chunk} : std::vector<size_t>{2 * n_chunk, k});
std::vector<size_t> output_chunk_shape = {2 * n_chunk, k};
size_t input_b_chunk_size = 2 * n_chunk * k;
size_t output_chunk_size = 2 * n_chunk * m;
// Initial 1X input chunk exchange between neighboring peers
int send_chunk_id = _tp_id;
......@@ -801,8 +853,9 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
// GEMM
auto input_b_chunk =
get_buffer_chunk_like(B, input_chunk_size * send_chunk_id, {n_chunk * 2, k});
auto output_chunk = get_tensor_chunk(D, output_chunk_size * send_chunk_id, {n_chunk * 2, m});
get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape);
auto output_chunk =
get_tensor_chunk(D, output_chunk_size * send_chunk_id, output_chunk_shape);
auto aux_chunk =
(do_gelu)
? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk * 2, k})
......@@ -834,6 +887,13 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
}
}
} else {
// Chunk dims
std::vector<size_t> input_b_chunk_shape =
(transb ? std::vector<size_t>{k, n_chunk} : std::vector<size_t>{n_chunk, k});
std::vector<size_t> output_chunk_shape = {n_chunk, m};
size_t input_b_chunk_size = n_chunk * k;
size_t output_chunk_size = n_chunk * m;
for (int i = 0; i < _tp_size; i++) {
// Set the userbuffer id. Buffer under send is the input for the current
// GEMM chunk The initial input chunk is stored _ubuf[rank]. This is to
......@@ -845,8 +905,10 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
int recv_offset = comm_bytes * recv_chunk_id;
// GEMM
auto input_b_chunk = get_buffer_chunk_like(B, input_chunk_size * send_chunk_id, {n_chunk, k});
auto output_chunk = get_tensor_chunk(D, output_chunk_size * send_chunk_id, {n_chunk, m});
auto input_b_chunk =
get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape);
auto output_chunk =
get_tensor_chunk(D, output_chunk_size * send_chunk_id, output_chunk_shape);
auto aux_chunk =
(do_gelu)
? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk, k})
......@@ -996,6 +1058,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
auto input_b_chunk = get_tensor_chunk(B, input_b_chunk_id * input_chunk_size, {n_chunk, k});
auto output_chunk = get_buffer_chunk_by_id(D, i);
auto workspace_chunk =
get_tensor_chunk(workspace, stream_id * workspace_size_chunk, {workspace_size_chunk});
......
......@@ -12,10 +12,18 @@
#include <cudnn.h>
#include <nvrtc.h>
#include <iostream>
#include <stdexcept>
#include "../util/string.h"
#define NVTE_WARN(...) \
do { \
std::cerr << ::transformer_engine::concat_strings( \
__FILE__ ":", __LINE__, " in function ", __func__, ": ", \
::transformer_engine::concat_strings(__VA_ARGS__), "\n"); \
} while (false)
#define NVTE_ERROR(...) \
do { \
throw ::std::runtime_error(::transformer_engine::concat_strings( \
......
......@@ -452,12 +452,10 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve
~CommOverlap() {}
void set_buffer_params(py::handle quantizer);
void copy_into_buffer(const at::Tensor &input, bool local_chunk = false);
void copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk = false);
py::object get_buffer(py::handle quantizer, bool local_chunk = false,
std::optional<const std::vector<int64_t>> shape = std::nullopt);
at::Tensor get_buffer(bool local_chunk = false,
std::optional<std::vector<int64_t>> shape = std::nullopt);
}; // CommOverlap
......@@ -473,12 +471,10 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm
~CommOverlapP2P() {}
void set_buffer_params(py::handle quantizer);
void copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk = false);
void copy_into_buffer(const at::Tensor &input, bool local_chunk = false);
py::object get_buffer(py::handle quantizer, bool local_chunk = false,
std::optional<const std::vector<int64_t>> shape = std::nullopt);
at::Tensor get_buffer(bool local_chunk = false,
std::optional<std::vector<int64_t>> shape = std::nullopt);
}; // CommOverlapP2P
......
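
The header changes above replace the quantizer-aware copy_into_buffer/get_buffer overloads with plain-tensor versions; the Python bindings later in this diff expose them with the same defaults. A minimal usage sketch follows, assuming ub_obj is an already-constructed tex.CommOverlap or tex.CommOverlapP2P object and local_inp is a torch.Tensor whose element size matches the buffer; the helper name below is hypothetical.

import torch
import transformer_engine.pytorch.cpp_extensions as tex  # ub_obj comes from tex.CommOverlap / tex.CommOverlapP2P

def stage_local_shard(ub_obj, local_inp: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: stage this rank's shard and view the whole buffer."""
    # Raw byte copy of the local shard into this rank's chunk of the registered buffer.
    ub_obj.copy_into_buffer(local_inp, local_chunk=True)
    # View the full buffer as a plain torch.Tensor; quantized-tensor handling now
    # happens in Python helpers such as fill_userbuffers_buffer_for_all_gather.
    return ub_obj.get_buffer(local_chunk=False)
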
......@@ -141,81 +141,79 @@ CommOverlap::CommOverlap(const std::vector<size_t> &buffer_shape, at::ScalarType
num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm,
set_sm_margin, atomic_gemm, rs_overlap_first_gemm) {}
void CommOverlap::set_buffer_params(py::handle quantizer) {
std::unique_ptr<te::pytorch::Quantizer> my_quantizer = te::pytorch::convert_quantizer(quantizer);
my_quantizer->set_quantization_params(&_ubuf);
_ubuf_scale_inv_initialized = true;
}
/*
** Helper function to copy input to _ubuf
*/
void CommOverlap::copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk) {
auto input_tensor = te::pytorch::makeTransformerEngineTensor(input, quantizer);
auto input_ptr = input_tensor.dptr() ? input_tensor.dptr() : input_tensor.columnwise_dptr();
NVTE_CHECK(input_ptr, "Input tensor does not have rowwise or columnwise data!");
char *ubuf_ptr = reinterpret_cast<char *>(_ubuf.dptr());
void CommOverlap::copy_into_buffer(const at::Tensor &input, bool local_chunk) {
const auto &input_ = input.contiguous();
// Check element size
const size_t element_size = input.element_size();
NVTE_CHECK(_ubuf.element_size() == element_size,
"Tried to copy data into a Userbuffers buffer but dtypes are not compatible ",
"(input dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(),
" bytes)");
// Input data
const size_t input_size = input_.numel();
const void *src_ptr = input_.data_ptr();
// Userbuffers data
const size_t ubuf_size = _ubuf.numel();
void *dst_ptr = _ubuf.dptr();
if (local_chunk) {
if (input_tensor.numel() * _tp_size > _ubuf.numel())
NVTE_ERROR("input is larger than the local communication buffer!");
if (input_tensor.element_size() != _ubuf.element_size())
NVTE_ERROR("input data type does not match communication buffer!");
ubuf_ptr += (_ubuf.numel() / _tp_size) * _tp_id * _ubuf.element_size();
NVTE_CHECK(input_size * _tp_size == ubuf_size,
"Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ",
"(input_size=", input_size, ", tensor_parallel_size=", _tp_size,
", ubuf_size=", ubuf_size, ")");
dst_ptr = (reinterpret_cast<char *>(dst_ptr) + (ubuf_size / _tp_size) * _tp_id * element_size);
} else {
if (input_tensor.numel() > _ubuf.numel())
NVTE_ERROR("input is larger than the global communication buffer!");
if (input_tensor.element_size() != _ubuf.element_size())
NVTE_ERROR("input data type does not match communication buffer!");
NVTE_CHECK(input_size == ubuf_size,
"Tried to copy an invalid tensor into a Userbuffers buffer ",
"(input_size=", input_size, ", ubuf_size=", ubuf_size, ")");
}
// Copy either row or columnwise data into the communication buffer's columnwise data
// NOTE: _ubuf.columnwise_dptr() is not a valid copy target because it is not registered with
// the Userbuffers communicator.
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
// Copy data
auto stream_main = at::cuda::getCurrentCUDAStream();
NVTE_CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0));
NVTE_CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input_tensor.dptr(),
input_tensor.numel() * input_tensor.element_size(),
NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, input_size * element_size,
cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_comm));
}
py::object CommOverlap::get_buffer(py::handle quantizer, bool local_chunk,
std::optional<const std::vector<int64_t>> shape) {
using namespace te::pytorch;
char *ubuf_wt_ptr = reinterpret_cast<char *>(_ubuf.dptr());
if (local_chunk) ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size();
std::vector<int64_t> torch_shape;
if (shape.has_value()) {
torch_shape = shape.value();
size_t requested = product(torch_shape);
auto expected = local_chunk ? _ubuf.numel() / _tp_size : _ubuf.numel();
NVTE_CHECK(requested == expected, "Number of elements in the requested shape (", requested,
") does not match allocated buffer size (", expected, ")!");
at::Tensor CommOverlap::get_buffer(bool local_chunk, std::optional<std::vector<int64_t>> shape) {
// Check buffer shape
const size_t ubuf_size = _ubuf.numel();
if (shape) {
const size_t requested_size = transformer_engine::pytorch::product(*shape);
if (local_chunk) {
NVTE_CHECK(requested_size * _tp_size == ubuf_size,
"Invalid shape for local chunk of a Userbuffers buffer (requested shape=", *shape,
", tensor_parallel_size=", _tp_size, ", ubuf_size=", ubuf_size, ")");
} else {
NVTE_CHECK(requested_size == ubuf_size,
"Invalid shape for a Userbuffers buffer (requested shape=", *shape,
", ubuf_size=", ubuf_size, ")");
}
} else {
int64_t output_c_dim0 = (local_chunk) ? _ubuf.size(0) / _tp_size : _ubuf.size(0);
int64_t output_c_dim1 = _ubuf.size(1);
torch_shape = {output_c_dim0, output_c_dim1};
int64_t dim0 = _ubuf.size(0);
int64_t dim1 = _ubuf.size(1);
if (local_chunk) {
dim0 /= _tp_size;
}
shape = {dim0, dim1};
}
auto ubuf_tensor = torch::from_blob(reinterpret_cast<void *>(ubuf_wt_ptr), torch_shape,
at::dtype(GetATenDType(_ubuf.dtype())).device(torch::kCUDA));
std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
std::vector<size_t> te_shape;
for (auto s : torch_shape) te_shape.emplace_back(static_cast<size_t>(s));
// Always output a rowwise-only QuantizedTensor
// TODO (Alp): This needs to produce an un-interleaved transpose when required.
auto is_internal = my_quantizer->internal;
auto uses_columnwise = my_quantizer->columnwise_usage;
my_quantizer->internal = false;
my_quantizer->columnwise_usage = false;
auto [te_tensor, py_tensor] = my_quantizer->create_tensor(te_shape, _ubuf.dtype(), ubuf_tensor);
my_quantizer->internal = is_internal;
my_quantizer->columnwise_usage = uses_columnwise;
return py_tensor;
// Data pointer
void *ubuf_ptr = _ubuf.dptr();
if (local_chunk) {
ubuf_ptr = (reinterpret_cast<char *>(ubuf_ptr) +
(ubuf_size / _tp_size) * _tp_id * _ubuf.element_size());
}
// Construct PyTorch tensor
const auto dtype = transformer_engine::pytorch::GetATenDType(_ubuf.dtype());
return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA));
}
/***************************************************************************************************
......@@ -236,74 +234,69 @@ CommOverlapP2P::CommOverlapP2P(const std::vector<size_t> &buffer_shape, at::Scal
comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce,
atomic_gemm, aggregate) {}
void CommOverlapP2P::set_buffer_params(py::handle quantizer) {
std::unique_ptr<te::pytorch::Quantizer> my_quantizer = te::pytorch::convert_quantizer(quantizer);
my_quantizer->set_quantization_params(&_ubuf);
for (size_t i = 0; i < _ubufs.size(); i++) my_quantizer->set_quantization_params(&_ubufs[i]);
}
/*
** Copy input to _ubufs[0]
*/
void CommOverlapP2P::copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk) {
auto input_tensor = te::pytorch::makeTransformerEngineTensor(input, quantizer);
auto input_ptr = input_tensor.dptr() ? input_tensor.dptr() : input_tensor.columnwise_dptr();
NVTE_CHECK(input_ptr, "Input tensor does not have rowwise or columnwise data!");
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
void CommOverlapP2P::copy_into_buffer(const at::Tensor &input, bool local_chunk) {
const auto &input_ = input.contiguous();
// Check element size
const size_t element_size = input.element_size();
NVTE_CHECK(_ubuf.element_size() == element_size,
"Tried to copy data into a Userbuffers buffer but dtypes are not compatible ",
"(input dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(),
" bytes)");
// Input data
const size_t input_size = input_.numel();
const void *src_ptr = input_.data_ptr();
// Userbuffers data
void *dst_ptr;
if (local_chunk) {
// Copy input to the target ubuf chunk by rank offset
if (input_tensor.numel() * _tp_size > _ubuf.numel())
NVTE_ERROR("input is larger than the local communication buffer!");
if (input_tensor.element_size() != _ubuf.element_size())
NVTE_ERROR("input data type does not match communication buffer!");
NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].dptr(), input_ptr,
input_tensor.numel() * input_tensor.element_size(),
cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main));
NVTE_CHECK(_ubufs[_tp_id].numel() == input_size,
"Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ",
"(input_size=", input_size, ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")");
dst_ptr = _ubufs[_tp_id].dptr();
} else {
if (input_tensor.numel() > _ubuf.numel())
NVTE_ERROR("input is larger than the global communication buffer!");
if (input_tensor.element_size() != _ubuf.element_size())
NVTE_ERROR("input data type does not match communication buffer!");
NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubuf.dptr(), input_ptr,
input_tensor.numel() * input_tensor.element_size(),
cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main));
NVTE_CHECK(_ubuf.numel() == input_size,
"Tried to copy an invalid tensor into a Userbuffers buffer ",
"(input_size=", input_size, ", ubuf_size=", _ubuf.numel(), ")");
dst_ptr = _ubuf.dptr();
}
// Copy data
NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, input_size * element_size,
cudaMemcpyDeviceToDevice,
(cudaStream_t)at::cuda::getCurrentCUDAStream()));
}
py::object CommOverlapP2P::get_buffer(py::handle quantizer, bool local_chunk,
std::optional<const std::vector<int64_t>> shape) {
using namespace te::pytorch;
char *ubuf_wt_ptr = reinterpret_cast<char *>(local_chunk ? _ubufs[_tp_id].dptr() : _ubuf.dptr());
std::vector<int64_t> torch_shape;
if (shape.has_value()) {
torch_shape = shape.value();
size_t requested = product(torch_shape);
auto expected = local_chunk ? _ubufs[_tp_id].numel() : _ubuf.numel();
NVTE_CHECK(requested == expected, "Number of elements in the requested shape (", requested,
") does not match allocated buffer size (", expected, ")!");
at::Tensor CommOverlapP2P::get_buffer(bool local_chunk, std::optional<std::vector<int64_t>> shape) {
// Check buffer shape
if (shape) {
const size_t requested_size = transformer_engine::pytorch::product(*shape);
if (local_chunk) {
NVTE_CHECK(requested_size == _ubufs[_tp_id].numel(),
"Invalid shape for local chunk of a Userbuffers buffer (requested shape=", *shape,
", local_ubuf_size=", _ubufs[_tp_id].numel(), ")");
} else {
NVTE_CHECK(requested_size == _ubuf.numel(),
"Invalid shape for a Userbuffers buffer (requested shape=", *shape,
", ubuf_size=", _ubuf.numel(), ")");
}
} else {
int64_t output_c_dim0 = (local_chunk) ? _ubuf.size(0) / _tp_size : _ubuf.size(0);
int64_t output_c_dim1 = _ubuf.size(1);
torch_shape = {output_c_dim0, output_c_dim1};
int64_t dim0 = _ubuf.size(0);
int64_t dim1 = _ubuf.size(1);
if (local_chunk) {
dim0 /= _tp_size;
}
shape = {dim0, dim1};
}
auto ubuf_tensor = torch::from_blob(reinterpret_cast<void *>(ubuf_wt_ptr), torch_shape,
at::dtype(GetATenDType(_ubuf.dtype())).device(torch::kCUDA));
std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
std::vector<size_t> te_shape;
for (auto s : torch_shape) te_shape.emplace_back(static_cast<size_t>(s));
// Always output a rowwise-only QuantizedTensor
// TODO (Alp): This needs to produce an un-interleaved transpose when required.
auto is_internal = my_quantizer->internal;
auto uses_columnwise = my_quantizer->columnwise_usage;
my_quantizer->internal = false;
my_quantizer->columnwise_usage = false;
auto [te_tensor, py_tensor] = my_quantizer->create_tensor(te_shape, _ubuf.dtype(), ubuf_tensor);
my_quantizer->internal = is_internal;
my_quantizer->columnwise_usage = uses_columnwise;
return py_tensor;
// Data pointer
void *ubuf_ptr = local_chunk ? _ubufs[_tp_id].dptr() : _ubuf.dptr();
// Construct PyTorch tensor
const auto dtype = transformer_engine::pytorch::GetATenDType(_ubuf.dtype());
return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA));
}
......@@ -360,10 +360,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("num_comm_sm") = 16, py::arg("set_sm_margin") = true,
py::arg("atomic_gemm") = false, py::arg("rs_overlap_first_gemm") = false)
.def("copy_into_buffer", &CommOverlap::copy_into_buffer, py::arg("input"),
py::arg("quantizer"), py::arg("local_chunk") = false)
.def("get_buffer", &CommOverlap::get_buffer, py::arg("quantizer"),
py::arg("local_chunk") = false, py::arg("shape") = std::nullopt)
.def("set_buffer_params", &CommOverlap::set_buffer_params);
py::arg("local_chunk") = false)
.def("get_buffer", &CommOverlap::get_buffer, py::arg("local_chunk") = false,
py::arg("shape") = std::nullopt);
py::class_<CommOverlapP2P, std::shared_ptr<CommOverlapP2P>,
transformer_engine::CommOverlapP2PBase, transformer_engine::CommOverlapCore>(
......@@ -378,8 +377,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("set_sm_margin") = false, py::arg("atomic_gemm") = false,
py::arg("use_ce") = true, py::arg("aggregate") = false)
.def("copy_into_buffer", &CommOverlapP2P::copy_into_buffer, py::arg("input"),
py::arg("quantizer"), py::arg("local_chunk") = false)
.def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("quantizer"),
py::arg("local_chunk") = false, py::arg("shape") = std::nullopt)
.def("set_buffer_params", &CommOverlapP2P::set_buffer_params);
py::arg("local_chunk") = false)
.def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("local_chunk") = false,
py::arg("shape") = std::nullopt);
}
......@@ -919,8 +919,10 @@ def _all_gather_fp8(
# Note: We cannot directly all-gather the transposed FP8 tensor,
# so temporarily modify quantizer to avoid creating FP8 transpose.
if not isinstance(inp, Float8TensorBase):
if quantizer is None:
raise ValueError("Input tensor is not FP8 and no quantizer was provided")
assert isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer))
# we cannot directly gather the transposed fp8 tensor
# so we need to disable columnwise usage for the quantizer
# and then set it back to the original value after quantizing
init_rowwise_usage = quantizer.rowwise_usage
init_columnwise_usage = quantizer.columnwise_usage
quantizer.set_usage(rowwise=True, columnwise=False)
......@@ -1063,11 +1065,11 @@ def _all_gather_mxfp8(
dtype = inp.dtype
elif isinstance(inp, MXFP8TensorBase):
if inp._rowwise_data is not None:
in_shape = inp._rowwise_data.device.size()
in_shape = inp._rowwise_data.size()
device = inp._rowwise_data.device
dtype = inp._rowwise_data.dtype
elif inp._columnwise_data is not None:
in_shape = inp._columnwise_data.device.size()
in_shape = inp._columnwise_data.size()
device = inp._columnwise_data.device
dtype = inp._columnwise_data.dtype
else:
......
......@@ -6,8 +6,6 @@
from typing import Any, List, Optional, Tuple, Union, Callable
from dataclasses import dataclass
from functools import reduce
from operator import mul as multiply_op
import queue
import torch
......@@ -15,7 +13,6 @@ import torch
from .. import cpp_extensions as tex
from ..constants import TE_DType
from ..utils import get_default_init_method
from ..tensor.float8_tensor import Float8Tensor
def _get_normalization_func(normalization: str, forward: bool):
......@@ -33,39 +30,6 @@ def _get_normalization_func(normalization: str, forward: bool):
return bwd_normalization_funcs[normalization]
def _fix_gathered_fp8_transpose(fp8_tensor: Float8Tensor, tp_size: int) -> Float8Tensor:
"""Reorder FP8 transposes after Userbuffers gather.
The all-gather is performed in-place in the Float8Tensor's
row-wise data, and afterwards we need to do a transpose to get the
correct ordering. This misuses data fields in Float8Tensor and
should be considered an evil hack. It would be best to move
transpose logic into CommOverlap::get_buffer.
Responsibility for fixing: adener, tmoon
"""
assert isinstance(fp8_tensor, Float8Tensor), "Tensor is not a Float8Tensor"
assert tp_size > 1, "The tensor transpose cannot be interleaved when TP size is 1"
assert fp8_tensor._data is not None, "The tensor does not hold any rowwise data"
assert (
fp8_tensor._data.shape[0] % tp_size == 0
), "Leading dimension of data is not divisble by TP size"
data = fp8_tensor._data
batched_size = reduce(multiply_op, data.shape[1:])
interleaved_shape = [tp_size, data.shape[0] // tp_size, batched_size]
transposed_shape = [data.shape[0] // tp_size, batched_size * tp_size]
fp8_tensor._transpose = (
data.view(interleaved_shape).transpose(0, 1).contiguous().view(transposed_shape)
)
fp8_tensor._transpose_invalid = False
fp8_tensor._data = None
return fp8_tensor
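
A toy, self-contained sketch of the reshuffle performed by `_fix_gathered_fp8_transpose` above (dims are made up and an integer tensor stands in for the FP8 payload):

import torch

tp_size = 2
data = torch.arange(8 * 4).view(8, 4)  # gathered row-wise data, interleaved by rank
chunk = data.shape[0] // tp_size
fixed = (
    data.view(tp_size, chunk, 4)       # [tp, rows / tp, cols]
    .transpose(0, 1)                   # undo the rank interleaving
    .contiguous()
    .view(chunk, 4 * tp_size)          # [rows / tp, cols * tp], i.e. the transpose layout
)
assert fixed.shape == (4, 8)
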
def apply_normalization(
inputmat: torch.Tensor,
ln_out: torch.Tensor,
......
......@@ -4,6 +4,7 @@
"""Base modules and utilities for TransformerEngine PyTorch API"""
import io
import math
import os
import pickle
import warnings
......@@ -35,7 +36,9 @@ from ..distributed import (
_fsdp_gather_tensors,
)
from ..constants import dist_group_type
from ..tensor import QuantizedTensor, Quantizer
from ..tensor.quantized_tensor import QuantizedTensor, QuantizedTensorBase, Quantizer
from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
......@@ -418,6 +421,142 @@ def destroy_ub():
layers_atomic_ring_exchange = []
def fill_userbuffers_buffer_for_all_gather(
comm,
local_tensor: torch.Tensor,
quantizer: Optional[Quantizer],
process_group,
) -> tuple[torch.Tensor | QuantizedTensorBase, torch.Tensor | QuantizedTensorBase]:
"""Fill local shard of Userbuffers buffer with data for all-gather
Returns the full tensor and the local shard, both using the
Userbuffers buffer as their underlying data. These tensors should
be used carefully (e.g. only immediately before and after a
Userbuffers operation) since the underlying data may be
overwritten by other Userbuffers operations.
May perform blocking communication if needed for the gathered
tensor's metadata, e.g. scaling factors.
"""
# Tensor dimensions
local_shape = local_tensor.size()
if not local_shape:
raise ValueError(f"Invalid local tensor (shape={tuple(local_shape)})")
process_group_size = torch.distributed.get_world_size(process_group)
global_shape = list(local_shape)
global_shape[0] *= process_group_size
# Unquantized data
if quantizer is None:
if isinstance(local_tensor, QuantizedTensorBase):
local_tensor = local_tensor.dequantize()
if comm.is_fp8_ubuf():
raise RuntimeError(
"Attempting to all-gather unquantized tensor, "
"but Userbuffers is initialized with FP8 buffers"
)
comm.copy_into_buffer(local_tensor, local_chunk=True)
global_tensor = comm.get_buffer(shape=global_shape)
return global_tensor, local_tensor
# FP8 data
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
if not isinstance(local_tensor, Float8TensorBase):
if isinstance(local_tensor, QuantizedTensorBase):
                local_tensor = local_tensor.dequantize()
quantizer.set_usage(rowwise=True, columnwise=False)
local_tensor = quantizer(local_tensor)
if not comm.is_fp8_ubuf():
raise RuntimeError(
"Attempting to all-gather FP8 tensor, "
"but Userbuffers is not initialized with FP8 buffers"
)
comm.copy_into_buffer(local_tensor._data, local_chunk=True)
global_tensor_data = comm.get_buffer(shape=global_shape)
global_tensor = Float8TensorBase(
data=global_tensor_data,
fp8_scale_inv=local_tensor._scale_inv,
fp8_dtype=local_tensor._fp8_dtype,
quantizer=quantizer,
)
return global_tensor, local_tensor
# MXFP8 data
if isinstance(quantizer, MXFP8Quantizer):
# Cast to MXFP8 if needed
if not isinstance(local_tensor, MXFP8TensorBase):
if isinstance(local_tensor, QuantizedTensorBase):
                local_tensor = local_tensor.dequantize()
local_tensor = quantizer(local_tensor)
if not comm.is_fp8_ubuf():
raise RuntimeError(
"Attempting to all-gather MXFP8 tensor, "
"but Userbuffers is not initialized with FP8 buffers"
)
# Check which MXFP8 buffer to communicate
if quantizer.rowwise_usage == quantizer.columnwise_usage:
raise ValueError(
"Userbuffers can only communicate one MXFP8 buffer at a time, "
f"but quantizer has rowwise_usage={quantizer.rowwise_usage}, "
f"columnwise_usage={quantizer.columnwise_usage}"
)
with_rowwise_data = quantizer.rowwise_usage
# Copy MXFP8 data to local chunk of Userbuffers buffer
local_data = (
local_tensor._rowwise_data if with_rowwise_data else local_tensor._columnwise_data
)
comm.copy_into_buffer(local_data, local_chunk=True)
# Gather scaling-inverses
if math.prod(local_shape[:-1]) % 128 != 0:
raise ValueError(
"Userbuffers requires MXFP8 tensor dims that are divisible by 128, "
f"but got MXFP8 tensor with shape={tuple(local_shape)}"
)
local_scale_inv = (
local_tensor._rowwise_scale_inv
if with_rowwise_data
else local_tensor._columnwise_scale_inv
)
local_scale_inv_size = list(local_scale_inv.size())
global_scale_inv = torch.empty(
[process_group_size * local_scale_inv_size[0]] + local_scale_inv_size[1:],
dtype=local_scale_inv.dtype,
device=local_scale_inv.device,
)
torch.distributed.all_gather_into_tensor(
global_scale_inv,
local_scale_inv,
group=process_group,
)
# Construct MXFP8 tensor with Userbuffers buffer
rowwise_data, rowwise_scale_inv = None, None
columnwise_data, columnwise_scale_inv = None, None
global_data = comm.get_buffer(shape=global_shape)
if with_rowwise_data:
rowwise_data, rowwise_scale_inv = global_data, global_scale_inv
else:
columnwise_data, columnwise_scale_inv = global_data, global_scale_inv
global_tensor = MXFP8TensorBase(
rowwise_data=rowwise_data,
rowwise_scale_inv=rowwise_scale_inv,
columnwise_data=columnwise_data,
columnwise_scale_inv=columnwise_scale_inv,
fp8_dtype=local_tensor._fp8_dtype,
quantizer=quantizer,
)
return global_tensor, local_tensor
# Unsupported data format
raise ValueError(f"Unsupported quantizer for Userbuffers ({quantizer})")
class TransformerEngineBaseModule(torch.nn.Module, ABC):
"""Base TE module."""
......@@ -866,11 +1005,15 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
            # Non-FP8 case: bgrad is fused with wgrad.
if not ctx.fp8 and not ctx.debug:
if gather_grad_output:
if not ctx.ub_overlap_ag:
if not ctx.ub_overlap_ag: # Perform NCCL all-gather
grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group)
else:
ctx.ub_obj_gradout.copy_into_buffer(grad_output, quantizer, local_chunk=True)
grad_output = ctx.ub_obj_gradout.get_buffer(quantizer)
else: # Initialize Userbuffers all-gather
grad_output, _ = fill_userbuffers_buffer_for_all_gather(
ctx.ub_obj_gradout,
grad_output,
None,
ctx.tp_group,
)
return grad_output, None
# FP8 with all-gather: unfused bgrad, fused cast + transpose
......@@ -893,8 +1036,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
grad_output = quantizer(grad_output)
# Copy into communication buffer, and replace original gradient with it
ctx.ub_obj_gradout.copy_into_buffer(grad_output, quantizer, local_chunk=True)
grad_output = ctx.ub_obj_gradout.get_buffer(quantizer)
grad_output, _ = fill_userbuffers_buffer_for_all_gather(
ctx.ub_obj_gradout,
grad_output,
quantizer,
ctx.tp_group,
)
else:
grad_output, _ = gather_along_first_dim(
grad_output,
......@@ -1165,7 +1312,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
return
with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"):
(wgrad, grad_bias_, _, _), _ = self.wgrad_store.pop()
(wgrad, bgrad), _ = self.wgrad_store.pop()
if not self.fuse_wgrad_accumulation:
unfused_weights = [getattr(self, name) for name in self.weight_names]
weight_tensor = noop_cat(unfused_weights)
......@@ -1174,9 +1321,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if self.use_bias:
bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
if bias_tensor.grad is None:
bias_tensor.grad = grad_bias_.to(bias_tensor.dtype)
del grad_bias_
del wgrad
bias_tensor.grad = bgrad.to(bias_tensor.dtype)
def _validate_name(self):
"""
......
......@@ -9,7 +9,6 @@ from typing import Callable, Dict, Optional, Tuple, Union
from functools import reduce
from operator import mul as multiply_op
import functools
import torch
from torch.nn import init
......@@ -18,6 +17,7 @@ import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import torch_version
from .base import (
fill_userbuffers_buffer_for_all_gather,
get_workspace,
get_ub,
TransformerEngineBaseModule,
......@@ -53,9 +53,10 @@ from ..distributed import (
from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ._common import apply_normalization, noop_cat, _fix_gathered_fp8_transpose, WeightGradStore
from ._common import apply_normalization, noop_cat, WeightGradStore
from ..tensor.quantized_tensor import (
QuantizedTensor,
QuantizedTensorBase,
Quantizer,
prepare_for_saving,
restore_from_saved,
......@@ -153,42 +154,43 @@ class _LayerNormLinear(torch.autograd.Function):
nvtx_range_pop(f"{nvtx_label}.norm_input_cast")
tp_world_size = get_distributed_world_size(tp_group)
ub_overlap_ag_fprop = (
ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output
)
weight_requires_grad = weight.requires_grad
backward_needs_input = is_grad_enabled and weight_requires_grad
with_input_all_gather = parallel_mode == "column" and sequence_parallel
# Check if Userbuffers is supported
if fp8:
if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and not (
FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
):
raise NotImplementedError(
"Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
" current scaling"
)
# Configure Userbuffers communication (comm+GEMM overlap)
ub_obj = None
ub_type = None
ub_overlap_ag_fprop = (
ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output
)
if ub_overlap_rs_fprop:
ub_obj = get_ub(ub_name + "_fprop")
ub_type = tex.CommOverlapType.RS
elif ub_overlap_ag_fprop:
ub_obj = get_ub(ub_name + "_fprop")
ub_type = tex.CommOverlapType.AG
# Configure quantizer for norm output
if fp8:
if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
columnwise_usage = backward_needs_input
if (
columnwise_usage
and with_input_all_gather
and not isinstance(input_quantizer, MXFP8Quantizer)
input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
if with_input_all_gather and isinstance(
input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
):
columnwise_usage = False
input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
# All-gather is not supported with FP8 column-wise data
input_quantizer.set_usage(columnwise=False)
# Avoid quantized norm kernel if norm output will be returned
# or if a gather of ln_out must be in high precision.
# Do TP communication in high precision if quantized format
# does not support communication
force_hp_blockwise_ln_out_gather = (
fp8 and with_input_all_gather and isinstance(input_quantizer, Float8BlockQuantizer)
) # Perform TP communication in high precision.
)
# Avoid quantized norm kernel if norm output will be returned
# or if a gather of ln_out must be in high precision.
with_quantized_norm = (
fp8
and not return_layernorm_output
......@@ -210,16 +212,19 @@ class _LayerNormLinear(torch.autograd.Function):
fwd_ln_sm_margin,
zero_centered_gamma,
)
nvtx_range_pop(f"{nvtx_label}.norm")
# Store unquantized layer norm output if we need to return it
ln_out_return = None
if return_layernorm_output or return_layernorm_output_gathered:
ln_out_return = ln_out
nvtx_range_pop(f"{nvtx_label}.norm")
# Prepare GEMM input
# ------------------------------------------------------
# Prepare GEMM input tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
# ------------------------------------------------------
nvtx_range_push(f"{nvtx_label}.gemm_input_cast_comm")
ln_out_total = None
ub_obj_fprop = None
if with_input_all_gather:
if return_layernorm_output_gathered:
# Perform all-gather in high precision if gathered
......@@ -227,47 +232,53 @@ class _LayerNormLinear(torch.autograd.Function):
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
ln_out_return = ln_out_total
if fp8 or debug:
ln_out = input_quantizer(ln_out)
if not force_hp_blockwise_ln_out_gather:
ln_out = input_quantizer(ln_out)
input_quantizer.set_usage(rowwise=True, columnwise=False)
ln_out_total = input_quantizer(ln_out_total)
else:
quantizer = None
if fp8 or debug:
quantizer = input_quantizer
if not with_quantized_norm and not force_hp_blockwise_ln_out_gather:
ln_out = input_quantizer(ln_out)
input_quantizer.set_usage(rowwise=True, columnwise=False)
if ub_overlap_ag_fprop:
# Copy into Userbuffers buffer
ub_obj_fprop = get_ub(ub_name + "_fprop")
ub_obj_fprop.get_buffer(input_quantizer, local_chunk=True).copy_(ln_out)
ln_out_total = ub_obj_fprop.get_buffer(input_quantizer)
else:
# All-gather with NCCL
ln_out = quantizer(ln_out)
quantizer.set_usage(rowwise=True, columnwise=False)
if ub_overlap_ag_fprop: # Initialize Userbuffers all-gather
ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj,
ln_out,
quantizer,
tp_group,
)
else: # Perform NCCL all-gather
ln_out_total, _ = gather_along_first_dim(
ln_out,
tp_group,
quantizer=(input_quantizer if fp8 or debug else None),
quantizer=quantizer,
)
else:
if (fp8 or debug) and not with_quantized_norm:
ln_out = input_quantizer(ln_out)
ln_out_total = ln_out
nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm")
# ------------------------------------------------------
# GEMM input tensor is ready...
# ------------------------------------------------------
# Cast weight to expected dtype
# ------------------------------------------------------
# Prepare weight tensor
# ------------------------------------------------------
weightmat = weight
quantized_weight = False
if not fp8 and not debug:
weightmat = cast_if_needed(weightmat, activation_dtype)
else:
quantized_weight = not isinstance(weight, QuantizedTensor)
if fp8 or debug:
quantized_weight = not isinstance(weight, QuantizedTensorBase)
# Configure quantizer
if weight_quantizer is not None:
weight_quantizer.set_usage(rowwise=True, columnwise=True)
# FP8 cast to workspace buffer
# Get quantized weight
update_workspace = is_first_microbatch is None or is_first_microbatch
weightmat = module.get_weight_workspace(
tensor=weight,
quantizer=weight_quantizer,
......@@ -277,17 +288,21 @@ class _LayerNormLinear(torch.autograd.Function):
fsdp_group=fsdp_group,
workspace_dtype=activation_dtype,
)
weightmat.update_usage(rowwise_usage=True)
else:
weightmat = cast_if_needed(weightmat, activation_dtype) # Cast for AMP
# ------------------------------------------------------
# Weight tensor is ready for GEMM...
# ------------------------------------------------------
# Cast bias to expected dtype
bias_dtype = activation_dtype
if needs_quantized_gemm(ln_out_total) and activation_dtype == torch.float32:
# cuBLAS does not support FP8 GEMM with FP32 bias, so we cast to BF16
bias_dtype = torch.bfloat16
bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias
# Configure output quantizer
if output_quantizer is not None:
output_quantizer.set_usage(rowwise=True, columnwise=False)
# Calibrate quantizers if needed
if not fp8 and fp8_calibration:
if input_quantizer is not None:
......@@ -295,47 +310,80 @@ class _LayerNormLinear(torch.autograd.Function):
if weight_quantizer is not None:
weight_quantizer.calibrate(weight)
ub_obj = None
ub_type = None
rs_out = None
if ub_overlap_rs_fprop:
ub_obj = get_ub(ub_name + "_fprop")
ub_type = tex.CommOverlapType.RS
out_shape = [reduce(multiply_op, inp_shape[:-1]) // tp_world_size, out_features]
rs_out = torch.empty(out_shape, dtype=activation_dtype, device=ln_out_total.device)
elif ub_overlap_ag_fprop:
ub_obj = get_ub(ub_name + "_fprop")
ub_type = tex.CommOverlapType.AG
if fp8:
assert ub_obj.is_fp8_ubuf(), "AG overlap with FP8 GEMM inputs requires FP8 buffer."
ln_out_total = ub_obj.get_buffer(input_quantizer)
nvtx_range_push(f"{nvtx_label}.gemm")
fprop_gemm_use_split_accumulator = _2X_ACC_FPROP
# Choose whether to use GEMM kernel with split accumulator
use_split_accumulator = _2X_ACC_FPROP
if fp8:
recipe = FP8GlobalStateManager.get_fp8_recipe()
if hasattr(recipe, "fp8_gemm_fprop"):
fprop_gemm_use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator
use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator
out, *_, rs_out = general_gemm(
# Configure output quantizer
if output_quantizer is not None:
output_quantizer.set_usage(rowwise=True, columnwise=False)
# Output buffer for Userbuffers reduce-scatter
reduce_scatter_out = None
if ub_overlap_rs_fprop:
out_shape = list(inp_shape)
out_shape[0] //= tp_world_size
out_shape[-1] = out_features
reduce_scatter_out = torch.empty(out_shape, dtype=activation_dtype, device=inp.device)
# ------------------------------------------------------
# Forward GEMM
# Note: y = x * w^T
# ------------------------------------------------------
nvtx_range_push(f"{nvtx_label}.gemm")
gemm_out, *_, reduce_scatter_out = general_gemm(
weightmat,
ln_out_total,
get_workspace(),
quantization_params=output_quantizer,
out_dtype=activation_dtype,
bias=bias,
use_split_accumulator=fprop_gemm_use_split_accumulator,
use_split_accumulator=use_split_accumulator,
ub=ub_obj,
ub_type=ub_type,
extra_output=rs_out,
extra_output=reduce_scatter_out,
)
nvtx_range_pop(f"{nvtx_label}.gemm")
# ------------------------------------------------------
# Finished forward GEMM...
# ------------------------------------------------------
# Deallocate GEMM input tensor if no longer needed
        if not weight.requires_grad and not return_layernorm_output:
            clear_tensor_data(ln_out, ln_out_total)
            ln_out = ln_out_total = None
# ------------------------------------------------------
# Prepare output tensor
# Note: Perform tensor-parallel communication
# ------------------------------------------------------
out = None
if ub_overlap_rs_fprop:
out = reduce_scatter_out
elif parallel_mode == "row" and tp_size > 1:
nvtx_range_push(f"{nvtx_label}.row_parallel_comm")
out = gemm_out
if sequence_parallel:
out, _ = reduce_scatter_along_first_dim(out, tp_group)
elif tensor_parallel:
if symmetric_ar_type is not None:
out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type)
else:
out, _ = allreduce(out, tp_group)
nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")
else:
out = gemm_out
out = out.view(-1, *inp_shape[1:-1], out_features)
# ------------------------------------------------------
# Output tensor is ready to return...
# ------------------------------------------------------
if not weight.requires_grad:
if not return_layernorm_output:
ln_out = ln_out_total = None
clear_tensor_data(ln_out, ln_out_total)
# ------------------------------------------------------
# Cache state for backward pass
# ------------------------------------------------------
if is_grad_enabled:
ctx.weight_quantizer = weight_quantizer
......@@ -346,19 +394,15 @@ class _LayerNormLinear(torch.autograd.Function):
# Input with column-wise usage is needed for wgrad GEMM.
if backward_needs_input:
if isinstance(ln_out, QuantizedTensor):
if isinstance(ln_out, QuantizedTensorBase):
                    # For sequence parallel with vanilla FP8, row-wise data is
                    # needed to gather the input. For MXFP8, only column-wise
                    # data can be all-gathered.
if isinstance(ln_out, MXFP8TensorBase) or not ctx.ln_out_needs_gather:
ln_out.update_usage(rowwise_usage=False)
# For force_hp_blockwise_ln_out_gather, we should
# be saving the unquantized ln_out to ctx.
assert not force_hp_blockwise_ln_out_gather
# Weight with column-wise usage is needed for dgrad GEMM.
if isinstance(weightmat, QuantizedTensor):
if isinstance(weightmat, QuantizedTensorBase):
weightmat.update_usage(columnwise_usage=True)
if cpu_offloading:
......@@ -445,22 +489,9 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.wgrad_store = wgrad_store
ctx.debug = debug
# Row Parallel Linear
if ub_overlap_rs_fprop:
out = rs_out
elif parallel_mode == "row":
nvtx_range_push(f"{nvtx_label}.row_parallel_comm")
if sequence_parallel:
out, _ = reduce_scatter_along_first_dim(out, tp_group)
elif tensor_parallel:
if symmetric_ar_type is not None:
out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type)
else:
out, _ = allreduce(out, tp_group)
nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")
# [*, in_features] -> [*, out_features] except first dimension changes for SP
out = out.view(-1, *inp_shape[1:-1], out_features)
# ------------------------------------------------------
# Cached state for backward pass is ready...
# ------------------------------------------------------
if return_layernorm_output:
if return_layernorm_output_gathered:
......@@ -482,24 +513,6 @@ class _LayerNormLinear(torch.autograd.Function):
nvtx_label = f"{nvtx_label}.{ctx.ub_name}"
with torch.cuda.nvtx.range("_LayerNormLinear_backward"):
if (
ctx.fp8
and any(
[
ctx.ub_overlap_ag,
ctx.ub_overlap_rs_dgrad,
ctx.ub_bulk_dgrad,
ctx.ub_bulk_wgrad,
]
)
and (ctx.fp8_recipe is not None)
):
if not ctx.fp8_recipe.float8_per_tensor_scaling():
raise NotImplementedError(
"Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
" current scaling"
)
saved_tensors = ctx.saved_tensors
( # pylint: disable=unbalanced-tuple-unpacking
inputmat,
......@@ -544,66 +557,50 @@ class _LayerNormLinear(torch.autograd.Function):
if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
origin_weight.main_grad = main_grad
# Configure Userbuffers communication (comm+GEMM overlap)
ctx.ub_obj_gradout = None
ub_obj_dgrad = None
ub_obj_wgrad = None
ub_type_dgrad = None
ub_type_wgrad = None
dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]]
rs_out = None
dgrad_bulk = None
if ctx.ub_overlap_ag:
# Overlap grad_output all-gather with dgrad compute
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
ub_obj_dgrad = ctx.ub_obj_gradout
ub_type_dgrad = tex.CommOverlapType.AG
elif ctx.ub_overlap_rs_dgrad:
# Overlap dgrad reduce-scatter with dgrad compute
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
ub_obj_dgrad = ctx.ub_obj_gradout
ub_type_dgrad = tex.CommOverlapType.RS
rs_out = torch.empty(
dgrad_shape, dtype=ctx.activation_dtype, device=inputmat.device
)
else:
if ctx.ub_bulk_dgrad:
# Overlap inputmat all-gather with dgrad compute
# NOTE: Copying into communication buffer will always prefer rowwise data,
# and will copy columnwise data if rowwise does not exist. In that case,
# the all-gather will apply to the leading dimension of the transpose,
# which then needs to be interleaved correctly before WGRAD.
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
ub_obj_dgrad = ctx.ub_obj_gradout
ub_type_dgrad = tex.CommOverlapType.AG
ub_obj_dgrad.copy_into_buffer(ln_out, ctx.input_quantizer, local_chunk=True)
if ctx.ub_bulk_wgrad:
# Overlap dgrad reduce-scatter with wgrad compute
ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad")
ub_type_wgrad = tex.CommOverlapType.RS
ub_obj_wgrad.set_buffer_params(ctx.grad_input_quantizer)
dgrad_bulk = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer)
# --------------------------------------------------
# Prepare grad output tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
# --------------------------------------------------
# Configure quantizer for grad output tensor
# Note: dgrad GEMM requires row-wise usage, wgrad GEMM
# requires column-wise usage
if ctx.grad_output_quantizer is not None:
rowwise_usage = True
columnwise_usage = True
if ctx.ub_overlap_ag and isinstance(
ctx.grad_output_quantizer,
(Float8Quantizer, Float8CurrentScalingQuantizer),
):
# If data is in FP8 and communication is handled
# with Userbuffers, we compute FP8 transposes
# manually
columnwise_usage = False
ctx.grad_output_quantizer.set_usage(
rowwise=rowwise_usage,
columnwise=columnwise_usage,
)
quantizer = ctx.grad_output_quantizer
quantizer.set_usage(rowwise=True, columnwise=True)
if ctx.ub_overlap_ag:
# Userbuffers only supports communication for one
# tensor usage at a time. Configure quantizer with
# usage for only dgrad GEMM.
quantizer.set_usage(columnwise=False)
# Prepare grad output tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
......@@ -619,12 +616,21 @@ class _LayerNormLinear(torch.autograd.Function):
)
nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess")
# Launch tensor-parallel communication for LayerNorm out tensor
# --------------------------------------------------
# Grad output tensor is ready for computing grad input...
# --------------------------------------------------
# --------------------------------------------------
# Prepare GEMM input tensor
# Note: Input tensor is needed for wgrad GEMM.
# Tensor-parallel communication is overlapped with dgrad
# GEMM.
# --------------------------------------------------
ln_out_total = None
ln_out_total_work = None
if ctx.ln_out_needs_gather and not ctx.ub_bulk_dgrad:
if ctx.ln_out_needs_gather:
quantizer = None
if ctx.input_quantizer is not None:
if ctx.input_quantizer is not None and not ctx.force_hp_blockwise_ln_out_gather:
quantizer = ctx.input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
# If data is in FP8, we compute FP8 transposes manually
......@@ -632,70 +638,92 @@ class _LayerNormLinear(torch.autograd.Function):
else:
# wgrad GEMM requires input with column-wise usage
quantizer.set_usage(rowwise=False, columnwise=True)
nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
# async_op is not compatible with high precision gather since
# gather_along_first_dim does not offer callback chaining.
gather_quantizer = None if ctx.force_hp_blockwise_ln_out_gather else quantizer
ln_out_total, ln_out_total_work = gather_along_first_dim(
ln_out,
ctx.tp_group,
async_op=True,
quantizer=gather_quantizer,
)
nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
if ctx.ub_bulk_dgrad:
ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj_dgrad,
ln_out,
quantizer,
ctx.tp_group,
)
else:
nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
ln_out_total, ln_out_total_work = gather_along_first_dim(
ln_out,
ctx.tp_group,
async_op=True,
quantizer=quantizer,
)
nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
else:
ln_out_total = ln_out
# Check whether to output wgrad GEMM directly into main grad
if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = (
ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
# dgrad GEMM
if ctx.grad_input_quantizer is not None:
ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)
nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD
# --------------------------------------------------
# Input tensor is ready for computing grad weight...
# --------------------------------------------------
# --------------------------------------------------
# Compute grad input tensor
# Note: Gradient w.r.t. GEMM input (i.e. norm output).
# --------------------------------------------------
# Make sure required data is available
if isinstance(grad_output, QuantizedTensorBase):
grad_output.update_usage(rowwise_usage=True)
if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorBase):
weight.update_usage(columnwise_usage=True)
# Choose whether to use GEMM kernel with split accumulator
use_split_accumulator = _2X_ACC_DGRAD
if ctx.fp8:
recipe = ctx.fp8_recipe
if hasattr(recipe, "fp8_gemm_dgrad"):
dgrad_gemm_use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator
use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator
if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensor):
weight.update_usage(
rowwise_usage=ctx.weight_quantizer.rowwise_usage,
columnwise_usage=ctx.weight_quantizer.columnwise_usage,
# Update grad input quantizer
if ctx.grad_input_quantizer is not None:
ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)
# Output buffers for Userbuffers reduce-scatter
gemm_out = None
reduce_scatter_out = None
if ctx.ub_overlap_rs_dgrad:
reduce_scatter_out = torch.empty(
dgrad_shape, dtype=ctx.activation_dtype, device=grad_outputs[0].device
)
dgrad, *_ = general_gemm(
elif ctx.ub_bulk_wgrad:
gemm_out = ub_obj_wgrad.get_buffer(local_chunk=False)
# dgrad GEMM
# Note: dx = dy * w
nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
gemm_out, *_, reduce_scatter_out = general_gemm(
weight,
grad_output,
get_workspace(),
layout="NN",
grad=True,
quantization_params=ctx.grad_input_quantizer,
out=dgrad_bulk,
out=gemm_out,
out_dtype=ctx.activation_dtype,
use_split_accumulator=dgrad_gemm_use_split_accumulator,
use_split_accumulator=use_split_accumulator,
ub=ub_obj_dgrad,
ub_type=ub_type_dgrad,
extra_output=rs_out,
extra_output=reduce_scatter_out,
bulk_overlap=ctx.ub_bulk_dgrad,
)
nvtx_range_pop(f"{nvtx_label}.dgrad_gemm")
# Launch tensor-parallel communication
# Prepare grad input tensor
# Note: Perform tensor-parallel communication
dgrad = None
dgrad_work = None
if ctx.ub_overlap_rs_dgrad:
dgrad = rs_out
elif ctx.parallel_mode == "column" and not ctx.ub_bulk_wgrad:
dgrad = reduce_scatter_out
elif ctx.ub_bulk_wgrad:
dgrad = ub_obj_wgrad.get_buffer(local_chunk=True)
elif ctx.parallel_mode == "column" and ctx.tp_size > 1:
nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad")
dgrad = gemm_out
if ctx.sequence_parallel:
if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered:
dgrad = dgrad + grad_outputs[1].view_as(dgrad)
dgrad, dgrad_work = reduce_scatter_along_first_dim(
dgrad,
ctx.tp_group,
......@@ -704,41 +732,55 @@ class _LayerNormLinear(torch.autograd.Function):
else:
dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True)
nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad")
else:
dgrad = gemm_out
# --------------------------------------------------
# Grad input tensor has been computed...
# --------------------------------------------------
# --------------------------------------------------
# Compute grad weight
# --------------------------------------------------
# Compute grad weight tensor
wgrad = None
if ctx.requires_wgrad:
# Synchronize tensor-parallel communication for input tensor
if ctx.ub_bulk_dgrad:
ln_out_total = ub_obj_dgrad.get_buffer(ctx.input_quantizer)
if ctx.fp8:
# FP8 GEMM on Hopper only supports TN layout so the gathered input must have
# a valid transpose.
if ln_out._data is None:
# All-gather executed on columnwise data and result is in rowwise data,
# so we need to fix the interleaving before WGRAD.
ln_out_total = _fix_gathered_fp8_transpose(ln_out_total, ctx.tp_size)
else:
# FP8 GEMM on Hopper only supports TN layout so the gathered input must
# have a valid transpose.
ln_out_total._create_transpose()
# Prepare input tensor
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if ln_out_total_work is not None:
ln_out_total_work.wait()
ln_out_total_work = None
if ctx.input_quantizer is not None and not isinstance(
ln_out_total, QuantizedTensor
):
# Async gather may have been done in BF16
# call quantizer after gather.
ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
ln_out_total = ctx.input_quantizer(ln_out_total)
# Make sure GEMM inputs have required data
if isinstance(ln_out_total, QuantizedTensor):
ln_out_total.update_usage(columnwise_usage=True)
if isinstance(grad_output, QuantizedTensor):
grad_output.update_usage(columnwise_usage=True)
if ctx.fp8 or ctx.debug:
if isinstance(ln_out_total, QuantizedTensorBase):
ln_out_total.update_usage(columnwise_usage=True)
else:
ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
ln_out_total = ctx.input_quantizer(ln_out_total)
# Prepare grad output tensor
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer):
# UB does not support overlapping grad output
# all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered
                    # for the dgrad GEMM. We work around this with a blocking
                    # all-gather of the column-scaled MXFP8 data.
ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
grad_output, _ = gather_along_first_dim(
grad_outputs[0],
ctx.tp_group,
quantizer=ctx.grad_output_quantizer,
)
if ctx.fp8 or ctx.debug:
if isinstance(grad_output, QuantizedTensorBase):
grad_output.update_usage(columnwise_usage=True)
else:
ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
grad_output = ctx.grad_output_quantizer(grad_output)
# Figure out whether to use split accumulator
use_split_accumulator = _2X_ACC_WGRAD
......@@ -747,55 +789,95 @@ class _LayerNormLinear(torch.autograd.Function):
if hasattr(recipe, "fp8_gemm_wgrad"):
use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator
# Output buffer for overlapping grad input
# Figure out whether to output wgrad GEMM directly into main grad
if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = (
ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
# Output buffer for overlapping FP8 grad input
# reduce-scatter with wgrad GEMM
reduce_scatter_out = None
if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf():
rs_out = torch.empty(
dgrad_shape, dtype=ctx.activation_dtype, device=inputmat.device
reduce_scatter_out = torch.empty(
dgrad_shape, dtype=ctx.activation_dtype, device=grad_outputs[0].device
)
# wgrad GEMM
# Note: Fuse with bgrad computation if needed
nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
general_gemm_wgrad = functools.partial(
general_gemm,
out_dtype=(
# Arguments to include in wgrad GEMM closure
wgrad_gemm_kwargs = {
"workspace": get_workspace(),
"out_dtype": (
main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
),
workspace=get_workspace(),
layout="NT",
grad=True,
bias=(bias if (grad_bias is None and not ctx.fp8) else None),
out=main_grad if ctx.fuse_wgrad_accumulation else None,
use_split_accumulator=use_split_accumulator,
accumulate=accumulate_wgrad_into_param_main_grad,
quantization_params=ctx.grad_weight_quantizer,
ub=ub_obj_wgrad,
ub_type=ub_type_wgrad,
extra_output=rs_out,
bulk_overlap=ctx.ub_bulk_wgrad,
)
"quantization_params": ctx.grad_weight_quantizer,
"accumulate": accumulate_wgrad_into_param_main_grad,
"layout": "NT",
"out": main_grad if ctx.fuse_wgrad_accumulation else None,
"bias": (bias if (grad_bias is None and not ctx.fp8) else None),
"use_split_accumulator": use_split_accumulator,
"grad": True,
"ub": ub_obj_wgrad,
"ub_type": ub_type_wgrad,
"extra_output": reduce_scatter_out,
"bulk_overlap": ctx.ub_bulk_wgrad,
}
def wgrad_gemm(
x: torch.Tensor,
dy: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Perform wgrad GEMM: dw = dy^T * x
May be fused with bgrad computation.
May be called outside of this function to enable
some advanced communication/compute overlapping.
"""
nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
dw, db, *_ = general_gemm(x, dy, **wgrad_gemm_kwargs)
nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")
return dw, db
# Choose whether to call wgrad GEMM now or delay
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
ctx.wgrad_store.put([ln_out_total, grad_output], general_gemm_wgrad)
if (
wgrad_gemm_kwargs["ub"] is not None
or wgrad_gemm_kwargs["ub_type"] is not None
or wgrad_gemm_kwargs["extra_output"] is not None
or wgrad_gemm_kwargs["bulk_overlap"]
):
raise NotImplementedError(
"Delayed weight grad computation is not supported "
"with Userbuffers (tensor-parallel communication overlapping)"
)
ctx.wgrad_store.put([ln_out_total, grad_output], wgrad_gemm)
else:
wgrad, grad_bias_, _, rs_out = general_gemm_wgrad(ln_out_total, grad_output)
# Call wgrad GEMM now
wgrad, grad_bias_ = wgrad_gemm(ln_out_total, grad_output)
# Update grad bias if needed
if grad_bias is None:
grad_bias = grad_bias_
del grad_bias_
# Deallocate input tensor
# Deallocate input tensor if permitted
if not ctx.return_layernorm_output:
# TODO (pgadzinski) - deallocate transpose only # pylint: disable=fixme
clear_tensor_data(ln_out_total)
nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")
# Update grad input if overlapping reduce-scatter with wgrad GEMM
if ctx.ub_bulk_wgrad:
if ub_obj_wgrad.is_fp8_ubuf():
dgrad = rs_out
dgrad = reduce_scatter_out
else:
dgrad = ub_obj_wgrad.get_buffer(None, local_chunk=True)
dgrad = ub_obj_wgrad.get_buffer(local_chunk=True).clone()
# --------------------------------------------------
# Grad weight has been computed...
# --------------------------------------------------
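
For reviewers unfamiliar with the delayed-wgrad path: a toy, self-contained stand-in for the put/pop flow used above (the real WeightGradStore comes from ._common; this mock only illustrates the mechanism):

from typing import Callable, List, Tuple
import torch

class _ToyWgradStore:
    """Illustrative stand-in: queue GEMM inputs plus a closure, execute them later."""

    def __init__(self):
        self._queue: List[Tuple[List[torch.Tensor], Callable]] = []

    def put(self, tensors, fn):
        self._queue.append((tensors, fn))

    def pop(self):
        tensors, fn = self._queue.pop(0)
        return fn(*tensors), None

store = _ToyWgradStore()
x, dy = torch.randn(8, 4), torch.randn(8, 2)
store.put([x, dy], lambda x, dy: (dy.t() @ x, dy.sum(0)))  # wgrad and bgrad closure
(wgrad, bgrad), _ = store.pop()  # executed later, e.g. from backward_dw
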
# Don't return grad bias if not needed
if not ctx.use_bias:
......@@ -870,7 +952,7 @@ class _LayerNormLinear(torch.autograd.Function):
nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors")
# Scatter fp8 weight buffers
# if ctx.fp8 and not isinstance(weight, QuantizedTensor):
# if ctx.fp8 and not isinstance(weight, QuantizedTensorBase):
# _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8)
return (
......@@ -1506,7 +1588,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
grad_output_quantizer = None
output_quantizer = None
input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
input_quantizer.internal = False
input_quantizer.internal = True
weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
weight_quantizer.internal = True
if fp8_output:
......@@ -1574,3 +1656,12 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
# parallel related
if self.sequence_parallel and self.parallel_mode == "row":
# customize grad_output_quantizer with amax reduction TP group
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].with_amax_reduction = True
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_reduction_group = self.tp_group
......@@ -8,7 +8,6 @@ import warnings
from typing import Callable, Optional, Tuple, Union
from functools import reduce
from operator import mul as multiply_op
import functools
import torch
from torch.nn.parameter import Parameter
......@@ -19,6 +18,7 @@ import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import torch_version
from .base import (
fill_userbuffers_buffer_for_all_gather,
get_workspace,
_ub_communicators,
get_ub,
......@@ -42,7 +42,6 @@ from ..utils import (
assert_dim_for_fp8_exec,
clear_tensor_data,
requires_grad,
is_non_tn_fp8_gemm_supported,
needs_quantized_gemm,
)
from ..distributed import (
......@@ -66,10 +65,11 @@ from ..tensor.float8_tensor import (
)
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ._common import apply_normalization, _fix_gathered_fp8_transpose, WeightGradStore
from ._common import apply_normalization, WeightGradStore
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ..tensor.quantized_tensor import (
QuantizedTensor,
QuantizedTensorBase,
Quantizer,
prepare_for_saving,
restore_from_saved,
......@@ -200,24 +200,16 @@ class _LayerNormMLP(torch.autograd.Function):
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# pylint: disable=missing-function-docstring
in_features, inp_shape = ln_weight.numel(), inp.shape
# Make sure input dimensions are compatible
in_features, inp_shape = ln_weight.numel(), inp.shape
assert inp_shape[-1] == in_features, "GEMM not possible"
inputmat = inp.view((-1, in_features))
if fp8:
assert_dim_for_fp8_exec(inputmat, fc1_weight, fc2_weight)
if any([ub_overlap_ag, ub_overlap_rs]) and not (
FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
):
raise NotImplementedError(
"Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
" current scaling"
)
activation_func = _act_func(
activation, FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
)[0]
device = inp.device
# Cast for native AMP
inputmat = cast_if_needed(inputmat, activation_dtype)
......@@ -225,6 +217,38 @@ class _LayerNormMLP(torch.autograd.Function):
if ln_bias is not None:
ln_bias = cast_if_needed(ln_bias, activation_dtype)
tp_world_size = get_distributed_world_size(tp_group)
backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad
device = inp.device
# Configure Userbuffers communication (comm+GEMM overlap)
ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output_gathered
ub_overlap_rs = ub_overlap_rs and is_grad_enabled
# Choose whether to use GEMM kernel with split accumulator
use_split_accumulator = _2X_ACC_FPROP
if fp8:
recipe = FP8GlobalStateManager.get_fp8_recipe()
if hasattr(recipe, "fp8_gemm_fprop"):
use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator
# Configure quantizer for norm output
if fp8:
if fc1_input_quantizer is None:
raise ValueError("Missing quantizer for FC1 input tensor")
fc1_input_quantizer.set_usage(rowwise=True, columnwise=backwards_needs_fc1_input)
if sequence_parallel and isinstance(
fc1_input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
):
# All-gather is not supported with FP8 column-wise data
fc1_input_quantizer.set_usage(columnwise=False)
# Do TP communication in high precision if quantized format
# does not support communication
force_hp_fc1_input_gather = (
fp8 and sequence_parallel and isinstance(fc1_input_quantizer, Float8BlockQuantizer)
)
# for fp8 DelayedScaling: layernorm output = FP8
# only output of the linear is returned
# for return_layernorm_output: layernorm output = High precision, then cast to FP8
......@@ -240,29 +264,6 @@ class _LayerNormMLP(torch.autograd.Function):
# Kernels not available for norm fusion.
with_quantized_norm = False
tp_world_size = get_distributed_world_size(tp_group)
ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output_gathered
ub_overlap_rs = ub_overlap_rs and is_grad_enabled
backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad
# TODO(kwyss): Support FP8 allgather of Float8BlockQuantizer recipe
force_hp_fc1_input_gather = (
fp8 and sequence_parallel and isinstance(fc1_input_quantizer, Float8BlockQuantizer)
) # Perform TP communication in high precision.
# Configure quantizer for norm output
if fp8:
if fc1_input_quantizer is None:
raise ValueError("Missing quantizer for FC1 input tensor")
columnwise_usage = backwards_needs_fc1_input
if (
columnwise_usage
and sequence_parallel
and not isinstance(fc1_input_quantizer, MXFP8Quantizer)
):
columnwise_usage = False
fc1_input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
# Apply normalization
ln_out, mu, rsigma = apply_normalization(
inputmat,
......@@ -296,39 +297,43 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
ln_out_total = fc1_input_quantizer(ln_out_total)
else:
quantizer = None
if fp8 or debug:
quantizer = fc1_input_quantizer
if not with_quantized_norm and not force_hp_fc1_input_gather:
ln_out = fc1_input_quantizer(ln_out)
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
if ub_overlap_ag:
# Copy into Userbuffers buffer
ub_obj_lnout = get_ub("fc1_fprop")
ub_obj_lnout.get_buffer(fc1_input_quantizer, local_chunk=True).copy_(ln_out)
ln_out_total = ub_obj_lnout.get_buffer(fc1_input_quantizer)
ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj_lnout,
ln_out,
quantizer,
tp_group,
)
else:
# All-gather with NCCL
ln_out_total, _ = gather_along_first_dim(
ln_out,
tp_group,
quantizer=(fc1_input_quantizer if fp8 or debug else None),
quantizer=quantizer,
)
else:
# NOTE: force_hp_fc1_input_gather is redundant with else, but
# here for clarity. We should not quantize ln_out if bwd needs
# to gather in hp.
if (fp8 or debug) and not with_quantized_norm and not force_hp_fc1_input_gather:
if (fp8 or debug) and not with_quantized_norm:
ln_out = fc1_input_quantizer(ln_out)
ln_out_total = ln_out
# Cast weights to expected dtype
fc1_weight_final = fc1_weight
fc2_weight_final = fc2_weight
if fp8 or debug:
# If weights are not quantized, we call get_weight_workspace,
# which handles weight caching etc.
# FP8 cast to workspace buffer
update_workspace = is_first_microbatch is None or is_first_microbatch
fc1_weight_quantizer.set_usage(rowwise=True, columnwise=True)
fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True)
fc1_weight_final = module.get_weight_workspace(
tensor=fc1_weight,
quantizer=fc1_weight_quantizer,
......@@ -338,7 +343,6 @@ class _LayerNormMLP(torch.autograd.Function):
fsdp_group=fsdp_group,
workspace_dtype=activation_dtype,
)
fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True)
fc2_weight_final = module.get_weight_workspace(
tensor=fc2_weight,
quantizer=fc2_weight_quantizer,
......@@ -348,6 +352,8 @@ class _LayerNormMLP(torch.autograd.Function):
fsdp_group=fsdp_group,
workspace_dtype=activation_dtype,
)
fc1_weight_final.update_usage(rowwise_usage=True)
fc2_weight_final.update_usage(rowwise_usage=True)
else:
fc1_weight_final = cast_if_needed(fc1_weight_final, activation_dtype)
fc2_weight_final = cast_if_needed(fc2_weight_final, activation_dtype)
......@@ -355,6 +361,7 @@ class _LayerNormMLP(torch.autograd.Function):
# Cast biases to expected dtype
bias_dtype = activation_dtype
if needs_quantized_gemm(ln_out_total) and activation_dtype == torch.float32:
# cuBLAS does not support FP8 GEMM with FP32 bias, so we cast to BF16
bias_dtype = torch.bfloat16
if fc1_bias is not None:
fc1_bias = cast_if_needed(fc1_bias, bias_dtype)
......@@ -368,7 +375,9 @@ class _LayerNormMLP(torch.autograd.Function):
if fc1_weight_quantizer is not None:
fc1_weight_quantizer.calibrate(fc1_weight)
# ------------------------------------------------------
# FC1 GEMM
# ------------------------------------------------------
# There are 2 fusions possible:
# - gemm_gelu_fusion - default for full precision, optional for fp8 - need to turn on gemm_gelu_fusion,
......@@ -400,11 +409,16 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_bias if not bias_gelu_fusion else None
), # otherwise bias is added later (fused with gelu)
gelu=gemm_gelu_fusion,
accumulate=_2X_ACC_FPROP,
use_split_accumulator=use_split_accumulator,
ub=ub_obj_lnout,
ub_type=tex.CommOverlapType.AG if ub_overlap_ag else None,
)
# ------------------------------------------------------
# Finished FC1 GEMM...
# ------------------------------------------------------
# Deallocate FC1 GEMM input tensor if no longer needed
if not is_grad_enabled and (ln_out_total is not ln_out_return):
clear_tensor_data(ln_out_total)
......@@ -438,45 +452,66 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_input_quantizer.calibrate(act_out)
fc2_weight_quantizer.calibrate(fc2_weight)
# Configure Userbuffers reduce-scatter if needed
ub_obj_fc2out = None
rs_out = None
fc2_out = None
reduce_scatter_out = None
if ub_overlap_rs:
ub_obj_fc2out = get_ub("fc2_fprop")
dim_size = list(act_out.size())
dim_size[0] = dim_size[0] // tp_world_size
dim_size[1] = fc2_weight.size(0)
rs_out = torch.empty(dim_size, dtype=activation_dtype, device=device)
else:
dim_size = list(act_out.size())
dim_size[1] = fc2_weight.size(0)
fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=device)
dim_size[0] //= tp_world_size
dim_size[-1] = fc2_weight.size(0)
reduce_scatter_out = torch.empty(dim_size, dtype=activation_dtype, device=device)
# ------------------------------------------------------
# FC2 GEMM
_ = general_gemm(
# ------------------------------------------------------
gemm_out, *_, reduce_scatter_out = general_gemm(
fc2_weight_final,
act_out,
get_workspace(),
out_dtype=activation_dtype,
bias=fc2_bias,
quantization_params=fc2_output_quantizer,
out=fc2_out,
use_split_accumulator=_2X_ACC_FPROP,
use_split_accumulator=use_split_accumulator,
ub=ub_obj_fc2out,
ub_type=tex.CommOverlapType.RS if ub_overlap_rs else None,
extra_output=rs_out,
extra_output=reduce_scatter_out,
)
# ------------------------------------------------------
# Finished FC2 GEMM...
# ------------------------------------------------------
# Deallocate tensors if no longer needed
if not is_grad_enabled:
clear_tensor_data(act_out, fc1_out_without_bias, fc1_out)
# Prepare output tensor
# Note: Perform tensor-parallel communication if needed
fc2_out = None
if ub_overlap_rs:
fc2_out = reduce_scatter_out
elif set_parallel_mode and sequence_parallel:
fc2_out, _ = reduce_scatter_along_first_dim(gemm_out, tp_group)
elif set_parallel_mode and tensor_parallel:
if symmetric_ar_type is not None:
fc2_out, _ = symmetric_all_reduce(
gemm_out, tp_group, all_reduce_type=symmetric_ar_type
)
else:
fc2_out, _ = allreduce(gemm_out, tp_group)
else:
fc2_out = gemm_out
fc2_out = fc2_out.view(-1, *inp_shape[1:-1], fc2_out.shape[-1])
# Weight with column-wise usage is needed for dgrad GEMM.
# Cache state for backward pass
if is_grad_enabled:
if isinstance(fc1_weight_final, QuantizedTensor):
# Weight with column-wise usage is needed for dgrad GEMM.
if isinstance(fc1_weight_final, QuantizedTensorBase):
fc1_weight_final.update_usage(columnwise_usage=True)
if isinstance(fc2_weight_final, QuantizedTensor):
if isinstance(fc2_weight_final, QuantizedTensorBase):
fc2_weight_final.update_usage(columnwise_usage=True)
if not is_grad_enabled:
clear_tensor_data(act_out, fc1_out_without_bias, fc1_out)
else:
if cpu_offloading:
mark_activation_offload(
inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out
......@@ -503,8 +538,6 @@ class _LayerNormMLP(torch.autograd.Function):
if not return_layernorm_output:
clear_tensor_data(ln_out)
ln_out = None
elif force_hp_fc1_input_gather:
assert not isinstance(ln_out, QuantizedTensor)
if not fc2_weight.requires_grad:
clear_tensor_data(act_out)
act_out = None
......@@ -591,22 +624,6 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.wgrad_store = wgrad_store
# Row Parallel Linear
if ub_overlap_rs:
fc2_out = rs_out
elif set_parallel_mode and sequence_parallel:
fc2_out, _ = reduce_scatter_along_first_dim(fc2_out, tp_group)
elif set_parallel_mode and tensor_parallel:
if symmetric_ar_type is not None:
fc2_out, _ = symmetric_all_reduce(
fc2_out, tp_group, all_reduce_type=symmetric_ar_type
)
else:
fc2_out, _ = allreduce(fc2_out, tp_group)
# [*, in_features] -> [*, out_features] except first dimension changes for SP
fc2_out = fc2_out.view(-1, *inp_shape[1:-1], fc2_out.shape[-1])
if return_layernorm_output:
if return_layernorm_output_gathered:
shape = list(inp_shape)
......@@ -621,24 +638,6 @@ class _LayerNormMLP(torch.autograd.Function):
) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
with torch.cuda.nvtx.range("_LayerNormMLP_backward"):
if (
ctx.fp8
and any(
[
ctx.ub_overlap_ag,
ctx.ub_overlap_rs_dgrad,
ctx.ub_bulk_dgrad,
ctx.ub_bulk_wgrad,
]
)
and (ctx.fp8_recipe is not None)
):
if not ctx.fp8_recipe.float8_per_tensor_scaling():
raise NotImplementedError(
"Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
" current scaling"
)
saved_tensors = ctx.saved_tensors
( # pylint: disable=unbalanced-tuple-unpacking
inputmat,
......@@ -698,6 +697,16 @@ class _LayerNormMLP(torch.autograd.Function):
# fc2_weight_fp8 if ctx.fp8 and not isinstance(fc2_weight, Float8Tensor) else None,
# )
# Choose whether to use GEMM kernel with split accumulator
dgrad_use_split_accumulator = _2X_ACC_DGRAD
wgrad_use_split_accumulator = _2X_ACC_WGRAD
if ctx.fp8:
recipe = ctx.fp8_recipe
if hasattr(recipe, "fp8_gemm_dgrad"):
dgrad_use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator
if hasattr(recipe, "fp8_gemm_wgrad"):
wgrad_use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator
# No need to do bulk DGRAD/WGRAD overlap if WGRAD is not required
ctx.ub_bulk_dgrad = ctx.fc1_weight_requires_grad and ctx.ub_bulk_dgrad
ctx.ub_bulk_wgrad = ctx.fc1_weight_requires_grad and ctx.ub_bulk_wgrad
......@@ -706,20 +715,13 @@ class _LayerNormMLP(torch.autograd.Function):
# Note: dgrad GEMM requires row-wise usage, wgrad GEMM
# requires column-wise usage
if ctx.fc2_grad_output_quantizer is not None:
rowwise_usage = True
columnwise_usage = True
if ctx.ub_overlap_ag and isinstance(
ctx.fc2_grad_output_quantizer,
(Float8Quantizer, Float8CurrentScalingQuantizer),
):
# If data is in FP8 and communication is handled
# with Userbuffers, we compute FP8 transposes
# manually
columnwise_usage = False
ctx.fc2_grad_output_quantizer.set_usage(
rowwise=rowwise_usage,
columnwise=columnwise_usage,
)
quantizer = ctx.fc2_grad_output_quantizer
quantizer.set_usage(rowwise=True, columnwise=True)
if ctx.ub_overlap_ag:
# Userbuffers only supports communication for one
# tensor usage at a time. Configure quantizer with
# usage for only dgrad GEMM.
quantizer.set_usage(columnwise=False)
# Prepare FC2 grad output tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
......@@ -737,14 +739,10 @@ class _LayerNormMLP(torch.autograd.Function):
# Launch tensor-parallel communication for FC1 GEMM input
ln_out_total = None
ln_out_total_work = None
if (
ctx.fc1_weight_requires_grad
and ctx.tensor_parallel
and ctx.sequence_parallel
and not ctx.ub_bulk_dgrad
):
ub_obj_fc1_dgrad = None
if ctx.fc1_weight_requires_grad and ctx.tensor_parallel and ctx.sequence_parallel:
quantizer = None
if ctx.fp8 or ctx.debug:
                if (ctx.fp8 or ctx.debug) and not ctx.force_hp_fc1_input_gather:
quantizer = ctx.fc1_input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
# If data is in FP8, we compute FP8 transposes manually
......@@ -752,13 +750,21 @@ class _LayerNormMLP(torch.autograd.Function):
else:
# wgrad GEMM requires input with column-wise usage
quantizer.set_usage(rowwise=False, columnwise=True)
gather_quantizer = None if ctx.force_hp_fc1_input_gather else quantizer
ln_out_total, ln_out_total_work = gather_along_first_dim(
ln_out,
ctx.tp_group,
async_op=True,
quantizer=gather_quantizer,
)
if ctx.ub_bulk_dgrad:
ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj_fc1_dgrad,
ln_out,
quantizer,
ctx.tp_group,
)
else:
ln_out_total, ln_out_total_work = gather_along_first_dim(
ln_out,
ctx.tp_group,
async_op=True,
quantizer=quantizer,
)
else:
ln_out_total = ln_out
......@@ -769,6 +775,11 @@ class _LayerNormMLP(torch.autograd.Function):
)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
# --------------------------------------------------
# FC2 DGRAD
# --------------------------------------------------
# There are 6 possible fusion paths
# 1 high-precision bias_gelu_fusion: gemm, FC1_bias + gelu,
# 2 high-precision fc2_dgrad_gemm_gelu_fusion: gemm + gelu, FC1_bias + quantize
......@@ -783,12 +794,15 @@ class _LayerNormMLP(torch.autograd.Function):
and (not ctx.debug)
)
# FC2 DGRAD; Unconditional
if ctx.fc2_weight_quantizer is not None and isinstance(ctx.fc2_weight, QuantizedTensor):
ctx.fc2_weight.update_usage(
rowwise_usage=ctx.fc2_weight_quantizer.rowwise_usage,
columnwise_usage=ctx.fc2_weight_quantizer.columnwise_usage,
)
# Make sure required data is available
if isinstance(grad_output, QuantizedTensorBase):
grad_output.update_usage(rowwise_usage=True)
if ctx.fc2_weight_quantizer is not None and isinstance(
ctx.fc2_weight, QuantizedTensorBase
):
ctx.fc2_weight.update_usage(columnwise_usage=True)
# Perform GEMM
gemm_output, *_ = general_gemm(
fc2_weight,
grad_output,
......@@ -803,52 +817,107 @@ class _LayerNormMLP(torch.autograd.Function):
out_dtype=ctx.activation_dtype,
gelu=fc2_dgrad_gemm_gelu_fusion,
gelu_in=fc1_out if fc2_dgrad_gemm_gelu_fusion else None,
use_split_accumulator=_2X_ACC_DGRAD,
use_split_accumulator=dgrad_use_split_accumulator,
ub=ub_obj_fc2_dgrad,
ub_type=tex.CommOverlapType.AG if ctx.ub_overlap_ag else None,
)
# Prepare input grad tensor
dact = None
fc2_dgrad = None
if fc2_dgrad_gemm_gelu_fusion:
dact = gemm_output
fc2_dgrad = None
else:
fc2_dgrad = gemm_output
# --------------------------------------------------
# Finished FC2 DGRAD...
# --------------------------------------------------
# --------------------------------------------------
# FC2 WGRAD
# --------------------------------------------------
fc2_wgrad = None
if ctx.fc2_weight_requires_grad:
if isinstance(act_out, QuantizedTensor):
act_out.update_usage(rowwise_usage=True, columnwise_usage=True)
if isinstance(grad_output, QuantizedTensor):
grad_output.update_usage(rowwise_usage=True, columnwise_usage=True)
# Prepare input tensor
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if ctx.fp8 or ctx.debug:
if isinstance(act_out, QuantizedTensorBase):
act_out.update_usage(columnwise_usage=True)
else:
ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True)
act_out = ctx.fc2_input_quantizer(act_out)
# Prepare grad output tensor
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if ctx.ub_overlap_ag and isinstance(ctx.fc2_grad_output_quantizer, MXFP8Quantizer):
# UB does not support overlapping grad output
# all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered
# for the dgrad GEMM. We work around this with a
# blocking all-gather of column-scaled MXFP8 data.
ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
grad_output, _ = gather_along_first_dim(
grad_outputs[0],
ctx.tp_group,
quantizer=ctx.fc2_grad_output_quantizer,
)
if ctx.fp8 or ctx.debug:
if isinstance(grad_output, QuantizedTensorBase):
grad_output.update_usage(columnwise_usage=True)
else:
ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
grad_output = ctx.fc2_grad_output_quantizer(grad_output)
# Whether to set grad arg in general_gemm
grad_arg = True
if ctx.fp8 and ctx.fp8_recipe.float8_block_scaling():
grad_arg = False
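# Note: grad=True lets the GEMM fuse the bias grad computation;
# blockwise FP8 GEMMs do not support this fusion, so the bias grad
# is computed separately below.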
general_gemm_fc2_wgrad = functools.partial(
general_gemm,
out_dtype=(
# Arguments to include in wgrad GEMM closure
fc2_wgrad_gemm_kwargs = {
"workspace": get_workspace(),
"out_dtype": (
origin_fc2_weight.main_grad.dtype
if ctx.fuse_wgrad_accumulation
else ctx.activation_dtype
),
workspace=get_workspace(),
quantization_params=ctx.fc2_grad_weight_quantizer, # wgrad in high precision
layout="NT",
grad=grad_arg,
bias=fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None,
accumulate=accumulate_wgrad_into_param_main_grad,
use_split_accumulator=_2X_ACC_WGRAD,
out=origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
"quantization_params": ctx.fc2_grad_weight_quantizer, # wgrad in high precision
"accumulate": accumulate_wgrad_into_param_main_grad,
"layout": "NT",
"out": origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
"bias": fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None,
"use_split_accumulator": wgrad_use_split_accumulator,
"grad": grad_arg,
}
def fc2_wgrad_gemm(
x: torch.Tensor,
dy: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Perform FC2 WGRAD GEMM
May be called outside of this function to enable
some advanced communication/compute overlapping.
"""
dw, db, *_ = general_gemm(x, dy, **fc2_wgrad_gemm_kwargs)
return dw, db
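# Hypothetical usage sketch (the WeightGradStore retrieval API is assumed,
# not shown in this diff): a delayed closure would later be drained as
#     (x, dy), gemm_fn = ctx.wgrad_store.pop()
#     fc2_wgrad, fc2_bias_grad = gemm_fn(x, dy)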
# Choose whether to call wgrad GEMM now or delay
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
ctx.wgrad_store.put([act_out, grad_output], general_gemm_fc2_wgrad)
fc2_wgrad = None
ctx.wgrad_store.put([act_out, grad_output], fc2_wgrad_gemm)
else:
fc2_wgrad, fc2_bias_grad_, *_ = general_gemm_fc2_wgrad(
act_out,
grad_output,
)
# Call wgrad GEMM now
fc2_wgrad, fc2_bias_grad_ = fc2_wgrad_gemm(act_out, grad_output)
# Update grad bias if needed
if fc2_bias_grad is None:
if (
ctx.fp8
......@@ -857,12 +926,17 @@ class _LayerNormMLP(torch.autograd.Function):
):
# BGRAD not fused with GEMM for float8 blockwise gemm.
fc2_bias_grad_ = act_out.view(-1, act_out.shape[-1]).sum(dim=0)
fc2_bias_grad = fc2_bias_grad_
del fc2_bias_grad_
# Deallocate input tensor if permitted
if ctx.wgrad_store is not None and not ctx.wgrad_store.delay_wgrad_compute():
clear_tensor_data(act_out)
# --------------------------------------------------
# Finished FC2 WGRAD...
# --------------------------------------------------
# bias computation
fc1_bias_grad = None
fuse_gemm_and_bias_fc1_wgrad = False
......@@ -926,63 +1000,69 @@ class _LayerNormMLP(torch.autograd.Function):
ub_obj_fc1_dgrad = None
ub_obj_fc1_wgrad = None
ub_type_fc1_dgrad = None
ub_type_fc1_wgrad = None
fc1_dgrad_shape = [reduce(multiply_op, inputmat.shape[:-1]), inputmat.shape[-1]]
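# e.g. an input of shape [s, b, h] gives a flattened grad input shape [s * b, h]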
fc1_dgrad_rs_out = None
fc1_dgrad_bulk = None
if ctx.ub_overlap_rs_dgrad:
# Overlap DGRAD+RS
ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
ub_type_fc1_dgrad = tex.CommOverlapType.RS
fc1_dgrad_rs_out = torch.empty(
fc1_dgrad_shape, dtype=ctx.activation_dtype, device="cuda"
)
else:
if ctx.ub_bulk_dgrad:
# Overlap ln_out all-gather with DGRAD compute
# NOTE: Copying into communication buffer will always prefer rowwise data,
# and will copy columnwise data if rowwise does not exist. In that case,
# the all-gather will apply to the leading dimension of the transpose,
# which then needs to be interleaved correctly before WGRAD.
ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
ub_type_fc1_dgrad = tex.CommOverlapType.AG
ub_obj_fc1_dgrad.copy_into_buffer(
ln_out, ctx.fc1_input_quantizer, local_chunk=True
)
if ctx.ub_bulk_wgrad:
# Overlap FC1 DGRAD reduce-scatter with WGRAD compute
ub_obj_fc1_wgrad = get_ub("fc1_wgrad")
fc1_dgrad_bulk = ub_obj_fc1_wgrad.get_buffer(None)
ub_type_fc1_wgrad = tex.CommOverlapType.RS
# FC1 DGRAD: Unconditional
# --------------------------------------------------
# FC1 DGRAD
# --------------------------------------------------
# Make sure required data is available
if ctx.fc1_weight_quantizer is not None and isinstance(
ctx.fc1_weight_quantizer, QuantizedTensor
ctx.fc1_weight, QuantizedTensorBase
):
ctx.fc1_weight.update_usage(
rowwise_usage=ctx.fc1_weight_quantizer.rowwise_usage,
columnwise_usage=ctx.fc1_weight_quantizer.columnwise_usage,
ctx.fc1_weight.update_usage(columnwise_usage=True)
# Output buffers for Userbuffers reduce-scatter
gemm_out = None
reduce_scatter_out = None
if ctx.ub_overlap_rs_dgrad:
reduce_scatter_out = torch.empty(
fc1_dgrad_shape, dtype=ctx.activation_dtype, device="cuda"
)
fc1_dgrad, *_, fc1_dgrad_rs_out = general_gemm(
if ctx.ub_bulk_wgrad:
gemm_out = ub_obj_fc1_wgrad.get_buffer(local_chunk=False)
# dgrad GEMM
gemm_out, *_, reduce_scatter_out = general_gemm(
fc1_weight,
dact,
get_workspace(),
out=fc1_dgrad_bulk,
out=gemm_out,
out_dtype=ctx.activation_dtype,
quantization_params=ctx.fc1_grad_input_quantizer,
layout="NN",
grad=True,
use_split_accumulator=dgrad_use_split_accumulator,
ub=ub_obj_fc1_dgrad,
ub_type=ub_type_fc1_dgrad,
extra_output=fc1_dgrad_rs_out,
extra_output=reduce_scatter_out,
bulk_overlap=ctx.ub_bulk_dgrad,
)
# Overlap dgrad-RS/AR with wgrad
# Prepare grad input tensor
# Note: Perform tensor-parallel communication
fc1_dgrad = None
fc1_dgrad_work = None
if ctx.ub_overlap_rs_dgrad:
fc1_dgrad = fc1_dgrad_rs_out
fc1_dgrad = reduce_scatter_out
elif ctx.ub_bulk_wgrad:
fc1_dgrad = ub_obj_fc1_wgrad.get_buffer(local_chunk=True)
elif ctx.set_parallel_mode and not ctx.ub_bulk_wgrad:
fc1_dgrad = gemm_out
if ctx.sequence_parallel:
if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered:
fc1_dgrad = fc1_dgrad + grad_outputs[1].view_as(fc1_dgrad)
......@@ -993,90 +1073,125 @@ class _LayerNormMLP(torch.autograd.Function):
)
elif ctx.tensor_parallel:
fc1_dgrad, fc1_dgrad_work = allreduce(fc1_dgrad, ctx.tp_group, async_op=True)
else:
fc1_dgrad = gemm_out
# --------------------------------------------------
# Finished FC1 DGRAD...
# --------------------------------------------------
# --------------------------------------------------
# FC1 WGRAD
# --------------------------------------------------
fc1_wgrad = None
if ctx.fc1_weight_requires_grad:
# Synchronize tensor-parallel communication for FC1 GEMM input tensor
if ctx.ub_bulk_dgrad:
ln_out_total = ub_obj_fc1_dgrad.get_buffer(ctx.fc1_input_quantizer)
if ctx.fp8:
if ln_out._data is None:
# All-gather executed on columnwise data and result is in rowwise data,
# so we need to fix the interleaving before WGRAD.
ln_out_total = _fix_gathered_fp8_transpose(ln_out_total, ctx.tp_size)
elif not is_non_tn_fp8_gemm_supported():
# FP8 GEMM on Hopper only supports TN layout so the gathered input must
# have a valid transpose.
ln_out_total._create_transpose()
# Prepare input tensor
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if ln_out_total_work is not None:
ln_out_total_work.wait()
ln_out_total_work = None
if ctx.fc1_input_quantizer is not None and not isinstance(
ln_out_total, QuantizedTensor
):
# The async BF16 gather does not apply the quantizer
# to the gathered result, so quantize it here.
ctx.fc1_input_quantizer.set_usage(rowwise=False, columnwise=True)
ln_out_total = ctx.fc1_input_quantizer(ln_out_total)
# Make sure GEMM inputs have required data
if isinstance(ln_out_total, QuantizedTensor):
ln_out_total.update_usage(columnwise_usage=True)
if isinstance(dact, QuantizedTensor):
dact.update_usage(columnwise_usage=True)
if ctx.fp8 or ctx.debug:
if isinstance(ln_out_total, QuantizedTensorBase):
ln_out_total.update_usage(columnwise_usage=True)
else:
ctx.fc1_input_quantizer.set_usage(rowwise=False, columnwise=True)
ln_out_total = ctx.fc1_input_quantizer(ln_out_total)
# Prepare grad output tensor
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if ctx.fp8 or ctx.debug:
if isinstance(dact, QuantizedTensorBase):
dact.update_usage(columnwise_usage=True)
else:
ctx.fc1_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
dact = ctx.fc1_grad_output_quantizer(dact)
# Output buffer for overlapping grad input
# reduce-scatter with wgrad GEMM
reduce_scatter_out = None
if ctx.ub_bulk_wgrad and ub_obj_fc1_wgrad.is_fp8_ubuf():
fc1_dgrad_rs_out = torch.empty(
reduce_scatter_out = torch.empty(
fc1_dgrad_shape, dtype=ctx.activation_dtype, device="cuda"
)
# wgrad GEMM
general_gemm_fc1_wgrad = functools.partial(
general_gemm,
out_dtype=(
# Arguments to include in wgrad GEMM closure
fc1_wgrad_gemm_kwargs = {
"workspace": get_workspace(),
"out_dtype": (
origin_fc1_weight.main_grad.dtype
if ctx.fuse_wgrad_accumulation
else ctx.activation_dtype
),
workspace=get_workspace(),
layout="NT",
quantization_params=ctx.fc1_grad_weight_quantizer,
grad=fuse_gemm_and_bias_fc1_wgrad,
bias=fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None,
accumulate=accumulate_wgrad_into_param_main_grad,
out=origin_fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
ub=ub_obj_fc1_wgrad,
ub_type=tex.CommOverlapType.RS if ctx.ub_bulk_wgrad else None,
extra_output=fc1_dgrad_rs_out,
bulk_overlap=ctx.ub_bulk_wgrad,
)
"quantization_params": ctx.fc1_grad_weight_quantizer,
"accumulate": accumulate_wgrad_into_param_main_grad,
"layout": "NT",
"out": origin_fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
"bias": fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None,
"use_split_accumulator": wgrad_use_split_accumulator,
"grad": fuse_gemm_and_bias_fc1_wgrad,
"ub": ub_obj_fc1_wgrad,
"ub_type": ub_type_fc1_wgrad,
"extra_output": reduce_scatter_out,
"bulk_overlap": ctx.ub_bulk_wgrad,
}
def fc1_wgrad_gemm(
x: torch.Tensor,
dy: torch.Tensor,
_is_delayed: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Perform FC1 WGRAD GEMM
May be called outside of this function to enable
some advanced communication/compute overlapping.
"""
dw, db, *_ = general_gemm(x, dy, **fc1_wgrad_gemm_kwargs)
return dw, db
# Choose whether to call wgrad GEMM now or delay
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
ctx.wgrad_store.put([ln_out_total, dact], general_gemm_fc1_wgrad)
if (
fc1_wgrad_gemm_kwargs["ub"] is not None
or fc1_wgrad_gemm_kwargs["ub_type"] is not None
or fc1_wgrad_gemm_kwargs["extra_output"] is not None
or fc1_wgrad_gemm_kwargs["bulk_overlap"]
):
raise NotImplementedError(
"Delayed weight grad computation is not supported "
"with Userbuffers (tensor-parallel communication overlapping)"
)
ctx.wgrad_store.put([ln_out_total, dact], fc1_wgrad_gemm)
fc1_wgrad = None
if fuse_gemm_and_bias_fc1_wgrad:
fc1_bias_grad = None
else:
fc1_wgrad_outputs = general_gemm_fc1_wgrad(
ln_out_total,
dact,
)
clear_tensor_data(ln_out_total, dact)
# Call wgrad GEMM now
fc1_wgrad_outputs = fc1_wgrad_gemm(ln_out_total, dact)
if fuse_gemm_and_bias_fc1_wgrad:
fc1_wgrad, fc1_bias_grad, *_ = fc1_wgrad_outputs
fc1_wgrad, fc1_bias_grad = fc1_wgrad_outputs
else:
fc1_wgrad, *_ = fc1_wgrad_outputs
fc1_wgrad, _ = fc1_wgrad_outputs
# Deallocate tensors if permitted
clear_tensor_data(dact)
if not ctx.return_layernorm_output_gathered:
clear_tensor_data(ln_out_total)
# Update grad input if overlapping reduce-scatter with wgrad GEMM
if ctx.ub_bulk_wgrad:
if ub_obj_fc1_wgrad.is_fp8_ubuf():
fc1_dgrad = fc1_dgrad_rs_out
fc1_dgrad = reduce_scatter_out
else:
fc1_dgrad = ub_obj_fc1_wgrad.get_buffer(None, local_chunk=True)
fc1_dgrad = ub_obj_fc1_wgrad.get_buffer(local_chunk=True).clone()
# --------------------------------------------------
# Finished FC1 WGRAD...
# --------------------------------------------------
# Make sure all tensor-parallel communication is finished
if ln_out_total_work is not None:
......@@ -1746,7 +1861,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
) = [None] * 12
if self.fp8:
fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
fc1_input_quantizer.internal = False # temporary
fc1_input_quantizer.internal = True
fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
fc1_weight_quantizer.internal = True
fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
......@@ -1754,6 +1869,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
rowwise=True,
columnwise=isinstance(fc2_input_quantizer, (MXFP8Quantizer, Float8BlockQuantizer)),
)
fc1_input_quantizer.internal = True
fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT]
fc2_weight_quantizer.internal = True
if fp8_output:
......@@ -1762,11 +1878,11 @@ class LayerNormMLP(TransformerEngineBaseModule):
]
if torch.is_grad_enabled():
fc2_grad_output_quantizer = self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
tex.FP8BwdTensors.GRAD_OUTPUT2
]
fc2_grad_output_quantizer.internal = True
fc1_grad_output_quantizer = self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_INPUT1
tex.FP8BwdTensors.GRAD_OUTPUT1
]
fc1_grad_output_quantizer.internal = True
......@@ -1851,25 +1967,25 @@ class LayerNormMLP(TransformerEngineBaseModule):
else:
# fc2_grad_output_quantizer: configure amax epsilon and power-2 scale options
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
tex.FP8BwdTensors.GRAD_OUTPUT2
].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
tex.FP8BwdTensors.GRAD_OUTPUT2
].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
# fc1_grad_output_quantizer: apply the same numerical configs
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_INPUT1
tex.FP8BwdTensors.GRAD_OUTPUT1
].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_INPUT1
tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
if self.sequence_parallel and self.set_parallel_mode:
# fc2_grad_output_quantizer: enable amax reduction over the TP group (row-parallel + sequence-parallel case)
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
tex.FP8BwdTensors.GRAD_OUTPUT2
].with_amax_reduction = True
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
tex.FP8BwdTensors.GRAD_OUTPUT2
].amax_reduction_group = self.tp_group
def backward_dw(self):
......
......@@ -8,7 +8,6 @@ from functools import reduce
from operator import mul as multiply_op
import warnings
import functools
import torch
import transformer_engine_torch as tex
......@@ -17,15 +16,16 @@ from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import torch_version
from .base import (
get_workspace,
fill_userbuffers_buffer_for_all_gather,
get_dummy_wgrad,
get_ub,
get_workspace,
TransformerEngineBaseModule,
get_dummy_wgrad,
_2X_ACC_FPROP,
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ._common import noop_cat, _fix_gathered_fp8_transpose, WeightGradStore
from ._common import noop_cat, WeightGradStore
from ..fp8 import FP8GlobalStateManager
from ..utils import (
cast_if_needed,
......@@ -34,7 +34,6 @@ from ..utils import (
init_method_constant,
requires_grad,
needs_quantized_gemm,
is_non_tn_fp8_gemm_supported,
assert_dim_for_fp8_exec,
nvtx_range_pop,
nvtx_range_push,
......@@ -130,86 +129,98 @@ class _Linear(torch.autograd.Function):
out_features, in_features = weight.shape
assert inp.shape[-1] == in_features, "GEMM not possible"
# Configure tensor-parallel communication
tp_world_size = get_distributed_world_size(tp_group)
backward_needs_input = is_grad_enabled and weight.requires_grad
# Prepare input tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
nvtx_range_push(f"{nvtx_label}.input_cast_comm")
inputmat = inp
inputmat_total = None
with_input_all_gather_nccl = (
parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
)
own_quantized_input = False
# TODO(kwyss): Support FP8 allgather for FP8 block quantization.
# Do TP communication in high precision if quantized format
# does not support communication
force_hp_input_gather = (
fp8 and with_input_all_gather_nccl and isinstance(input_quantizer, Float8BlockQuantizer)
) # Perform TP communication in high precision.
)
# Configure Userbuffers communication (comm+GEMM overlap)
ub_obj = None
ub_type = None
if ub_overlap_rs_fprop:
ub_obj = get_ub(ub_name + "_fprop")
ub_type = tex.CommOverlapType.RS
elif ub_overlap_ag_fprop:
ub_obj = get_ub(ub_name + "_fprop")
ub_type = tex.CommOverlapType.AG
# ------------------------------------------------------
# Prepare input tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
# ------------------------------------------------------
nvtx_range_push(f"{nvtx_label}.input_cast_comm")
inputmat = inp # Input tensor to save for backward (maybe sharded)
inputmat_total = None # Input tensor to pass to GEMM (gathered)
own_quantized_input = False
if fp8:
assert_dim_for_fp8_exec(inputmat, weight)
if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and not (
FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
):
raise NotImplementedError(
"Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
" current scaling"
)
if fp8 or debug:
if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
if with_input_all_gather_nccl:
if force_hp_input_gather:
input_quantizer.set_usage(rowwise=True, columnwise=False)
inputmat_total, _ = gather_along_first_dim(
inputmat, tp_group, quantizer=input_quantizer
)
else:
if not isinstance(inputmat, QuantizedTensorBase):
columnwise_usage = backward_needs_input and isinstance(
input_quantizer, MXFP8Quantizer
)
# force_hp_input_gather should enforce this
assert not isinstance(input_quantizer, Float8BlockQuantizer)
input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
inputmat = input_quantizer(inputmat)
own_quantized_input = True
input_quantizer.set_usage(rowwise=True, columnwise=False)
inputmat_total, _ = gather_along_first_dim(
inputmat,
tp_group,
quantizer=input_quantizer,
)
if with_input_all_gather_nccl or ub_overlap_ag_fprop: # All-gather input tensor
# Cast local input tensor if needed
if fp8 or debug:
if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
if not force_hp_input_gather and not isinstance(inputmat, QuantizedTensorBase):
input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
if isinstance(
input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
):
# All-gather is not supported with FP8 column-wise data
input_quantizer.set_usage(columnwise=False)
inputmat = input_quantizer(inputmat)
own_quantized_input = True
else:
if (
FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
and ub_bulk_dgrad
):
# reduce duplicated transpose in `_fix_gathered_fp8_transpose`
input_quantizer.set_usage(rowwise=True, columnwise=False)
inputmat = cast_if_needed(inp, activation_dtype) # Cast for AMP
# Initialize gathered input tensor
quantizer = None
if fp8 or debug:
quantizer = input_quantizer
quantizer.set_usage(rowwise=True, columnwise=False)
if with_input_all_gather_nccl: # Perform NCCL all-gather
inputmat_total, _ = gather_along_first_dim(
inputmat,
tp_group,
quantizer=quantizer,
)
elif ub_overlap_ag_fprop: # Initialize Userbuffers all-gather
inputmat_total, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj,
inputmat,
quantizer,
tp_group,
)
else: # Do not all-gather input tensor
if fp8 or debug:
if isinstance(inputmat, QuantizedTensorBase):
inputmat.update_usage(rowwise_usage=True)
else:
input_quantizer.set_usage(
rowwise=True,
columnwise=backward_needs_input,
)
if not isinstance(inputmat, QuantizedTensorBase):
if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
inputmat = input_quantizer(inputmat)
own_quantized_input = True
elif backward_needs_input:
inputmat.update_usage(rowwise_usage=True, columnwise_usage=True)
inputmat_total = inputmat
else:
inputmat = cast_if_needed(inp, activation_dtype)
if with_input_all_gather_nccl:
inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
else:
inputmat_total = inputmat
inputmat = cast_if_needed(inp, activation_dtype) # Cast for AMP
inputmat_total = inputmat
nvtx_range_pop(f"{nvtx_label}.input_cast_comm")
# ------------------------------------------------------
# Input tensor is ready for GEMM...
# ------------------------------------------------------
# Cast weight to expected dtype
# ------------------------------------------------------
# Prepare weight tensor
# ------------------------------------------------------
weightmat = weight
if fp8 or debug:
# Configure quantizer
if weight_quantizer is not None:
......@@ -220,7 +231,8 @@ class _Linear(torch.autograd.Function):
and not in_fp8_activation_recompute_phase()
)
weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
# FP8 cast to workspace buffer
# Get quantized weight
update_workspace = is_first_microbatch is None or is_first_microbatch
weightmat = module.get_weight_workspace(
tensor=weight,
......@@ -231,19 +243,21 @@ class _Linear(torch.autograd.Function):
fsdp_group=fsdp_group,
workspace_dtype=activation_dtype,
)
weightmat.update_usage(rowwise_usage=True)
else:
weightmat = cast_if_needed(weightmat, activation_dtype)
weightmat = cast_if_needed(weightmat, activation_dtype) # Cast for AMP
# ------------------------------------------------------
# Weight tensor is ready for GEMM...
# ------------------------------------------------------
# Cast bias to expected dtype
bias_dtype = activation_dtype
if needs_quantized_gemm(inputmat_total) and activation_dtype == torch.float32:
# cuBLAS does not support FP8 GEMM with FP32 bias, so we cast to BF16
bias_dtype = torch.bfloat16
bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias
# Configure output quantizer
if output_quantizer is not None:
output_quantizer.set_usage(rowwise=True, columnwise=False)
# Calibrate quantizers if needed
if not fp8 and fp8_calibration:
if input_quantizer is not None:
......@@ -251,44 +265,74 @@ class _Linear(torch.autograd.Function):
if weight_quantizer is not None:
weight_quantizer.calibrate(weight)
ub_obj = None
ub_type = None
rs_out = None
out_dtype = activation_dtype
if ub_overlap_rs_fprop:
ub_obj = get_ub(ub_name + "_fprop")
ub_type = tex.CommOverlapType.RS
out_shape = [reduce(multiply_op, inp.shape[:-1]) // tp_world_size, out_features]
rs_out = torch.empty(out_shape, dtype=activation_dtype, device=inp.device)
elif ub_overlap_ag_fprop:
ub_obj = get_ub(ub_name + "_fprop")
ub_type = tex.CommOverlapType.AG
if fp8:
assert ub_obj.is_fp8_ubuf(), "AG overlap with FP8 GEMM inputs requires FP8 buffer."
ub_obj.copy_into_buffer(inputmat_total, input_quantizer, local_chunk=True)
inputmat_total = ub_obj.get_buffer(input_quantizer)
nvtx_range_push(f"{nvtx_label}.gemm")
fprop_gemm_use_split_accumulator = _2X_ACC_FPROP
# Choose whether to use GEMM kernel with split accumulator
use_split_accumulator = _2X_ACC_FPROP
if fp8:
recipe = FP8GlobalStateManager.get_fp8_recipe()
if hasattr(recipe, "fp8_gemm_fprop"):
fprop_gemm_use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator
use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator
# Configure output quantizer
if output_quantizer is not None:
output_quantizer.set_usage(rowwise=True, columnwise=False)
out, *_, rs_out = general_gemm(
# Output buffer for Userbuffers reduce-scatter
reduce_scatter_out = None
if ub_overlap_rs_fprop:
out_shape = list(inp.shape)
out_shape[0] //= tp_world_size
out_shape[-1] = out_features
reduce_scatter_out = torch.empty(out_shape, dtype=activation_dtype, device=inp.device)
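# e.g. inp of shape [s, b, h] gives reduce_scatter_out of shape
# [s // tp_world_size, b, out_features] (sequence dim scattered across the TP group)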
# ------------------------------------------------------
# Forward GEMM
# Note: y = x * w^T
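# e.g. inputmat_total [s * b, in_features] x weightmat [out_features, in_features]^T
# -> gemm_out [s * b, out_features]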
# ------------------------------------------------------
nvtx_range_push(f"{nvtx_label}.gemm")
gemm_out, *_, reduce_scatter_out = general_gemm(
weightmat,
inputmat_total,
get_workspace(),
quantization_params=output_quantizer,
out_dtype=out_dtype,
out_dtype=activation_dtype,
bias=bias,
use_split_accumulator=fprop_gemm_use_split_accumulator,
use_split_accumulator=use_split_accumulator,
ub=ub_obj,
ub_type=ub_type,
extra_output=rs_out,
extra_output=reduce_scatter_out,
)
nvtx_range_pop(f"{nvtx_label}.gemm")
# ------------------------------------------------------
# Finished forward GEMM...
# ------------------------------------------------------
# ------------------------------------------------------
# Prepare output tensor
# Note: Perform tensor-parallel communication
# ------------------------------------------------------
out = None
if ub_overlap_rs_fprop:
out = reduce_scatter_out
elif parallel_mode == "row" and tp_size > 1:
nvtx_range_push(f"{nvtx_label}.row_parallel_comm")
out = gemm_out
if sequence_parallel:
out, _ = reduce_scatter_along_first_dim(out, tp_group)
elif tensor_parallel:
if symmetric_ar_type is not None:
out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type)
else:
out, _ = allreduce(out, tp_group)
nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")
else:
out = gemm_out
# ------------------------------------------------------
# Output tensor is ready to return...
# ------------------------------------------------------
# ------------------------------------------------------
# Cache state for backward pass
# ------------------------------------------------------
if is_grad_enabled:
ctx.weight_quantizer = weight_quantizer
......@@ -388,19 +432,9 @@ class _Linear(torch.autograd.Function):
FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
ctx.wgrad_store = wgrad_store
# Row Parallel Linear
if ub_overlap_rs_fprop:
out = rs_out
elif parallel_mode == "row":
nvtx_range_push(f"{nvtx_label}.row_parallel_comm")
if sequence_parallel:
out, _ = reduce_scatter_along_first_dim(out, tp_group)
elif tensor_parallel:
if symmetric_ar_type is not None:
out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type)
else:
out, _ = allreduce(out, tp_group)
nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")
# ------------------------------------------------------
# Cached state for backward pass is ready...
# ------------------------------------------------------
return out
......@@ -414,28 +448,11 @@ class _Linear(torch.autograd.Function):
nvtx_label = f"{nvtx_label}.{ctx.ub_name}"
with torch.cuda.nvtx.range("_Linear_backward"):
if (
ctx.fp8
and any(
[
ctx.ub_overlap_ag,
ctx.ub_overlap_rs_dgrad,
ctx.ub_bulk_dgrad,
ctx.ub_bulk_wgrad,
]
)
and (ctx.fp8_recipe is not None)
):
if not ctx.fp8_recipe.float8_per_tensor_scaling():
raise NotImplementedError(
"Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
" current scaling"
)
saved_tensors = ctx.saved_tensors
inputmat, weight_fp8, weight, bias = ( # pylint: disable=unbalanced-tuple-unpacking
restore_from_saved(ctx.tensor_objects, saved_tensors)
)
# Delete the references to tensor objects once they've been consumed
# by the `restore_from_saved` method to construct back the actual tensors.
ctx.tensor_objects = None
......@@ -465,69 +482,55 @@ class _Linear(torch.autograd.Function):
)
nvtx_range_pop(f"{nvtx_label}.fsdp_gather")
# Configure Userbuffers communication (comm+GEMM overlap)
ctx.ub_obj_gradout = None
ub_obj_dgrad = None
ub_obj_wgrad = None
ub_type_dgrad = None
ub_type_wgrad = None
dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]]
rs_out = None
dgrad_bulk = None
if ctx.ub_overlap_ag:
# Overlap grad_output all-gather with dgrad compute
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
ub_obj_dgrad = ctx.ub_obj_gradout
ub_type_dgrad = tex.CommOverlapType.AG
elif ctx.ub_overlap_rs_dgrad:
# Overlap dgrad reduce-scatter with dgrad compute
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
ub_obj_dgrad = ctx.ub_obj_gradout
ub_type_dgrad = tex.CommOverlapType.RS
rs_out = torch.empty(
dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device
)
else:
if ctx.ub_bulk_dgrad:
# Overlap inputmat all-gather with dgrad compute
# NOTE: Copying into communication buffer will always prefer rowwise data,
# and will copy columnwise data if rowwise does not exist. In that case,
# the all-gather will apply to the leading dimension of the transpose,
# which then needs to be interleaved correctly before WGRAD.
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
ub_obj_dgrad = ctx.ub_obj_gradout
ub_type_dgrad = tex.CommOverlapType.AG
ub_obj_dgrad.copy_into_buffer(inputmat, ctx.input_quantizer, local_chunk=True)
if ctx.ub_bulk_wgrad:
# Overlap dgrad reduce-scatter with wgrad compute
ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad")
ub_type_wgrad = tex.CommOverlapType.RS
ub_obj_wgrad.set_buffer_params(ctx.grad_input_quantizer)
dgrad_bulk = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer)
# --------------------------------------------------
# Prepare grad output tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
# --------------------------------------------------
# Unmodified grad output tensor
grad_output_arg = grad_output
# Configure quantizer for grad output tensor
# Note: dgrad GEMM requires row-wise usage, wgrad GEMM
# requires column-wise usage
if ctx.grad_output_quantizer is not None:
rowwise_usage = True
columnwise_usage = True
if ctx.ub_overlap_ag and isinstance(
ctx.grad_output_quantizer,
(Float8Quantizer, Float8CurrentScalingQuantizer),
):
# If data is in FP8 and communication is handled
# with Userbuffers, we compute FP8 transposes
# manually
columnwise_usage = False
ctx.grad_output_quantizer.set_usage(
rowwise=rowwise_usage,
columnwise=columnwise_usage,
)
quantizer = ctx.grad_output_quantizer
quantizer.set_usage(rowwise=True, columnwise=True)
if ctx.ub_overlap_ag:
# Userbuffers only supports communication for one
# tensor usage at a time, so configure the quantizer
# for dgrad GEMM usage only.
quantizer.set_usage(columnwise=False)
# Prepare grad output tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
nvtx_range_push(f"{nvtx_label}.grad_output_preprocess")
(
grad_output,
......@@ -540,12 +543,21 @@ class _Linear(torch.autograd.Function):
)
nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess")
# Launch tensor-parallel communication for input tensor
# --------------------------------------------------
# Grad output tensor is ready for computing grad input...
# --------------------------------------------------
# --------------------------------------------------
# Prepare input tensor
# Note: Input tensor is needed for wgrad GEMM.
# Tensor-parallel communication is overlapped with dgrad
# GEMM.
# --------------------------------------------------
inputmat_total = None
inputmat_total_work = None
if ctx.backward_input_needs_gather and not ctx.ub_bulk_dgrad:
if ctx.backward_input_needs_gather:
quantizer = None
if ctx.fp8 or ctx.debug:
if (ctx.fp8 or ctx.debug) and not ctx.force_hp_input_gather:
quantizer = ctx.input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
# If data is in FP8, we compute FP8 transposes manually
......@@ -553,72 +565,92 @@ class _Linear(torch.autograd.Function):
else:
# wgrad GEMM requires input with column-wise usage
quantizer.set_usage(rowwise=False, columnwise=True)
nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
gather_quantizer = None if ctx.force_hp_input_gather else quantizer
inputmat_total, inputmat_total_work = gather_along_first_dim(
inputmat,
ctx.tp_group,
async_op=True,
quantizer=gather_quantizer,
)
nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
if ctx.ub_bulk_dgrad:
inputmat_total, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj_dgrad,
inputmat,
quantizer,
ctx.tp_group,
)
else:
nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
inputmat_total, inputmat_total_work = gather_along_first_dim(
inputmat,
ctx.tp_group,
async_op=True,
quantizer=quantizer,
)
nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
else:
inputmat_total = inputmat
# --------------------------------------------------
# Input tensor is ready for computing grad weight...
# --------------------------------------------------
# Check whether to output wgrad GEMM directly into main grad
if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = (
ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
# --------------------------------------------------
# Compute grad input tensor
# --------------------------------------------------
dgrad = None
dgrad_work = None
if ctx.requires_dgrad:
# Update quantizer
if ctx.grad_input_quantizer is not None:
ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)
# dgrad GEMM
nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD
# Make sure required data is available
if isinstance(grad_output, QuantizedTensorBase):
grad_output.update_usage(rowwise_usage=True)
if ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensorBase):
weight_fp8.update_usage(columnwise_usage=True)
# Choose whether to use GEMM kernel with split accumulator
use_split_accumulator = _2X_ACC_DGRAD
if ctx.fp8:
recipe = ctx.fp8_recipe
if hasattr(recipe, "fp8_gemm_dgrad"):
dgrad_gemm_use_split_accumulator = (
recipe.fp8_gemm_dgrad.use_split_accumulator
)
use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator
if ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensorBase):
weight_fp8.update_usage(
rowwise_usage=ctx.weight_quantizer.rowwise_usage,
columnwise_usage=ctx.weight_quantizer.columnwise_usage,
# Update grad input quantizer
if ctx.grad_input_quantizer is not None:
ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)
# Output buffers for Userbuffers reduce-scatter
gemm_out = None
reduce_scatter_out = None
if ctx.ub_overlap_rs_dgrad:
reduce_scatter_out = torch.empty(
dgrad_shape, dtype=ctx.activation_dtype, device=grad_output_arg.device
)
elif ctx.ub_bulk_wgrad:
gemm_out = ub_obj_wgrad.get_buffer(local_chunk=False)
dgrad, *_, rs_out = general_gemm(
# dgrad GEMM
# Note: dx = dy * w
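# Shapes (NN layout): grad_output [s * b, out_features] x weight
# [out_features, in_features] -> grad input [s * b, in_features]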
nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
gemm_out, *_, reduce_scatter_out = general_gemm(
weight_fp8,
grad_output,
get_workspace(),
layout="NN",
grad=True,
quantization_params=ctx.grad_input_quantizer,
out=dgrad_bulk,
out=gemm_out,
out_dtype=ctx.activation_dtype,
use_split_accumulator=dgrad_gemm_use_split_accumulator,
use_split_accumulator=use_split_accumulator,
ub=ub_obj_dgrad,
ub_type=ub_type_dgrad,
extra_output=rs_out,
extra_output=reduce_scatter_out,
bulk_overlap=ctx.ub_bulk_dgrad,
)
nvtx_range_pop(f"{nvtx_label}.dgrad_gemm")
# Launch tensor-parallel communication
# Prepare grad input tensor
# Note: Perform tensor-parallel communication
if ctx.ub_overlap_rs_dgrad:
dgrad = rs_out
elif ctx.parallel_mode == "column" and not ctx.ub_bulk_wgrad:
dgrad = reduce_scatter_out
elif ctx.ub_bulk_wgrad:
dgrad = ub_obj_wgrad.get_buffer(local_chunk=True)
elif ctx.parallel_mode == "column" and ctx.tp_size > 1:
nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad")
dgrad = gemm_out
if ctx.sequence_parallel:
dgrad, dgrad_work = reduce_scatter_along_first_dim(
dgrad,
......@@ -628,41 +660,55 @@ class _Linear(torch.autograd.Function):
else:
dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True)
nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad")
else:
dgrad = gemm_out
# --------------------------------------------------
# Grad input tensor has been computed...
# --------------------------------------------------
# --------------------------------------------------
# Compute grad weight
# --------------------------------------------------
# Compute grad weight tensor
wgrad = None
if ctx.requires_wgrad:
# Synchronize tensor-parallel communication for input tensor
if ctx.ub_bulk_dgrad:
inputmat_total = ub_obj_dgrad.get_buffer(ctx.input_quantizer)
if ctx.fp8:
if inputmat._data is None:
# All-gather executed on columnwise data and result is in rowwise data,
# so we need to fix the interleaving before WGRAD.
inputmat_total = _fix_gathered_fp8_transpose(
inputmat_total, ctx.tp_size
)
elif not is_non_tn_fp8_gemm_supported():
# FP8 GEMM on Hopper only supports TN layout so the gathered input must
# have a valid transpose.
inputmat_total._create_transpose()
# Prepare input tensor
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if inputmat_total_work is not None:
inputmat_total_work.wait()
inputmat_total_work = None
if ctx.input_quantizer is not None and not isinstance(
inputmat_total, QuantizedTensorBase
):
# The async BF16 gather does not apply the quantizer
# to the gathered result, so quantize it here.
ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
inputmat_total = ctx.input_quantizer(inputmat_total)
# Make sure GEMM inputs have required data
if isinstance(inputmat_total, QuantizedTensorBase):
inputmat_total.update_usage(columnwise_usage=True)
if isinstance(grad_output, QuantizedTensorBase):
grad_output.update_usage(columnwise_usage=True)
if ctx.fp8 or ctx.debug:
if isinstance(inputmat_total, QuantizedTensorBase):
inputmat_total.update_usage(columnwise_usage=True)
else:
ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
inputmat_total = ctx.input_quantizer(inputmat_total)
# Prepare grad output tensor
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer):
# UB does not support overlapping grad output
# all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered
# for the dgrad GEMM. We work around this with a
# blocking all-gather of column-scaled MXFP8 data.
ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
grad_output, _ = gather_along_first_dim(
grad_output_arg,
ctx.tp_group,
quantizer=ctx.grad_output_quantizer,
)
if ctx.fp8 or ctx.debug:
if isinstance(grad_output, QuantizedTensorBase):
grad_output.update_usage(columnwise_usage=True)
else:
ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
grad_output = ctx.grad_output_quantizer(grad_output)
# Figure out whether to use split accumulator
use_split_accumulator = _2X_ACC_WGRAD
......@@ -671,54 +717,95 @@ class _Linear(torch.autograd.Function):
if hasattr(recipe, "fp8_gemm_wgrad"):
use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator
# Output buffer for overlapping grad input
# Figure out whether to output wgrad GEMM directly into main grad
if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = (
ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
# Output buffer for overlapping FP8 grad input
# reduce-scatter with wgrad GEMM
reduce_scatter_out = None
if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf():
rs_out = torch.empty(
dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device
reduce_scatter_out = torch.empty(
dgrad_shape, dtype=ctx.activation_dtype, device=grad_output_arg.device
)
# wgrad GEMM
# Note: Fuse with bgrad computation if needed
nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
general_gemm_wgrad = functools.partial(
general_gemm,
out_dtype=(
# Arguments to include in wgrad GEMM closure
wgrad_gemm_kwargs = {
"workspace": get_workspace(),
"out_dtype": (
main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
),
workspace=get_workspace(),
layout="NT",
grad=True,
bias=(bias if (grad_bias is None and not ctx.fp8) else None),
out=main_grad if ctx.fuse_wgrad_accumulation else None,
use_split_accumulator=use_split_accumulator,
accumulate=accumulate_wgrad_into_param_main_grad,
quantization_params=ctx.grad_weight_quantizer,
ub=ub_obj_wgrad,
ub_type=ub_type_wgrad,
extra_output=rs_out,
bulk_overlap=ctx.ub_bulk_wgrad,
)
"quantization_params": ctx.grad_weight_quantizer,
"accumulate": accumulate_wgrad_into_param_main_grad,
"layout": "NT",
"out": main_grad if ctx.fuse_wgrad_accumulation else None,
"bias": (bias if (grad_bias is None and not ctx.fp8) else None),
"use_split_accumulator": use_split_accumulator,
"grad": True,
"ub": ub_obj_wgrad,
"ub_type": ub_type_wgrad,
"extra_output": reduce_scatter_out,
"bulk_overlap": ctx.ub_bulk_wgrad,
}
def wgrad_gemm(
x: torch.Tensor,
dy: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Perform wgrad GEMM: dw = dy^T * x
May be fused with bgrad computation.
May be called outside of this function to enable
some advanced communication/compute overlapping.
"""
nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
dw, db, *_ = general_gemm(x, dy, **wgrad_gemm_kwargs)
nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")
return dw, db
# Choose whether to call wgrad GEMM now or delay
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
ctx.wgrad_store.put([inputmat_total, grad_output], general_gemm_wgrad)
if (
wgrad_gemm_kwargs["ub"] is not None
or wgrad_gemm_kwargs["ub_type"] is not None
or wgrad_gemm_kwargs["extra_output"] is not None
or wgrad_gemm_kwargs["bulk_overlap"]
):
raise NotImplementedError(
"Delayed weight grad computation is not supported "
"with Userbuffers (tensor-parallel communication overlapping)"
)
ctx.wgrad_store.put([inputmat_total, grad_output], wgrad_gemm)
else:
wgrad, grad_bias_, _, rs_out = general_gemm_wgrad(inputmat_total, grad_output)
# Call wgrad GEMM now
wgrad, grad_bias_ = wgrad_gemm(inputmat_total, grad_output)
# Update grad bias if needed
if grad_bias is None:
grad_bias = grad_bias_
del grad_bias_
# Deallocate input tensor
# Deallocate input tensor if permitted
if ctx.owns_input:
clear_tensor_data(inputmat_total)
nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")
# Update grad input if overlapping reduce-scatter with wgrad GEMM
if ctx.ub_bulk_wgrad:
if ub_obj_wgrad.is_fp8_ubuf():
dgrad = rs_out
dgrad = reduce_scatter_out
else:
dgrad = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer, local_chunk=True)
dgrad = ub_obj_wgrad.get_buffer(local_chunk=True).clone()
# --------------------------------------------------
# Grad weight has been computed...
# --------------------------------------------------
# Don't return grad bias if not needed
if not ctx.use_bias:
......@@ -756,6 +843,7 @@ class _Linear(torch.autograd.Function):
else:
wgrad = None
# Update FP8 scaling factors if needed
if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
nvtx_range_push(f"{nvtx_label}.reduce_and_update_fp8_tensors")
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
......
......@@ -534,7 +534,9 @@ class BasicLinear(BasicOperation):
# Configure input tensor for backward pass
if with_quantized_compute and isinstance(x_local, QuantizedTensor):
x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
if not (isinstance(x_local, Float8TensorBase) and with_x_all_gather):
# FP8 does not support all-gather of transpose data
x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
return y, x_local, w
......@@ -622,7 +624,10 @@ class BasicLinear(BasicOperation):
# Check datatype
if dtype is None:
dtype = weight.dtype
if weight is not None:
dtype = weight.dtype
else:
dtype = grad_output.dtype
dtype = canonicalize_dtype(dtype)
if dtype not in (torch.float32, torch.float16, torch.bfloat16):
raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")
......@@ -814,7 +819,7 @@ class BasicLinear(BasicOperation):
x_async = None
dy_async = None
# Check grad input tensor
# Check grad weight tensor
dw = grad_weight
dw_dtype = dtype
if dw is None:
......
......@@ -4,30 +4,27 @@
"""Linear layer backward with Userbuffers communication."""
# pylint: skip-file ### TODO Debug Userbuffers support
from __future__ import annotations
from collections.abc import Iterable
from typing import Any, Optional
from typing import Optional
import warnings
import torch
from transformer_engine_torch import CommOverlapAlgo
from transformer_engine_torch import CommOverlapType
from ...cpp_extensions import general_gemm
from ...distributed import get_distributed_world_size
from ...float8_tensor import Float8Tensor
from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype
from ...module.base import get_ub, get_workspace
from ...distributed import gather_along_first_dim, get_distributed_world_size
from ...module.base import (
fill_userbuffers_buffer_for_all_gather,
get_ub,
get_workspace,
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ...tensor.quantized_tensor import QuantizedTensorBase, Quantizer
from ...tensor.mxfp8_tensor import MXFP8Quantizer
from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data
from ..basic import BasicLinear, Bias, ReduceScatter
from ..op import FusedOperation, FusibleOperation, OperationContext
from .._common import (
convert_tensor,
get_fp8_meta_from_fp8_tensor,
is_float8_tensor,
reshape,
)
class UserbuffersBackwardLinear(FusedOperation):
......@@ -47,9 +44,6 @@ class UserbuffersBackwardLinear(FusedOperation):
reduce_scatter: Optional[ReduceScatter],
) -> None:
### TODO Debug Userbuffers support
raise NotImplementedError("Userbuffers support has been broken by recent refactors")
# Basic operations that comprise this fused operation
op_idxs = {"linear": None, "bias": None, "reduce_scatter": None}
ops = []
......@@ -89,9 +83,8 @@ class UserbuffersBackwardLinear(FusedOperation):
grad_output: torch.Tensor,
input: Optional[torch.Tensor], # pylint: disable=redefined-builtin
weight: Optional[torch.Tensor],
input_dims: Iterable[int],
weight_dims: Iterable[int],
*,
input_requires_grad: bool = True,
weight_requires_grad: bool = True,
bias_requires_grad: bool = False,
device: Optional[torch.device] = None,
......@@ -102,11 +95,11 @@ class UserbuffersBackwardLinear(FusedOperation):
tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
tensor_parallel_size: Optional[int] = None,
sequence_parallel: bool = False,
with_fp8_compute: bool = False,
input_fp8_meta: Optional[dict[str, Any]] = None,
weight_fp8_meta: Optional[dict[str, Any]] = None,
grad_output_fp8_meta: Optional[dict[str, Any]] = None,
grad_input_fp8_meta: Optional[dict[str, Any]] = None,
with_quantized_compute: bool = False,
input_quantizer: Optional[Quantizer] = None,
weight_quantizer: Optional[Quantizer] = None,
grad_output_quantizer: Optional[Quantizer] = None,
grad_input_quantizer: Optional[Quantizer] = None,
ub_comm_name: str,
) -> tuple[torch.Tensor, Optional[torch.Tensor], dict]:
"""Functional API for backward pass
......@@ -121,10 +114,6 @@ class UserbuffersBackwardLinear(FusedOperation):
weight: torch.Tensor, optional
Weight tensor. Required to compute loss gradient w.r.t.
input.
input_dims: iterable of int
Input tensor dimensions
weight_dims: iterable of int
Weight tensor dimensions
weight_requires_grad: bool
Whether to compute loss gradient w.r.t. weight tensor
bias_requires_grad: bool
......@@ -146,21 +135,18 @@ class UserbuffersBackwardLinear(FusedOperation):
parallelism, i.e. distributing input or output tensors
along outer dimension (sequence or batch dim) when not
distributing along inner dimension (embedding dim)
with_fp8_compute: bool, default = `False`
Whether to perform compute in FP8
input_fp8_meta: dict, optional
FP8 metadata for casting input tensor to FP8. Required for
FP8 compute if input is not already in FP8.
weight_fp8_meta: dict, optional
FP8 metadata for casting weight tensor to FP8. Required for
FP8 compute if weight is not already in FP8.
grad_output_fp8_meta: dict, optional
FP8 metadata for casting loss gradient w.r.t. output
tensor to FP8. Required if output grad is not already in
FP8.
grad_input_fp8_meta: dict, optional
FP8 metadata for casting loss gradient w.r.t. input
tensor to FP8
with_quantized_compute: bool, default = `False`
Whether to perform compute with quantized data.
input_quantizer: Quantizer, optional
Builder class for quantized input tensor.
weight_quantizer: Quantizer, optional
Builder class for quantized weight tensor.
grad_output_quantizer: Quantizer, optional
Builder class for quantized loss gradient w.r.t. output
tensor.
grad_input_quantizer: Quantizer, optional
Builder class for quantized loss gradient w.r.t. input
tensor.
ub_comm_name: str
Layer type (e.g. "qkv", "proj", "fc1", "fc2"). This is
used to access the corresponding Userbuffers communicators
......@@ -183,37 +169,24 @@ class UserbuffersBackwardLinear(FusedOperation):
# Check device
if device is None:
device = weight.device
if weight is not None:
device = weight.device
else:
device = grad_output.device
device = canonicalize_device(device)
if device.type != "cuda":
raise ValueError(f"Only CUDA devices are supported (got {device})")
# Check datatype
if dtype is None:
dtype = weight.dtype
if weight is not None:
dtype = weight.dtype
else:
dtype = grad_output.dtype
dtype = canonicalize_dtype(dtype)
if dtype not in (torch.float32, torch.float16, torch.bfloat16):
raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")
# Input tensor dims
output_dims = tuple(grad_output.size())
input_dims = tuple(input_dims)
weight_dims = tuple(weight_dims)
if len(weight_dims) != 2:
raise ValueError(f"Weight tensor is not 2D (shape={weight_dims})")
if len(input_dims) == 0 or weight_dims[1] != input_dims[-1]:
raise ValueError(
f"Input tensor (shape={input_dims}) "
f"and weight tensor (shape={weight_dims}) "
"are not compatible"
)
if weight_dims[0] != output_dims[-1]:
raise ValueError(
f"Grad output tensor (shape={output_dims}) "
f"and weight tensor (shape={weight_dims}) "
"are not compatible"
)
# Check tensor parallel group
if tensor_parallel_size is None:
tensor_parallel_size = get_distributed_world_size(tensor_parallel_group)
......@@ -227,373 +200,283 @@ class UserbuffersBackwardLinear(FusedOperation):
if not sequence_parallel:
raise RuntimeError(f"Invalid configuration for Userbuffers ({sequence_parallel=})")
# Check if FP8 is enabled
if with_fp8_compute:
if grad_output_fp8_meta is None and not is_float8_tensor(grad_output):
raise ValueError("No FP8 metadata was provided for casting output gradient to FP8")
# dgrad GEMM is required
if not input_requires_grad:
warnings.warn(
"Linear input doesn't require gradient, "
"but Userbuffers implementation requires dgrad GEMM."
)
input_requires_grad = True
# Check quantizers
if with_quantized_compute:
if weight_requires_grad and input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
if input_requires_grad and weight_quantizer is None:
raise ValueError("Missing quantizer for weight tensor")
if grad_output_quantizer is None:
raise ValueError("Missing quantizer for grad output tensor")
if grad_input_quantizer is not None:
raise ValueError("Quantized grad input is not supported")
else:
input_fp8_meta = None
weight_fp8_meta = None
grad_output_fp8_meta = None
grad_input_fp8_meta = None
with_fp8_grad_input = (
with_fp8_compute
and tensor_parallel_mode != "column"
and grad_input_fp8_meta is not None
)
input_quantizer = None
weight_quantizer = None
grad_output_quantizer = None
grad_input_quantizer = None
# Get Userbuffers communicators and algorithms
# Note: communication patterns are (1) overlap dy all-gather
# Get Userbuffers communicators
# Note: Communication patterns are (1) overlap dy all-gather
# with dgrad GEMM, (2) overlap x all-gather with dgrad GEMM
# and dx reduce-scatter with wgrad GEMM, (3) overlap dx
# reduce-scatter with dgrad GEMM.
with_ub_all_gather_dy = False
with_ub_reduce_scatter_dx = False
with_ub_all_gather_x = False
ub_comm_dy = None
ub_comm_dx = None
ub_comm_x = None
ub_algo_dy = None
ub_algo_dx = None
ub_algo_x = None
# reduce-scatter with dgrad GEMM
ub_comm_dgrad = None
ub_comm_wgrad = None
ub_type_dgrad = None
ub_type_wgrad = None
with_bulk_overlap = False
with_dgrad_all_gather_dy = False
with_dgrad_reduce_scatter_dx = False
with_dgrad_all_gather_x = False
with_wgrad_reduce_scatter_dx = False
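# Pattern (1) applies to row-parallel layers; patterns (2) and (3) apply
# to column-parallel layers, depending on whether the weight grad is needed.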
if tensor_parallel_mode == "row":
with_ub_all_gather_dy = True
ub_comm_dy = get_ub(ub_comm_name + "_dgrad")
if with_fp8_compute and ub_comm_dy.is_atomic_gemm():
ub_algo_dy = CommOverlapAlgo.ATOMIC_GEMM_AG_P2P
else:
ub_algo_dy = CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P
ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad")
ub_type_dgrad = CommOverlapType.AG
with_dgrad_all_gather_dy = True
elif tensor_parallel_mode == "column":
with_ub_reduce_scatter_dx = True
if weight_requires_grad:
with_ub_all_gather_x = True
ub_comm_dx = get_ub(ub_comm_name + "_wgrad")
ub_comm_x = get_ub(ub_comm_name + "_dgrad")
ub_algo_dx = CommOverlapAlgo.BULK_OVERLAP_RS
ub_algo_x = CommOverlapAlgo.BULK_OVERLAP_AG
if input_requires_grad and weight_requires_grad:
with_bulk_overlap = True
ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad")
ub_type_dgrad = CommOverlapType.AG
with_dgrad_all_gather_x = True
ub_comm_wgrad = get_ub(ub_comm_name + "_wgrad")
ub_type_wgrad = CommOverlapType.RS
with_wgrad_reduce_scatter_dx = True
if ub_comm_wgrad.is_fp8_ubuf():
raise RuntimeError(
"Userbuffers reduce-scatter is not supported with FP8 buffers"
)
else:
with_ub_all_gather_x = False
ub_comm_dx = get_ub(ub_comm_name + "_dgrad")
is_atomic_gemm = with_fp8_compute and ub_comm_dx.is_atomic_gemm()
ub_algo_dx = {
(True, True): CommOverlapAlgo.ATOMIC_GEMM_RS_P2P,
(True, False): CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P,
(False, True): CommOverlapAlgo.ATOMIC_GEMM_RS,
(False, False): CommOverlapAlgo.SPLIT_PIPELINED_RS,
}[(ub_comm_dx.is_p2p_overlap(), is_atomic_gemm)]
# Check grad output tensor
# Note: Possibly fuse cast with computing grad bias
dy_local = reshape(
grad_output,
(-1, output_dims[-1]),
device=device,
dtype=dtype,
)
ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad")
ub_type_dgrad = CommOverlapType.RS
with_dgrad_reduce_scatter_dx = True
if ub_comm_dgrad.is_fp8_ubuf():
raise RuntimeError(
"Userbuffers reduce-scatter is not supported with FP8 buffers"
)
# Compute grad bias if needed
db = None
db_async = None
if bias_requires_grad:
db = grad_output.sum(tuple(range(grad_output.dim() - 1)))
if tensor_parallel_mode == "row":
db_async = torch.distributed.all_reduce(
db,
group=tensor_parallel_group,
async_op=True,
)
# Cast grad output tensor dtype if needed
dy_local = grad_output
if with_quantized_compute:
if not isinstance(dy_local, QuantizedTensorBase):
with_columnwise = weight_requires_grad
if (
with_columnwise
and with_dgrad_all_gather_dy
and not isinstance(grad_output_quantizer, MXFP8Quantizer)
):
with_columnwise = False
grad_output_quantizer.set_usage(
rowwise=True,
columnwise=with_columnwise,
)
dy_local = grad_output_quantizer(dy_local)
else:
if isinstance(dy_local, QuantizedTensorBase):
dy_local = dy_local.dequantize(dtype=dtype)
elif dy_local.dtype != dtype:
dy_local = dy_local.to(dtype=dtype)
# Cast weight tensor dtype if needed
if weight is None:
raise ValueError("Weight tensor is required to compute input grad")
w = weight
if with_quantized_compute:
if not isinstance(w, QuantizedTensorBase):
weight_quantizer.set_usage(columnwise=True)
w = weight_quantizer(w)
else:
if isinstance(w, QuantizedTensorBase):
w = w.dequantize(dtype=dtype)
elif w.dtype != dtype:
w = w.to(dtype=dtype)
# Cast input tensor dtype if needed
x_local = None
if weight_requires_grad:
if input is None:
raise ValueError("Input tensor is required to compute weight grad")
x_local = input
if with_quantized_compute:
if not isinstance(x_local, QuantizedTensorBase):
input_quantizer.set_usage(columnwise=True)
x_local = input_quantizer(x_local)
else:
if isinstance(x_local, QuantizedTensorBase):
x_local = x_local.dequantize(dtype=dtype)
elif x_local.dtype != dtype:
x_local = x_local.to(dtype=dtype)
# dgrad GEMM
dx_local = None
dx = None
dy = None
x = None
if input_requires_grad:
# Initialize grad output
if with_dgrad_all_gather_dy:
if grad_output_quantizer is not None:
grad_output_quantizer.set_usage(rowwise=True, columnwise=False)
dy, _ = fill_userbuffers_buffer_for_all_gather(
ub_comm_dgrad,
dy_local,
grad_output_quantizer,
tensor_parallel_group,
)
else:
dy = dy_local
# Construct grad input tensor if needed
if with_dgrad_reduce_scatter_dx or with_wgrad_reduce_scatter_dx:
dx_size = list(dy.size())
dx_size[-1] = w.size(-1)
dx_local_size = list(dx_size)
dx_local_size[0] //= tensor_parallel_size
if with_dgrad_reduce_scatter_dx:
dx_local = torch.empty(
dx_local_size,
dtype=dtype,
device=device,
)
elif with_wgrad_reduce_scatter_dx:
dx_local = ub_comm_wgrad.get_buffer(
local_chunk=True,
shape=dx_local_size,
)
dx = ub_comm_wgrad.get_buffer(
local_chunk=False,
shape=dx_size,
)
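# Note: dx is the full-size reduce-scatter input while dx_local is the
# per-rank output chunk, i.e. the same shape with the leading dim divided
# by tensor_parallel_size.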
# Initialize input tensor if needed
if with_dgrad_all_gather_x:
if input_quantizer is not None:
input_quantizer.set_usage(rowwise=False, columnwise=True)
x, _ = fill_userbuffers_buffer_for_all_gather(
ub_comm_dgrad,
x_local,
input_quantizer,
tensor_parallel_group,
)
# Perform dgrad GEMM
dx, *_ = general_gemm(
w,
dy,
get_workspace(),
out_dtype=dtype,
quantization_params=grad_input_quantizer,
layout="NN",
out=dx,
use_split_accumulator=_2X_ACC_DGRAD,
grad=True,
ub=ub_comm_dgrad,
ub_type=ub_type_dgrad,
extra_output=dx_local if with_dgrad_reduce_scatter_dx else None,
bulk_overlap=with_bulk_overlap,
)
if not (with_dgrad_reduce_scatter_dx or with_wgrad_reduce_scatter_dx):
dx_local = dx
# wgrad GEMM
dw = None
if weight_requires_grad:
# Initialize grad output
if tensor_parallel_mode == "row" and isinstance(grad_output_quantizer, MXFP8Quantizer):
# UB does not support overlapping grad output
# all-gather with wgrad GEMM. Also, MXFP8 does not
# allow reusing the grad output that was gathered for
# the dgrad GEMM. We work around this with a blocking
# all-gather of column-scaled MXFP8 data.
grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
dy, _ = gather_along_first_dim(
grad_output,
tensor_parallel_group,
quantizer=grad_output_quantizer,
)
else:
if tensor_parallel_mode == "column":
dy = dy_local
if dy is None:
raise RuntimeError(
"wgrad GEMM requires grad output tensor, which has not been initialized"
)
if isinstance(dy, QuantizedTensorBase):
dy.update_usage(rowwise_usage=False, columnwise_usage=True)
# Initialize input tensor
if tensor_parallel_mode == "row":
x = x_local
if x is None:
raise RuntimeError(
"wgrad GEMM requires input tensor, which has not been initialized"
)
if isinstance(x, QuantizedTensorBase):
x.update_usage(rowwise_usage=False, columnwise_usage=True)
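# Note: the wgrad GEMM below uses the "NT" layout and consumes the
# column-wise (transposed) data of x and dy, which is why columnwise
# usage is requested for both tensors above.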
# Check grad weight tensor
dw = grad_weight
dw_dtype = dtype
if dw is None:
if accumulate_into_grad_weight:
raise ValueError(
"Attempted to accumulate into grad weight tensor "
"without providing grad weight tensor"
)
else:
dw_dtype = dw.dtype
# Perform wgrad GEMM
dw, *_ = general_gemm(
x,
dy,
get_workspace(),
out_dtype=dw_dtype,
accumulate=accumulate_into_grad_weight,
layout="NT",
out=dw,
use_split_accumulator=_2X_ACC_WGRAD,
grad=True,
ub=ub_comm_wgrad,
ub_type=ub_type_wgrad,
bulk_overlap=with_bulk_overlap,
)
# Bulk overlap reduce-scatter with non-FP8 buffer is
# in-place. Need to copy grad input tensor to avoid data
# corruption in Userbuffers buffer.
if with_wgrad_reduce_scatter_dx:
dx_local = dx_local.clone()
# Compute grad bias if needed
if db_async is not None:
db_async.wait()
if bias_requires_grad:
if db is None:
db = dy.sum(dim=0)
extra_outputs["grad_bias"] = db
return dx_local, dw, extra_outputs
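# For reference, ignoring quantization and the Userbuffers overlap, this
# backward is roughly equivalent to:
#   dx = dy @ w                      (dgrad GEMM, "NN" layout)
#   dw = dy^T @ x                    (wgrad GEMM, "NT" layout)
#   db = dy summed over all but the last dim
# with dy all-gathered along the first dim for row tensor parallelism and
# dx reduce-scattered along the first dim for column tensor parallelism.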
def fuser_backward(
self,
......@@ -633,40 +516,24 @@ class UserbuffersBackwardLinear(FusedOperation):
else:
accumulate_into_main_grad = False
# Linear backward pass
retval = UserbuffersBackwardLinear._functional_backward(
grad_output=grad_output,
input=x_local,
weight=linear_op.weight,
input_dims=linear_op_ctx.input_dims,
weight_dims=linear_op.weight.size(),
weight_requires_grad=linear_op_ctx.weight_requires_grad,
bias_requires_grad=(bias_op is not None),
device=linear_op.device,
dtype=linear_op_ctx.dtype,
grad_weight=grad_weight,
accumulate_into_grad_weight=accumulate_into_main_grad,
tensor_parallel_mode=self.tensor_parallel_mode,
tensor_parallel_group=self.tensor_parallel_group,
sequence_parallel=self.sequence_parallel,
with_quantized_compute=linear_op_ctx.with_quantized_compute,
input_quantizer=linear_op_ctx.input_quantizer,
weight_quantizer=linear_op_ctx.weight_quantizer,
grad_output_quantizer=linear_op_ctx.grad_output_quantizer,
grad_input_quantizer=None,  # Not supported
ub_comm_name=linear_op._userbuffers_options["comm_name"],
)
grad_input, grad_weight, extra_outputs = retval
......@@ -707,8 +574,6 @@ def fuse_userbuffers_backward_linear(
"""
return ops ### TODO Debug Userbuffers support
# Return immediately if environment is not distributed
if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1:
return ops
......
......@@ -4,20 +4,25 @@
"""Linear layer forward with Userbuffers communication."""
# pylint: skip-file ### TODO Debug Userbuffers support
from __future__ import annotations
from collections.abc import Iterable
from typing import Any, Optional
import torch
from transformer_engine_torch import CommOverlapType
from ...cpp_extensions import general_gemm
from ...distributed import get_distributed_world_size
from ...fp8 import FP8GlobalStateManager
from ...module.base import (
fill_userbuffers_buffer_for_all_gather,
get_ub,
get_workspace,
_2X_ACC_FPROP,
)
from ...tensor.quantized_tensor import QuantizedTensorBase, Quantizer
from ...tensor.float8_tensor import Float8Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ...utils import canonicalize_device, canonicalize_dtype
from ..basic import BasicLinear, Bias, ReduceScatter
from ..op import (
......@@ -26,12 +31,6 @@ from ..op import (
FusibleOperation,
OperationContext,
)
class UserbuffersForwardLinear(FusedOperation):
......@@ -51,9 +50,6 @@ class UserbuffersForwardLinear(FusedOperation):
reduce_scatter: Optional[ReduceScatter],
) -> None:
### TODO Debug Userbuffers support
raise NotImplementedError("Userbuffers support has been broken by recent refactors")
# Basic operations that comprise this fused operation
op_idxs = {"linear": 0, "bias": None, "reduce_scatter": None}
ops = [linear]
......@@ -98,10 +94,10 @@ class UserbuffersForwardLinear(FusedOperation):
tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
tensor_parallel_size: Optional[int] = None,
sequence_parallel: bool = False,
with_quantized_compute: bool = False,
input_quantizer: Optional[Quantizer] = None,
weight_quantizer: Optional[Quantizer] = None,
output_quantizer: Optional[Quantizer] = None,
ub_comm_name: str,
) -> tuple[torch.Tensor, dict]:
"""Functional API for forward pass
......@@ -127,16 +123,14 @@ class UserbuffersForwardLinear(FusedOperation):
parallelism, i.e. distributing input or output tensors
along outer dimension (sequence or batch dim) when not
distributing along inner dimension (embedding dim)
with_quantized_compute: bool, default = `False`
Whether to perform compute with quantized data.
input_quantizer: Quantizer, optional
Builder class for quantized input tensor.
weight_quantizer: Quantizer, optional
Builder class for quantized weight tensor.
output_quantizer: Quantizer, optional
Builder class for quantized output tensor.
ub_comm_name: str
Layer type (e.g. "qkv", "proj", "fc1", "fc2"). This is
used to access the corresponding Userbuffers communicators
......@@ -166,23 +160,6 @@ class UserbuffersForwardLinear(FusedOperation):
if dtype not in (torch.float32, torch.float16, torch.bfloat16):
raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")
# Check tensor parallel group
if tensor_parallel_size is None:
tensor_parallel_size = get_distributed_world_size(tensor_parallel_group)
......@@ -196,235 +173,106 @@ class UserbuffersForwardLinear(FusedOperation):
if not sequence_parallel:
raise RuntimeError(f"Invalid configuration for Userbuffers ({sequence_parallel=})")
# Check quantizers
if with_quantized_compute:
if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
if weight_quantizer is None:
raise ValueError("Missing quantizer for weight tensor")
if output_quantizer is not None:
raise ValueError("FP8 output is not supported")
else:
input_quantizer = None
weight_quantizer = None
output_quantizer = None
# Get Userbuffers communicator
ub_comm = get_ub(ub_comm_name + "_fprop")
with_ub_all_gather = tensor_parallel_mode == "column"
with_ub_reduce_scatter = tensor_parallel_mode == "row"
ub_type = CommOverlapType.AG if with_ub_all_gather else CommOverlapType.RS
# Initialize input tensor
x_local = input
x = None
if with_ub_all_gather:
if input_quantizer is not None:
if not isinstance(x_local, QuantizedTensorBase):
input_quantizer.set_usage(rowwise=True, columnwise=True)
if isinstance(input_quantizer, Float8Quantizer):
input_quantizer.set_usage(columnwise=False)
x_local = input_quantizer(x_local)
input_quantizer.set_usage(rowwise=True, columnwise=False)
x, x_local = fill_userbuffers_buffer_for_all_gather(
ub_comm,
x_local,
input_quantizer,
tensor_parallel_group,
)
else:
if with_quantized_compute:
if not isinstance(x_local, QuantizedTensorBase):
input_quantizer.set_usage(rowwise=True, columnwise=True)
x_local = input_quantizer(x_local)
else:
if isinstance(x_local, QuantizedTensorBase):
x_local = x_local.dequantize(dtype=dtype)
if x_local.dtype != dtype:
x_local = x_local.to(dtype=dtype)
x = x_local
# Initialize weight tensor
w = weight
w_is_quantized = isinstance(w, QuantizedTensorBase)
if with_quantized_compute and not w_is_quantized:
weight_quantizer.set_usage(rowwise=True)
w = weight_quantizer(w)
elif not with_quantized_compute and w_is_quantized:
w = w.dequantize()
if not with_quantized_compute and w.dtype != dtype:
w = w.to(dtype=dtype)
# Construct output tensor if needed
reduce_scatter_output = None
if with_ub_reduce_scatter:
y_local_size = list(x.size())
y_local_size[0] //= tensor_parallel_size
y_local_size[-1] = w.size(0)
reduce_scatter_output = torch.empty(y_local_size, dtype=dtype, device=device)
# Perform GEMM
gemm_output, *_, reduce_scatter_output = general_gemm(
w,
x,
get_workspace(),
out_dtype=dtype,
quantization_params=output_quantizer,
bias=bias,
use_split_accumulator=_2X_ACC_FPROP,
ub=ub_comm,
ub_type=ub_type,
extra_output=reduce_scatter_output,
)
if with_ub_reduce_scatter:
y_local = reduce_scatter_output
else:
y_local = gemm_output
# Detach input tensor if needed
# Note: PyTorch autograd produces esoteric errors if we save
# input tensor as context for backward pass.
if x_local is input:
x_local = x_local.detach()
# Configure input tensor for backward pass
if with_quantized_compute and isinstance(x_local, QuantizedTensorBase):
if not (isinstance(x_local, Float8TensorBase) and with_ub_all_gather):
# FP8 does not support all-gather of transpose data
x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
# Return cast tensors
extra_outputs = {"input": x_local, "weight": w}
return y_local, extra_outputs
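# For reference, ignoring quantization and the Userbuffers overlap, this
# forward is roughly equivalent to:
#   x = all_gather(x_local)          (column tensor parallelism + sequence parallelism)
#   y = x @ w^T + b                  (GEMM)
#   y_local = reduce_scatter(y)      (row tensor parallelism + sequence parallelism)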
def fuser_forward(
self,
......@@ -450,23 +298,22 @@ class UserbuffersForwardLinear(FusedOperation):
if basic_op_kwargs[idx]:
raise ValueError("Bias operation forward does not expect keyword arguments")
# Quantization metadata
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
input_quantizer = None
weight_quantizer = None
grad_output_quantizer = None
grad_input_quantizer = None
if with_quantized_compute:
recipe = FP8GlobalStateManager.get_fp8_recipe()
if not recipe.delayed() and not recipe.mxfp8():
raise RuntimeError("Userbuffers is only supported with FP8 delayed scaling and MXFP8 recipes")
input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1)
grad_output_quantizer = linear_op.get_quantizer("backward", 0)
prev_op = basic_op_prev_ops[0]
if prev_op is not None and prev_op.num_quantizers("backward") > 0 and recipe.delayed():
grad_input_quantizer = prev_op.get_quantizer("backward", 0)
# Get autocast dtype if needed
dtype = None
......@@ -482,26 +329,26 @@ class UserbuffersForwardLinear(FusedOperation):
input=input_,
weight=linear_op.weight,
bias=bias,
device=linear_op.device,
dtype=dtype,
tensor_parallel_mode=self.tensor_parallel_mode,
tensor_parallel_group=self.tensor_parallel_group,
tensor_parallel_size=self.tensor_parallel_size,
sequence_parallel=self.sequence_parallel,
with_quantized_compute=with_quantized_compute,
input_quantizer=input_quantizer,
weight_quantizer=weight_quantizer,
output_quantizer=None,  # Not supported
ub_comm_name=linear_op._userbuffers_options["comm_name"],
)
x_local = extra_outputs["input"]
# Save state for backward pass
linear_op_ctx.save_for_backward(x_local)
linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer
linear_op_ctx.weight_quantizer = weight_quantizer
linear_op_ctx.grad_output_quantizer = grad_output_quantizer
linear_op_ctx.grad_input_quantizer = grad_input_quantizer
linear_op_ctx.dtype = dtype
linear_op_ctx.input_dims = input_.size()
linear_op_ctx.input_requires_grad = input_.requires_grad
......@@ -529,8 +376,6 @@ def fuse_userbuffers_forward_linear(
"""
return ops ### TODO Debug Userbuffers support
# Return immediately if environment is not distributed
if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1:
return ops
......