Unverified Commit ce0b46c4 authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

MXFP8 support in Userbuffers (#1711)



* Initial work toward restoring UB support in te.Sequential
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Forward UB linear runs, but has numerical error
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Debug UB forward tests
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Minor tweaks
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Remove Python checks for MXFP8 UB linear forward
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Add dim check for MXFP8 full tiles
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Move QuantizedTensor logic out of UB comm and into Python helper function
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Support MXFP8 AGs
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Coalesce NCCL all-gathers for MXFP8 all-gather
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Initial impl of backward UB linear in te.Sequential
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Debug UB linear backward with no quantization

dgrad GEMM + dx RS is still broken.
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Fix chunk dims for dgrad GEMM + dx RS
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Debugging MXFP8 UB cases

Still failing with dy AG + wgrad GEMM
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Use NCCL to overlap dy AG with dgrad GEMM
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Debug UB GEMM tests
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Initial refactoring of linear module forward
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Refactor linear module backward
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Debug linear module UB tests
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Tweak test tensor dims
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Do not store autograd context within wgrad GEMM closure
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Fix linter warnings
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update LayerNormLinear
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update LayerNormMLP
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Debug UB tests
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Fix linter warnings
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Debug test failures
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Minor style tweaks
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix incorrect usage for GEMM input with block-scaled FP8
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Fix RS out dims
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Disable dgrad GEMM + UB AG + NCCL AG overlapping
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Disable dgrad GEMM + UB AG + NCCL AG overlap in te.Sequential
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Restore support for internal quantized tensors
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Add tests for MXFP8 GEMM with UB
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix linter warnings
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Debug test failures
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Debug test failures
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

---------
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 3d2152c2
......@@ -26,7 +26,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PAT
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
# python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" ### TODO Debug UB support with te.Sequential
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn_with_cp.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
......
......@@ -20,7 +20,11 @@ from torch.distributed.elastic.multiprocessing.errors import record
import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.module.base import get_cublas_workspace_size_bytes
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.module.base import (
fill_userbuffers_buffer_for_all_gather,
get_cublas_workspace_size_bytes,
)
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
......@@ -56,7 +60,11 @@ def _parse_args(argv=None, namespace=None):
)
parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
parser.add_argument(
"--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context."
"--quantization",
type=str.lower,
default="none",
choices=["none", "fp8", "mxfp8"],
help="Quantization recipe",
)
parser.add_argument(
"--fp8-output", action="store_true", default=False, help="Get FP8 output from GEMM."
......@@ -154,9 +162,9 @@ def _parse_args(argv=None, namespace=None):
if opts.atomic:
warnings.warn("Atomic GEMM is not supported with bulk overlap.")
opts.atomic = False
if opts.fp8:
if opts.quantization != "none":
warnings.warn("Bulk overlap is supported in FP8 but only tested in BF16.")
opts.fp8 = False
opts.quantization = "none"
elif opts.comm_type == tex.CommOverlapType.AG:
if opts.atomic:
setattr(opts, "atomic_rs_p2p", opts.p2p)
......@@ -164,8 +172,11 @@ def _parse_args(argv=None, namespace=None):
if opts.atomic:
if not te.fp8.check_fp8_support():
assert not opts.fp8, "Atomic GEMM is only supported in FP8."
opts.fp8 = True
assert opts.quantization == "none", "Atomic GEMM is only supported in FP8."
opts.quantization = "fp8"
if opts.fp8_output:
assert ops.quantization == "fp8", "FP8 output is only supported with FP8 compute."
return opts
......@@ -302,7 +313,11 @@ def _main(opts):
inp_shape = (opts.seq_length, opts.batch_size, hidden_size)
outer_size = reduce(operator.mul, inp_shape[:-1], 1)
buffer_dtype = torch.bfloat16
if opts.fp8 and not opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.AG:
if (
opts.quantization != "none"
and not opts.bulk_overlap
and opts.comm_type == tex.CommOverlapType.AG
):
buffer_dtype = torch.uint8
ub_obj = (
tex.CommOverlapP2P(
......@@ -447,6 +462,8 @@ def _main(opts):
inp2_g = torch.nn.functional.gelu(ref_g) # pylint: disable=not-callable
ref2_g = torch.matmul(inp2_g, ker2_g)
# Initialize quantizers
with_quantized_compute = opts.quantization != "none"
inp_quantizer = None
ker_quantizer = None
out_quantizer = None
......@@ -454,7 +471,7 @@ def _main(opts):
inp2_quantizer = None
ker2_quantizer = None
out2_quantizer = None
if opts.fp8:
if opts.quantization == "fp8":
# Structure to maintain amax and scale/scale_inv information for the kernel and input
num_gemms = 6 if ub_obj2 is not None else 3
fp8_dtype = tex.DType.kFloat8E4M3
......@@ -499,11 +516,23 @@ def _main(opts):
out2_quantizer = Float8Quantizer(
fp8_scales[5].clone(), fp8_amaxes[5].clone(), fp8_dtype
)
elif opts.quantization == "mxfp8":
fp8_dtype = tex.DType.kFloat8E4M3
inp_quantizer = MXFP8Quantizer(fp8_dtype, columnwise=False)
ker_quantizer = MXFP8Quantizer(fp8_dtype)
if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS:
bulk_inp_quantizer = MXFP8Quantizer(fp8_dtype, columnwise=False)
elif ub_obj2 is not None:
inp2_quantizer = MXFP8Quantizer(fp8_dtype, columnwise=False)
ker2_quantizer = MXFP8Quantizer(fp8_dtype)
# Quantize tensors
if with_quantized_compute:
# Cast input to Float8Tensor
# Quantize input tensor
inp_fp8 = inp_quantizer(inp)
# Cast kernel to Float8Tensor
# Quantize kernel tensor
kernel_t_fp8 = ker_quantizer(kernel_t)
if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS:
bulk_inp_fp8 = bulk_inp_quantizer(bulk_inp)
......@@ -540,31 +569,40 @@ def _main(opts):
)
# Set up comm/compute buffers
ag_out = None
rs_out = None
rs_out2 = None
if opts.comm_type == tex.CommOverlapType.AG:
if opts.bulk_overlap:
ub_obj.copy_into_buffer(bulk_inp, bulk_inp_quantizer, True)
ag_out, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj,
bulk_inp,
bulk_inp_quantizer,
tp_group,
)
gemm_inp = inp
else:
ub_obj.copy_into_buffer(inp_fp8 if opts.fp8 else inp, inp_quantizer, True)
gemm_inp = ub_obj.get_buffer(inp_quantizer, False, inp_g.size())
ag_out, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj,
inp_fp8 if with_quantized_compute else inp,
inp_quantizer,
tp_group,
)
gemm_inp = ag_out
if ub_obj2 is not None:
if opts.fp8 and opts.fp8_output:
ub_obj2.set_buffer_params(out_quantizer)
rs_out2 = torch.empty(
(outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda"
)
else:
if opts.bulk_overlap:
ub_obj.copy_into_buffer(
bulk_inp_fp8 if opts.fp8 else bulk_inp, bulk_inp_quantizer, False
)
if opts.fp8:
ub_obj.set_buffer_params(bulk_inp_quantizer)
elif opts.fp8 and opts.fp8_output:
ub_obj.set_buffer_params(out_quantizer)
gemm_inp = inp_fp8 if opts.fp8 else inp
if opts.quantization == "none":
ub_obj.copy_into_buffer(bulk_inp, local_chunk=False)
if opts.quantization == "fp8":
ub_obj.copy_into_buffer(bulk_inp_fp8._data, local_chunk=False)
elif opts.quantization == "mxfp8":
ub_obj.copy_into_buffer(bulk_inp_fp8._rowwise_data, local_chunk=False)
gemm_inp = inp_fp8 if with_quantized_compute else inp
rs_out = torch.empty(
(outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda"
)
......@@ -623,7 +661,7 @@ def _main(opts):
if opts.use_cuda_graphs:
# Trace the CUDA graph first
g = torch.cuda.CUDAGraph()
if opts.fp8:
if with_quantized_compute:
if ub_obj is None:
with torch.cuda.graph(g):
all_outputs = _fp8_gemm()
......@@ -643,7 +681,7 @@ def _main(opts):
else:
for i in range(total_iters):
if opts.fp8:
if with_quantized_compute:
start_events[i].record()
all_outputs = _fp8_gemm()
end_events[i].record()
......@@ -688,10 +726,22 @@ def _main(opts):
output_info = ""
if opts.comm_type == tex.CommOverlapType.AG:
# Bulk overlap AG output is already gathered
test_out = ub_obj.get_buffer(bulk_inp_quantizer, False)
test_out = ag_out
if bulk_inp_quantizer is None:
test_out = ub_obj.get_buffer(False)
else:
test_out = Float8Tensor(
shape=test_out.shape,
dtype=torch.bfloat16,
data=ub_obj.get_buffer(False),
fp8_scale=bulk_inp_quantizer.scale,
fp8_dtype=bulk_inp_quantizer.dtype,
quantizer=bulk_inp_quantizer,
)
else:
# Bulk overlap RS output needs to be gathered
out_local = ub_obj.get_buffer(bulk_inp_quantizer, True)
out_local = ub_obj.get_buffer(True)
output_info += f"rs_output: {list(out_local.shape)} | "
test_out = te.distributed.gather_along_first_dim(out_local, tp_group)[0]
......@@ -762,8 +812,8 @@ def _main(opts):
m = torch.argmax(diff)
abs_err = diff[m].item()
rel_err = abs_err / max(abs(ref_out.flatten()[m].item()), 1e-5)
rtol = 0.125 if opts.fp8 else 0.02
atol = 0.0625 if opts.fp8 else 0.001
rtol = 0.02 if opts.quantization == "none" else 0.125
atol = 0.001 if opts.quantization == "none" else 0.0625
if rel_err > rtol and abs_err > atol:
numerics_failed = True
numerics_info = (
......
......@@ -17,7 +17,12 @@ import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling, Float8CurrentScaling
from transformer_engine.common.recipe import (
DelayedScaling,
Float8CurrentScaling,
Format,
MXFP8BlockScaling,
)
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
......@@ -163,7 +168,7 @@ def _parse_args(argv=None, namespace=None):
"--quantization",
type=str.lower,
default="none",
choices=["none", "fp8_delayed_scaling", "fp8_current_scaling"],
choices=["none", "fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"],
help="Quantization recipe",
)
parser.add_argument(
......@@ -414,6 +419,8 @@ def _train(opts):
)
elif opts.quantization == "fp8_current_scaling":
fp8_recipe = Float8CurrentScaling(fp8_format=fp8_format)
elif opts.quantization == "mxfp8":
fp8_recipe = MXFP8BlockScaling()
# Prepare random input tensors
test_x = torch.randn(input_shape, dtype=torch.float32, device="cuda", requires_grad=True)
......
......@@ -15,11 +15,12 @@ if torch.cuda.device_count() < 2:
pytest.skip("Comm+GEMM overlap requires at least 2 GPUs.")
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
RNG_SEED: int = 42
SEQ_LENGTH: int = 1024
BATCH_SIZE: int = 2
NUM_HEADS: int = 16
NUM_HEADS: int = 32
HEAD_DIM: int = 48
TE_LAYERS = [
te.Linear,
......@@ -50,7 +51,7 @@ os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
torch._dynamo.reset()
def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8):
def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, quantization):
test_path = TEST_ROOT / "run_gemm_with_overlap.py"
test_cmd = LAUNCH_CMD + [
str(test_path),
......@@ -66,10 +67,11 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8):
if bulk:
test_cmd.append("--bulk-overlap")
else:
if fp8:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
test_cmd.append("--fp8")
if quantization == "fp8" and not fp8_available:
pytest.skip(reason_for_no_fp8)
if quantization == "mxfp8" and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
test_cmd.append(f"--quantization={quantization}")
if p2p:
test_cmd.append("--p2p")
if atomic:
......@@ -107,8 +109,10 @@ def _run_layer_with_overlap(
test_cmd.append("--overlap-rs-dgrad")
if fp8:
if not fp8_available:
if quantization in ("fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
pytest.skip(reason_for_no_fp8)
if quantization == "mxfp8" and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
test_cmd.append("--fp8")
test_cmd.append(f"--quantization={quantization}")
......@@ -130,51 +134,34 @@ def _run_layer_with_overlap(
raise AssertionError(result.stderr.decode())
@pytest.mark.parametrize(
"fp8",
(False, True),
ids=[" BF16 - RING-EXCHANGE ", " FP8 - RING-EXCHANGE "],
)
def test_split_all_gather_overlaps(fp8):
@pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8"))
def test_split_all_gather_overlaps(quantization):
"""
Test (split GEMM -> all-gather) overlaps with direct calls to te.cpp_extensions.gemm or
te.cpp_extensions.fp8_gemm.
"""
_run_gemm_with_overlap("AG", False, True, False, fp8)
_run_gemm_with_overlap("AG", False, True, False, quantization)
@pytest.mark.parametrize(
"fp8,p2p",
[
(False, False),
(False, True),
(True, False),
(True, True),
],
ids=[
" BF16 - PIPELINE ",
" BF16 - RING-EXCHANGE ",
" FP8 - PIPELINE ",
" FP8 - RING-EXCHANGE ",
],
)
def test_split_reduce_scatter_overlaps(fp8, p2p):
@pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8"))
@pytest.mark.parametrize("p2p", (False, True))
def test_split_reduce_scatter_overlaps(quantization, p2p):
"""
Test (reduce-scatter -> split GEMM) overlaps with direct calls to te.cpp_extensions.gemm or
te.cpp_extensions.fp8_gemm.
"""
_run_gemm_with_overlap("RS", False, p2p, False, fp8)
_run_gemm_with_overlap("RS", False, p2p, False, quantization)
@pytest.mark.parametrize(
"comm_type, fp8, connections",
"comm_type, quantization, connections",
[
("AG", False, 1),
("RS", False, 1),
("RS", True, 1),
("AG", False, 8),
("RS", False, 8),
("RS", True, 8),
("AG", "none", 1),
("RS", "none", 1),
("RS", "fp8", 1),
("AG", "none", 8),
("RS", "none", 8),
("RS", "fp8", 8),
],
ids=[
"ALL-GATHER - BF16 - 1 connections",
......@@ -185,7 +172,7 @@ def test_split_reduce_scatter_overlaps(fp8, p2p):
"REDUCE-SCATTER - FP8 - 8 connections",
],
)
def test_bulk_overlaps(comm_type, fp8, connections):
def test_bulk_overlaps(comm_type, quantization, connections):
"""
Test bulk overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm.
"""
......@@ -196,10 +183,10 @@ def test_bulk_overlaps(comm_type, fp8, connections):
" 9.0 (HOPPER ARCH)."
)
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8"
_run_gemm_with_overlap(comm_type, True, False, False, fp8)
_run_gemm_with_overlap(comm_type, True, False, False, quantization)
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
else:
_run_gemm_with_overlap(comm_type, True, False, False, fp8)
_run_gemm_with_overlap(comm_type, True, False, False, quantization)
@pytest.mark.parametrize(
......@@ -251,15 +238,7 @@ def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_d
@pytest.mark.parametrize(
"quantization",
["fp8_delayed_scaling", "fp8_current_scaling"],
ids=[" DELAYED SCALING ", " CURRENT SCALING "],
)
@pytest.mark.parametrize(
"fp8",
(True,),
ids=[
" FP8 ",
],
["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"],
)
@pytest.mark.parametrize(
"layer_type,linear_parallel_mode,overlap_rs_dgrad",
......@@ -279,15 +258,15 @@ def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_d
)
),
ids=[
f" {te.Linear.__name__} - ROW-PARALLEL ",
f" {te.Linear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ",
f" {te.Linear.__name__} - COL-PARLALEL - DGRAD+RS ",
f" {te.LayerNormLinear.__name__} - ROW-PARALLEL ",
f" {te.LayerNormLinear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ",
f" {te.LayerNormLinear.__name__} - COL-PARALLEL - DGRAD+RS ",
f"{te.Linear.__name__}-row_tensor_parallel",
f"{te.Linear.__name__}-col_tensor_parallel-BULK DGRAD/WGRAD",
f"{te.Linear.__name__}-col_tensor_parallel-DGRAD+RS",
f"{te.LayerNormLinear.__name__}-row_tensor_parallel",
f"{te.LayerNormLinear.__name__}-col_tensor_parallel-BULK DGRAD/WGRAD",
f"{te.LayerNormLinear.__name__}-col_tensor_parallel-DGRAD+RS",
]
+ [
" " + " - ".join(test_name_parts) + " "
"-".join(test_name_parts)
for test_name_parts in zip(
[layer.__name__ for layer in TE_LAYERS[2:] for _ in range(2)],
["BULK DGRAD/WGRAD", "DGRAD+RS"] * len(TE_LAYERS[2:]),
......@@ -295,12 +274,15 @@ def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_d
],
)
def test_layers_with_overlap_fp8(
layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization
layer_type,
linear_parallel_mode,
overlap_rs_dgrad,
quantization,
):
"""
Test Transformer Engine layers with comm+GEMM overlap.
"""
_run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization)
_run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, True, quantization)
@pytest.mark.parametrize(
......@@ -347,22 +329,11 @@ def test_multi_layer_with_overlap_bf16(
@pytest.mark.parametrize(
"quantization",
["fp8_delayed_scaling", "fp8_current_scaling"],
ids=[" DELAYED SCALING ", " CURRENT SCALING "],
)
@pytest.mark.parametrize(
"fp8",
(True,),
ids=[
" FP8 ",
],
["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"],
)
@pytest.mark.parametrize(
"num_layers",
(2,),
ids=[
" 2 layers ",
],
)
@pytest.mark.parametrize(
"layer_type,linear_parallel_mode,overlap_rs_dgrad",
......@@ -374,7 +345,7 @@ def test_multi_layer_with_overlap_bf16(
)
),
ids=[
" " + " - ".join(test_name_parts) + " "
"-".join(test_name_parts)
for test_name_parts in zip(
[te.TransformerLayer.__name__ for _ in range(2)],
["BULK DGRAD/WGRAD", "DGRAD+RS"],
......@@ -382,11 +353,11 @@ def test_multi_layer_with_overlap_bf16(
],
)
def test_multi_layer_with_overlap_fp8(
layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers
layer_type, linear_parallel_mode, overlap_rs_dgrad, quantization, num_layers
):
"""
Test Transformer Engine layers with comm+GEMM overlap.
"""
_run_layer_with_overlap(
layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers
layer_type, linear_parallel_mode, overlap_rs_dgrad, True, quantization, num_layers
)
......@@ -19,7 +19,6 @@ import torch
import transformer_engine
import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import is_float8_tensor
......@@ -27,6 +26,8 @@ from transformer_engine.pytorch.ops.fused import (
UserbuffersBackwardLinear,
UserbuffersForwardLinear,
)
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
from transformer_engine.pytorch.utils import is_bf16_compatible
# Import utility functions
......@@ -36,6 +37,13 @@ from utils import dtype_tols, str_to_dtype
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
quantization_list: list[Optional[str]] = [None]
if fp8_available:
quantization_list.append("fp8")
if mxfp8_available:
quantization_list.append("mxfp8")
# Check if there are multiple GPUs
if torch.cuda.device_count() < 2:
......@@ -51,7 +59,7 @@ class ModelConfig:
num_heads: int
head_dim: int
dtype: torch.dtype
fp8: bool
quantization: Optional[str]
@property
def hidden_size(self):
......@@ -129,12 +137,16 @@ def make_reference_and_test_tensors(
ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
# Make copy of tensor
test = ref.to(device=test_device, dtype=test_dtype)
if test_is_fp8:
test = Float8Tensor.to_float8(ref)
else:
test = ref.to(device=test_device, dtype=test_dtype)
if test.data_ptr() == ref.data_ptr():
test = test.clone()
quantizer = Float8Quantizer(
scale=torch.ones(1, dtype=torch.float32, device=test_device),
amax=torch.zeros(1, dtype=torch.float32, device=test_device),
fp8_dtype=tex.DType.kFloat8E4M3,
)
test = quantizer(test)
elif test.data_ptr() == ref.data_ptr():
test = test.clone()
# Make sure reference and test tensors represent exact same values
ref.copy_(test)
......@@ -145,6 +157,21 @@ def make_reference_and_test_tensors(
return ref, test
def make_recipe(name: Optional[str] = None) -> Optional[Recipe]:
"""Make recipe for quantization scheme"""
if name is None:
return None
if name == "fp8":
return transformer_engine.common.recipe.DelayedScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
if name == "mxfp8":
return transformer_engine.common.recipe.MXFP8BlockScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
raise ValueError(f"Unsupported quantization scheme ({name})")
def _test_linear(
*,
model_config: ModelConfig,
......@@ -155,7 +182,8 @@ def _test_linear(
weight_requires_grad: bool = True,
) -> None:
dtype = model_config.dtype
fp8_compute = model_config.fp8
quantization = model_config.quantization
quantized_compute = quantization is not None
# Distributed process group
process_group = world_group()
......@@ -175,14 +203,19 @@ def _test_linear(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8_compute,
test_is_fp8=quantized_compute,
)
if isinstance(x_test, QuantizedTensor):
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8_compute,
test_is_fp8=quantized_compute,
)
if isinstance(w_test, QuantizedTensor):
w_test = w_test.dequantize()
b_ref, b_test = None, None
if bias:
if tensor_parallel_mode == "row":
......@@ -198,9 +231,11 @@ def _test_linear(
out_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8_compute,
test_is_fp8=quantized_compute,
requires_grad=False,
)
if isinstance(dy_test, QuantizedTensor):
dy_test = dy_test.dequantize()
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref)
......@@ -265,21 +300,15 @@ def _test_linear(
x_test.requires_grad_()
# Implementation with fusible operation
with te.fp8_model_init(enabled=fp8_compute):
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_compute, recipe=recipe):
ops = []
linear_op = None
bias_op = None
if tensor_parallel_mode == "column":
userbuffers_options = {}
if not weight_requires_grad:
if fp8_compute:
userbuffers_options["comm_name"] = "fc1"
else:
# There is a correctness bug with overlapping
# dgrad reduce-scatter with dgrad GEMM. Fall back
# to overlapping dgrad reduce-scatter with wgrad
# GEMM, even though wgrad isn't needed.
userbuffers_options["comm_name"] = "qkv"
userbuffers_options["comm_name"] = "fc1"
else:
userbuffers_options["comm_name"] = "qkv"
linear_op = te_ops.BasicLinear(
......@@ -322,7 +351,7 @@ def _test_linear(
bias_op.bias.copy_(b_test)
del w_test
del b_test
with te.fp8_autocast(enabled=fp8_compute):
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
y_test = model(x_test)
y_test.backward(dy_test)
......@@ -338,7 +367,7 @@ def _test_linear(
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if fp8_compute:
if quantized_compute:
tols = dtype_tols(
model[0].weight._fp8_dtype
if is_float8_tensor(model[0].weight)
......@@ -370,7 +399,7 @@ def run_parallel_tests(model_config: ModelConfig) -> None:
for test_config in itertools.product(
(False, True), # bias
("column", "row"), # tensor_parallel_mode
(False, True), # weight_requires_grad
(True, False), # weight_requires_grad
):
if rank == 0:
print(f"Running _test_linear with {test_config=}")
......@@ -390,19 +419,15 @@ if torch.cuda.device_count() > 1:
@pytest.mark.parametrize("world_size", _world_sizes)
@pytest.mark.parametrize("fp8", (False, True))
@pytest.mark.parametrize("quantization", quantization_list)
def test_fuser_ops_with_userbuffers(
*,
world_size: int,
dtype: torch.dtype = torch.bfloat16,
fp8: bool,
quantization: Optional[str],
) -> None:
"""Launch parallel job and run tests"""
# Skip invalid configurations
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
# Parallel job launcher
command = []
if tex.ubuf_built_with_mpi():
......@@ -424,8 +449,8 @@ def test_fuser_ops_with_userbuffers(
str(dtype),
)
)
if fp8:
command.append("--fp8")
if quantization is not None:
command.extend(("--quantization", quantization))
# Environment
env = dict(os.environ)
......@@ -445,12 +470,12 @@ def main() -> None:
# Parse command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--parallel", action="store_true", help="Run parallel tests")
parser.add_argument("--sequence-length", type=int, default=32)
parser.add_argument("--sequence-length", type=int, default=256)
parser.add_argument("--batch-size", type=int, default=16)
parser.add_argument("--num-heads", type=int, default=16)
parser.add_argument("--head-dim", type=int, default=32)
parser.add_argument("--head-dim", type=int, default=256)
parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--fp8", action="store_true")
parser.add_argument("--quantization", type=str, default=None)
args = parser.parse_args()
# Run parallel tests if needed
......@@ -463,14 +488,17 @@ def main() -> None:
num_heads=args.num_heads,
head_dim=args.head_dim,
dtype=str_to_dtype(args.dtype),
fp8=args.fp8,
quantization=args.quantization,
)
# Initialize Userbuffers
group = world_group() # Initialize NCCL
bootstrap_backend = "mpi" if launcher() == "ompi" else "nccl"
userbuffer_configs = {
"fc1_dgrad": {"method": "pipeline"}, # Overlap dgrad RS with dgrad GEMM
"fc1_dgrad": {
"method": "ring_exchange",
"fp8_buf": False,
}, # Overlap dgrad RS with dgrad GEMM
}
te.module.base.initialize_ub(
[
......@@ -478,7 +506,7 @@ def main() -> None:
model_config.num_heads * model_config.head_dim,
],
torch.distributed.get_world_size(group),
use_fp8=model_config.fp8,
use_fp8=model_config.quantization is not None,
dtype=model_config.dtype,
bootstrap_backend=bootstrap_backend,
ub_cfgs=userbuffer_configs,
......
......@@ -147,7 +147,44 @@ CommOverlapCore::~CommOverlapCore() {
TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, size_t chunk_offset,
const std::vector<size_t> &chunk_shape) {
TensorWrapper chunk;
const auto scaling_mode = source.scaling_mode();
// Tensor dimensions
std::vector<size_t> shape = shape_to_vector(source.shape());
auto flatten_shape_to_2d = [](const std::vector<size_t> &shape) -> std::pair<size_t, size_t> {
if (shape.empty()) {
return {1, 1};
}
size_t height = 1;
for (size_t i = 0; i < shape.size() - 1; ++i) {
height *= shape[i];
}
return {height, shape.back()};
};
size_t height, width, chunk_height, chunk_width;
std::tie(height, width) = flatten_shape_to_2d(shape);
std::tie(chunk_height, chunk_width) = flatten_shape_to_2d(chunk_shape);
// Check tensor dimensions
#define NVTE_DIM_CHECK(cond, message) \
NVTE_CHECK(cond, message, " (tensor shape=", shape, ", chunk shape=", chunk_shape, \
", chunk offset=", chunk_offset, ")")
NVTE_DIM_CHECK(height > 0 && width > 0, "Attempted to get chunk from empty tensor");
NVTE_DIM_CHECK(chunk_height > 0 && chunk_width > 0, "Attempted to get empty tensor chunk");
NVTE_DIM_CHECK(chunk_height <= height && chunk_width <= width,
"Attempted to get out-of-bounds tensor chunk");
if (scaling_mode == NVTEScalingMode::NVTE_MXFP8_1D_SCALING) {
// MXFP8 scale-inverses are padded to a 2D matrix with dims that
// are divisible by 128. UB doesn't handle this padding yet.
NVTE_DIM_CHECK(height % 128 == 0 && width % 128 == 0,
"Userbuffers requires MXFP8 tensor dims that are divisible by 128");
NVTE_DIM_CHECK(chunk_height % 128 == 0 && chunk_width % 128 == 0,
"Userbuffers requires MXFP8 tensor chunk dims that are divisible by 128");
}
#undef NVTE_DIM_CHECK
// Construct tensor chunk
TensorWrapper chunk(scaling_mode);
for (int param_id = 0; param_id < NVTETensorParam::kNVTENumTensorParams; param_id++) {
auto param_type = static_cast<NVTETensorParam>(param_id);
auto param = source.get_parameter(param_type);
......@@ -163,8 +200,8 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz
param_shape = chunk_shape;
if (param_type == NVTETensorParam::kNVTEColumnwiseData &&
source.scaling_mode() != NVTEScalingMode::NVTE_MXFP8_1D_SCALING) {
// Columnwise shape for non-block scaled tensors shifts the last dimension to the front
source.scaling_mode() == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING) {
// Columnwise shape for FP8 tensor-scaled tensors shifts the last dimension to the front
auto last_dim = param_shape.back();
param_shape.pop_back();
param_shape.insert(param_shape.begin(), last_dim);
......@@ -172,18 +209,16 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz
} else if (source.scaling_mode() == NVTEScalingMode::NVTE_MXFP8_1D_SCALING &&
(param_type == NVTETensorParam::kNVTERowwiseScaleInv ||
param_type == NVTETensorParam::kNVTEColumnwiseScaleInv)) {
// Calculate block scaling offset and size
auto scaled_tensor_dim_size = (param_type == NVTETensorParam::kNVTERowwiseScaleInv)
? source.shape().data[0]
: source.columnwise_shape().data[0];
auto scaled_chunk_dim_size = (param_type == NVTETensorParam::kNVTERowwiseScaleInv)
? chunk_shape.front()
: chunk_shape.back();
auto chunk_scale_start = chunk_offset / 32;
auto chunk_scale_end = (chunk_offset + scaled_chunk_dim_size) / 32;
auto chunk_scale_size = chunk_scale_end - chunk_scale_start;
param_dptr += chunk_scale_start * typeToSize(param_dtype);
param_shape = std::vector<size_t>{chunk_scale_size};
// Calculate offset and size for MXFP8 scale-invs
size_t chunk_scale_height = chunk_height;
size_t chunk_scale_width = chunk_width;
if (param_type == NVTETensorParam::kNVTERowwiseScaleInv) {
chunk_scale_width /= 32;
} else {
chunk_scale_height /= 32;
}
param_dptr += (chunk_offset / 32) * typeToSize(param_dtype);
param_shape = {chunk_scale_height, chunk_scale_width};
}
// Set chunked source parameters into the chunked tensor output
......@@ -422,10 +457,21 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
size_t k = transa ? A.size(1) : A.size(0);
size_t n = _ubuf.size(0);
size_t m_chunk = m / _num_splits;
const std::vector<size_t> input_a_chunk_shape =
(transa ? std::vector<size_t>{m_chunk, k} : std::vector<size_t>{k, m_chunk});
const std::vector<size_t> output_chunk_shape = {n, m_chunk};
size_t input_a_chunk_size = m_chunk * k;
size_t output_chunk_size = n * m_chunk;
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
// Helper function to get bias chunk if needed
auto maybe_get_bias_chunk = [this, &bias, m_chunk](size_t chunk_id) -> TensorWrapper {
if (bias.dptr() == nullptr) {
return TensorWrapper();
}
return get_tensor_chunk(bias, chunk_id * m_chunk, {m_chunk});
};
// Catch up the default torch stream
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
for (size_t i = 0; i < _stream_compute.size(); i++) {
......@@ -437,21 +483,23 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
if (_rs_overlap_first_gemm) {
auto input_a_chunk = get_tensor_chunk(A, 0, {m_chunk, k});
auto output_chunk = get_buffer_chunk_like(D, 0, {m, m_chunk});
auto input_a_chunk = get_tensor_chunk(A, 0, input_a_chunk_shape);
auto output_chunk = get_buffer_chunk_like(D, 0, output_chunk_shape);
auto bias_chunk = maybe_get_bias_chunk(0);
auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk});
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _stream_compute[0]);
for (int i = 1; i < _num_splits; i++) {
input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k});
output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, {n, m_chunk});
input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, input_a_chunk_shape);
output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, output_chunk_shape);
bias_chunk = maybe_get_bias_chunk(i);
workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]);
......@@ -494,12 +542,13 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
}
} else {
for (int i = 0; i < _num_splits; i++) {
auto input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k});
auto output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, {n, m_chunk});
auto input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, input_a_chunk_shape);
auto output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, output_chunk_shape);
auto bias_chunk = maybe_get_bias_chunk(i);
auto workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]);
......@@ -759,8 +808,6 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
// Get communication and GEMM output chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
const bool do_gelu = pre_gelu_out.numel() > 0;
size_t input_chunk_size = n_chunk * k;
size_t output_chunk_size = n_chunk * m;
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
......@@ -771,8 +818,13 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
}
if (_aggregate) {
const int num_steps = _tp_size / 2;
input_chunk_size *= 2;
output_chunk_size *= 2;
// Chunk dims
std::vector<size_t> input_b_chunk_shape =
(transb ? std::vector<size_t>{k, 2 * n_chunk} : std::vector<size_t>{2 * n_chunk, k});
std::vector<size_t> output_chunk_shape = {2 * n_chunk, k};
size_t input_b_chunk_size = 2 * n_chunk * k;
size_t output_chunk_size = 2 * n_chunk * m;
// Initial 1X input chunk exchange between neighboring peers
int send_chunk_id = _tp_id;
......@@ -801,8 +853,9 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
// GEMM
auto input_b_chunk =
get_buffer_chunk_like(B, input_chunk_size * send_chunk_id, {n_chunk * 2, k});
auto output_chunk = get_tensor_chunk(D, output_chunk_size * send_chunk_id, {n_chunk * 2, m});
get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape);
auto output_chunk =
get_tensor_chunk(D, output_chunk_size * send_chunk_id, output_chunk_shape);
auto aux_chunk =
(do_gelu)
? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk * 2, k})
......@@ -834,6 +887,13 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
}
}
} else {
// Chunk dims
std::vector<size_t> input_b_chunk_shape =
(transb ? std::vector<size_t>{k, n_chunk} : std::vector<size_t>{n_chunk, k});
std::vector<size_t> output_chunk_shape = {n_chunk, m};
size_t input_b_chunk_size = n_chunk * k;
size_t output_chunk_size = n_chunk * m;
for (int i = 0; i < _tp_size; i++) {
// Set the userbuffer id. Buffer under send is the input for the current
// GEMM chunk The initial input chunk is stored _ubuf[rank]. This is to
......@@ -845,8 +905,10 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
int recv_offset = comm_bytes * recv_chunk_id;
// GEMM
auto input_b_chunk = get_buffer_chunk_like(B, input_chunk_size * send_chunk_id, {n_chunk, k});
auto output_chunk = get_tensor_chunk(D, output_chunk_size * send_chunk_id, {n_chunk, m});
auto input_b_chunk =
get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape);
auto output_chunk =
get_tensor_chunk(D, output_chunk_size * send_chunk_id, output_chunk_shape);
auto aux_chunk =
(do_gelu)
? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk, k})
......@@ -996,6 +1058,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
auto input_b_chunk = get_tensor_chunk(B, input_b_chunk_id * input_chunk_size, {n_chunk, k});
auto output_chunk = get_buffer_chunk_by_id(D, i);
auto workspace_chunk =
get_tensor_chunk(workspace, stream_id * workspace_size_chunk, {workspace_size_chunk});
......
......@@ -12,10 +12,18 @@
#include <cudnn.h>
#include <nvrtc.h>
#include <iostream>
#include <stdexcept>
#include "../util/string.h"
#define NVTE_WARN(...) \
do { \
std::cerr << ::transformer_engine::concat_strings( \
__FILE__ ":", __LINE__, " in function ", __func__, ": ", \
::transformer_engine::concat_strings(__VA_ARGS__), "\n"); \
} while (false)
#define NVTE_ERROR(...) \
do { \
throw ::std::runtime_error(::transformer_engine::concat_strings( \
......
......@@ -452,12 +452,10 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve
~CommOverlap() {}
void set_buffer_params(py::handle quantizer);
void copy_into_buffer(const at::Tensor &input, bool local_chunk = false);
void copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk = false);
py::object get_buffer(py::handle quantizer, bool local_chunk = false,
std::optional<const std::vector<int64_t>> shape = std::nullopt);
at::Tensor get_buffer(bool local_chunk = false,
std::optional<std::vector<int64_t>> shape = std::nullopt);
}; // CommOverlap
......@@ -473,12 +471,10 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm
~CommOverlapP2P() {}
void set_buffer_params(py::handle quantizer);
void copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk = false);
void copy_into_buffer(const at::Tensor &input, bool local_chunk = false);
py::object get_buffer(py::handle quantizer, bool local_chunk = false,
std::optional<const std::vector<int64_t>> shape = std::nullopt);
at::Tensor get_buffer(bool local_chunk = false,
std::optional<std::vector<int64_t>> shape = std::nullopt);
}; // CommOverlapP2P
......
......@@ -141,81 +141,79 @@ CommOverlap::CommOverlap(const std::vector<size_t> &buffer_shape, at::ScalarType
num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm,
set_sm_margin, atomic_gemm, rs_overlap_first_gemm) {}
void CommOverlap::set_buffer_params(py::handle quantizer) {
std::unique_ptr<te::pytorch::Quantizer> my_quantizer = te::pytorch::convert_quantizer(quantizer);
my_quantizer->set_quantization_params(&_ubuf);
_ubuf_scale_inv_initialized = true;
}
/*
** Helper function to copy input to _ubuf
*/
void CommOverlap::copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk) {
auto input_tensor = te::pytorch::makeTransformerEngineTensor(input, quantizer);
auto input_ptr = input_tensor.dptr() ? input_tensor.dptr() : input_tensor.columnwise_dptr();
NVTE_CHECK(input_ptr, "Input tensor does not have rowwise or columnwise data!");
char *ubuf_ptr = reinterpret_cast<char *>(_ubuf.dptr());
void CommOverlap::copy_into_buffer(const at::Tensor &input, bool local_chunk) {
const auto &input_ = input.contiguous();
// Check element size
const size_t element_size = input.element_size();
NVTE_CHECK(_ubuf.element_size() == element_size,
"Tried to copy data into a Userbuffers buffer but dtypes are not compatible ",
"(input dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(),
" bytes)");
// Input data
const size_t input_size = input_.numel();
const void *src_ptr = input_.data_ptr();
// Userbuffers data
const size_t ubuf_size = _ubuf.numel();
void *dst_ptr = _ubuf.dptr();
if (local_chunk) {
if (input_tensor.numel() * _tp_size > _ubuf.numel())
NVTE_ERROR("input is larger than the local communication buffer!");
if (input_tensor.element_size() != _ubuf.element_size())
NVTE_ERROR("input data type does not match communication buffer!");
ubuf_ptr += (_ubuf.numel() / _tp_size) * _tp_id * _ubuf.element_size();
NVTE_CHECK(input_size * _tp_size == ubuf_size,
"Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ",
"(input_size=", input_size, ", tensor_parallel_size=", _tp_size,
", ubuf_size=", ubuf_size, ")");
dst_ptr = (reinterpret_cast<char *>(dst_ptr) + (ubuf_size / _tp_size) * _tp_id * element_size);
} else {
if (input_tensor.numel() > _ubuf.numel())
NVTE_ERROR("input is larger than the global communication buffer!");
if (input_tensor.element_size() != _ubuf.element_size())
NVTE_ERROR("input data type does not match communication buffer!");
NVTE_CHECK(input_size == ubuf_size,
"Tried to copy an invalid tensor into a Userbuffers buffer ",
"(input_size=", input_size, ", ubuf_size=", ubuf_size, ")");
}
// Copy either row or columnwise data into the communication buffer's columnwise data
// NOTE: _ubuf.columnwise_dptr() is not a valid copy target because it is not registered with
// the Userbuffers communicator.
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
// Copy data
auto stream_main = at::cuda::getCurrentCUDAStream();
NVTE_CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0));
NVTE_CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input_tensor.dptr(),
input_tensor.numel() * input_tensor.element_size(),
NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, input_size * element_size,
cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_comm));
}
py::object CommOverlap::get_buffer(py::handle quantizer, bool local_chunk,
std::optional<const std::vector<int64_t>> shape) {
using namespace te::pytorch;
char *ubuf_wt_ptr = reinterpret_cast<char *>(_ubuf.dptr());
if (local_chunk) ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size();
std::vector<int64_t> torch_shape;
if (shape.has_value()) {
torch_shape = shape.value();
size_t requested = product(torch_shape);
auto expected = local_chunk ? _ubuf.numel() / _tp_size : _ubuf.numel();
NVTE_CHECK(requested == expected, "Number of elements in the requested shape (", requested,
") does not match allocated buffer size (", expected, ")!");
at::Tensor CommOverlap::get_buffer(bool local_chunk, std::optional<std::vector<int64_t>> shape) {
// Check buffer shape
const size_t ubuf_size = _ubuf.numel();
if (shape) {
const size_t requested_size = transformer_engine::pytorch::product(*shape);
if (local_chunk) {
NVTE_CHECK(requested_size * _tp_size == ubuf_size,
"Invalid shape for local chunk of a Userbuffers buffer (requested shape=", *shape,
", tensor_parallel_size=", _tp_size, ", ubuf_size=", ubuf_size, ")");
} else {
NVTE_CHECK(requested_size == ubuf_size,
"Invalid shape for a Userbuffers buffer (requested shape=", *shape,
", ubuf_size=", ubuf_size, ")");
}
} else {
int64_t output_c_dim0 = (local_chunk) ? _ubuf.size(0) / _tp_size : _ubuf.size(0);
int64_t output_c_dim1 = _ubuf.size(1);
torch_shape = {output_c_dim0, output_c_dim1};
int64_t dim0 = _ubuf.size(0);
int64_t dim1 = _ubuf.size(1);
if (local_chunk) {
dim0 /= _tp_size;
}
shape = {dim0, dim1};
}
auto ubuf_tensor = torch::from_blob(reinterpret_cast<void *>(ubuf_wt_ptr), torch_shape,
at::dtype(GetATenDType(_ubuf.dtype())).device(torch::kCUDA));
std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
std::vector<size_t> te_shape;
for (auto s : torch_shape) te_shape.emplace_back(static_cast<size_t>(s));
// Always output a rowwise-only QuantizedTensor
// TODO (Alp): This needs to produce an un-interleaved transpose when required.
auto is_internal = my_quantizer->internal;
auto uses_columnwise = my_quantizer->columnwise_usage;
my_quantizer->internal = false;
my_quantizer->columnwise_usage = false;
auto [te_tensor, py_tensor] = my_quantizer->create_tensor(te_shape, _ubuf.dtype(), ubuf_tensor);
my_quantizer->internal = is_internal;
my_quantizer->columnwise_usage = uses_columnwise;
return py_tensor;
// Data pointer
void *ubuf_ptr = _ubuf.dptr();
if (local_chunk) {
ubuf_ptr = (reinterpret_cast<char *>(ubuf_ptr) +
(ubuf_size / _tp_size) * _tp_id * _ubuf.element_size());
}
// Construct PyTorch tensor
const auto dtype = transformer_engine::pytorch::GetATenDType(_ubuf.dtype());
return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA));
}
/***************************************************************************************************
......@@ -236,74 +234,69 @@ CommOverlapP2P::CommOverlapP2P(const std::vector<size_t> &buffer_shape, at::Scal
comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce,
atomic_gemm, aggregate) {}
void CommOverlapP2P::set_buffer_params(py::handle quantizer) {
std::unique_ptr<te::pytorch::Quantizer> my_quantizer = te::pytorch::convert_quantizer(quantizer);
my_quantizer->set_quantization_params(&_ubuf);
for (size_t i = 0; i < _ubufs.size(); i++) my_quantizer->set_quantization_params(&_ubufs[i]);
}
/*
** Copy input to _ubufs[0]
*/
void CommOverlapP2P::copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk) {
auto input_tensor = te::pytorch::makeTransformerEngineTensor(input, quantizer);
auto input_ptr = input_tensor.dptr() ? input_tensor.dptr() : input_tensor.columnwise_dptr();
NVTE_CHECK(input_ptr, "Input tensor does not have rowwise or columnwise data!");
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
void CommOverlapP2P::copy_into_buffer(const at::Tensor &input, bool local_chunk) {
const auto &input_ = input.contiguous();
// Check element size
const size_t element_size = input.element_size();
NVTE_CHECK(_ubuf.element_size() == element_size,
"Tried to copy data into a Userbuffers buffer but dtypes are not compatible ",
"(input dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(),
" bytes)");
// Input data
const size_t input_size = input_.numel();
const void *src_ptr = input_.data_ptr();
// Userbuffers data
void *dst_ptr;
if (local_chunk) {
// Copy input to the target ubuf chunk by rank offset
if (input_tensor.numel() * _tp_size > _ubuf.numel())
NVTE_ERROR("input is larger than the local communication buffer!");
if (input_tensor.element_size() != _ubuf.element_size())
NVTE_ERROR("input data type does not match communication buffer!");
NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].dptr(), input_ptr,
input_tensor.numel() * input_tensor.element_size(),
cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main));
NVTE_CHECK(_ubufs[_tp_id].numel() == input_size,
"Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ",
"(input_size=", input_size, ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")");
dst_ptr = _ubufs[_tp_id].dptr();
} else {
if (input_tensor.numel() > _ubuf.numel())
NVTE_ERROR("input is larger than the global communication buffer!");
if (input_tensor.element_size() != _ubuf.element_size())
NVTE_ERROR("input data type does not match communication buffer!");
NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubuf.dptr(), input_ptr,
input_tensor.numel() * input_tensor.element_size(),
cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main));
NVTE_CHECK(_ubuf.numel() == input_size,
"Tried to copy an invalid tensor into a Userbuffers buffer ",
"(input_size=", input_size, ", ubuf_size=", _ubuf.numel(), ")");
dst_ptr = _ubuf.dptr();
}
// Copy data
NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, input_size * element_size,
cudaMemcpyDeviceToDevice,
(cudaStream_t)at::cuda::getCurrentCUDAStream()));
}
py::object CommOverlapP2P::get_buffer(py::handle quantizer, bool local_chunk,
std::optional<const std::vector<int64_t>> shape) {
using namespace te::pytorch;
char *ubuf_wt_ptr = reinterpret_cast<char *>(local_chunk ? _ubufs[_tp_id].dptr() : _ubuf.dptr());
std::vector<int64_t> torch_shape;
if (shape.has_value()) {
torch_shape = shape.value();
size_t requested = product(torch_shape);
auto expected = local_chunk ? _ubufs[_tp_id].numel() : _ubuf.numel();
NVTE_CHECK(requested == expected, "Number of elements in the requested shape (", requested,
") does not match allocated buffer size (", expected, ")!");
at::Tensor CommOverlapP2P::get_buffer(bool local_chunk, std::optional<std::vector<int64_t>> shape) {
// Check buffer shape
if (shape) {
const size_t requested_size = transformer_engine::pytorch::product(*shape);
if (local_chunk) {
NVTE_CHECK(requested_size == _ubufs[_tp_id].numel(),
"Invalid shape for local chunk of a Userbuffers buffer (requested shape=", *shape,
", local_ubuf_size=", _ubufs[_tp_id].numel(), ")");
} else {
NVTE_CHECK(requested_size == _ubuf.numel(),
"Invalid shape for a Userbuffers buffer (requested shape=", *shape,
", ubuf_size=", _ubuf.numel(), ")");
}
} else {
int64_t output_c_dim0 = (local_chunk) ? _ubuf.size(0) / _tp_size : _ubuf.size(0);
int64_t output_c_dim1 = _ubuf.size(1);
torch_shape = {output_c_dim0, output_c_dim1};
int64_t dim0 = _ubuf.size(0);
int64_t dim1 = _ubuf.size(1);
if (local_chunk) {
dim0 /= _tp_size;
}
shape = {dim0, dim1};
}
auto ubuf_tensor = torch::from_blob(reinterpret_cast<void *>(ubuf_wt_ptr), torch_shape,
at::dtype(GetATenDType(_ubuf.dtype())).device(torch::kCUDA));
std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
std::vector<size_t> te_shape;
for (auto s : torch_shape) te_shape.emplace_back(static_cast<size_t>(s));
// Always output a rowwise-only QuantizedTensor
// TODO (Alp): This needs to produce an un-interleaved transpose when required.
auto is_internal = my_quantizer->internal;
auto uses_columnwise = my_quantizer->columnwise_usage;
my_quantizer->internal = false;
my_quantizer->columnwise_usage = false;
auto [te_tensor, py_tensor] = my_quantizer->create_tensor(te_shape, _ubuf.dtype(), ubuf_tensor);
my_quantizer->internal = is_internal;
my_quantizer->columnwise_usage = uses_columnwise;
return py_tensor;
// Data pointer
void *ubuf_ptr = local_chunk ? _ubufs[_tp_id].dptr() : _ubuf.dptr();
// Construct PyTorch tensor
const auto dtype = transformer_engine::pytorch::GetATenDType(_ubuf.dtype());
return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA));
}
......@@ -360,10 +360,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("num_comm_sm") = 16, py::arg("set_sm_margin") = true,
py::arg("atomic_gemm") = false, py::arg("rs_overlap_first_gemm") = false)
.def("copy_into_buffer", &CommOverlap::copy_into_buffer, py::arg("input"),
py::arg("quantizer"), py::arg("local_chunk") = false)
.def("get_buffer", &CommOverlap::get_buffer, py::arg("quantizer"),
py::arg("local_chunk") = false, py::arg("shape") = std::nullopt)
.def("set_buffer_params", &CommOverlap::set_buffer_params);
py::arg("local_chunk") = false)
.def("get_buffer", &CommOverlap::get_buffer, py::arg("local_chunk") = false,
py::arg("shape") = std::nullopt);
py::class_<CommOverlapP2P, std::shared_ptr<CommOverlapP2P>,
transformer_engine::CommOverlapP2PBase, transformer_engine::CommOverlapCore>(
......@@ -378,8 +377,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("set_sm_margin") = false, py::arg("atomic_gemm") = false,
py::arg("use_ce") = true, py::arg("aggregate") = false)
.def("copy_into_buffer", &CommOverlapP2P::copy_into_buffer, py::arg("input"),
py::arg("quantizer"), py::arg("local_chunk") = false)
.def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("quantizer"),
py::arg("local_chunk") = false, py::arg("shape") = std::nullopt)
.def("set_buffer_params", &CommOverlapP2P::set_buffer_params);
py::arg("local_chunk") = false)
.def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("local_chunk") = false,
py::arg("shape") = std::nullopt);
}
......@@ -919,8 +919,10 @@ def _all_gather_fp8(
# Note: We cannot directly all-gather the transposed FP8 tensor,
# so temporarily modify quantizer to avoid creating FP8 transpose.
if not isinstance(inp, Float8TensorBase):
if quantizer is None:
raise ValueError("Input tensor is not FP8 and no quantizer was provided")
assert isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer))
# we cannot directly gather the transposed fp8 tensor
# so we need to disable columnwise usage for the quantizer
# and then set it back to the original value after quantizing
init_rowwise_usage = quantizer.rowwise_usage
init_columnwise_usage = quantizer.columnwise_usage
quantizer.set_usage(rowwise=True, columnwise=False)
......@@ -1063,11 +1065,11 @@ def _all_gather_mxfp8(
dtype = inp.dtype
elif isinstance(inp, MXFP8TensorBase):
if inp._rowwise_data is not None:
in_shape = inp._rowwise_data.device.size()
in_shape = inp._rowwise_data.size()
device = inp._rowwise_data.device
dtype = inp._rowwise_data.dtype
elif inp._columnwise_data is not None:
in_shape = inp._columnwise_data.device.size()
in_shape = inp._columnwise_data.size()
device = inp._columnwise_data.device
dtype = inp._columnwise_data.dtype
else:
......
......@@ -6,8 +6,6 @@
from typing import Any, List, Optional, Tuple, Union, Callable
from dataclasses import dataclass
from functools import reduce
from operator import mul as multiply_op
import queue
import torch
......@@ -15,7 +13,6 @@ import torch
from .. import cpp_extensions as tex
from ..constants import TE_DType
from ..utils import get_default_init_method
from ..tensor.float8_tensor import Float8Tensor
def _get_normalization_func(normalization: str, forward: bool):
......@@ -33,39 +30,6 @@ def _get_normalization_func(normalization: str, forward: bool):
return bwd_normalization_funcs[normalization]
def _fix_gathered_fp8_transpose(fp8_tensor: Float8Tensor, tp_size: int) -> Float8Tensor:
"""Reorder FP8 transposes after Userbuffers gather.
The all-gather is performed in-place in the Float8Tensor's
row-wise data, and afterwards we need to do a transpose to get the
correct ordering. This misuses data fields in Float8Tensor and
should be considered an evil hack. It would be best to move
transpose logic into CommOverlap::get_buffer.
Responsibility for fixing: adener, tmoon
"""
assert isinstance(fp8_tensor, Float8Tensor), "Tensor is not a Float8Tensor"
assert tp_size > 1, "The tensor transpose cannot be interleaved when TP size is 1"
assert fp8_tensor._data is not None, "The tensor does not hold any rowwise data"
assert (
fp8_tensor._data.shape[0] % tp_size == 0
), "Leading dimension of data is not divisble by TP size"
data = fp8_tensor._data
batched_size = reduce(multiply_op, data.shape[1:])
interleaved_shape = [tp_size, data.shape[0] // tp_size, batched_size]
transposed_shape = [data.shape[0] // tp_size, batched_size * tp_size]
fp8_tensor._transpose = (
data.view(interleaved_shape).transpose(0, 1).contiguous().view(transposed_shape)
)
fp8_tensor._transpose_invalid = False
fp8_tensor._data = None
return fp8_tensor
def apply_normalization(
inputmat: torch.Tensor,
ln_out: torch.Tensor,
......
......@@ -4,6 +4,7 @@
"""Base modules and utilities for TransformerEngine PyTorch API"""
import io
import math
import os
import pickle
import warnings
......@@ -35,7 +36,9 @@ from ..distributed import (
_fsdp_gather_tensors,
)
from ..constants import dist_group_type
from ..tensor import QuantizedTensor, Quantizer
from ..tensor.quantized_tensor import QuantizedTensor, QuantizedTensorBase, Quantizer
from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
......@@ -418,6 +421,142 @@ def destroy_ub():
layers_atomic_ring_exchange = []
def fill_userbuffers_buffer_for_all_gather(
comm,
local_tensor: torch.Tensor,
quantizer: Optional[Quantizer],
process_group,
) -> tuple[torch.Tensor | QuantizedTensorBase, torch.Tensor | QuantizedTensorBase]:
"""Fill local shard of Userbuffers buffer with data for all-gather
Returns the full tensor and the local shard, both using the
Userbuffers buffer as their underlying data. These tensors should
be used carefully (e.g. only immediately before and after a
Userbuffers operation) since the underlying data may be
overwritten by other Userbuffers operations.
May perform blocking communication if needed for the gathered
tensor's metadata, e.g. scaling factors.
"""
# Tensor dimensions
local_shape = local_tensor.size()
if not local_shape:
raise ValueError(f"Invalid local tensor (shape={tuple(local_shape)})")
process_group_size = torch.distributed.get_world_size(process_group)
global_shape = list(local_shape)
global_shape[0] *= process_group_size
# Unquantized data
if quantizer is None:
if isinstance(local_tensor, QuantizedTensorBase):
local_tensor = local_tensor.dequantize()
if comm.is_fp8_ubuf():
raise RuntimeError(
"Attempting to all-gather unquantized tensor, "
"but Userbuffers is initialized with FP8 buffers"
)
comm.copy_into_buffer(local_tensor, local_chunk=True)
global_tensor = comm.get_buffer(shape=global_shape)
return global_tensor, local_tensor
# FP8 data
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
if not isinstance(local_tensor, Float8TensorBase):
if isinstance(local_tensor, QuantizedTensorBase):
local_tensor.dequantize()
quantizer.set_usage(rowwise=True, columnwise=False)
local_tensor = quantizer(local_tensor)
if not comm.is_fp8_ubuf():
raise RuntimeError(
"Attempting to all-gather FP8 tensor, "
"but Userbuffers is not initialized with FP8 buffers"
)
comm.copy_into_buffer(local_tensor._data, local_chunk=True)
global_tensor_data = comm.get_buffer(shape=global_shape)
global_tensor = Float8TensorBase(
data=global_tensor_data,
fp8_scale_inv=local_tensor._scale_inv,
fp8_dtype=local_tensor._fp8_dtype,
quantizer=quantizer,
)
return global_tensor, local_tensor
# MXFP8 data
if isinstance(quantizer, MXFP8Quantizer):
# Cast to MXFP8 if needed
if not isinstance(local_tensor, MXFP8TensorBase):
if isinstance(local_tensor, QuantizedTensorBase):
local_tensor.dequantize()
local_tensor = quantizer(local_tensor)
if not comm.is_fp8_ubuf():
raise RuntimeError(
"Attempting to all-gather MXFP8 tensor, "
"but Userbuffers is not initialized with FP8 buffers"
)
# Check which MXFP8 buffer to communicate
if quantizer.rowwise_usage == quantizer.columnwise_usage:
raise ValueError(
"Userbuffers can only communicate one MXFP8 buffer at a time, "
f"but quantizer has rowwise_usage={quantizer.rowwise_usage}, "
f"columnwise_usage={quantizer.columnwise_usage}"
)
with_rowwise_data = quantizer.rowwise_usage
# Copy MXFP8 data to local chunk of Userbuffers buffer
local_data = (
local_tensor._rowwise_data if with_rowwise_data else local_tensor._columnwise_data
)
comm.copy_into_buffer(local_data, local_chunk=True)
# Gather scaling-inverses
if math.prod(local_shape[:-1]) % 128 != 0:
raise ValueError(
"Userbuffers requires MXFP8 tensor dims that are divisible by 128, "
f"but got MXFP8 tensor with shape={tuple(local_shape)}"
)
local_scale_inv = (
local_tensor._rowwise_scale_inv
if with_rowwise_data
else local_tensor._columnwise_scale_inv
)
local_scale_inv_size = list(local_scale_inv.size())
global_scale_inv = torch.empty(
[process_group_size * local_scale_inv_size[0]] + local_scale_inv_size[1:],
dtype=local_scale_inv.dtype,
device=local_scale_inv.device,
)
torch.distributed.all_gather_into_tensor(
global_scale_inv,
local_scale_inv,
group=process_group,
)
# Construct MXFP8 tensor with Userbuffers buffer
rowwise_data, rowwise_scale_inv = None, None
columnwise_data, columnwise_scale_inv = None, None
global_data = comm.get_buffer(shape=global_shape)
if with_rowwise_data:
rowwise_data, rowwise_scale_inv = global_data, global_scale_inv
else:
columnwise_data, columnwise_scale_inv = global_data, global_scale_inv
global_tensor = MXFP8TensorBase(
rowwise_data=rowwise_data,
rowwise_scale_inv=rowwise_scale_inv,
columnwise_data=columnwise_data,
columnwise_scale_inv=columnwise_scale_inv,
fp8_dtype=local_tensor._fp8_dtype,
quantizer=quantizer,
)
return global_tensor, local_tensor
# Unsupported data format
raise ValueError(f"Unsupported quantizer for Userbuffers ({quantizer})")
class TransformerEngineBaseModule(torch.nn.Module, ABC):
"""Base TE module."""
......@@ -866,11 +1005,15 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# Non-FP8 case: bgrad is fused with wgrad for this case.
if not ctx.fp8 and not ctx.debug:
if gather_grad_output:
if not ctx.ub_overlap_ag:
if not ctx.ub_overlap_ag: # Perform NCCL all-gather
grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group)
else:
ctx.ub_obj_gradout.copy_into_buffer(grad_output, quantizer, local_chunk=True)
grad_output = ctx.ub_obj_gradout.get_buffer(quantizer)
else: # Initialize Userbuffers all-gather
grad_output, _ = fill_userbuffers_buffer_for_all_gather(
ctx.ub_obj_gradout,
grad_output,
None,
ctx.tp_group,
)
return grad_output, None
# FP8 with all-gather: unfused bgrad, fused cast + transpose
......@@ -893,8 +1036,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
grad_output = quantizer(grad_output)
# Copy into communication buffer, and replace original gradient with it
ctx.ub_obj_gradout.copy_into_buffer(grad_output, quantizer, local_chunk=True)
grad_output = ctx.ub_obj_gradout.get_buffer(quantizer)
grad_output, _ = fill_userbuffers_buffer_for_all_gather(
ctx.ub_obj_gradout,
grad_output,
quantizer,
ctx.tp_group,
)
else:
grad_output, _ = gather_along_first_dim(
grad_output,
......@@ -1165,7 +1312,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
return
with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"):
(wgrad, grad_bias_, _, _), _ = self.wgrad_store.pop()
(wgrad, bgrad), _ = self.wgrad_store.pop()
if not self.fuse_wgrad_accumulation:
unfused_weights = [getattr(self, name) for name in self.weight_names]
weight_tensor = noop_cat(unfused_weights)
......@@ -1174,9 +1321,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if self.use_bias:
bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
if bias_tensor.grad is None:
bias_tensor.grad = grad_bias_.to(bias_tensor.dtype)
del grad_bias_
del wgrad
bias_tensor.grad = bgrad.to(bias_tensor.dtype)
def _validate_name(self):
"""
......
This diff is collapsed.
......@@ -534,7 +534,9 @@ class BasicLinear(BasicOperation):
# Configure input tensor for backward pass
if with_quantized_compute and isinstance(x_local, QuantizedTensor):
x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
if not (isinstance(x_local, Float8TensorBase) and with_x_all_gather):
# FP8 does not support all-gather of transpose data
x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
return y, x_local, w
......@@ -622,7 +624,10 @@ class BasicLinear(BasicOperation):
# Check datatype
if dtype is None:
dtype = weight.dtype
if weight is not None:
dtype = weight.dtype
else:
dtype = grad_output.dtype
dtype = canonicalize_dtype(dtype)
if dtype not in (torch.float32, torch.float16, torch.bfloat16):
raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")
......@@ -814,7 +819,7 @@ class BasicLinear(BasicOperation):
x_async = None
dy_async = None
# Check grad input tensor
# Check grad weight tensor
dw = grad_weight
dw_dtype = dtype
if dw is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment