Unverified commit ce0b46c4, authored by Tim Moon, committed by GitHub

MXFP8 support in Userbuffers (#1711)



* Initial work toward restoring UB support in te.Sequential
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Forward UB linear runs, but has numerical error
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug UB forward tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Minor tweaks
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove Python checks for MXFP8 UB linear forward
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add dim check for MXFP8 full tiles
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Move QuantizedTensor logic out of UB comm and into Python helper function
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Support MXFP8 AGs
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Coalesce NCCL all-gathers for MXFP8 all-gather
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Initial impl of backward UB linear in te.Sequential
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug UB linear backward with no quantization

dgrad GEMM + dx RS is still broken.

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix chunk dims for dgrad GEMM + dx RS
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debugging MXFP8 UB cases

Still failing with dy AG + wgrad GEMM.

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use NCCL to overlap dy AG with dgrad GEMM
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug UB GEMM tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Initial refactoring of linear module forward
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Refactor linear module backward
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug linear module UB tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Tweak test tensor dims
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Do not store autograd context within wgrad GEMM closure
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update LayerNormLinear
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update LayerNormMLP
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug UB tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Debug test failures
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Minor style tweaks
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix incorrect usage for GEMM input with block-scaled FP8
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix RS out dims
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Disable dgrad GEMM + UB AG + NCCL AG overlapping
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Disable dgrad GEMM + UB AG + NCCL AG overlap in te.Sequential
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Restore support for internal quantized tensors
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add tests for MXFP8 GEMM with UB
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix linter warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug test failures
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug test failures
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 3d2152c2
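The diff below replaces the boolean `--fp8` switch in the distributed test scripts with a `--quantization` option that also accepts MXFP8. A minimal sketch of the new-style switch (the surrounding parser setup is assumed, not shown in the hunks):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--quantization",
    type=str.lower,
    default="none",
    choices=["none", "fp8", "mxfp8"],
    help="Quantization recipe",
)

# e.g. run_gemm_with_overlap.py --quantization MXFP8
opts = parser.parse_args(["--quantization", "MXFP8"])  # type=str.lower normalizes case
assert opts.quantization == "mxfp8"
```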
@@ -26,7 +26,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PAT
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
-# python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" ### TODO Debug UB support with te.Sequential
+python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn_with_cp.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
...
@@ -20,7 +20,11 @@ from torch.distributed.elastic.multiprocessing.errors import record
 import transformer_engine.pytorch as te
 import transformer_engine.pytorch.cpp_extensions as tex
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
-from transformer_engine.pytorch.module.base import get_cublas_workspace_size_bytes
+from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
+from transformer_engine.pytorch.module.base import (
+    fill_userbuffers_buffer_for_all_gather,
+    get_cublas_workspace_size_bytes,
+)
 
 warnings.filterwarnings("ignore", category=DeprecationWarning)
 warnings.filterwarnings("ignore", category=FutureWarning)
@@ -56,7 +60,11 @@ def _parse_args(argv=None, namespace=None):
     )
     parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
     parser.add_argument(
-        "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context."
+        "--quantization",
+        type=str.lower,
+        default="none",
+        choices=["none", "fp8", "mxfp8"],
+        help="Quantization recipe",
     )
     parser.add_argument(
         "--fp8-output", action="store_true", default=False, help="Get FP8 output from GEMM."
@@ -154,9 +162,9 @@ def _parse_args(argv=None, namespace=None):
         if opts.atomic:
             warnings.warn("Atomic GEMM is not supported with bulk overlap.")
             opts.atomic = False
-        if opts.fp8:
+        if opts.quantization != "none":
             warnings.warn("Bulk overlap is supported in FP8 but only tested in BF16.")
-            opts.fp8 = False
+            opts.quantization = "none"
     elif opts.comm_type == tex.CommOverlapType.AG:
         if opts.atomic:
             setattr(opts, "atomic_rs_p2p", opts.p2p)
@@ -164,8 +172,11 @@ def _parse_args(argv=None, namespace=None):
     if opts.atomic:
         if not te.fp8.check_fp8_support():
-            assert not opts.fp8, "Atomic GEMM is only supported in FP8."
-        opts.fp8 = True
+            assert opts.quantization == "none", "Atomic GEMM is only supported in FP8."
+        opts.quantization = "fp8"
+    if opts.fp8_output:
+        assert opts.quantization == "fp8", "FP8 output is only supported with FP8 compute."
 
     return opts
@@ -302,7 +313,11 @@ def _main(opts):
     inp_shape = (opts.seq_length, opts.batch_size, hidden_size)
     outer_size = reduce(operator.mul, inp_shape[:-1], 1)
     buffer_dtype = torch.bfloat16
-    if opts.fp8 and not opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.AG:
+    if (
+        opts.quantization != "none"
+        and not opts.bulk_overlap
+        and opts.comm_type == tex.CommOverlapType.AG
+    ):
         buffer_dtype = torch.uint8
     ub_obj = (
         tex.CommOverlapP2P(
@@ -447,6 +462,8 @@ def _main(opts):
     inp2_g = torch.nn.functional.gelu(ref_g)  # pylint: disable=not-callable
     ref2_g = torch.matmul(inp2_g, ker2_g)
 
+    # Initialize quantizers
+    with_quantized_compute = opts.quantization != "none"
     inp_quantizer = None
     ker_quantizer = None
     out_quantizer = None
@@ -454,7 +471,7 @@ def _main(opts):
     inp2_quantizer = None
     ker2_quantizer = None
     out2_quantizer = None
-    if opts.fp8:
+    if opts.quantization == "fp8":
         # Structure to maintain amax and scale/scale_inv information for the kernel and input
         num_gemms = 6 if ub_obj2 is not None else 3
         fp8_dtype = tex.DType.kFloat8E4M3
@@ -499,11 +516,23 @@ def _main(opts):
             out2_quantizer = Float8Quantizer(
                 fp8_scales[5].clone(), fp8_amaxes[5].clone(), fp8_dtype
             )
+    elif opts.quantization == "mxfp8":
+        fp8_dtype = tex.DType.kFloat8E4M3
+        inp_quantizer = MXFP8Quantizer(fp8_dtype, columnwise=False)
+        ker_quantizer = MXFP8Quantizer(fp8_dtype)
+        if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS:
+            bulk_inp_quantizer = MXFP8Quantizer(fp8_dtype, columnwise=False)
+        elif ub_obj2 is not None:
+            inp2_quantizer = MXFP8Quantizer(fp8_dtype, columnwise=False)
+            ker2_quantizer = MXFP8Quantizer(fp8_dtype)
 
-        # Cast input to Float8Tensor
+    # Quantize tensors
+    if with_quantized_compute:
+        # Quantize input tensor
         inp_fp8 = inp_quantizer(inp)
-        # Cast kernel to Float8Tensor
+        # Quantize kernel tensor
         kernel_t_fp8 = ker_quantizer(kernel_t)
         if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS:
             bulk_inp_fp8 = bulk_inp_quantizer(bulk_inp)
@@ -540,31 +569,40 @@ def _main(opts):
         )
 
     # Set up comm/compute buffers
+    ag_out = None
     rs_out = None
     rs_out2 = None
     if opts.comm_type == tex.CommOverlapType.AG:
         if opts.bulk_overlap:
-            ub_obj.copy_into_buffer(bulk_inp, bulk_inp_quantizer, True)
+            ag_out, _ = fill_userbuffers_buffer_for_all_gather(
+                ub_obj,
+                bulk_inp,
+                bulk_inp_quantizer,
+                tp_group,
+            )
             gemm_inp = inp
         else:
-            ub_obj.copy_into_buffer(inp_fp8 if opts.fp8 else inp, inp_quantizer, True)
-            gemm_inp = ub_obj.get_buffer(inp_quantizer, False, inp_g.size())
+            ag_out, _ = fill_userbuffers_buffer_for_all_gather(
+                ub_obj,
+                inp_fp8 if with_quantized_compute else inp,
+                inp_quantizer,
+                tp_group,
+            )
+            gemm_inp = ag_out
         if ub_obj2 is not None:
-            if opts.fp8 and opts.fp8_output:
-                ub_obj2.set_buffer_params(out_quantizer)
             rs_out2 = torch.empty(
                 (outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda"
             )
     else:
         if opts.bulk_overlap:
-            ub_obj.copy_into_buffer(
-                bulk_inp_fp8 if opts.fp8 else bulk_inp, bulk_inp_quantizer, False
-            )
-            if opts.fp8:
-                ub_obj.set_buffer_params(bulk_inp_quantizer)
-        elif opts.fp8 and opts.fp8_output:
-            ub_obj.set_buffer_params(out_quantizer)
-        gemm_inp = inp_fp8 if opts.fp8 else inp
+            if opts.quantization == "none":
+                ub_obj.copy_into_buffer(bulk_inp, local_chunk=False)
+            if opts.quantization == "fp8":
+                ub_obj.copy_into_buffer(bulk_inp_fp8._data, local_chunk=False)
+            elif opts.quantization == "mxfp8":
+                ub_obj.copy_into_buffer(bulk_inp_fp8._rowwise_data, local_chunk=False)
+        gemm_inp = inp_fp8 if with_quantized_compute else inp
         rs_out = torch.empty(
             (outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda"
         )
@@ -623,7 +661,7 @@ def _main(opts):
     if opts.use_cuda_graphs:
         # Trace the CUDA graph first
         g = torch.cuda.CUDAGraph()
-        if opts.fp8:
+        if with_quantized_compute:
             if ub_obj is None:
                 with torch.cuda.graph(g):
                     all_outputs = _fp8_gemm()
@@ -643,7 +681,7 @@ def _main(opts):
     else:
         for i in range(total_iters):
-            if opts.fp8:
+            if with_quantized_compute:
                 start_events[i].record()
                 all_outputs = _fp8_gemm()
                 end_events[i].record()
@@ -688,10 +726,22 @@ def _main(opts):
         output_info = ""
         if opts.comm_type == tex.CommOverlapType.AG:
             # Bulk overlap AG output is already gathered
-            test_out = ub_obj.get_buffer(bulk_inp_quantizer, False)
+            test_out = ag_out
+            if bulk_inp_quantizer is None:
+                test_out = ub_obj.get_buffer(False)
+            else:
+                test_out = Float8Tensor(
+                    shape=test_out.shape,
+                    dtype=torch.bfloat16,
+                    data=ub_obj.get_buffer(False),
+                    fp8_scale=bulk_inp_quantizer.scale,
+                    fp8_dtype=bulk_inp_quantizer.dtype,
+                    quantizer=bulk_inp_quantizer,
+                )
         else:
             # Bulk overlap RS output needs to be gathered
-            out_local = ub_obj.get_buffer(bulk_inp_quantizer, True)
+            out_local = ub_obj.get_buffer(True)
             output_info += f"rs_output: {list(out_local.shape)} | "
             test_out = te.distributed.gather_along_first_dim(out_local, tp_group)[0]
@@ -762,8 +812,8 @@ def _main(opts):
     m = torch.argmax(diff)
     abs_err = diff[m].item()
     rel_err = abs_err / max(abs(ref_out.flatten()[m].item()), 1e-5)
-    rtol = 0.125 if opts.fp8 else 0.02
-    atol = 0.0625 if opts.fp8 else 0.001
+    rtol = 0.02 if opts.quantization == "none" else 0.125
+    atol = 0.001 if opts.quantization == "none" else 0.0625
     if rel_err > rtol and abs_err > atol:
         numerics_failed = True
         numerics_info = (
...
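The hunks above create rowwise-only MXFP8 quantizers for GEMM inputs and route gathers through the new `fill_userbuffers_buffer_for_all_gather` helper. As a standalone illustration, here is a hedged sketch of the MXFP8 quantize/dequantize round trip; the tensor sizes are assumptions chosen to satisfy the 128-divisibility check the C++ changes below enforce, and running it requires a GPU with MXFP8 support:

```python
import torch
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer

# Rowwise-only quantizer, as used for GEMM inputs in the diff above
quantizer = MXFP8Quantizer(tex.DType.kFloat8E4M3, columnwise=False)
x = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda")  # dims divisible by 128
x_mxfp8 = quantizer(x)            # block-scaled FP8 data plus per-32-element scales
x_round_trip = x_mxfp8.dequantize()
print(x_round_trip.dtype)         # back to bfloat16, up to MXFP8 rounding error
```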
@@ -17,7 +17,12 @@ import torch
 import torch.distributed as dist
 
 import transformer_engine.pytorch as te
-from transformer_engine.common.recipe import Format, DelayedScaling, Float8CurrentScaling
+from transformer_engine.common.recipe import (
+    DelayedScaling,
+    Float8CurrentScaling,
+    Format,
+    MXFP8BlockScaling,
+)
 
 warnings.filterwarnings("ignore", category=DeprecationWarning)
 warnings.filterwarnings("ignore", category=FutureWarning)
@@ -163,7 +168,7 @@ def _parse_args(argv=None, namespace=None):
         "--quantization",
         type=str.lower,
         default="none",
-        choices=["none", "fp8_delayed_scaling", "fp8_current_scaling"],
+        choices=["none", "fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"],
         help="Quantization recipe",
     )
     parser.add_argument(
@@ -414,6 +419,8 @@ def _train(opts):
         )
     elif opts.quantization == "fp8_current_scaling":
         fp8_recipe = Float8CurrentScaling(fp8_format=fp8_format)
+    elif opts.quantization == "mxfp8":
+        fp8_recipe = MXFP8BlockScaling()
 
     # Prepare random input tensors
     test_x = torch.randn(input_shape, dtype=torch.float32, device="cuda", requires_grad=True)
...
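The recipe dispatch added above maps each `--quantization` choice onto a Transformer Engine recipe object. A small sketch of the same logic in isolation; the helper name `pick_recipe` is ours, and default recipe arguments are assumed:

```python
from transformer_engine.common.recipe import (
    DelayedScaling,
    Float8CurrentScaling,
    MXFP8BlockScaling,
)

def pick_recipe(quantization: str):
    """Mirror the elif chain in _train (hypothetical helper)."""
    if quantization == "fp8_delayed_scaling":
        return DelayedScaling()
    if quantization == "fp8_current_scaling":
        return Float8CurrentScaling()
    if quantization == "mxfp8":
        return MXFP8BlockScaling()
    return None  # "none" runs in plain BF16

print(type(pick_recipe("mxfp8")).__name__)  # MXFP8BlockScaling
```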
@@ -15,11 +15,12 @@ if torch.cuda.device_count() < 2:
     pytest.skip("Comm+GEMM overlap requires at least 2 GPUs.")
 
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
 
 RNG_SEED: int = 42
 SEQ_LENGTH: int = 1024
 BATCH_SIZE: int = 2
-NUM_HEADS: int = 16
+NUM_HEADS: int = 32
 HEAD_DIM: int = 48
 
 TE_LAYERS = [
     te.Linear,
@@ -50,7 +51,7 @@ os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
 torch._dynamo.reset()
 
 
-def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8):
+def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, quantization):
     test_path = TEST_ROOT / "run_gemm_with_overlap.py"
     test_cmd = LAUNCH_CMD + [
         str(test_path),
@@ -66,10 +67,11 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, quantization):
     if bulk:
         test_cmd.append("--bulk-overlap")
     else:
-        if fp8:
-            if not fp8_available:
-                pytest.skip(reason_for_no_fp8)
-            test_cmd.append("--fp8")
+        if quantization == "fp8" and not fp8_available:
+            pytest.skip(reason_for_no_fp8)
+        if quantization == "mxfp8" and not mxfp8_available:
+            pytest.skip(reason_for_no_mxfp8)
+        test_cmd.append(f"--quantization={quantization}")
         if p2p:
             test_cmd.append("--p2p")
         if atomic:
@@ -107,8 +109,10 @@ def _run_layer_with_overlap(
         test_cmd.append("--overlap-rs-dgrad")
 
     if fp8:
-        if not fp8_available:
+        if quantization in ("fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
             pytest.skip(reason_for_no_fp8)
+        if quantization == "mxfp8" and not mxfp8_available:
+            pytest.skip(reason_for_no_mxfp8)
         test_cmd.append("--fp8")
         test_cmd.append(f"--quantization={quantization}")
@@ -130,51 +134,34 @@ def _run_layer_with_overlap(
     raise AssertionError(result.stderr.decode())
 
 
-@pytest.mark.parametrize(
-    "fp8",
-    (False, True),
-    ids=[" BF16 - RING-EXCHANGE ", " FP8 - RING-EXCHANGE "],
-)
-def test_split_all_gather_overlaps(fp8):
+@pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8"))
+def test_split_all_gather_overlaps(quantization):
     """
     Test (split GEMM -> all-gather) overlaps with direct calls to te.cpp_extensions.gemm or
     te.cpp_extensions.fp8_gemm.
     """
-    _run_gemm_with_overlap("AG", False, True, False, fp8)
+    _run_gemm_with_overlap("AG", False, True, False, quantization)
 
 
-@pytest.mark.parametrize(
-    "fp8,p2p",
-    [
-        (False, False),
-        (False, True),
-        (True, False),
-        (True, True),
-    ],
-    ids=[
-        " BF16 - PIPELINE ",
-        " BF16 - RING-EXCHANGE ",
-        " FP8 - PIPELINE ",
-        " FP8 - RING-EXCHANGE ",
-    ],
-)
-def test_split_reduce_scatter_overlaps(fp8, p2p):
+@pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8"))
+@pytest.mark.parametrize("p2p", (False, True))
+def test_split_reduce_scatter_overlaps(quantization, p2p):
     """
     Test (reduce-scatter -> split GEMM) overlaps with direct calls to te.cpp_extensions.gemm or
     te.cpp_extensions.fp8_gemm.
     """
-    _run_gemm_with_overlap("RS", False, p2p, False, fp8)
+    _run_gemm_with_overlap("RS", False, p2p, False, quantization)
 
 
 @pytest.mark.parametrize(
-    "comm_type, fp8, connections",
+    "comm_type, quantization, connections",
     [
-        ("AG", False, 1),
-        ("RS", False, 1),
-        ("RS", True, 1),
-        ("AG", False, 8),
-        ("RS", False, 8),
-        ("RS", True, 8),
+        ("AG", "none", 1),
+        ("RS", "none", 1),
+        ("RS", "fp8", 1),
+        ("AG", "none", 8),
+        ("RS", "none", 8),
+        ("RS", "fp8", 8),
     ],
     ids=[
         "ALL-GATHER - BF16 - 1 connections",
@@ -185,7 +172,7 @@ def test_split_reduce_scatter_overlaps(quantization, p2p):
         "REDUCE-SCATTER - FP8 - 8 connections",
     ],
 )
-def test_bulk_overlaps(comm_type, fp8, connections):
+def test_bulk_overlaps(comm_type, quantization, connections):
     """
     Test bulk overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm.
     """
@@ -196,10 +183,10 @@ def test_bulk_overlaps(comm_type, quantization, connections):
             " 9.0 (HOPPER ARCH)."
         )
        os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8"
-        _run_gemm_with_overlap(comm_type, True, False, False, fp8)
+        _run_gemm_with_overlap(comm_type, True, False, False, quantization)
        os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
     else:
-        _run_gemm_with_overlap(comm_type, True, False, False, fp8)
+        _run_gemm_with_overlap(comm_type, True, False, False, quantization)
 
 
 @pytest.mark.parametrize(
@@ -251,15 +238,7 @@ def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_d
 @pytest.mark.parametrize(
     "quantization",
-    ["fp8_delayed_scaling", "fp8_current_scaling"],
-    ids=[" DELAYED SCALING ", " CURRENT SCALING "],
-)
-@pytest.mark.parametrize(
-    "fp8",
-    (True,),
-    ids=[
-        " FP8 ",
-    ],
+    ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"],
 )
 @pytest.mark.parametrize(
     "layer_type,linear_parallel_mode,overlap_rs_dgrad",
@@ -279,15 +258,15 @@ def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_d
         )
     ),
     ids=[
-        f" {te.Linear.__name__} - ROW-PARALLEL ",
-        f" {te.Linear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ",
-        f" {te.Linear.__name__} - COL-PARLALEL - DGRAD+RS ",
-        f" {te.LayerNormLinear.__name__} - ROW-PARALLEL ",
-        f" {te.LayerNormLinear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ",
-        f" {te.LayerNormLinear.__name__} - COL-PARALLEL - DGRAD+RS ",
+        f"{te.Linear.__name__}-row_tensor_parallel",
+        f"{te.Linear.__name__}-col_tensor_parallel-BULK DGRAD/WGRAD",
+        f"{te.Linear.__name__}-col_tensor_parallel-DGRAD+RS",
+        f"{te.LayerNormLinear.__name__}-row_tensor_parallel",
+        f"{te.LayerNormLinear.__name__}-col_tensor_parallel-BULK DGRAD/WGRAD",
+        f"{te.LayerNormLinear.__name__}-col_tensor_parallel-DGRAD+RS",
     ]
     + [
-        " " + " - ".join(test_name_parts) + " "
+        "-".join(test_name_parts)
         for test_name_parts in zip(
             [layer.__name__ for layer in TE_LAYERS[2:] for _ in range(2)],
             ["BULK DGRAD/WGRAD", "DGRAD+RS"] * len(TE_LAYERS[2:]),
@@ -295,12 +274,15 @@ def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_d
     ],
 )
 def test_layers_with_overlap_fp8(
-    layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization
+    layer_type,
+    linear_parallel_mode,
+    overlap_rs_dgrad,
+    quantization,
 ):
     """
     Test Transformer Engine layers with comm+GEMM overlap.
     """
-    _run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization)
+    _run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, True, quantization)
 
 
 @pytest.mark.parametrize(
@@ -347,22 +329,11 @@ def test_multi_layer_with_overlap_bf16(
 @pytest.mark.parametrize(
     "quantization",
-    ["fp8_delayed_scaling", "fp8_current_scaling"],
-    ids=[" DELAYED SCALING ", " CURRENT SCALING "],
-)
-@pytest.mark.parametrize(
-    "fp8",
-    (True,),
-    ids=[
-        " FP8 ",
-    ],
+    ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"],
 )
 @pytest.mark.parametrize(
     "num_layers",
     (2,),
-    ids=[
-        " 2 layers ",
-    ],
 )
 @pytest.mark.parametrize(
     "layer_type,linear_parallel_mode,overlap_rs_dgrad",
@@ -374,7 +345,7 @@ def test_multi_layer_with_overlap_bf16(
         )
     ),
     ids=[
-        " " + " - ".join(test_name_parts) + " "
+        "-".join(test_name_parts)
         for test_name_parts in zip(
             [te.TransformerLayer.__name__ for _ in range(2)],
             ["BULK DGRAD/WGRAD", "DGRAD+RS"],
@@ -382,11 +353,11 @@ def test_multi_layer_with_overlap_bf16(
     ],
 )
 def test_multi_layer_with_overlap_fp8(
-    layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers
+    layer_type, linear_parallel_mode, overlap_rs_dgrad, quantization, num_layers
 ):
     """
     Test Transformer Engine layers with comm+GEMM overlap.
     """
     _run_layer_with_overlap(
-        layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers
+        layer_type, linear_parallel_mode, overlap_rs_dgrad, True, quantization, num_layers
     )
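The tests above gate each quantization mode on hardware capability before launching the multi-GPU script. A sketch of that skip pattern in isolation; the test body is a placeholder for the real subprocess launch:

```python
import pytest
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()

@pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8"))
def test_overlap_smoke(quantization):
    if quantization == "fp8" and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    # ... launch the run script with f"--quantization={quantization}" ...
```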
@@ -19,7 +19,6 @@ import torch
 
 import transformer_engine
 import transformer_engine.pytorch as te
 import transformer_engine.pytorch.cpp_extensions as tex
-from transformer_engine.pytorch.float8_tensor import Float8Tensor
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 import transformer_engine.pytorch.ops as te_ops
 from transformer_engine.pytorch.ops._common import is_float8_tensor
@@ -27,6 +26,8 @@ from transformer_engine.pytorch.ops.fused import (
     UserbuffersBackwardLinear,
     UserbuffersForwardLinear,
 )
+from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
+from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
 from transformer_engine.pytorch.utils import is_bf16_compatible
 
 # Import utility functions
@@ -36,6 +37,13 @@ from utils import dtype_tols, str_to_dtype
 
 # Check if FP8 is supported
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
+quantization_list: list[Optional[str]] = [None]
+if fp8_available:
+    quantization_list.append("fp8")
+if mxfp8_available:
+    quantization_list.append("mxfp8")
 
 # Check if there are multiple GPUs
 if torch.cuda.device_count() < 2:
@@ -51,7 +59,7 @@ class ModelConfig:
     num_heads: int
     head_dim: int
     dtype: torch.dtype
-    fp8: bool
+    quantization: Optional[str]
 
     @property
     def hidden_size(self):
@@ -129,12 +137,16 @@ def make_reference_and_test_tensors(
     ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
 
     # Make copy of tensor
-    if test_is_fp8:
-        test = Float8Tensor.to_float8(ref)
-    else:
-        test = ref.to(device=test_device, dtype=test_dtype)
-        if test.data_ptr() == ref.data_ptr():
-            test = test.clone()
+    test = ref.to(device=test_device, dtype=test_dtype)
+    if test_is_fp8:
+        quantizer = Float8Quantizer(
+            scale=torch.ones(1, dtype=torch.float32, device=test_device),
+            amax=torch.zeros(1, dtype=torch.float32, device=test_device),
+            fp8_dtype=tex.DType.kFloat8E4M3,
+        )
+        test = quantizer(test)
+    elif test.data_ptr() == ref.data_ptr():
+        test = test.clone()
 
     # Make sure reference and test tensors represent exact same values
     ref.copy_(test)
@@ -145,6 +157,21 @@ def make_reference_and_test_tensors(
     return ref, test
 
 
+def make_recipe(name: Optional[str] = None) -> Optional[Recipe]:
+    """Make recipe for quantization scheme"""
+    if name is None:
+        return None
+    if name == "fp8":
+        return transformer_engine.common.recipe.DelayedScaling(
+            fp8_format=transformer_engine.common.recipe.Format.E4M3,
+        )
+    if name == "mxfp8":
+        return transformer_engine.common.recipe.MXFP8BlockScaling(
+            fp8_format=transformer_engine.common.recipe.Format.E4M3,
+        )
+    raise ValueError(f"Unsupported quantization scheme ({name})")
+
+
 def _test_linear(
     *,
     model_config: ModelConfig,
@@ -155,7 +182,8 @@ def _test_linear(
     weight_requires_grad: bool = True,
 ) -> None:
     dtype = model_config.dtype
-    fp8_compute = model_config.fp8
+    quantization = model_config.quantization
+    quantized_compute = quantization is not None
 
     # Distributed process group
     process_group = world_group()
@@ -175,14 +203,19 @@ def _test_linear(
         in_shape,
         test_dtype=dtype,
         test_device=device,
-        test_is_fp8=fp8_compute,
+        test_is_fp8=quantized_compute,
     )
+    if isinstance(x_test, QuantizedTensor):
+        with torch.no_grad():
+            x_test = x_test.dequantize().requires_grad_()
     w_ref, w_test = make_reference_and_test_tensors(
         (out_features, in_features),
         test_dtype=dtype,
         test_device=device,
-        test_is_fp8=fp8_compute,
+        test_is_fp8=quantized_compute,
     )
+    if isinstance(w_test, QuantizedTensor):
+        w_test = w_test.dequantize()
     b_ref, b_test = None, None
     if bias:
         if tensor_parallel_mode == "row":
@@ -198,9 +231,11 @@ def _test_linear(
         out_shape,
         test_dtype=dtype,
         test_device=device,
-        test_is_fp8=fp8_compute,
+        test_is_fp8=quantized_compute,
         requires_grad=False,
     )
+    if isinstance(dy_test, QuantizedTensor):
+        dy_test = dy_test.dequantize()
 
     # Plain PyTorch implementation
     y_ref = torch.nn.functional.linear(x_ref, w_ref)
@@ -265,21 +300,15 @@ def _test_linear(
     x_test.requires_grad_()
 
     # Implementation with fusible operation
-    with te.fp8_model_init(enabled=fp8_compute):
+    recipe = make_recipe(quantization)
+    with te.fp8_model_init(enabled=quantized_compute, recipe=recipe):
         ops = []
         linear_op = None
         bias_op = None
         if tensor_parallel_mode == "column":
             userbuffers_options = {}
             if not weight_requires_grad:
-                if fp8_compute:
-                    userbuffers_options["comm_name"] = "fc1"
-                else:
-                    # There is a correctness bug with overlapping
-                    # dgrad reduce-scatter with dgrad GEMM. Fall back
-                    # to overlapping dgrad reduce-scatter with wgrad
-                    # GEMM, even though wgrad isn't needed.
-                    userbuffers_options["comm_name"] = "qkv"
+                userbuffers_options["comm_name"] = "fc1"
             else:
                 userbuffers_options["comm_name"] = "qkv"
             linear_op = te_ops.BasicLinear(
@@ -322,7 +351,7 @@ def _test_linear(
             bias_op.bias.copy_(b_test)
         del w_test
         del b_test
-    with te.fp8_autocast(enabled=fp8_compute):
+    with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
         y_test = model(x_test)
     y_test.backward(dy_test)
@@ -338,7 +367,7 @@ def _test_linear(
     tols = dtype_tols(dtype)
     if dtype == torch.float32:
         tols = dtype_tols(torch.float16)  # TF32 GEMM
-    if fp8_compute:
+    if quantized_compute:
         tols = dtype_tols(
             model[0].weight._fp8_dtype
             if is_float8_tensor(model[0].weight)
@@ -370,7 +399,7 @@ def run_parallel_tests(model_config: ModelConfig) -> None:
     for test_config in itertools.product(
         (False, True),  # bias
         ("column", "row"),  # tensor_parallel_mode
-        (False, True),  # weight_requires_grad
+        (True, False),  # weight_requires_grad
     ):
         if rank == 0:
             print(f"Running _test_linear with {test_config=}")
@@ -390,19 +419,15 @@ if torch.cuda.device_count() > 1:
 
 @pytest.mark.parametrize("world_size", _world_sizes)
-@pytest.mark.parametrize("fp8", (False, True))
+@pytest.mark.parametrize("quantization", quantization_list)
 def test_fuser_ops_with_userbuffers(
     *,
     world_size: int,
     dtype: torch.dtype = torch.bfloat16,
-    fp8: bool,
+    quantization: Optional[str],
 ) -> None:
     """Launch parallel job and run tests"""
 
-    # Skip invalid configurations
-    if fp8 and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-
     # Parallel job launcher
     command = []
     if tex.ubuf_built_with_mpi():
@@ -424,8 +449,8 @@ def test_fuser_ops_with_userbuffers(
             str(dtype),
         )
     )
-    if fp8:
-        command.append("--fp8")
+    if quantization is not None:
+        command.extend(("--quantization", quantization))
 
     # Environment
     env = dict(os.environ)
@@ -445,12 +470,12 @@ def main() -> None:
     # Parse command-line arguments
     parser = argparse.ArgumentParser()
     parser.add_argument("--parallel", action="store_true", help="Run parallel tests")
-    parser.add_argument("--sequence-length", type=int, default=32)
+    parser.add_argument("--sequence-length", type=int, default=256)
     parser.add_argument("--batch-size", type=int, default=16)
     parser.add_argument("--num-heads", type=int, default=16)
-    parser.add_argument("--head-dim", type=int, default=32)
+    parser.add_argument("--head-dim", type=int, default=256)
     parser.add_argument("--dtype", type=str, default="bfloat16")
-    parser.add_argument("--fp8", action="store_true")
+    parser.add_argument("--quantization", type=str, default=None)
     args = parser.parse_args()
 
     # Run parallel tests if needed
@@ -463,14 +488,17 @@ def main() -> None:
         num_heads=args.num_heads,
         head_dim=args.head_dim,
         dtype=str_to_dtype(args.dtype),
-        fp8=args.fp8,
+        quantization=args.quantization,
     )
 
     # Initialize Userbuffers
     group = world_group()  # Initialize NCCL
     bootstrap_backend = "mpi" if launcher() == "ompi" else "nccl"
     userbuffer_configs = {
-        "fc1_dgrad": {"method": "pipeline"},  # Overlap dgrad RS with dgrad GEMM
+        "fc1_dgrad": {
+            "method": "ring_exchange",
+            "fp8_buf": False,
+        },  # Overlap dgrad RS with dgrad GEMM
     }
     te.module.base.initialize_ub(
         [
@@ -478,7 +506,7 @@ def main() -> None:
             model_config.num_heads * model_config.head_dim,
         ],
         torch.distributed.get_world_size(group),
-        use_fp8=model_config.fp8,
+        use_fp8=model_config.quantization is not None,
         dtype=model_config.dtype,
         bootstrap_backend=bootstrap_backend,
         ub_cfgs=userbuffer_configs,
...
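The test above builds its recipe with `make_recipe` and threads it through both `te.fp8_model_init` and `te.fp8_autocast`. A condensed sketch of that pattern on a single layer; the sizes are assumptions (MXFP8 with Userbuffers wants dims divisible by 128), and it requires an MXFP8-capable GPU:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import MXFP8BlockScaling

recipe = MXFP8BlockScaling()
with te.fp8_model_init(enabled=True, recipe=recipe):
    linear = te.Linear(1024, 1024, bias=False)  # weights stored quantized
x = torch.randn(256, 1024, dtype=torch.bfloat16, device="cuda")
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = linear(x)  # forward GEMM runs with quantized inputs/weights
```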
@@ -147,7 +147,44 @@ CommOverlapCore::~CommOverlapCore() {
 
 TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, size_t chunk_offset,
                                                 const std::vector<size_t> &chunk_shape) {
-  TensorWrapper chunk;
+  const auto scaling_mode = source.scaling_mode();
+
+  // Tensor dimensions
+  std::vector<size_t> shape = shape_to_vector(source.shape());
+  auto flatten_shape_to_2d = [](const std::vector<size_t> &shape) -> std::pair<size_t, size_t> {
+    if (shape.empty()) {
+      return {1, 1};
+    }
+    size_t height = 1;
+    for (size_t i = 0; i < shape.size() - 1; ++i) {
+      height *= shape[i];
+    }
+    return {height, shape.back()};
+  };
+  size_t height, width, chunk_height, chunk_width;
+  std::tie(height, width) = flatten_shape_to_2d(shape);
+  std::tie(chunk_height, chunk_width) = flatten_shape_to_2d(chunk_shape);
+
+  // Check tensor dimensions
+#define NVTE_DIM_CHECK(cond, message)                                                 \
+  NVTE_CHECK(cond, message, " (tensor shape=", shape, ", chunk shape=", chunk_shape, \
+             ", chunk offset=", chunk_offset, ")")
+  NVTE_DIM_CHECK(height > 0 && width > 0, "Attempted to get chunk from empty tensor");
+  NVTE_DIM_CHECK(chunk_height > 0 && chunk_width > 0, "Attempted to get empty tensor chunk");
+  NVTE_DIM_CHECK(chunk_height <= height && chunk_width <= width,
+                 "Attempted to get out-of-bounds tensor chunk");
+  if (scaling_mode == NVTEScalingMode::NVTE_MXFP8_1D_SCALING) {
+    // MXFP8 scale-inverses are padded to a 2D matrix with dims that
+    // are divisible by 128. UB doesn't handle this padding yet.
+    NVTE_DIM_CHECK(height % 128 == 0 && width % 128 == 0,
+                   "Userbuffers requires MXFP8 tensor dims that are divisible by 128");
+    NVTE_DIM_CHECK(chunk_height % 128 == 0 && chunk_width % 128 == 0,
+                   "Userbuffers requires MXFP8 tensor chunk dims that are divisible by 128");
+  }
+#undef NVTE_DIM_CHECK
+
+  // Construct tensor chunk
+  TensorWrapper chunk(scaling_mode);
   for (int param_id = 0; param_id < NVTETensorParam::kNVTENumTensorParams; param_id++) {
     auto param_type = static_cast<NVTETensorParam>(param_id);
     auto param = source.get_parameter(param_type);
@@ -163,8 +200,8 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz
       param_shape = chunk_shape;
       if (param_type == NVTETensorParam::kNVTEColumnwiseData &&
-          source.scaling_mode() != NVTEScalingMode::NVTE_MXFP8_1D_SCALING) {
-        // Columnwise shape for non-block scaled tensors shifts the last dimension to the front
+          source.scaling_mode() == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING) {
+        // Columnwise shape for FP8 tensor-scaled tensors shifts the last dimension to the front
         auto last_dim = param_shape.back();
         param_shape.pop_back();
         param_shape.insert(param_shape.begin(), last_dim);
@@ -172,18 +209,16 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz
       } else if (source.scaling_mode() == NVTEScalingMode::NVTE_MXFP8_1D_SCALING &&
                  (param_type == NVTETensorParam::kNVTERowwiseScaleInv ||
                   param_type == NVTETensorParam::kNVTEColumnwiseScaleInv)) {
-        // Calculate block scaling offset and size
-        auto scaled_tensor_dim_size = (param_type == NVTETensorParam::kNVTERowwiseScaleInv)
-                                          ? source.shape().data[0]
-                                          : source.columnwise_shape().data[0];
-        auto scaled_chunk_dim_size = (param_type == NVTETensorParam::kNVTERowwiseScaleInv)
-                                         ? chunk_shape.front()
-                                         : chunk_shape.back();
-        auto chunk_scale_start = chunk_offset / 32;
-        auto chunk_scale_end = (chunk_offset + scaled_chunk_dim_size) / 32;
-        auto chunk_scale_size = chunk_scale_end - chunk_scale_start;
-        param_dptr += chunk_scale_start * typeToSize(param_dtype);
-        param_shape = std::vector<size_t>{chunk_scale_size};
+        // Calculate offset and size for MXFP8 scale-invs
+        size_t chunk_scale_height = chunk_height;
+        size_t chunk_scale_width = chunk_width;
+        if (param_type == NVTETensorParam::kNVTERowwiseScaleInv) {
+          chunk_scale_width /= 32;
+        } else {
+          chunk_scale_height /= 32;
+        }
+        param_dptr += (chunk_offset / 32) * typeToSize(param_dtype);
+        param_shape = {chunk_scale_height, chunk_scale_width};
       }
 
       // Set chunked source parameters into the chunked tensor output
@@ -422,10 +457,21 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
   size_t k = transa ? A.size(1) : A.size(0);
   size_t n = _ubuf.size(0);
   size_t m_chunk = m / _num_splits;
+  const std::vector<size_t> input_a_chunk_shape =
+      (transa ? std::vector<size_t>{m_chunk, k} : std::vector<size_t>{k, m_chunk});
+  const std::vector<size_t> output_chunk_shape = {n, m_chunk};
   size_t input_a_chunk_size = m_chunk * k;
   size_t output_chunk_size = n * m_chunk;
   size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
 
+  // Helper function to get bias chunk if needed
+  auto maybe_get_bias_chunk = [this, &bias, m_chunk](size_t chunk_id) -> TensorWrapper {
+    if (bias.dptr() == nullptr) {
+      return TensorWrapper();
+    }
+    return get_tensor_chunk(bias, chunk_id * m_chunk, {m_chunk});
+  };
+
   // Catch up the default torch stream
   NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
   for (size_t i = 0; i < _stream_compute.size(); i++) {
@@ -437,21 +483,23 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
   char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
 
   if (_rs_overlap_first_gemm) {
-    auto input_a_chunk = get_tensor_chunk(A, 0, {m_chunk, k});
-    auto output_chunk = get_buffer_chunk_like(D, 0, {m, m_chunk});
+    auto input_a_chunk = get_tensor_chunk(A, 0, input_a_chunk_shape);
+    auto output_chunk = get_buffer_chunk_like(D, 0, output_chunk_shape);
+    auto bias_chunk = maybe_get_bias_chunk(0);
     auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk});
 
-    nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
+    nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
                      pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
                      use_split_accumulator, _math_sms, _stream_compute[0]);
 
     for (int i = 1; i < _num_splits; i++) {
-      input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k});
-      output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, {n, m_chunk});
+      input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, input_a_chunk_shape);
+      output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, output_chunk_shape);
+      bias_chunk = maybe_get_bias_chunk(i);
       workspace_chunk = get_tensor_chunk(
           workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
 
-      nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
+      nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
                        pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
                        accumulate, use_split_accumulator, _math_sms,
                        _stream_compute[i % _stream_compute.size()]);
@@ -494,12 +542,13 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
     }
   } else {
     for (int i = 0; i < _num_splits; i++) {
-      auto input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k});
-      auto output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, {n, m_chunk});
+      auto input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, input_a_chunk_shape);
+      auto output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, output_chunk_shape);
+      auto bias_chunk = maybe_get_bias_chunk(i);
       auto workspace_chunk = get_tensor_chunk(
           workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
 
-      nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
+      nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
                        pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
                        accumulate, use_split_accumulator, _math_sms,
                        _stream_compute[i % _stream_compute.size()]);
...@@ -759,8 +808,6 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, ...@@ -759,8 +808,6 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
// Get communication and GEMM output chunk sizes // Get communication and GEMM output chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
const bool do_gelu = pre_gelu_out.numel() > 0; const bool do_gelu = pre_gelu_out.numel() > 0;
size_t input_chunk_size = n_chunk * k;
size_t output_chunk_size = n_chunk * m;
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
...
@@ -771,8 +818,13 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
   }
   if (_aggregate) {
     const int num_steps = _tp_size / 2;
-    input_chunk_size *= 2;
-    output_chunk_size *= 2;
+
+    // Chunk dims
+    std::vector<size_t> input_b_chunk_shape =
+        (transb ? std::vector<size_t>{k, 2 * n_chunk} : std::vector<size_t>{2 * n_chunk, k});
+    std::vector<size_t> output_chunk_shape = {2 * n_chunk, m};
+    size_t input_b_chunk_size = 2 * n_chunk * k;
+    size_t output_chunk_size = 2 * n_chunk * m;

     // Initial 1X input chunk exchange between neighboring peers
     int send_chunk_id = _tp_id;
@@ -801,8 +853,9 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,

       // GEMM
-      auto input_b_chunk =
-          get_buffer_chunk_like(B, input_chunk_size * send_chunk_id, {n_chunk * 2, k});
-      auto output_chunk = get_tensor_chunk(D, output_chunk_size * send_chunk_id, {n_chunk * 2, m});
+      auto input_b_chunk =
+          get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape);
+      auto output_chunk =
+          get_tensor_chunk(D, output_chunk_size * send_chunk_id, output_chunk_shape);
       auto aux_chunk =
           (do_gelu)
               ? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk * 2, k})
@@ -834,6 +887,13 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
       }
     }
   } else {
+    // Chunk dims
+    std::vector<size_t> input_b_chunk_shape =
+        (transb ? std::vector<size_t>{k, n_chunk} : std::vector<size_t>{n_chunk, k});
+    std::vector<size_t> output_chunk_shape = {n_chunk, m};
+    size_t input_b_chunk_size = n_chunk * k;
+    size_t output_chunk_size = n_chunk * m;
+
     for (int i = 0; i < _tp_size; i++) {
       // Set the userbuffer id. Buffer under send is the input for the current
       // GEMM chunk. The initial input chunk is stored in _ubuf[rank]. This is to
@@ -845,8 +905,10 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
       int recv_offset = comm_bytes * recv_chunk_id;

       // GEMM
-      auto input_b_chunk = get_buffer_chunk_like(B, input_chunk_size * send_chunk_id, {n_chunk, k});
-      auto output_chunk = get_tensor_chunk(D, output_chunk_size * send_chunk_id, {n_chunk, m});
+      auto input_b_chunk =
+          get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape);
+      auto output_chunk =
+          get_tensor_chunk(D, output_chunk_size * send_chunk_id, output_chunk_shape);
       auto aux_chunk =
           (do_gelu)
               ? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk, k})
@@ -996,6 +1058,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
       auto input_b_chunk = get_tensor_chunk(B, input_b_chunk_id * input_chunk_size, {n_chunk, k});
       auto output_chunk = get_buffer_chunk_by_id(D, i);
+      auto workspace_chunk =
           get_tensor_chunk(workspace, stream_id * workspace_size_chunk, {workspace_size_chunk});
...
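To make the new chunk-dimension logic concrete, here is a minimal Python mirror of the computation above; the variable names follow the C++ code, but the GEMM dimensions are made up for illustration:

    # Illustration only: mirrors the chunk-dim logic in split_overlap_ag.
    # D is (n x m) output, B is the (n x k) gathered GEMM input.
    tp_size = 4
    m, k, n = 1024, 512, 2048      # hypothetical GEMM dims
    n_chunk = n // tp_size         # each peer owns one chunk of the gathered dim

    transb = False
    # Non-aggregate path: one chunk of B and D per ring step
    input_b_chunk_shape = (k, n_chunk) if transb else (n_chunk, k)
    output_chunk_shape = (n_chunk, m)

    # Aggregate path: neighboring peers pre-exchange chunks, so each of the
    # tp_size / 2 steps covers two chunks at once
    input_b_chunk_shape_2x = (k, 2 * n_chunk) if transb else (2 * n_chunk, k)
    output_chunk_shape_2x = (2 * n_chunk, m)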
@@ -12,10 +12,18 @@
 #include <cudnn.h>
 #include <nvrtc.h>

+#include <iostream>
 #include <stdexcept>

 #include "../util/string.h"

+#define NVTE_WARN(...)                                                 \
+  do {                                                                 \
+    std::cerr << ::transformer_engine::concat_strings(                 \
+        __FILE__ ":", __LINE__, " in function ", __func__, ": ",       \
+        ::transformer_engine::concat_strings(__VA_ARGS__), "\n");      \
+  } while (false)
+
 #define NVTE_ERROR(...)                                                \
   do {                                                                 \
     throw ::std::runtime_error(::transformer_engine::concat_strings(   \
...
@@ -452,12 +452,10 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve
   ~CommOverlap() {}

-  void set_buffer_params(py::handle quantizer);
-
-  void copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk = false);
-
-  py::object get_buffer(py::handle quantizer, bool local_chunk = false,
-                        std::optional<const std::vector<int64_t>> shape = std::nullopt);
+  void copy_into_buffer(const at::Tensor &input, bool local_chunk = false);
+
+  at::Tensor get_buffer(bool local_chunk = false,
+                        std::optional<std::vector<int64_t>> shape = std::nullopt);
 };  // CommOverlap
@@ -473,12 +471,10 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm
   ~CommOverlapP2P() {}

-  void set_buffer_params(py::handle quantizer);
-
-  void copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk = false);
-
-  py::object get_buffer(py::handle quantizer, bool local_chunk = false,
-                        std::optional<const std::vector<int64_t>> shape = std::nullopt);
+  void copy_into_buffer(const at::Tensor &input, bool local_chunk = false);
+
+  at::Tensor get_buffer(bool local_chunk = false,
+                        std::optional<std::vector<int64_t>> shape = std::nullopt);
 };  // CommOverlapP2P
...
@@ -141,81 +141,79 @@ CommOverlap::CommOverlap(const std::vector<size_t> &buffer_shape, at::ScalarType
                           num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm,
                           set_sm_margin, atomic_gemm, rs_overlap_first_gemm) {}

-void CommOverlap::set_buffer_params(py::handle quantizer) {
-  std::unique_ptr<te::pytorch::Quantizer> my_quantizer = te::pytorch::convert_quantizer(quantizer);
-  my_quantizer->set_quantization_params(&_ubuf);
-  _ubuf_scale_inv_initialized = true;
-}
-
 /*
 ** Helper function to copy input to _ubuf
 */
-void CommOverlap::copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk) {
-  auto input_tensor = te::pytorch::makeTransformerEngineTensor(input, quantizer);
-  auto input_ptr = input_tensor.dptr() ? input_tensor.dptr() : input_tensor.columnwise_dptr();
-  NVTE_CHECK(input_ptr, "Input tensor does not have rowwise or columnwise data!");
-
-  char *ubuf_ptr = reinterpret_cast<char *>(_ubuf.dptr());
+void CommOverlap::copy_into_buffer(const at::Tensor &input, bool local_chunk) {
+  const auto &input_ = input.contiguous();
+
+  // Check element size
+  const size_t element_size = input.element_size();
+  NVTE_CHECK(_ubuf.element_size() == element_size,
+             "Tried to copy data into a Userbuffers buffer but dtypes are not compatible ",
+             "(input dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(),
+             " bytes)");
+
+  // Input data
+  const size_t input_size = input_.numel();
+  const void *src_ptr = input_.data_ptr();
+
+  // Userbuffers data
+  const size_t ubuf_size = _ubuf.numel();
+  void *dst_ptr = _ubuf.dptr();
   if (local_chunk) {
-    if (input_tensor.numel() * _tp_size > _ubuf.numel())
-      NVTE_ERROR("input is larger than the local communication buffer!");
-    if (input_tensor.element_size() != _ubuf.element_size())
-      NVTE_ERROR("input data type does not match communication buffer!");
-    ubuf_ptr += (_ubuf.numel() / _tp_size) * _tp_id * _ubuf.element_size();
+    NVTE_CHECK(input_size * _tp_size == ubuf_size,
+               "Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ",
+               "(input_size=", input_size, ", tensor_parallel_size=", _tp_size,
+               ", ubuf_size=", ubuf_size, ")");
+    dst_ptr = (reinterpret_cast<char *>(dst_ptr) + (ubuf_size / _tp_size) * _tp_id * element_size);
   } else {
-    if (input_tensor.numel() > _ubuf.numel())
-      NVTE_ERROR("input is larger than the global communication buffer!");
-    if (input_tensor.element_size() != _ubuf.element_size())
-      NVTE_ERROR("input data type does not match communication buffer!");
+    NVTE_CHECK(input_size == ubuf_size,
+               "Tried to copy an invalid tensor into a Userbuffers buffer ",
+               "(input_size=", input_size, ", ubuf_size=", ubuf_size, ")");
   }

-  // Copy either row or columnwise data into the communication buffer's columnwise data
-  // NOTE: _ubuf.columnwise_dptr() is not a valid copy target because it is not registered with
-  // the Userbuffers communicator.
-  at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
+  // Copy data
+  auto stream_main = at::cuda::getCurrentCUDAStream();
   NVTE_CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main));
   NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0));
-  NVTE_CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input_tensor.dptr(),
-                                  input_tensor.numel() * input_tensor.element_size(),
+  NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, input_size * element_size,
                                   cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_comm));
 }

-py::object CommOverlap::get_buffer(py::handle quantizer, bool local_chunk,
-                                   std::optional<const std::vector<int64_t>> shape) {
-  using namespace te::pytorch;
-  char *ubuf_wt_ptr = reinterpret_cast<char *>(_ubuf.dptr());
-  if (local_chunk) ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size();
-
-  std::vector<int64_t> torch_shape;
-  if (shape.has_value()) {
-    torch_shape = shape.value();
-    size_t requested = product(torch_shape);
-    auto expected = local_chunk ? _ubuf.numel() / _tp_size : _ubuf.numel();
-    NVTE_CHECK(requested == expected, "Number of elements in the requested shape (", requested,
-               ") does not match allocated buffer size (", expected, ")!");
+at::Tensor CommOverlap::get_buffer(bool local_chunk, std::optional<std::vector<int64_t>> shape) {
+  // Check buffer shape
+  const size_t ubuf_size = _ubuf.numel();
+  if (shape) {
+    const size_t requested_size = transformer_engine::pytorch::product(*shape);
+    if (local_chunk) {
+      NVTE_CHECK(requested_size * _tp_size == ubuf_size,
+                 "Invalid shape for local chunk of a Userbuffers buffer (requested shape=", *shape,
+                 ", tensor_parallel_size=", _tp_size, ", ubuf_size=", ubuf_size, ")");
+    } else {
+      NVTE_CHECK(requested_size == ubuf_size,
+                 "Invalid shape for a Userbuffers buffer (requested shape=", *shape,
+                 ", ubuf_size=", ubuf_size, ")");
+    }
   } else {
-    int64_t output_c_dim0 = (local_chunk) ? _ubuf.size(0) / _tp_size : _ubuf.size(0);
-    int64_t output_c_dim1 = _ubuf.size(1);
-    torch_shape = {output_c_dim0, output_c_dim1};
+    int64_t dim0 = _ubuf.size(0);
+    int64_t dim1 = _ubuf.size(1);
+    if (local_chunk) {
+      dim0 /= _tp_size;
+    }
+    shape = {dim0, dim1};
  }

-  auto ubuf_tensor = torch::from_blob(reinterpret_cast<void *>(ubuf_wt_ptr), torch_shape,
-                                      at::dtype(GetATenDType(_ubuf.dtype())).device(torch::kCUDA));
-
-  std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
-  std::vector<size_t> te_shape;
-  for (auto s : torch_shape) te_shape.emplace_back(static_cast<size_t>(s));
-
-  // Always output a rowwise-only QuantizedTensor
-  // TODO (Alp): This needs to produce an un-interleaved transpose when required.
-  auto is_internal = my_quantizer->internal;
-  auto uses_columnwise = my_quantizer->columnwise_usage;
-  my_quantizer->internal = false;
-  my_quantizer->columnwise_usage = false;
-  auto [te_tensor, py_tensor] = my_quantizer->create_tensor(te_shape, _ubuf.dtype(), ubuf_tensor);
-  my_quantizer->internal = is_internal;
-  my_quantizer->columnwise_usage = uses_columnwise;
-  return py_tensor;
+  // Data pointer
+  void *ubuf_ptr = _ubuf.dptr();
+  if (local_chunk) {
+    ubuf_ptr = (reinterpret_cast<char *>(ubuf_ptr) +
+                (ubuf_size / _tp_size) * _tp_id * _ubuf.element_size());
+  }
+
+  // Construct PyTorch tensor
+  const auto dtype = transformer_engine::pytorch::GetATenDType(_ubuf.dtype());
+  return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA));
 }
 /***************************************************************************************************
@@ -236,74 +234,69 @@ CommOverlapP2P::CommOverlapP2P(const std::vector<size_t> &buffer_shape, at::Scal
                     comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce,
                     atomic_gemm, aggregate) {}

-void CommOverlapP2P::set_buffer_params(py::handle quantizer) {
-  std::unique_ptr<te::pytorch::Quantizer> my_quantizer = te::pytorch::convert_quantizer(quantizer);
-  my_quantizer->set_quantization_params(&_ubuf);
-  for (size_t i = 0; i < _ubufs.size(); i++) my_quantizer->set_quantization_params(&_ubufs[i]);
-}
-
 /*
 ** Copy input to _ubufs[0]
 */
-void CommOverlapP2P::copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk) {
-  auto input_tensor = te::pytorch::makeTransformerEngineTensor(input, quantizer);
-  auto input_ptr = input_tensor.dptr() ? input_tensor.dptr() : input_tensor.columnwise_dptr();
-  NVTE_CHECK(input_ptr, "Input tensor does not have rowwise or columnwise data!");
-
-  at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
+void CommOverlapP2P::copy_into_buffer(const at::Tensor &input, bool local_chunk) {
+  const auto &input_ = input.contiguous();
+
+  // Check element size
+  const size_t element_size = input.element_size();
+  NVTE_CHECK(_ubuf.element_size() == element_size,
+             "Tried to copy data into a Userbuffers buffer but dtypes are not compatible ",
+             "(input dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(),
+             " bytes)");
+
+  // Input data
+  const size_t input_size = input_.numel();
+  const void *src_ptr = input_.data_ptr();
+
+  // Userbuffers data
+  void *dst_ptr;
   if (local_chunk) {
-    // Copy input to the target ubuf chunk by rank offset
-    if (input_tensor.numel() * _tp_size > _ubuf.numel())
-      NVTE_ERROR("input is larger than the local communication buffer!");
-    if (input_tensor.element_size() != _ubuf.element_size())
-      NVTE_ERROR("input data type does not match communication buffer!");
-    NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].dptr(), input_ptr,
-                                    input_tensor.numel() * input_tensor.element_size(),
-                                    cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main));
+    NVTE_CHECK(_ubufs[_tp_id].numel() == input_size,
+               "Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ",
+               "(input_size=", input_size, ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")");
+    dst_ptr = _ubufs[_tp_id].dptr();
   } else {
-    if (input_tensor.numel() > _ubuf.numel())
-      NVTE_ERROR("input is larger than the global communication buffer!");
-    if (input_tensor.element_size() != _ubuf.element_size())
-      NVTE_ERROR("input data type does not match communication buffer!");
-    NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubuf.dptr(), input_ptr,
-                                    input_tensor.numel() * input_tensor.element_size(),
-                                    cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main));
+    NVTE_CHECK(_ubuf.numel() == input_size,
+               "Tried to copy an invalid tensor into a Userbuffers buffer ",
+               "(input_size=", input_size, ", ubuf_size=", _ubuf.numel(), ")");
+    dst_ptr = _ubuf.dptr();
   }
+
+  // Copy data
+  NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, input_size * element_size,
+                                  cudaMemcpyDeviceToDevice,
+                                  (cudaStream_t)at::cuda::getCurrentCUDAStream()));
 }

-py::object CommOverlapP2P::get_buffer(py::handle quantizer, bool local_chunk,
-                                      std::optional<const std::vector<int64_t>> shape) {
-  using namespace te::pytorch;
-  char *ubuf_wt_ptr = reinterpret_cast<char *>(local_chunk ? _ubufs[_tp_id].dptr() : _ubuf.dptr());
-
-  std::vector<int64_t> torch_shape;
-  if (shape.has_value()) {
-    torch_shape = shape.value();
-    size_t requested = product(torch_shape);
-    auto expected = local_chunk ? _ubufs[_tp_id].numel() : _ubuf.numel();
-    NVTE_CHECK(requested == expected, "Number of elements in the requested shape (", requested,
-               ") does not match allocated buffer size (", expected, ")!");
+at::Tensor CommOverlapP2P::get_buffer(bool local_chunk, std::optional<std::vector<int64_t>> shape) {
+  // Check buffer shape
+  if (shape) {
+    const size_t requested_size = transformer_engine::pytorch::product(*shape);
+    if (local_chunk) {
+      NVTE_CHECK(requested_size == _ubufs[_tp_id].numel(),
+                 "Invalid shape for local chunk of a Userbuffers buffer (requested shape=", *shape,
+                 ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")");
+    } else {
+      NVTE_CHECK(requested_size == _ubuf.numel(),
+                 "Invalid shape for a Userbuffers buffer (requested shape=", *shape,
+                 ", ubuf_size=", _ubuf.numel(), ")");
+    }
   } else {
-    int64_t output_c_dim0 = (local_chunk) ? _ubuf.size(0) / _tp_size : _ubuf.size(0);
-    int64_t output_c_dim1 = _ubuf.size(1);
-    torch_shape = {output_c_dim0, output_c_dim1};
+    int64_t dim0 = _ubuf.size(0);
+    int64_t dim1 = _ubuf.size(1);
+    if (local_chunk) {
+      dim0 /= _tp_size;
+    }
+    shape = {dim0, dim1};
   }

-  auto ubuf_tensor = torch::from_blob(reinterpret_cast<void *>(ubuf_wt_ptr), torch_shape,
-                                      at::dtype(GetATenDType(_ubuf.dtype())).device(torch::kCUDA));
-
-  std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
-  std::vector<size_t> te_shape;
-  for (auto s : torch_shape) te_shape.emplace_back(static_cast<size_t>(s));
-
-  // Always output a rowwise-only QuantizedTensor
-  // TODO (Alp): This needs to produce an un-interleaved transpose when required.
-  auto is_internal = my_quantizer->internal;
-  auto uses_columnwise = my_quantizer->columnwise_usage;
-  my_quantizer->internal = false;
-  my_quantizer->columnwise_usage = false;
-  auto [te_tensor, py_tensor] = my_quantizer->create_tensor(te_shape, _ubuf.dtype(), ubuf_tensor);
-  my_quantizer->internal = is_internal;
-  my_quantizer->columnwise_usage = uses_columnwise;
-  return py_tensor;
+  // Data pointer
+  void *ubuf_ptr = local_chunk ? _ubufs[_tp_id].dptr() : _ubuf.dptr();
+
+  // Construct PyTorch tensor
+  const auto dtype = transformer_engine::pytorch::GetATenDType(_ubuf.dtype());
+  return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA));
 }
@@ -360,10 +360,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
            py::arg("num_comm_sm") = 16, py::arg("set_sm_margin") = true,
            py::arg("atomic_gemm") = false, py::arg("rs_overlap_first_gemm") = false)
       .def("copy_into_buffer", &CommOverlap::copy_into_buffer, py::arg("input"),
-           py::arg("quantizer"), py::arg("local_chunk") = false)
-      .def("get_buffer", &CommOverlap::get_buffer, py::arg("quantizer"),
-           py::arg("local_chunk") = false, py::arg("shape") = std::nullopt)
-      .def("set_buffer_params", &CommOverlap::set_buffer_params);
+           py::arg("local_chunk") = false)
+      .def("get_buffer", &CommOverlap::get_buffer, py::arg("local_chunk") = false,
+           py::arg("shape") = std::nullopt);

   py::class_<CommOverlapP2P, std::shared_ptr<CommOverlapP2P>,
              transformer_engine::CommOverlapP2PBase, transformer_engine::CommOverlapCore>(
@@ -378,8 +377,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
            py::arg("set_sm_margin") = false, py::arg("atomic_gemm") = false,
            py::arg("use_ce") = true, py::arg("aggregate") = false)
       .def("copy_into_buffer", &CommOverlapP2P::copy_into_buffer, py::arg("input"),
-           py::arg("quantizer"), py::arg("local_chunk") = false)
-      .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("quantizer"),
-           py::arg("local_chunk") = false, py::arg("shape") = std::nullopt)
-      .def("set_buffer_params", &CommOverlapP2P::set_buffer_params);
+           py::arg("local_chunk") = false)
+      .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("local_chunk") = false,
+           py::arg("shape") = std::nullopt);
 }
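With the quantizer argument removed from these bindings, the Python side now works with plain torch.Tensor data. A minimal sketch of the new calling convention (the `comm` object, sizes, and names here are hypothetical, not part of the diff):

    import torch

    # Sketch only: `comm` stands for an initialized CommOverlap (or CommOverlapP2P)
    # object whose buffer is bf16 with shape [seq * tp_size, hidden].
    seq, hidden, tp_size = 512, 1024, 4
    local = torch.randn(seq, hidden, dtype=torch.bfloat16, device="cuda")

    comm.copy_into_buffer(local, local_chunk=True)             # fill this rank's shard
    gathered = comm.get_buffer(shape=[seq * tp_size, hidden])  # view of the full buffer
    shard = comm.get_buffer(local_chunk=True)                  # view of this rank's shard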
@@ -919,8 +919,10 @@ def _all_gather_fp8(
     # Note: We cannot directly all-gather the transposed FP8 tensor,
     # so temporarily modify quantizer to avoid creating FP8 transpose.
     if not isinstance(inp, Float8TensorBase):
-        if quantizer is None:
-            raise ValueError("Input tensor is not FP8 and no quantizer was provided")
+        assert isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer))
+        # We cannot directly gather the transposed FP8 tensor, so disable
+        # columnwise usage for the quantizer and restore the original
+        # value after quantizing.
         init_rowwise_usage = quantizer.rowwise_usage
         init_columnwise_usage = quantizer.columnwise_usage
         quantizer.set_usage(rowwise=True, columnwise=False)
@@ -1063,11 +1065,11 @@ def _all_gather_mxfp8(
         dtype = inp.dtype
     elif isinstance(inp, MXFP8TensorBase):
         if inp._rowwise_data is not None:
-            in_shape = inp._rowwise_data.device.size()
+            in_shape = inp._rowwise_data.size()
             device = inp._rowwise_data.device
             dtype = inp._rowwise_data.dtype
         elif inp._columnwise_data is not None:
-            in_shape = inp._columnwise_data.device.size()
+            in_shape = inp._columnwise_data.size()
             device = inp._columnwise_data.device
             dtype = inp._columnwise_data.dtype
         else:
...
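The save/toggle/restore pattern referenced in the comments above looks roughly like this (a sketch under the assumption that `quantizer` is a Float8 quantizer and `inp` a high-precision tensor):

    # Quantize rowwise-only so the all-gather never has to move FP8 transpose
    # data, then restore the quantizer's original usage flags.
    init_rowwise_usage = quantizer.rowwise_usage
    init_columnwise_usage = quantizer.columnwise_usage
    quantizer.set_usage(rowwise=True, columnwise=False)
    inp_fp8 = quantizer(inp)
    quantizer.set_usage(rowwise=init_rowwise_usage, columnwise=init_columnwise_usage)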
@@ -6,8 +6,6 @@
 from typing import Any, List, Optional, Tuple, Union, Callable
 from dataclasses import dataclass
-from functools import reduce
-from operator import mul as multiply_op
 import queue

 import torch
@@ -15,7 +13,6 @@ import torch
 from .. import cpp_extensions as tex
 from ..constants import TE_DType
 from ..utils import get_default_init_method
-from ..tensor.float8_tensor import Float8Tensor


 def _get_normalization_func(normalization: str, forward: bool):
@@ -33,39 +30,6 @@ def _get_normalization_func(normalization: str, forward: bool):
     return bwd_normalization_funcs[normalization]


-def _fix_gathered_fp8_transpose(fp8_tensor: Float8Tensor, tp_size: int) -> Float8Tensor:
-    """Reorder FP8 transposes after Userbuffers gather.
-
-    The all-gather is performed in-place in the Float8Tensor's
-    row-wise data, and afterwards we need to do a transpose to get the
-    correct ordering. This misuses data fields in Float8Tensor and
-    should be considered an evil hack. It would be best to move
-    transpose logic into CommOverlap::get_buffer.
-
-    Responsibility for fixing: adener, tmoon
-
-    """
-    assert isinstance(fp8_tensor, Float8Tensor), "Tensor is not a Float8Tensor"
-    assert tp_size > 1, "The tensor transpose cannot be interleaved when TP size is 1"
-    assert fp8_tensor._data is not None, "The tensor does not hold any rowwise data"
-    assert (
-        fp8_tensor._data.shape[0] % tp_size == 0
-    ), "Leading dimension of data is not divisible by TP size"
-    data = fp8_tensor._data
-    batched_size = reduce(multiply_op, data.shape[1:])
-    interleaved_shape = [tp_size, data.shape[0] // tp_size, batched_size]
-    transposed_shape = [data.shape[0] // tp_size, batched_size * tp_size]
-    fp8_tensor._transpose = (
-        data.view(interleaved_shape).transpose(0, 1).contiguous().view(transposed_shape)
-    )
-    fp8_tensor._transpose_invalid = False
-    fp8_tensor._data = None
-    return fp8_tensor
-
-
 def apply_normalization(
     inputmat: torch.Tensor,
     ln_out: torch.Tensor,
...
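For context on why the removed helper existed: an in-place all-gather over row-wise FP8 data leaves each rank's transpose fragment interleaved along the leading dimension, and the helper de-interleaved it with a view/transpose/reshape. The core reordering, as a standalone sketch (sizes are made up; int32 stands in for the uint8 FP8 payload):

    import torch

    tp_size = 4
    rows, cols = 64, 32  # rows % tp_size == 0
    data = torch.arange(rows * cols, dtype=torch.int32).reshape(rows, cols)
    interleaved_shape = [tp_size, rows // tp_size, cols]
    transposed_shape = [rows // tp_size, cols * tp_size]
    fixed = data.view(interleaved_shape).transpose(0, 1).contiguous().view(transposed_shape)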
@@ -4,6 +4,7 @@
 """Base modules and utilities for TransformerEngine PyTorch API"""

 import io
+import math
 import os
 import pickle
 import warnings
@@ -35,7 +36,9 @@ from ..distributed import (
     _fsdp_gather_tensors,
 )
 from ..constants import dist_group_type
-from ..tensor import QuantizedTensor, Quantizer
+from ..tensor.quantized_tensor import QuantizedTensor, QuantizedTensorBase, Quantizer
+from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
+from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
 from ..tensor._internal.float8_tensor_base import Float8TensorBase
 from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
@@ -418,6 +421,142 @@ def destroy_ub():
     layers_atomic_ring_exchange = []
+
+
+def fill_userbuffers_buffer_for_all_gather(
+    comm,
+    local_tensor: torch.Tensor,
+    quantizer: Optional[Quantizer],
+    process_group,
+) -> tuple[torch.Tensor | QuantizedTensorBase, torch.Tensor | QuantizedTensorBase]:
+    """Fill local shard of Userbuffers buffer with data for all-gather
+
+    Returns the full tensor and the local shard, both using the
+    Userbuffers buffer as their underlying data. These tensors should
+    be used carefully (e.g. only immediately before and after a
+    Userbuffers operation) since the underlying data may be
+    overwritten by other Userbuffers operations.
+
+    May perform blocking communication if needed for the gathered
+    tensor's metadata, e.g. scaling factors.
+
+    """
+
+    # Tensor dimensions
+    local_shape = local_tensor.size()
+    if not local_shape:
+        raise ValueError(f"Invalid local tensor (shape={tuple(local_shape)})")
+    process_group_size = torch.distributed.get_world_size(process_group)
+    global_shape = list(local_shape)
+    global_shape[0] *= process_group_size
+
+    # Unquantized data
+    if quantizer is None:
+        if isinstance(local_tensor, QuantizedTensorBase):
+            local_tensor = local_tensor.dequantize()
+        if comm.is_fp8_ubuf():
+            raise RuntimeError(
+                "Attempting to all-gather unquantized tensor, "
+                "but Userbuffers is initialized with FP8 buffers"
+            )
+        comm.copy_into_buffer(local_tensor, local_chunk=True)
+        global_tensor = comm.get_buffer(shape=global_shape)
+        return global_tensor, local_tensor
+
+    # FP8 data
+    if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
+        if not isinstance(local_tensor, Float8TensorBase):
+            if isinstance(local_tensor, QuantizedTensorBase):
+                local_tensor = local_tensor.dequantize()
+            quantizer.set_usage(rowwise=True, columnwise=False)
+            local_tensor = quantizer(local_tensor)
+        if not comm.is_fp8_ubuf():
+            raise RuntimeError(
+                "Attempting to all-gather FP8 tensor, "
+                "but Userbuffers is not initialized with FP8 buffers"
+            )
+        comm.copy_into_buffer(local_tensor._data, local_chunk=True)
+        global_tensor_data = comm.get_buffer(shape=global_shape)
+        global_tensor = Float8TensorBase(
+            data=global_tensor_data,
+            fp8_scale_inv=local_tensor._scale_inv,
+            fp8_dtype=local_tensor._fp8_dtype,
+            quantizer=quantizer,
+        )
+        return global_tensor, local_tensor
+
+    # MXFP8 data
+    if isinstance(quantizer, MXFP8Quantizer):
+
+        # Cast to MXFP8 if needed
+        if not isinstance(local_tensor, MXFP8TensorBase):
+            if isinstance(local_tensor, QuantizedTensorBase):
+                local_tensor = local_tensor.dequantize()
+            local_tensor = quantizer(local_tensor)
+        if not comm.is_fp8_ubuf():
+            raise RuntimeError(
+                "Attempting to all-gather MXFP8 tensor, "
+                "but Userbuffers is not initialized with FP8 buffers"
+            )
+
+        # Check which MXFP8 buffer to communicate
+        if quantizer.rowwise_usage == quantizer.columnwise_usage:
+            raise ValueError(
+                "Userbuffers can only communicate one MXFP8 buffer at a time, "
+                f"but quantizer has rowwise_usage={quantizer.rowwise_usage}, "
+                f"columnwise_usage={quantizer.columnwise_usage}"
+            )
+        with_rowwise_data = quantizer.rowwise_usage
+
+        # Copy MXFP8 data to local chunk of Userbuffers buffer
+        local_data = (
+            local_tensor._rowwise_data if with_rowwise_data else local_tensor._columnwise_data
+        )
+        comm.copy_into_buffer(local_data, local_chunk=True)
+
+        # Gather scaling-inverses
+        if math.prod(local_shape[:-1]) % 128 != 0:
+            raise ValueError(
+                "Userbuffers requires MXFP8 tensor dims that are divisible by 128, "
+                f"but got MXFP8 tensor with shape={tuple(local_shape)}"
+            )
+        local_scale_inv = (
+            local_tensor._rowwise_scale_inv
+            if with_rowwise_data
+            else local_tensor._columnwise_scale_inv
+        )
+        local_scale_inv_size = list(local_scale_inv.size())
+        global_scale_inv = torch.empty(
+            [process_group_size * local_scale_inv_size[0]] + local_scale_inv_size[1:],
+            dtype=local_scale_inv.dtype,
+            device=local_scale_inv.device,
+        )
+        torch.distributed.all_gather_into_tensor(
+            global_scale_inv,
+            local_scale_inv,
+            group=process_group,
+        )
+
+        # Construct MXFP8 tensor with Userbuffers buffer
+        rowwise_data, rowwise_scale_inv = None, None
+        columnwise_data, columnwise_scale_inv = None, None
+        global_data = comm.get_buffer(shape=global_shape)
+        if with_rowwise_data:
+            rowwise_data, rowwise_scale_inv = global_data, global_scale_inv
+        else:
+            columnwise_data, columnwise_scale_inv = global_data, global_scale_inv
+        global_tensor = MXFP8TensorBase(
+            rowwise_data=rowwise_data,
+            rowwise_scale_inv=rowwise_scale_inv,
+            columnwise_data=columnwise_data,
+            columnwise_scale_inv=columnwise_scale_inv,
+            fp8_dtype=local_tensor._fp8_dtype,
+            quantizer=quantizer,
+        )
+        return global_tensor, local_tensor
+
+    # Unsupported data format
+    raise ValueError(f"Unsupported quantizer for Userbuffers ({quantizer})")
+
+
 class TransformerEngineBaseModule(torch.nn.Module, ABC):
     """Base TE module."""
@@ -866,11 +1005,15 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         # Non-FP8 case: bgrad is fused with wgrad for this case.
         if not ctx.fp8 and not ctx.debug:
             if gather_grad_output:
-                if not ctx.ub_overlap_ag:
+                if not ctx.ub_overlap_ag:  # Perform NCCL all-gather
                     grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group)
-                else:
-                    ctx.ub_obj_gradout.copy_into_buffer(grad_output, quantizer, local_chunk=True)
-                    grad_output = ctx.ub_obj_gradout.get_buffer(quantizer)
+                else:  # Initialize Userbuffers all-gather
+                    grad_output, _ = fill_userbuffers_buffer_for_all_gather(
+                        ctx.ub_obj_gradout,
+                        grad_output,
+                        None,
+                        ctx.tp_group,
+                    )
             return grad_output, None

         # FP8 with all-gather: unfused bgrad, fused cast + transpose
@@ -893,8 +1036,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                 grad_output = quantizer(grad_output)

                 # Copy into communication buffer, and replace original gradient with it
-                ctx.ub_obj_gradout.copy_into_buffer(grad_output, quantizer, local_chunk=True)
-                grad_output = ctx.ub_obj_gradout.get_buffer(quantizer)
+                grad_output, _ = fill_userbuffers_buffer_for_all_gather(
+                    ctx.ub_obj_gradout,
+                    grad_output,
+                    quantizer,
+                    ctx.tp_group,
+                )
             else:
                 grad_output, _ = gather_along_first_dim(
                     grad_output,
@@ -1165,7 +1312,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
             return
         with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"):
-            (wgrad, grad_bias_, _, _), _ = self.wgrad_store.pop()
+            (wgrad, bgrad), _ = self.wgrad_store.pop()
             if not self.fuse_wgrad_accumulation:
                 unfused_weights = [getattr(self, name) for name in self.weight_names]
                 weight_tensor = noop_cat(unfused_weights)
@@ -1174,9 +1321,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             if self.use_bias:
                 bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
                 if bias_tensor.grad is None:
-                    bias_tensor.grad = grad_bias_.to(bias_tensor.dtype)
-            del grad_bias_
-            del wgrad
+                    bias_tensor.grad = bgrad.to(bias_tensor.dtype)

     def _validate_name(self):
         """
...
@@ -534,7 +534,9 @@ class BasicLinear(BasicOperation):
         # Configure input tensor for backward pass
         if with_quantized_compute and isinstance(x_local, QuantizedTensor):
-            x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
+            if not (isinstance(x_local, Float8TensorBase) and with_x_all_gather):
+                # FP8 does not support all-gather of transpose data
+                x_local.update_usage(rowwise_usage=False, columnwise_usage=True)

         return y, x_local, w
@@ -622,7 +624,10 @@ class BasicLinear(BasicOperation):
         # Check datatype
         if dtype is None:
-            dtype = weight.dtype
+            if weight is not None:
+                dtype = weight.dtype
+            else:
+                dtype = grad_output.dtype
         dtype = canonicalize_dtype(dtype)
         if dtype not in (torch.float32, torch.float16, torch.bfloat16):
             raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")
@@ -814,7 +819,7 @@ class BasicLinear(BasicOperation):
            x_async = None
            dy_async = None

-        # Check grad input tensor
+        # Check grad weight tensor
         dw = grad_weight
         dw_dtype = dtype
         if dw is None:
...