Commit ce0b46c4 authored by Tim Moon, committed by GitHub

MXFP8 support in Userbuffers (#1711)



* Initial work toward restoring UB support in te.Sequential
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Forward UB linear runs, but has numerical error
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug UB forward tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Minor tweaks
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove Python checks for MXFP8 UB linear forward
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add dim check for MXFP8 full tiles
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Move QuantizedTensor logic out of UB comm and into Python helper function
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Support MXFP8 AGs
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Coalesce NCCL all-gathers for MXFP8 all-gather
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Initial impl of backward UB linear in te.Sequential
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug UB linear backward with no quantization

dgrad GEMM + dx RS is still broken.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix chunk dims for dgrad GEMM + dx RS
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debugging MXFP8 UB cases

Still failing with dy AG + wgrad GEMM
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use NCCL to overlap dy AG with dgrad GEMM
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug UB GEMM tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Initial refactoring of linear module forward
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Refactor linear module backward
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug linear module UB tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Tweak test tensor dims
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Do not store autograd context within wgrad GEMM closure
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update LayerNormLinear
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update LayerNormMLP
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug UB tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Debug test failures
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Minor style tweaks
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix incorrect usage for GEMM input with block-scaled FP8
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix RS out dims
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Disable dgrad GEMM + UB AG + NCCL AG overlapping
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Disable dgrad GEMM + UB AG + NCCL AG overlap in te.Sequential
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Restore support for internal quantized tensors
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add tests for MXFP8 GEMM with UB
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix linter warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug test failures
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug test failures
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 3d2152c2
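A recurring change throughout this PR replaces the boolean `--fp8` flag with a `--quantization` choice so that MXFP8 can be selected alongside FP8. A minimal, self-contained sketch of that argparse pattern (illustrative only; the real test scripts define many more options):

```python
import argparse

def parse_args(argv=None):
    """Parse the quantization choice the way the updated test scripts do."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--quantization",
        type=str.lower,          # accept any casing, e.g. "MXFP8"
        default="none",
        choices=["none", "fp8", "mxfp8"],
        help="Quantization recipe",
    )
    return parser.parse_args(argv)

opts = parse_args(["--quantization", "MXFP8"])
with_quantized_compute = opts.quantization != "none"
print(opts.quantization, with_quantized_compute)  # mxfp8 True
```

Because `type=str.lower` runs before the `choices` check, the flag is effectively case-insensitive, and downstream code branches on the string instead of a boolean.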
@@ -26,7 +26,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PAT
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
-# python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" ### TODO Debug UB support with te.Sequential
+python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn_with_cp.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
...
@@ -20,7 +20,11 @@ from torch.distributed.elastic.multiprocessing.errors import record
 import transformer_engine.pytorch as te
 import transformer_engine.pytorch.cpp_extensions as tex
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
-from transformer_engine.pytorch.module.base import get_cublas_workspace_size_bytes
+from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
+from transformer_engine.pytorch.module.base import (
+    fill_userbuffers_buffer_for_all_gather,
+    get_cublas_workspace_size_bytes,
+)
 
 warnings.filterwarnings("ignore", category=DeprecationWarning)
 warnings.filterwarnings("ignore", category=FutureWarning)
@@ -56,7 +60,11 @@ def _parse_args(argv=None, namespace=None):
     )
     parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
     parser.add_argument(
-        "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context."
+        "--quantization",
+        type=str.lower,
+        default="none",
+        choices=["none", "fp8", "mxfp8"],
+        help="Quantization recipe",
     )
     parser.add_argument(
         "--fp8-output", action="store_true", default=False, help="Get FP8 output from GEMM."
@@ -154,9 +162,9 @@ def _parse_args(argv=None, namespace=None):
         if opts.atomic:
             warnings.warn("Atomic GEMM is not supported with bulk overlap.")
             opts.atomic = False
-        if opts.fp8:
+        if opts.quantization != "none":
             warnings.warn("Bulk overlap is supported in FP8 but only tested in BF16.")
-            opts.fp8 = False
+            opts.quantization = "none"
     elif opts.comm_type == tex.CommOverlapType.AG:
         if opts.atomic:
             setattr(opts, "atomic_rs_p2p", opts.p2p)
@@ -164,8 +172,11 @@ def _parse_args(argv=None, namespace=None):
     if opts.atomic:
         if not te.fp8.check_fp8_support():
-            assert not opts.fp8, "Atomic GEMM is only supported in FP8."
-        opts.fp8 = True
+            assert opts.quantization == "none", "Atomic GEMM is only supported in FP8."
+        opts.quantization = "fp8"
+
+    if opts.fp8_output:
+        assert opts.quantization == "fp8", "FP8 output is only supported with FP8 compute."
 
     return opts
@@ -302,7 +313,11 @@ def _main(opts):
     inp_shape = (opts.seq_length, opts.batch_size, hidden_size)
     outer_size = reduce(operator.mul, inp_shape[:-1], 1)
     buffer_dtype = torch.bfloat16
-    if opts.fp8 and not opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.AG:
+    if (
+        opts.quantization != "none"
+        and not opts.bulk_overlap
+        and opts.comm_type == tex.CommOverlapType.AG
+    ):
         buffer_dtype = torch.uint8
     ub_obj = (
         tex.CommOverlapP2P(
@@ -447,6 +462,8 @@ def _main(opts):
     inp2_g = torch.nn.functional.gelu(ref_g)  # pylint: disable=not-callable
     ref2_g = torch.matmul(inp2_g, ker2_g)
 
+    # Initialize quantizers
+    with_quantized_compute = opts.quantization != "none"
     inp_quantizer = None
     ker_quantizer = None
     out_quantizer = None
@@ -454,7 +471,7 @@ def _main(opts):
     inp2_quantizer = None
     ker2_quantizer = None
     out2_quantizer = None
-    if opts.fp8:
+    if opts.quantization == "fp8":
         # Structure to maintain amax and scale/scale_inv information for the kernel and input
         num_gemms = 6 if ub_obj2 is not None else 3
         fp8_dtype = tex.DType.kFloat8E4M3
@@ -499,11 +516,23 @@ def _main(opts):
             out2_quantizer = Float8Quantizer(
                 fp8_scales[5].clone(), fp8_amaxes[5].clone(), fp8_dtype
             )
+    elif opts.quantization == "mxfp8":
+        fp8_dtype = tex.DType.kFloat8E4M3
+        inp_quantizer = MXFP8Quantizer(fp8_dtype, columnwise=False)
+        ker_quantizer = MXFP8Quantizer(fp8_dtype)
+        if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS:
+            bulk_inp_quantizer = MXFP8Quantizer(fp8_dtype, columnwise=False)
+        elif ub_obj2 is not None:
+            inp2_quantizer = MXFP8Quantizer(fp8_dtype, columnwise=False)
+            ker2_quantizer = MXFP8Quantizer(fp8_dtype)
 
-    # Cast input to Float8Tensor
+    # Quantize tensors
+    if with_quantized_compute:
+        # Quantize input tensor
         inp_fp8 = inp_quantizer(inp)
-    # Cast kernel to Float8Tensor
+        # Quantize kernel tensor
         kernel_t_fp8 = ker_quantizer(kernel_t)
         if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS:
             bulk_inp_fp8 = bulk_inp_quantizer(bulk_inp)
@@ -540,31 +569,40 @@ def _main(opts):
     )
 
     # Set up comm/compute buffers
+    ag_out = None
     rs_out = None
     rs_out2 = None
     if opts.comm_type == tex.CommOverlapType.AG:
         if opts.bulk_overlap:
-            ub_obj.copy_into_buffer(bulk_inp, bulk_inp_quantizer, True)
+            ag_out, _ = fill_userbuffers_buffer_for_all_gather(
+                ub_obj,
+                bulk_inp,
+                bulk_inp_quantizer,
+                tp_group,
+            )
             gemm_inp = inp
         else:
-            ub_obj.copy_into_buffer(inp_fp8 if opts.fp8 else inp, inp_quantizer, True)
-            gemm_inp = ub_obj.get_buffer(inp_quantizer, False, inp_g.size())
+            ag_out, _ = fill_userbuffers_buffer_for_all_gather(
+                ub_obj,
+                inp_fp8 if with_quantized_compute else inp,
+                inp_quantizer,
+                tp_group,
+            )
+            gemm_inp = ag_out
         if ub_obj2 is not None:
-            if opts.fp8 and opts.fp8_output:
-                ub_obj2.set_buffer_params(out_quantizer)
             rs_out2 = torch.empty(
                 (outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda"
             )
     else:
         if opts.bulk_overlap:
-            ub_obj.copy_into_buffer(
-                bulk_inp_fp8 if opts.fp8 else bulk_inp, bulk_inp_quantizer, False
-            )
-            if opts.fp8:
-                ub_obj.set_buffer_params(bulk_inp_quantizer)
-        elif opts.fp8 and opts.fp8_output:
-            ub_obj.set_buffer_params(out_quantizer)
-        gemm_inp = inp_fp8 if opts.fp8 else inp
+            if opts.quantization == "none":
+                ub_obj.copy_into_buffer(bulk_inp, local_chunk=False)
+            elif opts.quantization == "fp8":
+                ub_obj.copy_into_buffer(bulk_inp_fp8._data, local_chunk=False)
+            elif opts.quantization == "mxfp8":
+                ub_obj.copy_into_buffer(bulk_inp_fp8._rowwise_data, local_chunk=False)
+        gemm_inp = inp_fp8 if with_quantized_compute else inp
         rs_out = torch.empty(
             (outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda"
         )
@@ -623,7 +661,7 @@ def _main(opts):
     if opts.use_cuda_graphs:
         # Trace the CUDA graph first
         g = torch.cuda.CUDAGraph()
-        if opts.fp8:
+        if with_quantized_compute:
             if ub_obj is None:
                 with torch.cuda.graph(g):
                     all_outputs = _fp8_gemm()
@@ -643,7 +681,7 @@ def _main(opts):
     else:
         for i in range(total_iters):
-            if opts.fp8:
+            if with_quantized_compute:
                 start_events[i].record()
                 all_outputs = _fp8_gemm()
                 end_events[i].record()
@@ -688,10 +726,22 @@ def _main(opts):
         output_info = ""
         if opts.comm_type == tex.CommOverlapType.AG:
             # Bulk overlap AG output is already gathered
-            test_out = ub_obj.get_buffer(bulk_inp_quantizer, False)
+            test_out = ag_out
+            if bulk_inp_quantizer is None:
+                test_out = ub_obj.get_buffer(False)
+            else:
+                test_out = Float8Tensor(
+                    shape=test_out.shape,
+                    dtype=torch.bfloat16,
+                    data=ub_obj.get_buffer(False),
+                    fp8_scale=bulk_inp_quantizer.scale,
+                    fp8_dtype=bulk_inp_quantizer.dtype,
+                    quantizer=bulk_inp_quantizer,
+                )
         else:
             # Bulk overlap RS output needs to be gathered
-            out_local = ub_obj.get_buffer(bulk_inp_quantizer, True)
+            out_local = ub_obj.get_buffer(True)
             output_info += f"rs_output: {list(out_local.shape)} | "
             test_out = te.distributed.gather_along_first_dim(out_local, tp_group)[0]
@@ -762,8 +812,8 @@ def _main(opts):
         m = torch.argmax(diff)
         abs_err = diff[m].item()
         rel_err = abs_err / max(abs(ref_out.flatten()[m].item()), 1e-5)
-        rtol = 0.125 if opts.fp8 else 0.02
-        atol = 0.0625 if opts.fp8 else 0.001
+        rtol = 0.02 if opts.quantization == "none" else 0.125
+        atol = 0.001 if opts.quantization == "none" else 0.0625
         if rel_err > rtol and abs_err > atol:
             numerics_failed = True
             numerics_info = (
...
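The numerics check in this script locates the worst-offending element and flags a failure only when both the relative and absolute error exceed their tolerances, with looser tolerances for quantized compute. A plain-Python sketch of that acceptance rule (tolerance values taken from the diff; the helper name is ours):

```python
def check_numerics(test_vals, ref_vals, quantization="none"):
    """Return (passed, abs_err, rel_err) for the worst element,
    using looser tolerances for FP8/MXFP8 compute."""
    rtol = 0.02 if quantization == "none" else 0.125
    atol = 0.001 if quantization == "none" else 0.0625
    # Find the element with the largest absolute error
    abs_err, ref_at_max = max(
        ((abs(t - r), r) for t, r in zip(test_vals, ref_vals)),
        key=lambda pair: pair[0],
    )
    rel_err = abs_err / max(abs(ref_at_max), 1e-5)
    # Fail only if BOTH tolerances are exceeded
    passed = not (rel_err > rtol and abs_err > atol)
    return passed, abs_err, rel_err
```

Requiring both tolerances to be violated avoids spurious failures near zero, where relative error is meaningless, and for large magnitudes, where absolute error alone is too strict.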
@@ -17,7 +17,12 @@ import torch
 import torch.distributed as dist
 
 import transformer_engine.pytorch as te
-from transformer_engine.common.recipe import Format, DelayedScaling, Float8CurrentScaling
+from transformer_engine.common.recipe import (
+    DelayedScaling,
+    Float8CurrentScaling,
+    Format,
+    MXFP8BlockScaling,
+)
 
 warnings.filterwarnings("ignore", category=DeprecationWarning)
 warnings.filterwarnings("ignore", category=FutureWarning)
@@ -163,7 +168,7 @@ def _parse_args(argv=None, namespace=None):
         "--quantization",
         type=str.lower,
         default="none",
-        choices=["none", "fp8_delayed_scaling", "fp8_current_scaling"],
+        choices=["none", "fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"],
         help="Quantization recipe",
     )
     parser.add_argument(
@@ -414,6 +419,8 @@ def _train(opts):
         )
     elif opts.quantization == "fp8_current_scaling":
         fp8_recipe = Float8CurrentScaling(fp8_format=fp8_format)
+    elif opts.quantization == "mxfp8":
+        fp8_recipe = MXFP8BlockScaling()
 
     # Prepare random input tensors
     test_x = torch.randn(input_shape, dtype=torch.float32, device="cuda", requires_grad=True)
...
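The recipe selection above is a simple string-to-constructor dispatch. A self-contained sketch of that mapping (the real script builds `transformer_engine.common.recipe` objects — `DelayedScaling`, `Float8CurrentScaling`, `MXFP8BlockScaling` — here we return the class name only so the sketch runs without Transformer Engine installed):

```python
def select_recipe(quantization: str):
    """Map the --quantization choice to a recipe class name (stand-in
    for constructing the actual transformer_engine recipe object)."""
    recipes = {
        "none": None,
        "fp8_delayed_scaling": "DelayedScaling",
        "fp8_current_scaling": "Float8CurrentScaling",
        "mxfp8": "MXFP8BlockScaling",
    }
    if quantization not in recipes:
        raise ValueError(f"Unsupported quantization scheme ({quantization})")
    return recipes[quantization]
```

Centralizing the mapping keeps the argparse `choices` list and the recipe construction in sync when a new scheme such as MXFP8 is added.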
@@ -15,11 +15,12 @@ if torch.cuda.device_count() < 2:
     pytest.skip("Comm+GEMM overlap requires at least 2 GPUs.")
 
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
 
 RNG_SEED: int = 42
 SEQ_LENGTH: int = 1024
 BATCH_SIZE: int = 2
-NUM_HEADS: int = 16
+NUM_HEADS: int = 32
 HEAD_DIM: int = 48
 
 TE_LAYERS = [
     te.Linear,
@@ -50,7 +51,7 @@ os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
 torch._dynamo.reset()
 
 
-def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8):
+def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, quantization):
     test_path = TEST_ROOT / "run_gemm_with_overlap.py"
     test_cmd = LAUNCH_CMD + [
         str(test_path),
@@ -66,10 +67,11 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, quantization):
     if bulk:
         test_cmd.append("--bulk-overlap")
     else:
-        if fp8:
-            if not fp8_available:
-                pytest.skip(reason_for_no_fp8)
-            test_cmd.append("--fp8")
+        if quantization == "fp8" and not fp8_available:
+            pytest.skip(reason_for_no_fp8)
+        if quantization == "mxfp8" and not mxfp8_available:
+            pytest.skip(reason_for_no_mxfp8)
+        test_cmd.append(f"--quantization={quantization}")
         if p2p:
             test_cmd.append("--p2p")
     if atomic:
@@ -107,8 +109,10 @@ def _run_layer_with_overlap(
         test_cmd.append("--overlap-rs-dgrad")
 
     if fp8:
-        if not fp8_available:
+        if quantization in ("fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
             pytest.skip(reason_for_no_fp8)
+        if quantization == "mxfp8" and not mxfp8_available:
+            pytest.skip(reason_for_no_mxfp8)
         test_cmd.append("--fp8")
         test_cmd.append(f"--quantization={quantization}")
@@ -130,51 +134,34 @@ def _run_layer_with_overlap(
         raise AssertionError(result.stderr.decode())
 
 
-@pytest.mark.parametrize(
-    "fp8",
-    (False, True),
-    ids=[" BF16 - RING-EXCHANGE ", " FP8 - RING-EXCHANGE "],
-)
-def test_split_all_gather_overlaps(fp8):
+@pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8"))
+def test_split_all_gather_overlaps(quantization):
     """
     Test (split GEMM -> all-gather) overlaps with direct calls to te.cpp_extensions.gemm or
     te.cpp_extensions.fp8_gemm.
     """
-    _run_gemm_with_overlap("AG", False, True, False, fp8)
+    _run_gemm_with_overlap("AG", False, True, False, quantization)
 
 
-@pytest.mark.parametrize(
-    "fp8,p2p",
-    [
-        (False, False),
-        (False, True),
-        (True, False),
-        (True, True),
-    ],
-    ids=[
-        " BF16 - PIPELINE ",
-        " BF16 - RING-EXCHANGE ",
-        " FP8 - PIPELINE ",
-        " FP8 - RING-EXCHANGE ",
-    ],
-)
-def test_split_reduce_scatter_overlaps(fp8, p2p):
+@pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8"))
+@pytest.mark.parametrize("p2p", (False, True))
+def test_split_reduce_scatter_overlaps(quantization, p2p):
     """
     Test (reduce-scatter -> split GEMM) overlaps with direct calls to te.cpp_extensions.gemm or
     te.cpp_extensions.fp8_gemm.
     """
-    _run_gemm_with_overlap("RS", False, p2p, False, fp8)
+    _run_gemm_with_overlap("RS", False, p2p, False, quantization)
 
 
 @pytest.mark.parametrize(
-    "comm_type, fp8, connections",
+    "comm_type, quantization, connections",
     [
-        ("AG", False, 1),
-        ("RS", False, 1),
-        ("RS", True, 1),
-        ("AG", False, 8),
-        ("RS", False, 8),
-        ("RS", True, 8),
+        ("AG", "none", 1),
+        ("RS", "none", 1),
+        ("RS", "fp8", 1),
+        ("AG", "none", 8),
+        ("RS", "none", 8),
+        ("RS", "fp8", 8),
     ],
     ids=[
         "ALL-GATHER - BF16 - 1 connections",
@@ -185,7 +172,7 @@ def test_split_reduce_scatter_overlaps(quantization, p2p):
         "REDUCE-SCATTER - FP8 - 8 connections",
     ],
 )
-def test_bulk_overlaps(comm_type, fp8, connections):
+def test_bulk_overlaps(comm_type, quantization, connections):
     """
     Test bulk overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm.
     """
@@ -196,10 +183,10 @@ def test_bulk_overlaps(comm_type, quantization, connections):
             " 9.0 (HOPPER ARCH)."
         )
         os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8"
-        _run_gemm_with_overlap(comm_type, True, False, False, fp8)
+        _run_gemm_with_overlap(comm_type, True, False, False, quantization)
         os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
     else:
-        _run_gemm_with_overlap(comm_type, True, False, False, fp8)
+        _run_gemm_with_overlap(comm_type, True, False, False, quantization)
 @pytest.mark.parametrize(
@@ -251,15 +238,7 @@ def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_d
 @pytest.mark.parametrize(
     "quantization",
-    ["fp8_delayed_scaling", "fp8_current_scaling"],
-    ids=[" DELAYED SCALING ", " CURRENT SCALING "],
-)
-@pytest.mark.parametrize(
-    "fp8",
-    (True,),
-    ids=[
-        " FP8 ",
-    ],
+    ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"],
 )
 @pytest.mark.parametrize(
     "layer_type,linear_parallel_mode,overlap_rs_dgrad",
@@ -279,15 +258,15 @@ def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_d
         )
     ),
     ids=[
-        f" {te.Linear.__name__} - ROW-PARALLEL ",
-        f" {te.Linear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ",
-        f" {te.Linear.__name__} - COL-PARLALEL - DGRAD+RS ",
-        f" {te.LayerNormLinear.__name__} - ROW-PARALLEL ",
-        f" {te.LayerNormLinear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ",
-        f" {te.LayerNormLinear.__name__} - COL-PARALLEL - DGRAD+RS ",
+        f"{te.Linear.__name__}-row_tensor_parallel",
+        f"{te.Linear.__name__}-col_tensor_parallel-BULK DGRAD/WGRAD",
+        f"{te.Linear.__name__}-col_tensor_parallel-DGRAD+RS",
+        f"{te.LayerNormLinear.__name__}-row_tensor_parallel",
+        f"{te.LayerNormLinear.__name__}-col_tensor_parallel-BULK DGRAD/WGRAD",
+        f"{te.LayerNormLinear.__name__}-col_tensor_parallel-DGRAD+RS",
     ]
     + [
-        " " + " - ".join(test_name_parts) + " "
+        "-".join(test_name_parts)
         for test_name_parts in zip(
             [layer.__name__ for layer in TE_LAYERS[2:] for _ in range(2)],
             ["BULK DGRAD/WGRAD", "DGRAD+RS"] * len(TE_LAYERS[2:]),
@@ -295,12 +274,15 @@ def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_d
     ],
 )
 def test_layers_with_overlap_fp8(
-    layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization
+    layer_type,
+    linear_parallel_mode,
+    overlap_rs_dgrad,
+    quantization,
 ):
     """
     Test Transformer Engine layers with comm+GEMM overlap.
     """
-    _run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization)
+    _run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, True, quantization)
 @pytest.mark.parametrize(
@@ -347,22 +329,11 @@ def test_multi_layer_with_overlap_bf16(
 @pytest.mark.parametrize(
     "quantization",
-    ["fp8_delayed_scaling", "fp8_current_scaling"],
-    ids=[" DELAYED SCALING ", " CURRENT SCALING "],
-)
-@pytest.mark.parametrize(
-    "fp8",
-    (True,),
-    ids=[
-        " FP8 ",
-    ],
+    ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"],
 )
 @pytest.mark.parametrize(
     "num_layers",
     (2,),
-    ids=[
-        " 2 layers ",
-    ],
 )
 @pytest.mark.parametrize(
     "layer_type,linear_parallel_mode,overlap_rs_dgrad",
@@ -374,7 +345,7 @@ def test_multi_layer_with_overlap_bf16(
         )
     ),
     ids=[
-        " " + " - ".join(test_name_parts) + " "
+        "-".join(test_name_parts)
        for test_name_parts in zip(
             [te.TransformerLayer.__name__ for _ in range(2)],
             ["BULK DGRAD/WGRAD", "DGRAD+RS"],
@@ -382,11 +353,11 @@ def test_multi_layer_with_overlap_bf16(
     ],
 )
 def test_multi_layer_with_overlap_fp8(
-    layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers
+    layer_type, linear_parallel_mode, overlap_rs_dgrad, quantization, num_layers
 ):
     """
     Test Transformer Engine layers with comm+GEMM overlap.
     """
     _run_layer_with_overlap(
-        layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers
+        layer_type, linear_parallel_mode, overlap_rs_dgrad, True, quantization, num_layers
     )
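The updated tests gate each quantization mode on device support by skipping inside the test body rather than pruning the parameter list, so unsupported modes show up as skips instead of silently disappearing. A sketch of that pattern with stand-in availability flags (in the real tests they come from `FP8GlobalStateManager`):

```python
import pytest

# Stand-in availability flags; the real values come from
# FP8GlobalStateManager.is_fp8_available() / is_mxfp8_available()
fp8_available, reason_for_no_fp8 = True, "FP8 not supported on this device"
mxfp8_available, reason_for_no_mxfp8 = False, "MXFP8 not supported on this device"

@pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8"))
def test_overlap(quantization):
    if quantization == "fp8" and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    # ... launch the distributed GEMM+overlap script here ...
```

`pytest.skip` raises immediately, so the expensive multi-GPU launch below it never runs on unsupported hardware, and the skip reason is reported per parameter combination.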
...@@ -19,7 +19,6 @@ import torch ...@@ -19,7 +19,6 @@ import torch
import transformer_engine import transformer_engine
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpp_extensions as tex import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import is_float8_tensor from transformer_engine.pytorch.ops._common import is_float8_tensor
...@@ -27,6 +26,8 @@ from transformer_engine.pytorch.ops.fused import ( ...@@ -27,6 +26,8 @@ from transformer_engine.pytorch.ops.fused import (
UserbuffersBackwardLinear, UserbuffersBackwardLinear,
UserbuffersForwardLinear, UserbuffersForwardLinear,
) )
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
from transformer_engine.pytorch.utils import is_bf16_compatible from transformer_engine.pytorch.utils import is_bf16_compatible
# Import utility functions # Import utility functions
...@@ -36,6 +37,13 @@ from utils import dtype_tols, str_to_dtype ...@@ -36,6 +37,13 @@ from utils import dtype_tols, str_to_dtype
# Check if FP8 is supported # Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
quantization_list: list[Optional[str]] = [None]
if fp8_available:
quantization_list.append("fp8")
if mxfp8_available:
quantization_list.append("mxfp8")
# Check if there are multiple GPUs
if torch.cuda.device_count() < 2:
@@ -51,7 +59,7 @@ class ModelConfig:
    num_heads: int
    head_dim: int
    dtype: torch.dtype
-    fp8: bool
+    quantization: Optional[str]

    @property
    def hidden_size(self):
@@ -129,11 +137,15 @@ def make_reference_and_test_tensors(
    ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)

    # Make copy of tensor
-    if test_is_fp8:
-        test = Float8Tensor.to_float8(ref)
-    else:
-        test = ref.to(device=test_device, dtype=test_dtype)
-    if test.data_ptr() == ref.data_ptr():
+    test = ref.to(device=test_device, dtype=test_dtype)
+    if test_is_fp8:
+        quantizer = Float8Quantizer(
+            scale=torch.ones(1, dtype=torch.float32, device=test_device),
+            amax=torch.zeros(1, dtype=torch.float32, device=test_device),
+            fp8_dtype=tex.DType.kFloat8E4M3,
+        )
+        test = quantizer(test)
+    elif test.data_ptr() == ref.data_ptr():
        test = test.clone()

    # Make sure reference and test tensors represent exact same values
@@ -145,6 +157,21 @@ def make_reference_and_test_tensors(
    return ref, test
+def make_recipe(name: Optional[str] = None) -> Optional[Recipe]:
+    """Make recipe for quantization scheme"""
+    if name is None:
+        return None
+    if name == "fp8":
+        return transformer_engine.common.recipe.DelayedScaling(
+            fp8_format=transformer_engine.common.recipe.Format.E4M3,
+        )
+    if name == "mxfp8":
+        return transformer_engine.common.recipe.MXFP8BlockScaling(
+            fp8_format=transformer_engine.common.recipe.Format.E4M3,
+        )
+    raise ValueError(f"Unsupported quantization scheme ({name})")
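For readers without Transformer Engine installed, the `None`/`"fp8"`/`"mxfp8"` dispatch in `make_recipe` can be sketched in plain Python; the `DelayedScaling` and `MXFP8BlockScaling` classes below are dummy stand-ins for the real recipe classes, not the actual API:

```python
from typing import Optional


# Hypothetical stand-ins for the transformer_engine.common.recipe classes
class DelayedScaling:
    """Placeholder for the FP8 delayed-scaling recipe."""
    name = "fp8"


class MXFP8BlockScaling:
    """Placeholder for the MXFP8 block-scaling recipe."""
    name = "mxfp8"


def make_recipe(name: Optional[str] = None):
    """Map a quantization scheme name to a recipe object (sketch)."""
    if name is None:
        return None
    if name == "fp8":
        return DelayedScaling()
    if name == "mxfp8":
        return MXFP8BlockScaling()
    raise ValueError(f"Unsupported quantization scheme ({name})")
```

The same pattern lets the test file treat "no quantization" uniformly with the quantized cases: `None` flows through `fp8_model_init(enabled=False, recipe=None)` unchanged.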
def _test_linear(
    *,
    model_config: ModelConfig,
@@ -155,7 +182,8 @@ def _test_linear(
    weight_requires_grad: bool = True,
) -> None:
    dtype = model_config.dtype
-    fp8_compute = model_config.fp8
+    quantization = model_config.quantization
+    quantized_compute = quantization is not None

    # Distributed process group
    process_group = world_group()
@@ -175,14 +203,19 @@ def _test_linear(
        in_shape,
        test_dtype=dtype,
        test_device=device,
-        test_is_fp8=fp8_compute,
+        test_is_fp8=quantized_compute,
    )
+    if isinstance(x_test, QuantizedTensor):
+        with torch.no_grad():
+            x_test = x_test.dequantize().requires_grad_()
    w_ref, w_test = make_reference_and_test_tensors(
        (out_features, in_features),
        test_dtype=dtype,
        test_device=device,
-        test_is_fp8=fp8_compute,
+        test_is_fp8=quantized_compute,
    )
+    if isinstance(w_test, QuantizedTensor):
+        w_test = w_test.dequantize()
    b_ref, b_test = None, None
    if bias:
        if tensor_parallel_mode == "row":
@@ -198,9 +231,11 @@ def _test_linear(
        out_shape,
        test_dtype=dtype,
        test_device=device,
-        test_is_fp8=fp8_compute,
+        test_is_fp8=quantized_compute,
        requires_grad=False,
    )
+    if isinstance(dy_test, QuantizedTensor):
+        dy_test = dy_test.dequantize()
    # Plain PyTorch implementation
    y_ref = torch.nn.functional.linear(x_ref, w_ref)
@@ -265,21 +300,15 @@ def _test_linear(
    x_test.requires_grad_()

    # Implementation with fusible operation
-    with te.fp8_model_init(enabled=fp8_compute):
+    recipe = make_recipe(quantization)
+    with te.fp8_model_init(enabled=quantized_compute, recipe=recipe):
        ops = []
        linear_op = None
        bias_op = None
        if tensor_parallel_mode == "column":
            userbuffers_options = {}
            if not weight_requires_grad:
-                if fp8_compute:
-                    userbuffers_options["comm_name"] = "fc1"
-                else:
-                    # There is a correctness bug with overlapping
-                    # dgrad reduce-scatter with dgrad GEMM. Fall back
-                    # to overlapping dgrad reduce-scatter with wgrad
-                    # GEMM, even though wgrad isn't needed.
-                    userbuffers_options["comm_name"] = "qkv"
+                userbuffers_options["comm_name"] = "fc1"
            else:
                userbuffers_options["comm_name"] = "qkv"
            linear_op = te_ops.BasicLinear(
@@ -322,7 +351,7 @@ def _test_linear(
            bias_op.bias.copy_(b_test)
    del w_test
    del b_test
-    with te.fp8_autocast(enabled=fp8_compute):
+    with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
        y_test = model(x_test)
    y_test.backward(dy_test)
@@ -338,7 +367,7 @@ def _test_linear(
    tols = dtype_tols(dtype)
    if dtype == torch.float32:
        tols = dtype_tols(torch.float16)  # TF32 GEMM
-    if fp8_compute:
+    if quantized_compute:
        tols = dtype_tols(
            model[0].weight._fp8_dtype
            if is_float8_tensor(model[0].weight)
@@ -370,7 +399,7 @@ def run_parallel_tests(model_config: ModelConfig) -> None:
    for test_config in itertools.product(
        (False, True),  # bias
        ("column", "row"),  # tensor_parallel_mode
-        (False, True),  # weight_requires_grad
+        (True, False),  # weight_requires_grad
    ):
        if rank == 0:
            print(f"Running _test_linear with {test_config=}")
@@ -390,19 +419,15 @@ if torch.cuda.device_count() > 1:
@pytest.mark.parametrize("world_size", _world_sizes)
-@pytest.mark.parametrize("fp8", (False, True))
+@pytest.mark.parametrize("quantization", quantization_list)
def test_fuser_ops_with_userbuffers(
    *,
    world_size: int,
    dtype: torch.dtype = torch.bfloat16,
-    fp8: bool,
+    quantization: Optional[str],
) -> None:
    """Launch parallel job and run tests"""
-    # Skip invalid configurations
-    if fp8 and not fp8_available:
-        pytest.skip(reason_for_no_fp8)

    # Parallel job launcher
    command = []
    if tex.ubuf_built_with_mpi():
@@ -424,8 +449,8 @@ def test_fuser_ops_with_userbuffers(
            str(dtype),
        )
    )
-    if fp8:
-        command.append("--fp8")
+    if quantization is not None:
+        command.extend(("--quantization", quantization))

    # Environment
    env = dict(os.environ)
@@ -445,12 +470,12 @@ def main() -> None:
    # Parse command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--parallel", action="store_true", help="Run parallel tests")
-    parser.add_argument("--sequence-length", type=int, default=32)
+    parser.add_argument("--sequence-length", type=int, default=256)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--num-heads", type=int, default=16)
-    parser.add_argument("--head-dim", type=int, default=32)
+    parser.add_argument("--head-dim", type=int, default=256)
    parser.add_argument("--dtype", type=str, default="bfloat16")
-    parser.add_argument("--fp8", action="store_true")
+    parser.add_argument("--quantization", type=str, default=None)
    args = parser.parse_args()

    # Run parallel tests if needed
@@ -463,14 +488,17 @@ def main() -> None:
        num_heads=args.num_heads,
        head_dim=args.head_dim,
        dtype=str_to_dtype(args.dtype),
-        fp8=args.fp8,
+        quantization=args.quantization,
    )

    # Initialize Userbuffers
    group = world_group()  # Initialize NCCL
    bootstrap_backend = "mpi" if launcher() == "ompi" else "nccl"
    userbuffer_configs = {
-        "fc1_dgrad": {"method": "pipeline"},  # Overlap dgrad RS with dgrad GEMM
+        "fc1_dgrad": {
+            "method": "ring_exchange",
+            "fp8_buf": False,
+        },  # Overlap dgrad RS with dgrad GEMM
    }
    te.module.base.initialize_ub(
        [
@@ -478,7 +506,7 @@ def main() -> None:
            model_config.num_heads * model_config.head_dim,
        ],
        torch.distributed.get_world_size(group),
-        use_fp8=model_config.fp8,
+        use_fp8=model_config.quantization is not None,
        dtype=model_config.dtype,
        bootstrap_backend=bootstrap_backend,
        ub_cfgs=userbuffer_configs,
...
@@ -147,7 +147,44 @@ CommOverlapCore::~CommOverlapCore() {

TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, size_t chunk_offset,
                                                const std::vector<size_t> &chunk_shape) {
-  TensorWrapper chunk;
+  const auto scaling_mode = source.scaling_mode();
+
+  // Tensor dimensions
+  std::vector<size_t> shape = shape_to_vector(source.shape());
+  auto flatten_shape_to_2d = [](const std::vector<size_t> &shape) -> std::pair<size_t, size_t> {
+    if (shape.empty()) {
+      return {1, 1};
+    }
+    size_t height = 1;
+    for (size_t i = 0; i < shape.size() - 1; ++i) {
+      height *= shape[i];
+    }
+    return {height, shape.back()};
+  };
+  size_t height, width, chunk_height, chunk_width;
+  std::tie(height, width) = flatten_shape_to_2d(shape);
+  std::tie(chunk_height, chunk_width) = flatten_shape_to_2d(chunk_shape);
+
+  // Check tensor dimensions
+#define NVTE_DIM_CHECK(cond, message) \
+  NVTE_CHECK(cond, message, " (tensor shape=", shape, ", chunk shape=", chunk_shape, \
+             ", chunk offset=", chunk_offset, ")")
+  NVTE_DIM_CHECK(height > 0 && width > 0, "Attempted to get chunk from empty tensor");
+  NVTE_DIM_CHECK(chunk_height > 0 && chunk_width > 0, "Attempted to get empty tensor chunk");
+  NVTE_DIM_CHECK(chunk_height <= height && chunk_width <= width,
+                 "Attempted to get out-of-bounds tensor chunk");
+  if (scaling_mode == NVTEScalingMode::NVTE_MXFP8_1D_SCALING) {
+    // MXFP8 scale-inverses are padded to a 2D matrix with dims that
+    // are divisible by 128. UB doesn't handle this padding yet.
+    NVTE_DIM_CHECK(height % 128 == 0 && width % 128 == 0,
+                   "Userbuffers requires MXFP8 tensor dims that are divisible by 128");
+    NVTE_DIM_CHECK(chunk_height % 128 == 0 && chunk_width % 128 == 0,
+                   "Userbuffers requires MXFP8 tensor chunk dims that are divisible by 128");
+  }
+#undef NVTE_DIM_CHECK
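The flatten-to-2D logic and the MXFP8 divisibility checks above boil down to simple index arithmetic. A plain-Python sketch, independent of the C++ `TensorWrapper` API:

```python
def flatten_shape_to_2d(shape):
    """Collapse all leading dims into a height; the last dim is the width."""
    if not shape:
        return (1, 1)
    height = 1
    for dim in shape[:-1]:
        height *= dim
    return (height, shape[-1])


def mxfp8_chunk_is_valid(shape, chunk_shape):
    """Mimic the dim checks: chunk must fit in the tensor, and for MXFP8
    both tensor and chunk dims must be divisible by 128 (the padding
    granularity of the scale-inverse matrices)."""
    height, width = flatten_shape_to_2d(shape)
    chunk_height, chunk_width = flatten_shape_to_2d(chunk_shape)
    assert height > 0 and width > 0, "empty tensor"
    assert 0 < chunk_height <= height and 0 < chunk_width <= width, "bad chunk"
    return (height % 128 == 0 and width % 128 == 0
            and chunk_height % 128 == 0 and chunk_width % 128 == 0)
```

For example, a `[2, 128, 256]` activation flattens to a 256x256 matrix, so its 128x256 per-rank chunks pass the MXFP8 check, while any 64-wide tensor fails it.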
+
+  // Construct tensor chunk
+  TensorWrapper chunk(scaling_mode);
  for (int param_id = 0; param_id < NVTETensorParam::kNVTENumTensorParams; param_id++) {
    auto param_type = static_cast<NVTETensorParam>(param_id);
    auto param = source.get_parameter(param_type);
@@ -163,8 +200,8 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz
      param_shape = chunk_shape;
      if (param_type == NVTETensorParam::kNVTEColumnwiseData &&
-          source.scaling_mode() != NVTEScalingMode::NVTE_MXFP8_1D_SCALING) {
-        // Columnwise shape for non-block scaled tensors shifts the last dimension to the front
+          source.scaling_mode() == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING) {
+        // Columnwise shape for FP8 tensor-scaled tensors shifts the last dimension to the front
        auto last_dim = param_shape.back();
        param_shape.pop_back();
        param_shape.insert(param_shape.begin(), last_dim);
@@ -172,18 +209,16 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz
      } else if (source.scaling_mode() == NVTEScalingMode::NVTE_MXFP8_1D_SCALING &&
                 (param_type == NVTETensorParam::kNVTERowwiseScaleInv ||
                  param_type == NVTETensorParam::kNVTEColumnwiseScaleInv)) {
-        // Calculate block scaling offset and size
-        auto scaled_tensor_dim_size = (param_type == NVTETensorParam::kNVTERowwiseScaleInv)
-                                          ? source.shape().data[0]
-                                          : source.columnwise_shape().data[0];
-        auto scaled_chunk_dim_size = (param_type == NVTETensorParam::kNVTERowwiseScaleInv)
-                                         ? chunk_shape.front()
-                                         : chunk_shape.back();
-        auto chunk_scale_start = chunk_offset / 32;
-        auto chunk_scale_end = (chunk_offset + scaled_chunk_dim_size) / 32;
-        auto chunk_scale_size = chunk_scale_end - chunk_scale_start;
-        param_dptr += chunk_scale_start * typeToSize(param_dtype);
-        param_shape = std::vector<size_t>{chunk_scale_size};
+        // Calculate offset and size for MXFP8 scale-invs
+        size_t chunk_scale_height = chunk_height;
+        size_t chunk_scale_width = chunk_width;
+        if (param_type == NVTETensorParam::kNVTERowwiseScaleInv) {
+          chunk_scale_width /= 32;
+        } else {
+          chunk_scale_height /= 32;
+        }
+        param_dptr += (chunk_offset / 32) * typeToSize(param_dtype);
+        param_shape = {chunk_scale_height, chunk_scale_width};
      }

      // Set chunked source parameters into the chunked tensor output
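The scale-inverse chunking above relies on MXFP8's 1D scaling layout: one FP8 scale factor per 32 contiguous elements, so the rowwise scale matrix is 32x narrower than the data and the columnwise one is 32x shorter. A plain-Python sketch of the shape and offset math (ignoring the 128-row/column padding of the real scale matrices):

```python
def mxfp8_scale_chunk(chunk_height, chunk_width, chunk_offset, rowwise):
    """Shape and element offset of the scale-inverse chunk for an MXFP8 tensor.

    MXFP8 1D scaling stores one scale per 32-element block along the
    scaled dimension: rowwise scaling shrinks the width by 32x,
    columnwise scaling shrinks the height by 32x. The data-element
    offset likewise maps to a 32x smaller scale offset.
    """
    if rowwise:
        scale_shape = (chunk_height, chunk_width // 32)
    else:
        scale_shape = (chunk_height // 32, chunk_width)
    scale_offset = chunk_offset // 32
    return scale_shape, scale_offset
```

For a 128x256 chunk starting at element offset 128*256 = 32768, the rowwise scale chunk is 128x8 and starts 1024 scale entries into the scale matrix.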
@@ -422,10 +457,21 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
  size_t k = transa ? A.size(1) : A.size(0);
  size_t n = _ubuf.size(0);
  size_t m_chunk = m / _num_splits;
+  const std::vector<size_t> input_a_chunk_shape =
+      (transa ? std::vector<size_t>{m_chunk, k} : std::vector<size_t>{k, m_chunk});
+  const std::vector<size_t> output_chunk_shape = {n, m_chunk};
  size_t input_a_chunk_size = m_chunk * k;
  size_t output_chunk_size = n * m_chunk;
  size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();

+  // Helper function to get bias chunk if needed
+  auto maybe_get_bias_chunk = [this, &bias, m_chunk](size_t chunk_id) -> TensorWrapper {
+    if (bias.dptr() == nullptr) {
+      return TensorWrapper();
+    }
+    return get_tensor_chunk(bias, chunk_id * m_chunk, {m_chunk});
+  };
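The `maybe_get_bias_chunk` helper simply slices out the bias elements feeding GEMM chunk `chunk_id`, or returns an empty tensor when there is no bias. A plain-Python sketch of the same slicing:

```python
def bias_chunk(bias, chunk_id, m_chunk):
    """Slice the bias elements feeding GEMM chunk `chunk_id` (sketch).

    Returns None when there is no bias, mirroring the empty TensorWrapper
    the C++ helper returns when bias.dptr() is null.
    """
    if bias is None:
        return None
    start = chunk_id * m_chunk
    return bias[start:start + m_chunk]
```

Passing the per-chunk bias slice (rather than the full bias, as the old code did) keeps each split GEMM's bias aligned with its `m_chunk` output columns.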
  // Catch up the default torch stream
  NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
  for (size_t i = 0; i < _stream_compute.size(); i++) {
@@ -437,21 +483,23 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
  char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());

  if (_rs_overlap_first_gemm) {
-    auto input_a_chunk = get_tensor_chunk(A, 0, {m_chunk, k});
-    auto output_chunk = get_buffer_chunk_like(D, 0, {m, m_chunk});
+    auto input_a_chunk = get_tensor_chunk(A, 0, input_a_chunk_shape);
+    auto output_chunk = get_buffer_chunk_like(D, 0, output_chunk_shape);
+    auto bias_chunk = maybe_get_bias_chunk(0);
    auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk});
-    nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
+    nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
                     pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
                     use_split_accumulator, _math_sms, _stream_compute[0]);

    for (int i = 1; i < _num_splits; i++) {
-      input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k});
-      output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, {n, m_chunk});
+      input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, input_a_chunk_shape);
+      output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, output_chunk_shape);
+      bias_chunk = maybe_get_bias_chunk(i);
      workspace_chunk = get_tensor_chunk(
          workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
-      nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
+      nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
                       pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
                       accumulate, use_split_accumulator, _math_sms,
                       _stream_compute[i % _stream_compute.size()]);
@@ -494,12 +542,13 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
    }
  } else {
    for (int i = 0; i < _num_splits; i++) {
-      auto input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k});
-      auto output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, {n, m_chunk});
+      auto input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, input_a_chunk_shape);
+      auto output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, output_chunk_shape);
+      auto bias_chunk = maybe_get_bias_chunk(i);
      auto workspace_chunk = get_tensor_chunk(
          workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
-      nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
+      nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
                       pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
                       accumulate, use_split_accumulator, _math_sms,
                       _stream_compute[i % _stream_compute.size()]);
@@ -759,8 +808,6 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
  // Get communication and GEMM output chunk sizes
  const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
  const bool do_gelu = pre_gelu_out.numel() > 0;
-  size_t input_chunk_size = n_chunk * k;
-  size_t output_chunk_size = n_chunk * m;
  size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();

  NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
@@ -771,8 +818,13 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
  }

  if (_aggregate) {
    const int num_steps = _tp_size / 2;
-    input_chunk_size *= 2;
-    output_chunk_size *= 2;
+
+    // Chunk dims
+    std::vector<size_t> input_b_chunk_shape =
+        (transb ? std::vector<size_t>{k, 2 * n_chunk} : std::vector<size_t>{2 * n_chunk, k});
+    std::vector<size_t> output_chunk_shape = {2 * n_chunk, m};
+    size_t input_b_chunk_size = 2 * n_chunk * k;
+    size_t output_chunk_size = 2 * n_chunk * m;

    // Initial 1X input chunk exchange between neighboring peers
    int send_chunk_id = _tp_id;
@@ -801,8 +853,9 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
      // GEMM
      auto input_b_chunk =
-          get_buffer_chunk_like(B, input_chunk_size * send_chunk_id, {n_chunk * 2, k});
-      auto output_chunk = get_tensor_chunk(D, output_chunk_size * send_chunk_id, {n_chunk * 2, m});
+          get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape);
+      auto output_chunk =
+          get_tensor_chunk(D, output_chunk_size * send_chunk_id, output_chunk_shape);
      auto aux_chunk =
          (do_gelu)
              ? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk * 2, k})
@@ -834,6 +887,13 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
      }
    }
  } else {
+    // Chunk dims
+    std::vector<size_t> input_b_chunk_shape =
+        (transb ? std::vector<size_t>{k, n_chunk} : std::vector<size_t>{n_chunk, k});
+    std::vector<size_t> output_chunk_shape = {n_chunk, m};
+    size_t input_b_chunk_size = n_chunk * k;
+    size_t output_chunk_size = n_chunk * m;
+
    for (int i = 0; i < _tp_size; i++) {
      // Set the userbuffer id. Buffer under send is the input for the current
      // GEMM chunk The initial input chunk is stored _ubuf[rank]. This is to
@@ -845,8 +905,10 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
      int recv_offset = comm_bytes * recv_chunk_id;

      // GEMM
-      auto input_b_chunk = get_buffer_chunk_like(B, input_chunk_size * send_chunk_id, {n_chunk, k});
-      auto output_chunk = get_tensor_chunk(D, output_chunk_size * send_chunk_id, {n_chunk, m});
+      auto input_b_chunk =
+          get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape);
+      auto output_chunk =
+          get_tensor_chunk(D, output_chunk_size * send_chunk_id, output_chunk_shape);
      auto aux_chunk =
          (do_gelu)
              ? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk, k})
@@ -996,6 +1058,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
      auto input_b_chunk = get_tensor_chunk(B, input_b_chunk_id * input_chunk_size, {n_chunk, k});
      auto output_chunk = get_buffer_chunk_by_id(D, i);
      auto workspace_chunk =
          get_tensor_chunk(workspace, stream_id * workspace_size_chunk, {workspace_size_chunk});
...
@@ -12,10 +12,18 @@
#include <cudnn.h>
#include <nvrtc.h>

+#include <iostream>
#include <stdexcept>

#include "../util/string.h"

+#define NVTE_WARN(...) \
+  do { \
+    std::cerr << ::transformer_engine::concat_strings( \
+        __FILE__ ":", __LINE__, " in function ", __func__, ": ", \
+        ::transformer_engine::concat_strings(__VA_ARGS__), "\n"); \
+  } while (false)
+
#define NVTE_ERROR(...) \
  do { \
    throw ::std::runtime_error(::transformer_engine::concat_strings( \
...
@@ -452,12 +452,10 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve
  ~CommOverlap() {}

-  void set_buffer_params(py::handle quantizer);
-
-  void copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk = false);
-
-  py::object get_buffer(py::handle quantizer, bool local_chunk = false,
-                        std::optional<const std::vector<int64_t>> shape = std::nullopt);
+  void copy_into_buffer(const at::Tensor &input, bool local_chunk = false);
+
+  at::Tensor get_buffer(bool local_chunk = false,
+                        std::optional<std::vector<int64_t>> shape = std::nullopt);
};  // CommOverlap

@@ -473,12 +471,10 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm
  ~CommOverlapP2P() {}

-  void set_buffer_params(py::handle quantizer);
-
-  void copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk = false);
-
-  py::object get_buffer(py::handle quantizer, bool local_chunk = false,
-                        std::optional<const std::vector<int64_t>> shape = std::nullopt);
+  void copy_into_buffer(const at::Tensor &input, bool local_chunk = false);
+
+  at::Tensor get_buffer(bool local_chunk = false,
+                        std::optional<std::vector<int64_t>> shape = std::nullopt);
};  // CommOverlapP2P
...
@@ -141,81 +141,79 @@ CommOverlap::CommOverlap(const std::vector<size_t> &buffer_shape, at::ScalarType
                 num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm,
                 set_sm_margin, atomic_gemm, rs_overlap_first_gemm) {}

-void CommOverlap::set_buffer_params(py::handle quantizer) {
-  std::unique_ptr<te::pytorch::Quantizer> my_quantizer = te::pytorch::convert_quantizer(quantizer);
-  my_quantizer->set_quantization_params(&_ubuf);
-  _ubuf_scale_inv_initialized = true;
-}
-
/*
** Helper function to copy input to _ubuf
*/
-void CommOverlap::copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk) {
-  auto input_tensor = te::pytorch::makeTransformerEngineTensor(input, quantizer);
-  auto input_ptr = input_tensor.dptr() ? input_tensor.dptr() : input_tensor.columnwise_dptr();
-  NVTE_CHECK(input_ptr, "Input tensor does not have rowwise or columnwise data!");
-
-  char *ubuf_ptr = reinterpret_cast<char *>(_ubuf.dptr());
+void CommOverlap::copy_into_buffer(const at::Tensor &input, bool local_chunk) {
+  const auto &input_ = input.contiguous();
+
+  // Check element size
+  const size_t element_size = input.element_size();
+  NVTE_CHECK(_ubuf.element_size() == element_size,
+             "Tried to copy data into a Userbuffers buffer but dtypes are not compatible ",
+             "(input dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(),
+             " bytes)");
+
+  // Input data
+  const size_t input_size = input_.numel();
+  const void *src_ptr = input_.data_ptr();
+
+  // Userbuffers data
+  const size_t ubuf_size = _ubuf.numel();
+  void *dst_ptr = _ubuf.dptr();
  if (local_chunk) {
-    if (input_tensor.numel() * _tp_size > _ubuf.numel())
-      NVTE_ERROR("input is larger than the local communication buffer!");
-    if (input_tensor.element_size() != _ubuf.element_size())
-      NVTE_ERROR("input data type does not match communication buffer!");
-    ubuf_ptr += (_ubuf.numel() / _tp_size) * _tp_id * _ubuf.element_size();
+    NVTE_CHECK(input_size * _tp_size == ubuf_size,
+               "Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ",
+               "(input_size=", input_size, ", tensor_parallel_size=", _tp_size,
+               ", ubuf_size=", ubuf_size, ")");
+    dst_ptr = (reinterpret_cast<char *>(dst_ptr) + (ubuf_size / _tp_size) * _tp_id * element_size);
  } else {
-    if (input_tensor.numel() > _ubuf.numel())
-      NVTE_ERROR("input is larger than the global communication buffer!");
-    if (input_tensor.element_size() != _ubuf.element_size())
-      NVTE_ERROR("input data type does not match communication buffer!");
+    NVTE_CHECK(input_size == ubuf_size,
+               "Tried to copy an invalid tensor into a Userbuffers buffer ",
+               "(input_size=", input_size, ", ubuf_size=", ubuf_size, ")");
  }

-  // Copy either row or columnwise data into the communication buffer's columnwise data
-  // NOTE: _ubuf.columnwise_dptr() is not a valid copy target because it is not registered with
-  // the Userbuffers communicator.
-  at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
+  // Copy data
+  auto stream_main = at::cuda::getCurrentCUDAStream();
  NVTE_CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main));
  NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0));
-  NVTE_CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input_tensor.dptr(),
-                                  input_tensor.numel() * input_tensor.element_size(),
+  NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, input_size * element_size,
                                  cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_comm));
}
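With `local_chunk=true`, the input must be exactly `1/tp_size` of the buffer and is written at byte offset `(ubuf_size / tp_size) * tp_id * element_size`, so each rank's shard lands in its own slot of the shared buffer. A plain-Python sketch of the size check and offset math:

```python
def local_chunk_byte_offset(ubuf_size, input_size, tp_size, tp_id, element_size):
    """Byte offset where rank `tp_id`'s local chunk lands in the UB buffer.

    Mirrors the copy_into_buffer check: the input must be exactly
    1/tp_size of the full buffer (in elements).
    """
    assert input_size * tp_size == ubuf_size, "input must be 1/tp_size of the buffer"
    return (ubuf_size // tp_size) * tp_id * element_size
```

For a 1024-element bf16 buffer split over 4 ranks, rank 2's 256-element chunk starts 1024 bytes in; consecutive ranks' chunks are contiguous and non-overlapping.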
-py::object CommOverlap::get_buffer(py::handle quantizer, bool local_chunk,
-                                   std::optional<const std::vector<int64_t>> shape) {
-  using namespace te::pytorch;
-  char *ubuf_wt_ptr = reinterpret_cast<char *>(_ubuf.dptr());
-  if (local_chunk) ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size();
-
-  std::vector<int64_t> torch_shape;
-  if (shape.has_value()) {
-    torch_shape = shape.value();
-    size_t requested = product(torch_shape);
-    auto expected = local_chunk ? _ubuf.numel() / _tp_size : _ubuf.numel();
-    NVTE_CHECK(requested == expected, "Number of elements in the requested shape (", requested,
-               ") does not match allocated buffer size (", expected, ")!");
-  } else {
-    int64_t output_c_dim0 = (local_chunk) ? _ubuf.size(0) / _tp_size : _ubuf.size(0);
-    int64_t output_c_dim1 = _ubuf.size(1);
-    torch_shape = {output_c_dim0, output_c_dim1};
-  }
-  auto ubuf_tensor = torch::from_blob(reinterpret_cast<void *>(ubuf_wt_ptr), torch_shape,
-                                      at::dtype(GetATenDType(_ubuf.dtype())).device(torch::kCUDA));
-
-  std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
-  std::vector<size_t> te_shape;
-  for (auto s : torch_shape) te_shape.emplace_back(static_cast<size_t>(s));
-
-  // Always output a rowwise-only QuantizedTensor
-  // TODO (Alp): This needs to produce an un-interleaved transpose when required.
-  auto is_internal = my_quantizer->internal;
-  auto uses_columnwise = my_quantizer->columnwise_usage;
-  my_quantizer->internal = false;
-  my_quantizer->columnwise_usage = false;
-  auto [te_tensor, py_tensor] = my_quantizer->create_tensor(te_shape, _ubuf.dtype(), ubuf_tensor);
-  my_quantizer->internal = is_internal;
-  my_quantizer->columnwise_usage = uses_columnwise;
-  return py_tensor;
+at::Tensor CommOverlap::get_buffer(bool local_chunk, std::optional<std::vector<int64_t>> shape) {
+  // Check buffer shape
+  const size_t ubuf_size = _ubuf.numel();
+  if (shape) {
+    const size_t requested_size = transformer_engine::pytorch::product(*shape);
+    if (local_chunk) {
+      NVTE_CHECK(requested_size * _tp_size == ubuf_size,
+                 "Invalid shape for local chunk of a Userbuffers buffer (requested shape=", *shape,
+                 ", tensor_parallel_size=", _tp_size, ", ubuf_size=", ubuf_size, ")");
+    } else {
+      NVTE_CHECK(requested_size == ubuf_size,
+                 "Invalid shape for a Userbuffers buffer (requested shape=", *shape,
+                 ", ubuf_size=", ubuf_size, ")");
+    }
+  } else {
+    int64_t dim0 = _ubuf.size(0);
+    int64_t dim1 = _ubuf.size(1);
+    if (local_chunk) {
+      dim0 /= _tp_size;
+    }
+    shape = {dim0, dim1};
+  }
+
+  // Data pointer
+  void *ubuf_ptr = _ubuf.dptr();
+  if (local_chunk) {
+    ubuf_ptr = (reinterpret_cast<char *>(ubuf_ptr) +
+                (ubuf_size / _tp_size) * _tp_id * _ubuf.element_size());
+  }
+
+  // Construct PyTorch tensor
+  const auto dtype = transformer_engine::pytorch::GetATenDType(_ubuf.dtype());
+  return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA));
 }
/***************************************************************************************************
@@ -236,74 +234,69 @@ CommOverlapP2P::CommOverlapP2P(const std::vector<size_t> &buffer_shape, at::Scal
           comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce,
           atomic_gemm, aggregate) {}
-void CommOverlapP2P::set_buffer_params(py::handle quantizer) {
-  std::unique_ptr<te::pytorch::Quantizer> my_quantizer = te::pytorch::convert_quantizer(quantizer);
-  my_quantizer->set_quantization_params(&_ubuf);
-  for (size_t i = 0; i < _ubufs.size(); i++) my_quantizer->set_quantization_params(&_ubufs[i]);
-}
-
 /*
 ** Copy input to _ubufs[0]
 */
-void CommOverlapP2P::copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk) {
-  auto input_tensor = te::pytorch::makeTransformerEngineTensor(input, quantizer);
-  auto input_ptr = input_tensor.dptr() ? input_tensor.dptr() : input_tensor.columnwise_dptr();
-  NVTE_CHECK(input_ptr, "Input tensor does not have rowwise or columnwise data!");
-
-  at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
+void CommOverlapP2P::copy_into_buffer(const at::Tensor &input, bool local_chunk) {
+  const auto &input_ = input.contiguous();
+
+  // Check element size
+  const size_t element_size = input.element_size();
+  NVTE_CHECK(_ubuf.element_size() == element_size,
+             "Tried to copy data into a Userbuffers buffer but dtypes are not compatible ",
+             "(input dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(),
+             " bytes)");
+
+  // Input data
+  const size_t input_size = input_.numel();
+  const void *src_ptr = input_.data_ptr();
+
+  // Userbuffers data
+  void *dst_ptr;
   if (local_chunk) {
-    // Copy input to the target ubuf chunk by rank offset
-    if (input_tensor.numel() * _tp_size > _ubuf.numel())
-      NVTE_ERROR("input is larger than the local communication buffer!");
-    if (input_tensor.element_size() != _ubuf.element_size())
-      NVTE_ERROR("input data type does not match communication buffer!");
-    NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].dptr(), input_ptr,
-                                    input_tensor.numel() * input_tensor.element_size(),
-                                    cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main));
+    NVTE_CHECK(_ubufs[_tp_id].numel() == input_size,
+               "Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ",
+               "(input_size=", input_size, ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")");
+    dst_ptr = _ubufs[_tp_id].dptr();
   } else {
-    if (input_tensor.numel() > _ubuf.numel())
-      NVTE_ERROR("input is larger than the global communication buffer!");
-    if (input_tensor.element_size() != _ubuf.element_size())
-      NVTE_ERROR("input data type does not match communication buffer!");
-    NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubuf.dptr(), input_ptr,
-                                    input_tensor.numel() * input_tensor.element_size(),
-                                    cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main));
+    NVTE_CHECK(_ubuf.numel() == input_size,
+               "Tried to copy an invalid tensor into a Userbuffers buffer ",
+               "(input_size=", input_size, ", ubuf_size=", _ubuf.numel(), ")");
+    dst_ptr = _ubuf.dptr();
   }
+
+  // Copy data
+  NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, input_size * element_size,
+                                  cudaMemcpyDeviceToDevice,
+                                  (cudaStream_t)at::cuda::getCurrentCUDAStream()));
 }
-py::object CommOverlapP2P::get_buffer(py::handle quantizer, bool local_chunk,
-                                      std::optional<const std::vector<int64_t>> shape) {
-  using namespace te::pytorch;
-  char *ubuf_wt_ptr = reinterpret_cast<char *>(local_chunk ? _ubufs[_tp_id].dptr() : _ubuf.dptr());
-
-  std::vector<int64_t> torch_shape;
-  if (shape.has_value()) {
-    torch_shape = shape.value();
-    size_t requested = product(torch_shape);
-    auto expected = local_chunk ? _ubufs[_tp_id].numel() : _ubuf.numel();
-    NVTE_CHECK(requested == expected, "Number of elements in the requested shape (", requested,
-               ") does not match allocated buffer size (", expected, ")!");
-  } else {
-    int64_t output_c_dim0 = (local_chunk) ? _ubuf.size(0) / _tp_size : _ubuf.size(0);
-    int64_t output_c_dim1 = _ubuf.size(1);
-    torch_shape = {output_c_dim0, output_c_dim1};
-  }
-  auto ubuf_tensor = torch::from_blob(reinterpret_cast<void *>(ubuf_wt_ptr), torch_shape,
-                                      at::dtype(GetATenDType(_ubuf.dtype())).device(torch::kCUDA));
-
-  std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
-  std::vector<size_t> te_shape;
-  for (auto s : torch_shape) te_shape.emplace_back(static_cast<size_t>(s));
-
-  // Always output a rowwise-only QuantizedTensor
-  // TODO (Alp): This needs to produce an un-interleaved transpose when required.
-  auto is_internal = my_quantizer->internal;
-  auto uses_columnwise = my_quantizer->columnwise_usage;
-  my_quantizer->internal = false;
-  my_quantizer->columnwise_usage = false;
-  auto [te_tensor, py_tensor] = my_quantizer->create_tensor(te_shape, _ubuf.dtype(), ubuf_tensor);
-  my_quantizer->internal = is_internal;
-  my_quantizer->columnwise_usage = uses_columnwise;
-  return py_tensor;
+at::Tensor CommOverlapP2P::get_buffer(bool local_chunk, std::optional<std::vector<int64_t>> shape) {
+  // Check buffer shape
+  if (shape) {
+    const size_t requested_size = transformer_engine::pytorch::product(*shape);
+    if (local_chunk) {
+      NVTE_CHECK(requested_size == _ubufs[_tp_id].numel(),
+                 "Invalid shape for local chunk of a Userbuffers buffer (requested shape=", *shape,
+                 ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")");
+    } else {
+      NVTE_CHECK(requested_size == _ubuf.numel(),
+                 "Invalid shape for a Userbuffers buffer (requested shape=", *shape,
+                 ", ubuf_size=", _ubuf.numel(), ")");
+    }
+  } else {
+    int64_t dim0 = _ubuf.size(0);
+    int64_t dim1 = _ubuf.size(1);
+    if (local_chunk) {
+      dim0 /= _tp_size;
+    }
+    shape = {dim0, dim1};
+  }
+
+  // Data pointer
+  void *ubuf_ptr = local_chunk ? _ubufs[_tp_id].dptr() : _ubuf.dptr();
+
+  // Construct PyTorch tensor
+  const auto dtype = transformer_engine::pytorch::GetATenDType(_ubuf.dtype());
+  return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA));
 }
@@ -360,10 +360,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
            py::arg("num_comm_sm") = 16, py::arg("set_sm_margin") = true,
            py::arg("atomic_gemm") = false, py::arg("rs_overlap_first_gemm") = false)
       .def("copy_into_buffer", &CommOverlap::copy_into_buffer, py::arg("input"),
-           py::arg("quantizer"), py::arg("local_chunk") = false)
-      .def("get_buffer", &CommOverlap::get_buffer, py::arg("quantizer"),
-           py::arg("local_chunk") = false, py::arg("shape") = std::nullopt)
-      .def("set_buffer_params", &CommOverlap::set_buffer_params);
+           py::arg("local_chunk") = false)
+      .def("get_buffer", &CommOverlap::get_buffer, py::arg("local_chunk") = false,
+           py::arg("shape") = std::nullopt);
 
   py::class_<CommOverlapP2P, std::shared_ptr<CommOverlapP2P>,
              transformer_engine::CommOverlapP2PBase, transformer_engine::CommOverlapCore>(
@@ -378,8 +377,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
            py::arg("set_sm_margin") = false, py::arg("atomic_gemm") = false,
            py::arg("use_ce") = true, py::arg("aggregate") = false)
       .def("copy_into_buffer", &CommOverlapP2P::copy_into_buffer, py::arg("input"),
-           py::arg("quantizer"), py::arg("local_chunk") = false)
-      .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("quantizer"),
-           py::arg("local_chunk") = false, py::arg("shape") = std::nullopt)
-      .def("set_buffer_params", &CommOverlapP2P::set_buffer_params);
+           py::arg("local_chunk") = false)
+      .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("local_chunk") = false,
+           py::arg("shape") = std::nullopt);
 }
@@ -919,8 +919,10 @@ def _all_gather_fp8(
     # Note: We cannot directly all-gather the transposed FP8 tensor,
     # so temporarily modify quantizer to avoid creating FP8 transpose.
     if not isinstance(inp, Float8TensorBase):
-        if quantizer is None:
-            raise ValueError("Input tensor is not FP8 and no quantizer was provided")
+        assert isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer))
+        # we cannot directly gather the transposed fp8 tensor
+        # so we need to disable columnwise usage for the quantizer
+        # and then set it back to the original value after quantizing
         init_rowwise_usage = quantizer.rowwise_usage
         init_columnwise_usage = quantizer.columnwise_usage
         quantizer.set_usage(rowwise=True, columnwise=False)
@@ -1063,11 +1065,11 @@ def _all_gather_mxfp8(
         dtype = inp.dtype
     elif isinstance(inp, MXFP8TensorBase):
         if inp._rowwise_data is not None:
-            in_shape = inp._rowwise_data.device.size()
+            in_shape = inp._rowwise_data.size()
             device = inp._rowwise_data.device
             dtype = inp._rowwise_data.dtype
         elif inp._columnwise_data is not None:
-            in_shape = inp._columnwise_data.device.size()
+            in_shape = inp._columnwise_data.size()
             device = inp._columnwise_data.device
             dtype = inp._columnwise_data.dtype
         else:
...
@@ -6,8 +6,6 @@
 from typing import Any, List, Optional, Tuple, Union, Callable
 from dataclasses import dataclass
-from functools import reduce
-from operator import mul as multiply_op
 import queue
 
 import torch
@@ -15,7 +13,6 @@ import torch
 from .. import cpp_extensions as tex
 from ..constants import TE_DType
 from ..utils import get_default_init_method
-from ..tensor.float8_tensor import Float8Tensor
 
 
 def _get_normalization_func(normalization: str, forward: bool):
@@ -33,39 +30,6 @@ def _get_normalization_func(normalization: str, forward: bool):
     return bwd_normalization_funcs[normalization]
-def _fix_gathered_fp8_transpose(fp8_tensor: Float8Tensor, tp_size: int) -> Float8Tensor:
-    """Reorder FP8 transposes after Userbuffers gather.
-
-    The all-gather is performed in-place in the Float8Tensor's
-    row-wise data, and afterwards we need to do a transpose to get the
-    correct ordering. This misuses data fields in Float8Tensor and
-    should be considered an evil hack. It would be best to move
-    transpose logic into CommOverlap::get_buffer.
-
-    Responsibility for fixing: adener, tmoon
-    """
-    assert isinstance(fp8_tensor, Float8Tensor), "Tensor is not a Float8Tensor"
-    assert tp_size > 1, "The tensor transpose cannot be interleaved when TP size is 1"
-    assert fp8_tensor._data is not None, "The tensor does not hold any rowwise data"
-    assert (
-        fp8_tensor._data.shape[0] % tp_size == 0
-    ), "Leading dimension of data is not divisble by TP size"
-    data = fp8_tensor._data
-    batched_size = reduce(multiply_op, data.shape[1:])
-    interleaved_shape = [tp_size, data.shape[0] // tp_size, batched_size]
-    transposed_shape = [data.shape[0] // tp_size, batched_size * tp_size]
-    fp8_tensor._transpose = (
-        data.view(interleaved_shape).transpose(0, 1).contiguous().view(transposed_shape)
-    )
-    fp8_tensor._transpose_invalid = False
-    fp8_tensor._data = None
-    return fp8_tensor
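The view/transpose/view trick in the removed helper can be illustrated with plain Python lists (a torch-free sketch with made-up sizes):

```python
tp_size = 2
rows, cols = 4, 3  # gathered data; `rows` must be divisible by tp_size
data = [[r * cols + c for c in range(cols)] for r in range(rows)]

# Equivalent of data.view(tp_size, rows // tp_size, cols)
#                   .transpose(0, 1).contiguous()
#                   .view(rows // tp_size, cols * tp_size)
chunk = rows // tp_size
reordered = [
    [data[t * chunk + i][c] for t in range(tp_size) for c in range(cols)]
    for i in range(chunk)
]
# Row i of the result concatenates row i of every rank's chunk,
# un-interleaving the per-rank blocks produced by the in-place gather.
```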
 def apply_normalization(
     inputmat: torch.Tensor,
     ln_out: torch.Tensor,
...
@@ -4,6 +4,7 @@
 """Base modules and utilities for TransformerEngine PyTorch API"""
 import io
+import math
 import os
 import pickle
 import warnings
@@ -35,7 +36,9 @@ from ..distributed import (
     _fsdp_gather_tensors,
 )
 from ..constants import dist_group_type
-from ..tensor import QuantizedTensor, Quantizer
+from ..tensor.quantized_tensor import QuantizedTensor, QuantizedTensorBase, Quantizer
+from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
+from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
 from ..tensor._internal.float8_tensor_base import Float8TensorBase
 from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
@@ -418,6 +421,142 @@ def destroy_ub():
     layers_atomic_ring_exchange = []
def fill_userbuffers_buffer_for_all_gather(
comm,
local_tensor: torch.Tensor,
quantizer: Optional[Quantizer],
process_group,
) -> tuple[torch.Tensor | QuantizedTensorBase, torch.Tensor | QuantizedTensorBase]:
"""Fill local shard of Userbuffers buffer with data for all-gather
Returns the full tensor and the local shard, both using the
Userbuffers buffer as their underlying data. These tensors should
be used carefully (e.g. only immediately before and after a
Userbuffers operation) since the underlying data may be
overwritten by other Userbuffers operations.
May perform blocking communication if needed for the gathered
tensor's metadata, e.g. scaling factors.
"""
# Tensor dimensions
local_shape = local_tensor.size()
if not local_shape:
raise ValueError(f"Invalid local tensor (shape={tuple(local_shape)})")
process_group_size = torch.distributed.get_world_size(process_group)
global_shape = list(local_shape)
global_shape[0] *= process_group_size
# Unquantized data
if quantizer is None:
if isinstance(local_tensor, QuantizedTensorBase):
local_tensor = local_tensor.dequantize()
if comm.is_fp8_ubuf():
raise RuntimeError(
"Attempting to all-gather unquantized tensor, "
"but Userbuffers is initialized with FP8 buffers"
)
comm.copy_into_buffer(local_tensor, local_chunk=True)
global_tensor = comm.get_buffer(shape=global_shape)
return global_tensor, local_tensor
# FP8 data
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
if not isinstance(local_tensor, Float8TensorBase):
            if isinstance(local_tensor, QuantizedTensorBase):
                local_tensor = local_tensor.dequantize()
quantizer.set_usage(rowwise=True, columnwise=False)
local_tensor = quantizer(local_tensor)
if not comm.is_fp8_ubuf():
raise RuntimeError(
"Attempting to all-gather FP8 tensor, "
"but Userbuffers is not initialized with FP8 buffers"
)
comm.copy_into_buffer(local_tensor._data, local_chunk=True)
global_tensor_data = comm.get_buffer(shape=global_shape)
global_tensor = Float8TensorBase(
data=global_tensor_data,
fp8_scale_inv=local_tensor._scale_inv,
fp8_dtype=local_tensor._fp8_dtype,
quantizer=quantizer,
)
return global_tensor, local_tensor
# MXFP8 data
if isinstance(quantizer, MXFP8Quantizer):
# Cast to MXFP8 if needed
if not isinstance(local_tensor, MXFP8TensorBase):
            if isinstance(local_tensor, QuantizedTensorBase):
                local_tensor = local_tensor.dequantize()
local_tensor = quantizer(local_tensor)
if not comm.is_fp8_ubuf():
raise RuntimeError(
"Attempting to all-gather MXFP8 tensor, "
"but Userbuffers is not initialized with FP8 buffers"
)
# Check which MXFP8 buffer to communicate
if quantizer.rowwise_usage == quantizer.columnwise_usage:
raise ValueError(
"Userbuffers can only communicate one MXFP8 buffer at a time, "
f"but quantizer has rowwise_usage={quantizer.rowwise_usage}, "
f"columnwise_usage={quantizer.columnwise_usage}"
)
with_rowwise_data = quantizer.rowwise_usage
# Copy MXFP8 data to local chunk of Userbuffers buffer
local_data = (
local_tensor._rowwise_data if with_rowwise_data else local_tensor._columnwise_data
)
comm.copy_into_buffer(local_data, local_chunk=True)
# Gather scaling-inverses
if math.prod(local_shape[:-1]) % 128 != 0:
raise ValueError(
"Userbuffers requires MXFP8 tensor dims that are divisible by 128, "
f"but got MXFP8 tensor with shape={tuple(local_shape)}"
)
local_scale_inv = (
local_tensor._rowwise_scale_inv
if with_rowwise_data
else local_tensor._columnwise_scale_inv
)
local_scale_inv_size = list(local_scale_inv.size())
global_scale_inv = torch.empty(
[process_group_size * local_scale_inv_size[0]] + local_scale_inv_size[1:],
dtype=local_scale_inv.dtype,
device=local_scale_inv.device,
)
torch.distributed.all_gather_into_tensor(
global_scale_inv,
local_scale_inv,
group=process_group,
)
# Construct MXFP8 tensor with Userbuffers buffer
rowwise_data, rowwise_scale_inv = None, None
columnwise_data, columnwise_scale_inv = None, None
global_data = comm.get_buffer(shape=global_shape)
if with_rowwise_data:
rowwise_data, rowwise_scale_inv = global_data, global_scale_inv
else:
columnwise_data, columnwise_scale_inv = global_data, global_scale_inv
global_tensor = MXFP8TensorBase(
rowwise_data=rowwise_data,
rowwise_scale_inv=rowwise_scale_inv,
columnwise_data=columnwise_data,
columnwise_scale_inv=columnwise_scale_inv,
fp8_dtype=local_tensor._fp8_dtype,
quantizer=quantizer,
)
return global_tensor, local_tensor
# Unsupported data format
raise ValueError(f"Unsupported quantizer for Userbuffers ({quantizer})")
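For intuition, the shape bookkeeping in fill_userbuffers_buffer_for_all_gather can be condensed into a small sketch (hypothetical helper name; the real function also manages quantized data, scale-inverse gathers, and the Userbuffers copy itself):

```python
import math

def gathered_shape(local_shape, world_size, mxfp8=False):
    """Shape of the all-gathered tensor: dim 0 grows by the world size.

    For MXFP8, Userbuffers requires the flattened leading dims of the
    local tensor to be divisible by 128 (full scaling-factor tiles).
    """
    if not local_shape:
        raise ValueError(f"Invalid local tensor (shape={tuple(local_shape)})")
    if mxfp8 and math.prod(local_shape[:-1]) % 128 != 0:
        raise ValueError("MXFP8 with Userbuffers needs leading dims divisible by 128")
    out = list(local_shape)
    out[0] *= world_size
    return out
```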
 class TransformerEngineBaseModule(torch.nn.Module, ABC):
     """Base TE module."""
@@ -866,11 +1005,15 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             # Non-FP8 case: bgrad is fused with wgrad for this case.
             if not ctx.fp8 and not ctx.debug:
                 if gather_grad_output:
-                    if not ctx.ub_overlap_ag:
+                    if not ctx.ub_overlap_ag:  # Perform NCCL all-gather
                         grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group)
-                    else:
-                        ctx.ub_obj_gradout.copy_into_buffer(grad_output, quantizer, local_chunk=True)
-                        grad_output = ctx.ub_obj_gradout.get_buffer(quantizer)
+                    else:  # Initialize Userbuffers all-gather
+                        grad_output, _ = fill_userbuffers_buffer_for_all_gather(
+                            ctx.ub_obj_gradout,
+                            grad_output,
+                            None,
+                            ctx.tp_group,
+                        )
                 return grad_output, None
             # FP8 with all-gather: unfused bgrad, fused cast + transpose
@@ -893,8 +1036,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                     grad_output = quantizer(grad_output)
 
                 # Copy into communication buffer, and replace original gradient with it
-                ctx.ub_obj_gradout.copy_into_buffer(grad_output, quantizer, local_chunk=True)
-                grad_output = ctx.ub_obj_gradout.get_buffer(quantizer)
+                grad_output, _ = fill_userbuffers_buffer_for_all_gather(
+                    ctx.ub_obj_gradout,
+                    grad_output,
+                    quantizer,
+                    ctx.tp_group,
+                )
             else:
                 grad_output, _ = gather_along_first_dim(
                     grad_output,
@@ -1165,7 +1312,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
             return
         with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"):
-            (wgrad, grad_bias_, _, _), _ = self.wgrad_store.pop()
+            (wgrad, bgrad), _ = self.wgrad_store.pop()
             if not self.fuse_wgrad_accumulation:
                 unfused_weights = [getattr(self, name) for name in self.weight_names]
                 weight_tensor = noop_cat(unfused_weights)
@@ -1174,9 +1321,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             if self.use_bias:
                 bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
                 if bias_tensor.grad is None:
-                    bias_tensor.grad = grad_bias_.to(bias_tensor.dtype)
-            del grad_bias_
-            del wgrad
+                    bias_tensor.grad = bgrad.to(bias_tensor.dtype)
 
     def _validate_name(self):
         """
...
@@ -9,7 +9,6 @@ from typing import Callable, Dict, Optional, Tuple, Union
 from functools import reduce
 from operator import mul as multiply_op
-import functools
 
 import torch
 from torch.nn import init
@@ -18,6 +17,7 @@ import transformer_engine_torch as tex
 from transformer_engine.common.recipe import Recipe
 from transformer_engine.pytorch import torch_version
 from .base import (
+    fill_userbuffers_buffer_for_all_gather,
     get_workspace,
     get_ub,
     TransformerEngineBaseModule,
@@ -53,9 +53,10 @@ from ..distributed import (
 from ..constants import GemmParallelModes, dist_group_type
 from ..jit import no_torch_dynamo
 from ..graph import is_graph_capturing
-from ._common import apply_normalization, noop_cat, _fix_gathered_fp8_transpose, WeightGradStore
+from ._common import apply_normalization, noop_cat, WeightGradStore
 from ..tensor.quantized_tensor import (
     QuantizedTensor,
+    QuantizedTensorBase,
     Quantizer,
     prepare_for_saving,
     restore_from_saved,
...@@ -153,42 +154,43 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -153,42 +154,43 @@ class _LayerNormLinear(torch.autograd.Function):
nvtx_range_pop(f"{nvtx_label}.norm_input_cast") nvtx_range_pop(f"{nvtx_label}.norm_input_cast")
tp_world_size = get_distributed_world_size(tp_group) tp_world_size = get_distributed_world_size(tp_group)
ub_overlap_ag_fprop = (
ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output
)
weight_requires_grad = weight.requires_grad weight_requires_grad = weight.requires_grad
backward_needs_input = is_grad_enabled and weight_requires_grad backward_needs_input = is_grad_enabled and weight_requires_grad
with_input_all_gather = parallel_mode == "column" and sequence_parallel with_input_all_gather = parallel_mode == "column" and sequence_parallel
# Check if Userbuffers is supported # Configure Userbuffers communication (comm+GEMM overlap)
if fp8: ub_obj = None
if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and not ( ub_type = None
FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling() ub_overlap_ag_fprop = (
): ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output
raise NotImplementedError(
"Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
" current scaling"
) )
if ub_overlap_rs_fprop:
ub_obj = get_ub(ub_name + "_fprop")
ub_type = tex.CommOverlapType.RS
elif ub_overlap_ag_fprop:
ub_obj = get_ub(ub_name + "_fprop")
ub_type = tex.CommOverlapType.AG
# Configure quantizer for norm output # Configure quantizer for norm output
if fp8: if fp8:
if input_quantizer is None: if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor") raise ValueError("Missing quantizer for input tensor")
columnwise_usage = backward_needs_input input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
if ( if with_input_all_gather and isinstance(
columnwise_usage input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
and with_input_all_gather
and not isinstance(input_quantizer, MXFP8Quantizer)
): ):
columnwise_usage = False # All-gather is not supported with FP8 column-wise data
input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) input_quantizer.set_usage(columnwise=False)
# Avoid quantized norm kernel if norm output will be returned # Do TP communication in high precision if quantized format
# or if a gather of ln_out must be in high precision. # does not support communication
force_hp_blockwise_ln_out_gather = ( force_hp_blockwise_ln_out_gather = (
fp8 and with_input_all_gather and isinstance(input_quantizer, Float8BlockQuantizer) fp8 and with_input_all_gather and isinstance(input_quantizer, Float8BlockQuantizer)
) # Perform TP communication in high precision. )
# Avoid quantized norm kernel if norm output will be returned
# or if a gather of ln_out must be in high precision.
with_quantized_norm = ( with_quantized_norm = (
fp8 fp8
and not return_layernorm_output and not return_layernorm_output
...@@ -210,16 +212,19 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -210,16 +212,19 @@ class _LayerNormLinear(torch.autograd.Function):
fwd_ln_sm_margin, fwd_ln_sm_margin,
zero_centered_gamma, zero_centered_gamma,
) )
nvtx_range_pop(f"{nvtx_label}.norm")
# Store unquantized layer norm output if we need to return it
ln_out_return = None ln_out_return = None
if return_layernorm_output or return_layernorm_output_gathered: if return_layernorm_output or return_layernorm_output_gathered:
ln_out_return = ln_out ln_out_return = ln_out
nvtx_range_pop(f"{nvtx_label}.norm")
# Prepare GEMM input # ------------------------------------------------------
# Prepare GEMM input tensor
# Note: Cast to expected dtype and perform tensor-parallel communication # Note: Cast to expected dtype and perform tensor-parallel communication
# ------------------------------------------------------
nvtx_range_push(f"{nvtx_label}.gemm_input_cast_comm") nvtx_range_push(f"{nvtx_label}.gemm_input_cast_comm")
ln_out_total = None ln_out_total = None
ub_obj_fprop = None
if with_input_all_gather: if with_input_all_gather:
if return_layernorm_output_gathered: if return_layernorm_output_gathered:
# Perform all-gather in high precision if gathered # Perform all-gather in high precision if gathered
...@@ -227,47 +232,53 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -227,47 +232,53 @@ class _LayerNormLinear(torch.autograd.Function):
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
ln_out_return = ln_out_total ln_out_return = ln_out_total
if fp8 or debug: if fp8 or debug:
if not force_hp_blockwise_ln_out_gather:
ln_out = input_quantizer(ln_out) ln_out = input_quantizer(ln_out)
input_quantizer.set_usage(rowwise=True, columnwise=False) input_quantizer.set_usage(rowwise=True, columnwise=False)
ln_out_total = input_quantizer(ln_out_total) ln_out_total = input_quantizer(ln_out_total)
else: else:
quantizer = None
if fp8 or debug: if fp8 or debug:
quantizer = input_quantizer
if not with_quantized_norm and not force_hp_blockwise_ln_out_gather: if not with_quantized_norm and not force_hp_blockwise_ln_out_gather:
ln_out = input_quantizer(ln_out) ln_out = quantizer(ln_out)
input_quantizer.set_usage(rowwise=True, columnwise=False) quantizer.set_usage(rowwise=True, columnwise=False)
if ub_overlap_ag_fprop: if ub_overlap_ag_fprop: # Initialize Userbuffers all-gather
# Copy into Userbuffers buffer ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj_fprop = get_ub(ub_name + "_fprop") ub_obj,
ub_obj_fprop.get_buffer(input_quantizer, local_chunk=True).copy_(ln_out) ln_out,
ln_out_total = ub_obj_fprop.get_buffer(input_quantizer) quantizer,
else: tp_group,
# All-gather with NCCL )
else: # Perform NCCL all-gather
ln_out_total, _ = gather_along_first_dim( ln_out_total, _ = gather_along_first_dim(
ln_out, ln_out,
tp_group, tp_group,
quantizer=(input_quantizer if fp8 or debug else None), quantizer=quantizer,
) )
else: else:
if (fp8 or debug) and not with_quantized_norm: if (fp8 or debug) and not with_quantized_norm:
ln_out = input_quantizer(ln_out) ln_out = input_quantizer(ln_out)
ln_out_total = ln_out ln_out_total = ln_out
nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm") nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm")
+            # ------------------------------------------------------
+            # GEMM input tensor is ready...
+            # ------------------------------------------------------

-            # Cast weight to expected dtype
+            # ------------------------------------------------------
+            # Prepare weight tensor
+            # ------------------------------------------------------
             weightmat = weight
             quantized_weight = False
-            if not fp8 and not debug:
-                weightmat = cast_if_needed(weightmat, activation_dtype)
-            else:
-                quantized_weight = not isinstance(weight, QuantizedTensor)
+            if fp8 or debug:
+                quantized_weight = not isinstance(weight, QuantizedTensorBase)

                 # Configure quantizer
                 if weight_quantizer is not None:
                     weight_quantizer.set_usage(rowwise=True, columnwise=True)

-                # FP8 cast to workspace buffer
+                # Get quantized weight
                 update_workspace = is_first_microbatch is None or is_first_microbatch
                 weightmat = module.get_weight_workspace(
                     tensor=weight,
                     quantizer=weight_quantizer,
@@ -277,17 +288,21 @@ class _LayerNormLinear(torch.autograd.Function):
                     fsdp_group=fsdp_group,
                     workspace_dtype=activation_dtype,
                 )
+                weightmat.update_usage(rowwise_usage=True)
+            else:
+                weightmat = cast_if_needed(weightmat, activation_dtype)  # Cast for AMP
+            # ------------------------------------------------------
+            # Weight tensor is ready for GEMM...
+            # ------------------------------------------------------
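The `update_workspace = is_first_microbatch is None or is_first_microbatch` logic above refreshes the quantized-weight workspace only on the first microbatch and reuses it afterwards. A hypothetical, simplified stand-in (the `WeightWorkspace` class and `round`-based "quantization" are illustrative, not TE's API) shows the caching behavior:

```python
# Sketch of microbatch-aware weight caching: the quantized copy of the
# weight is recomputed when is_first_microbatch is None (no microbatching)
# or True, and reused for the remaining microbatches.

class WeightWorkspace:
    def __init__(self):
        self.cached = None
        self.quantize_count = 0

    def quantize(self, weight):
        self.quantize_count += 1
        return [round(w) for w in weight]  # stand-in for FP8 quantization

    def get(self, weight, is_first_microbatch):
        update = is_first_microbatch is None or is_first_microbatch
        if update or self.cached is None:
            self.cached = self.quantize(weight)
        return self.cached

ws = WeightWorkspace()
w = [0.4, 1.6]
ws.get(w, True)   # first microbatch: quantize
ws.get(w, False)  # later microbatches: reuse the cached copy
ws.get(w, False)
print(ws.quantize_count)  # 1
```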
             # Cast bias to expected dtype
             bias_dtype = activation_dtype
             if needs_quantized_gemm(ln_out_total) and activation_dtype == torch.float32:
+                # cuBLAS does not support FP8 GEMM with FP32 bias, so we cast to BF16
                 bias_dtype = torch.bfloat16
             bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias

-            # Configure output quantizer
-            if output_quantizer is not None:
-                output_quantizer.set_usage(rowwise=True, columnwise=False)

             # Calibrate quantizers if needed
             if not fp8 and fp8_calibration:
                 if input_quantizer is not None:
@@ -295,48 +310,81 @@ class _LayerNormLinear(torch.autograd.Function):
                 if weight_quantizer is not None:
                     weight_quantizer.calibrate(weight)

-            ub_obj = None
-            ub_type = None
-            rs_out = None
-            if ub_overlap_rs_fprop:
-                ub_obj = get_ub(ub_name + "_fprop")
-                ub_type = tex.CommOverlapType.RS
-                out_shape = [reduce(multiply_op, inp_shape[:-1]) // tp_world_size, out_features]
-                rs_out = torch.empty(out_shape, dtype=activation_dtype, device=ln_out_total.device)
-            elif ub_overlap_ag_fprop:
-                ub_obj = get_ub(ub_name + "_fprop")
-                ub_type = tex.CommOverlapType.AG
-                if fp8:
-                    assert ub_obj.is_fp8_ubuf(), "AG overlap with FP8 GEMM inputs requires FP8 buffer."
-                    ln_out_total = ub_obj.get_buffer(input_quantizer)

-            nvtx_range_push(f"{nvtx_label}.gemm")
-            fprop_gemm_use_split_accumulator = _2X_ACC_FPROP
+            # Choose whether to use GEMM kernel with split accumulator
+            use_split_accumulator = _2X_ACC_FPROP
             if fp8:
                 recipe = FP8GlobalStateManager.get_fp8_recipe()
                 if hasattr(recipe, "fp8_gemm_fprop"):
-                    fprop_gemm_use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator
+                    use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator

+            # Configure output quantizer
+            if output_quantizer is not None:
+                output_quantizer.set_usage(rowwise=True, columnwise=False)

+            # Output buffer for Userbuffers reduce-scatter
+            reduce_scatter_out = None
+            if ub_overlap_rs_fprop:
+                out_shape = list(inp_shape)
+                out_shape[0] //= tp_world_size
+                out_shape[-1] = out_features
+                reduce_scatter_out = torch.empty(out_shape, dtype=activation_dtype, device=inp.device)

+            # ------------------------------------------------------
+            # Forward GEMM
+            # Note: y = x * w^T
+            # ------------------------------------------------------
+            nvtx_range_push(f"{nvtx_label}.gemm")
-            out, *_, rs_out = general_gemm(
+            gemm_out, *_, reduce_scatter_out = general_gemm(
                 weightmat,
                 ln_out_total,
                 get_workspace(),
                 quantization_params=output_quantizer,
                 out_dtype=activation_dtype,
                 bias=bias,
-                use_split_accumulator=fprop_gemm_use_split_accumulator,
+                use_split_accumulator=use_split_accumulator,
                 ub=ub_obj,
                 ub_type=ub_type,
-                extra_output=rs_out,
+                extra_output=reduce_scatter_out,
             )
             nvtx_range_pop(f"{nvtx_label}.gemm")
+            # ------------------------------------------------------
+            # Finished forward GEMM...
+            # ------------------------------------------------------
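The forward GEMM convention noted in the new banner is y = x * w^T, with the weight stored as [out_features, in_features]. A minimal reference for those shapes (nested lists instead of tensors; this only illustrates the math, not the fused cuBLAS kernel):

```python
# y[i][o] = sum_k x[i][k] * w[o][k] (+ bias[o]), i.e. x times w transposed.

def gemm_nt(x, w, bias=None):
    out = []
    for row in x:
        y_row = []
        for o, w_row in enumerate(w):
            acc = sum(a * b for a, b in zip(row, w_row))
            if bias is not None:
                acc += bias[o]
            y_row.append(acc)
        out.append(y_row)
    return out

x = [[1.0, 2.0]]                           # [tokens=1, in_features=2]
w = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # [out_features=3, in_features=2]
y = gemm_nt(x, w, bias=[0.5, 0.5, 0.5])
print(y)  # [[1.5, 2.5, 3.5]]
```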
-            if not weight.requires_grad:
-                if not return_layernorm_output:
+            # Deallocate GEMM input tensor if no longer needed
+            if not weight.requires_grad and not return_layernorm_output:
                 ln_out = ln_out_total = None
                 clear_tensor_data(ln_out, ln_out_total)

+            # ------------------------------------------------------
+            # Prepare output tensor
+            # Note: Perform tensor-parallel communication
+            # ------------------------------------------------------
+            out = None
+            if ub_overlap_rs_fprop:
+                out = reduce_scatter_out
+            elif parallel_mode == "row" and tp_size > 1:
+                nvtx_range_push(f"{nvtx_label}.row_parallel_comm")
+                out = gemm_out
+                if sequence_parallel:
+                    out, _ = reduce_scatter_along_first_dim(out, tp_group)
+                elif tensor_parallel:
+                    if symmetric_ar_type is not None:
+                        out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type)
+                    else:
+                        out, _ = allreduce(out, tp_group)
+                nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")
+            else:
+                out = gemm_out
+            out = out.view(-1, *inp_shape[1:-1], out_features)
+            # ------------------------------------------------------
+            # Output tensor is ready to return...
+            # ------------------------------------------------------
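In the row-parallel epilogue above, each rank holds a partial output that must be summed across ranks: with sequence parallelism the sum is reduce-scattered (each rank keeps one chunk of rows), otherwise it is all-reduced (every rank keeps the full sum). A toy single-process sketch of the two collectives' semantics:

```python
# Per-rank partial outputs are summed; reduce-scatter additionally shards
# the summed result so each rank keeps only its chunk of rows.

def allreduce(partials):
    n = len(partials[0])
    return [sum(p[i] for p in partials) for i in range(n)]

def reduce_scatter(partials):
    full = allreduce(partials)
    tp = len(partials)
    chunk = len(full) // tp
    return [full[r * chunk:(r + 1) * chunk] for r in range(tp)]

partials = [[1, 2, 3, 4], [10, 20, 30, 40]]  # two ranks' partial sums
print(allreduce(partials))       # [11, 22, 33, 44]
print(reduce_scatter(partials))  # [[11, 22], [33, 44]]
```

Reduce-scatter moves 1/tp as much data per rank as all-reduce, which is why the sequence-parallel path prefers it.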
+            # ------------------------------------------------------
+            # Cache state for backward pass
+            # ------------------------------------------------------
             if is_grad_enabled:
                 ctx.weight_quantizer = weight_quantizer
                 ctx.ln_out_needs_gather = (
@@ -346,19 +394,15 @@ class _LayerNormLinear(torch.autograd.Function):
                 # Input with column-wise usage is needed for wgrad GEMM.
                 if backward_needs_input:
-                    if isinstance(ln_out, QuantizedTensor):
+                    if isinstance(ln_out, QuantizedTensorBase):
                         # For sequence parallel in vanilla FP8, rowwise data is
                         # to gather the input. For MXFP8, columnwise only data
                         # can be allgathered.
                         if isinstance(ln_out, MXFP8TensorBase) or not ctx.ln_out_needs_gather:
                             ln_out.update_usage(rowwise_usage=False)
-                    # For force_hp_blockwise_ln_out_gather, we should
-                    # be saving the unquantized ln_out to ctx.
-                    assert not force_hp_blockwise_ln_out_gather

                 # Weight with column-wise usage is needed for dgrad GEMM.
-                if isinstance(weightmat, QuantizedTensor):
+                if isinstance(weightmat, QuantizedTensorBase):
                     weightmat.update_usage(columnwise_usage=True)

                 if cpu_offloading:
@@ -445,22 +489,9 @@ class _LayerNormLinear(torch.autograd.Function):
             ctx.wgrad_store = wgrad_store
             ctx.debug = debug

-            # Row Parallel Linear
-            if ub_overlap_rs_fprop:
-                out = rs_out
-            elif parallel_mode == "row":
-                nvtx_range_push(f"{nvtx_label}.row_parallel_comm")
-                if sequence_parallel:
-                    out, _ = reduce_scatter_along_first_dim(out, tp_group)
-                elif tensor_parallel:
-                    if symmetric_ar_type is not None:
-                        out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type)
-                    else:
-                        out, _ = allreduce(out, tp_group)
-                nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")
-            # [*, in_features] -> [*, out_features] except first dimension changes for SP
-            out = out.view(-1, *inp_shape[1:-1], out_features)
+            # ------------------------------------------------------
+            # Cached state for backward pass is ready...
+            # ------------------------------------------------------

         if return_layernorm_output:
             if return_layernorm_output_gathered:
@@ -482,24 +513,6 @@ class _LayerNormLinear(torch.autograd.Function):
             nvtx_label = f"{nvtx_label}.{ctx.ub_name}"

         with torch.cuda.nvtx.range("_LayerNormLinear_backward"):
-            if (
-                ctx.fp8
-                and any(
-                    [
-                        ctx.ub_overlap_ag,
-                        ctx.ub_overlap_rs_dgrad,
-                        ctx.ub_bulk_dgrad,
-                        ctx.ub_bulk_wgrad,
-                    ]
-                )
-                and (ctx.fp8_recipe is not None)
-            ):
-                if not ctx.fp8_recipe.float8_per_tensor_scaling():
-                    raise NotImplementedError(
-                        "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
-                        " current scaling"
-                    )
             saved_tensors = ctx.saved_tensors
             (  # pylint: disable=unbalanced-tuple-unpacking
                 inputmat,
@@ -544,66 +557,50 @@ class _LayerNormLinear(torch.autograd.Function):
             if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
                 origin_weight.main_grad = main_grad

+            # Configure Userbuffers communication (comm+GEMM overlap)
             ctx.ub_obj_gradout = None
             ub_obj_dgrad = None
             ub_obj_wgrad = None
             ub_type_dgrad = None
             ub_type_wgrad = None
             dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]]
-            rs_out = None
-            dgrad_bulk = None
             if ctx.ub_overlap_ag:
                 # Overlap grad_output all-gather with dgrad compute
                 ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
                 ub_obj_dgrad = ctx.ub_obj_gradout
                 ub_type_dgrad = tex.CommOverlapType.AG
             elif ctx.ub_overlap_rs_dgrad:
                 # Overlap dgrad reduce-scatter with dgrad compute
                 ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
                 ub_obj_dgrad = ctx.ub_obj_gradout
                 ub_type_dgrad = tex.CommOverlapType.RS
-                rs_out = torch.empty(
-                    dgrad_shape, dtype=ctx.activation_dtype, device=inputmat.device
-                )
             else:
                 if ctx.ub_bulk_dgrad:
                     # Overlap inputmat all-gather with dgrad compute
-                    # NOTE: Copying into communication buffer will always prefer rowwise data,
-                    # and will copy columnwise data if rowwise does not exist. In that case,
-                    # the all-gather will apply to the leading dimension of the transpose,
-                    # which then needs to be interleaved correctly before WGRAD.
                     ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
                     ub_obj_dgrad = ctx.ub_obj_gradout
                     ub_type_dgrad = tex.CommOverlapType.AG
-                    ub_obj_dgrad.copy_into_buffer(ln_out, ctx.input_quantizer, local_chunk=True)
                 if ctx.ub_bulk_wgrad:
                     # Overlap dgrad reduce-scatter with wgrad compute
                     ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad")
                     ub_type_wgrad = tex.CommOverlapType.RS
-                    ub_obj_wgrad.set_buffer_params(ctx.grad_input_quantizer)
-                    dgrad_bulk = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer)

+            # --------------------------------------------------
+            # Prepare grad output tensor
+            # Note: Cast to expected dtype and perform tensor-parallel communication
+            # --------------------------------------------------

             # Configure quantizer for grad output tensor
             # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
             # requires column-wise usage
             if ctx.grad_output_quantizer is not None:
-                rowwise_usage = True
-                columnwise_usage = True
-                if ctx.ub_overlap_ag and isinstance(
-                    ctx.grad_output_quantizer,
-                    (Float8Quantizer, Float8CurrentScalingQuantizer),
-                ):
-                    # If data is in FP8 and communication is handled
-                    # with Userbuffers, we compute FP8 transposes
-                    # manually
-                    columnwise_usage = False
-                ctx.grad_output_quantizer.set_usage(
-                    rowwise=rowwise_usage,
-                    columnwise=columnwise_usage,
-                )
+                quantizer = ctx.grad_output_quantizer
+                quantizer.set_usage(rowwise=True, columnwise=True)
+                if ctx.ub_overlap_ag:
+                    # Userbuffers only supports communication for one
+                    # tensor usage at a time. Configure quantizer with
+                    # usage for only dgrad GEMM.
+                    quantizer.set_usage(columnwise=False)

             # Prepare grad output tensor
             # Note: Cast to expected dtype and perform tensor-parallel communication
@@ -619,12 +616,21 @@ class _LayerNormLinear(torch.autograd.Function):
             )
             nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess")
+            # --------------------------------------------------
+            # Grad output tensor is ready for computing grad input...
+            # --------------------------------------------------

-            # Launch tensor-parallel communication for LayerNorm out tensor
+            # --------------------------------------------------
+            # Prepare GEMM input tensor
+            # Note: Input tensor is needed for wgrad GEMM.
+            # Tensor-parallel communication is overlapped with dgrad
+            # GEMM.
+            # --------------------------------------------------
             ln_out_total = None
             ln_out_total_work = None
-            if ctx.ln_out_needs_gather and not ctx.ub_bulk_dgrad:
+            if ctx.ln_out_needs_gather:
                 quantizer = None
-                if ctx.input_quantizer is not None:
+                if ctx.input_quantizer is not None and not ctx.force_hp_blockwise_ln_out_gather:
                     quantizer = ctx.input_quantizer
                     if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
                         # If data is in FP8, we compute FP8 transposes manually
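The input gather a few lines below is launched with `async_op=True`: the call returns a work handle, the dgrad GEMM runs while communication is in flight, and the handle's `wait()` blocks only when the gathered result is actually needed. A threading-based stand-in for that handle pattern (the `AsyncWork` class is illustrative, not the NCCL API):

```python
import threading

class AsyncWork:
    """Launches fn on a background thread; wait() blocks until it finishes."""

    def __init__(self, fn):
        self.result = None
        self._t = threading.Thread(target=self._run, args=(fn,))
        self._t.start()

    def _run(self, fn):
        self.result = fn()

    def wait(self):
        self._t.join()
        return self.result

work = AsyncWork(lambda: sum(range(1000)))  # stands in for the all-gather
overlap = sum(range(10))                    # "dgrad GEMM" runs meanwhile
print(work.wait(), overlap)  # 499500 45
```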
@@ -632,70 +638,92 @@ class _LayerNormLinear(torch.autograd.Function):
                     else:
                         # wgrad GEMM requires input with column-wise usage
                         quantizer.set_usage(rowwise=False, columnwise=True)
+                if ctx.ub_bulk_dgrad:
+                    ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
+                        ub_obj_dgrad,
+                        ln_out,
+                        quantizer,
+                        ctx.tp_group,
+                    )
+                else:
                     nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
-                # async_op is not compatible with high precision gather since
-                # gather_along_first_dim does not offer callback chaining.
-                gather_quantizer = None if ctx.force_hp_blockwise_ln_out_gather else quantizer
                     ln_out_total, ln_out_total_work = gather_along_first_dim(
                         ln_out,
                         ctx.tp_group,
                         async_op=True,
-                        quantizer=gather_quantizer,
+                        quantizer=quantizer,
                     )
                     nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
             else:
                 ln_out_total = ln_out
+            # --------------------------------------------------
+            # Input tensor is ready for computing grad weight...
+            # --------------------------------------------------

-            # Check whether to output wgrad GEMM directly into main grad
-            if ctx.is_first_microbatch is not None:
-                accumulate_wgrad_into_param_main_grad = (
-                    ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
-                )
-            else:
-                accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
+            # --------------------------------------------------
+            # Compute grad input tensor
+            # Note: Gradient w.r.t. GEMM input (i.e. norm output).
+            # --------------------------------------------------

-            # dgrad GEMM
-            if ctx.grad_input_quantizer is not None:
-                ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)
+            # Make sure required data is available
+            if isinstance(grad_output, QuantizedTensorBase):
+                grad_output.update_usage(rowwise_usage=True)
+            if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorBase):
+                weight.update_usage(columnwise_usage=True)

-            nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
-            dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD
+            # Choose whether to use GEMM kernel with split accumulator
+            use_split_accumulator = _2X_ACC_DGRAD
             if ctx.fp8:
                 recipe = ctx.fp8_recipe
                 if hasattr(recipe, "fp8_gemm_dgrad"):
-                    dgrad_gemm_use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator
+                    use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator

+            # Update grad input quantizer
+            if ctx.grad_input_quantizer is not None:
+                ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)

-            if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensor):
-                weight.update_usage(
-                    rowwise_usage=ctx.weight_quantizer.rowwise_usage,
-                    columnwise_usage=ctx.weight_quantizer.columnwise_usage,
-                )
+            # Output buffers for Userbuffers reduce-scatter
+            gemm_out = None
+            reduce_scatter_out = None
+            if ctx.ub_overlap_rs_dgrad:
+                reduce_scatter_out = torch.empty(
+                    dgrad_shape, dtype=ctx.activation_dtype, device=grad_outputs[0].device
+                )
+            elif ctx.ub_bulk_wgrad:
+                gemm_out = ub_obj_wgrad.get_buffer(local_chunk=False)

-            dgrad, *_ = general_gemm(
+            # dgrad GEMM
+            # Note: dx = dy * w
+            nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
+            gemm_out, *_, reduce_scatter_out = general_gemm(
                 weight,
                 grad_output,
                 get_workspace(),
                 layout="NN",
                 grad=True,
                 quantization_params=ctx.grad_input_quantizer,
-                out=dgrad_bulk,
+                out=gemm_out,
                 out_dtype=ctx.activation_dtype,
-                use_split_accumulator=dgrad_gemm_use_split_accumulator,
+                use_split_accumulator=use_split_accumulator,
                 ub=ub_obj_dgrad,
                 ub_type=ub_type_dgrad,
-                extra_output=rs_out,
+                extra_output=reduce_scatter_out,
                 bulk_overlap=ctx.ub_bulk_dgrad,
             )
             nvtx_range_pop(f"{nvtx_label}.dgrad_gemm")
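The dgrad GEMM's note dx = dy * w uses the "NN" layout: dy is [tokens, out_features] and w is read as [out_features, in_features] with no transpose, so dx comes out [tokens, in_features]. A shape sketch (plain nested lists, illustration only):

```python
# Plain NN matmul: c[i][j] = sum_k a[i][k] * b[k][j].

def gemm_nn(a, b):
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]

dy = [[1.0, 0.0, 2.0]]                    # [tokens=1, out_features=3]
w = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # [out_features=3, in_features=2]
dx = gemm_nn(dy, w)
print(dx)  # [[11.0, 14.0]], i.e. [tokens, in_features]
```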
-            # Launch tensor-parallel communication
+            # Prepare grad input tensor
+            # Note: Perform tensor-parallel communication
+            dgrad = None
             dgrad_work = None
             if ctx.ub_overlap_rs_dgrad:
-                dgrad = rs_out
+                dgrad = reduce_scatter_out
-            elif ctx.parallel_mode == "column" and not ctx.ub_bulk_wgrad:
+            elif ctx.ub_bulk_wgrad:
+                dgrad = ub_obj_wgrad.get_buffer(local_chunk=True)
+            elif ctx.parallel_mode == "column" and ctx.tp_size > 1:
                 nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad")
+                dgrad = gemm_out
                 if ctx.sequence_parallel:
+                    if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered:
+                        dgrad = dgrad + grad_outputs[1].view_as(dgrad)
                     dgrad, dgrad_work = reduce_scatter_along_first_dim(
                         dgrad,
                         ctx.tp_group,
@@ -704,41 +732,55 @@ class _LayerNormLinear(torch.autograd.Function):
                 else:
                     dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True)
                 nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad")
+            else:
+                dgrad = gemm_out
+            # --------------------------------------------------
+            # Grad input tensor has been computed...
+            # --------------------------------------------------

+            # --------------------------------------------------
+            # Compute grad weight
+            # --------------------------------------------------
-            # Compute grad weight tensor
             wgrad = None
             if ctx.requires_wgrad:
-                # Synchronize tensor-parallel communication for input tensor
-                if ctx.ub_bulk_dgrad:
-                    ln_out_total = ub_obj_dgrad.get_buffer(ctx.input_quantizer)
-                    if ctx.fp8:
-                        # FP8 GEMM on Hopper only supports TN layout so the gathered input must have
-                        # a valid transpose.
-                        if ln_out._data is None:
-                            # All-gather executed on columnwise data and result is in rowwise data,
-                            # so we need to fix the interleaving before WGRAD.
-                            ln_out_total = _fix_gathered_fp8_transpose(ln_out_total, ctx.tp_size)
-                        else:
-                            ln_out_total._create_transpose()
+                # Prepare input tensor
+                # Note: Synchronize tensor-parallel communication and
+                # make sure required data is available
                 if ln_out_total_work is not None:
                     ln_out_total_work.wait()
                     ln_out_total_work = None
-                if ctx.input_quantizer is not None and not isinstance(
-                    ln_out_total, QuantizedTensor
-                ):
-                    # Async gather may have been done in BF16
-                    # call quantizer after gather.
-                    ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
-                    ln_out_total = ctx.input_quantizer(ln_out_total)
-                # Make sure GEMM inputs have required data
-                if isinstance(ln_out_total, QuantizedTensor):
-                    ln_out_total.update_usage(columnwise_usage=True)
+                if ctx.fp8 or ctx.debug:
+                    if isinstance(ln_out_total, QuantizedTensorBase):
+                        ln_out_total.update_usage(columnwise_usage=True)
+                    else:
+                        ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
+                        ln_out_total = ctx.input_quantizer(ln_out_total)

-                if isinstance(grad_output, QuantizedTensor):
-                    grad_output.update_usage(columnwise_usage=True)
+                # Prepare grad output tensor
+                # Note: Synchronize tensor-parallel communication and
+                # make sure required data is available
+                if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer):
+                    # UB does not support overlapping grad output
+                    # all-gather with wgrad GEMM. Also, we can't
+                    # convert row-scaled MXFP8 to column-scaled, so we
+                    # can't reuse the grad output that was gathered
+                    # for the dgrad GEMM. We work around with blocking
+                    # all-gather for column-scaled MXFP8 data.
+                    ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
+                    grad_output, _ = gather_along_first_dim(
+                        grad_outputs[0],
+                        ctx.tp_group,
+                        quantizer=ctx.grad_output_quantizer,
+                    )
+                if ctx.fp8 or ctx.debug:
+                    if isinstance(grad_output, QuantizedTensorBase):
+                        grad_output.update_usage(columnwise_usage=True)
+                    else:
+                        ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
+                        grad_output = ctx.grad_output_quantizer(grad_output)
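The MXFP8 workaround above exists because block-scaled formats keep one scale per block along a fixed axis, so row-scaled data cannot simply be transposed into valid column-scaled data; the grad output must be re-quantized column-wise and gathered again. A toy demonstration of that asymmetry, using one scale per full row or column instead of MXFP8's 32-element blocks (illustrative only):

```python
# Quantize each row to integers in [-7, 7] with a per-row scale.

def quantize_rowwise(x):
    scales = [max(abs(v) for v in row) or 1.0 for row in x]
    q = [[round(7 * v / s) for v in row] for row, s in zip(x, scales)]
    return q, scales

x = [[0.1, 0.2], [100.0, 200.0]]
q_rows, row_scales = quantize_rowwise(x)

# Column-scaled quantization == row-scaled quantization of the transpose
x_t = [list(col) for col in zip(*x)]
q_cols, col_scales = quantize_rowwise(x_t)

# Transposing the row-scaled integers does NOT give the column-scaled ones:
q_rows_t = [list(col) for col in zip(*q_rows)]
print(q_rows_t == q_cols)  # False
```

Each row's scale erases the huge magnitude difference between rows, so the transposed row-quantized integers carry the wrong per-column scaling, which is why the re-gather is unavoidable.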
@@ -747,55 +789,95 @@ class _LayerNormLinear(torch.autograd.Function):
                     if hasattr(recipe, "fp8_gemm_wgrad"):
                         use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator

+                # Figure out whether to output wgrad GEMM directly into main grad
+                if ctx.is_first_microbatch is not None:
+                    accumulate_wgrad_into_param_main_grad = (
+                        ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
+                    )
+                else:
+                    accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation

-                # Output buffer for overlapping grad input
+                # Output buffer for overlapping FP8 grad input
                 # reduce-scatter with wgrad GEMM
+                reduce_scatter_out = None
                 if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf():
-                    rs_out = torch.empty(
-                        dgrad_shape, dtype=ctx.activation_dtype, device=inputmat.device
-                    )
+                    reduce_scatter_out = torch.empty(
+                        dgrad_shape, dtype=ctx.activation_dtype, device=grad_outputs[0].device
+                    )

-                # wgrad GEMM
-                # Note: Fuse with bgrad computation if needed
-                nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
-                general_gemm_wgrad = functools.partial(
-                    general_gemm,
-                    out_dtype=(
-                        main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
-                    ),
-                    workspace=get_workspace(),
-                    layout="NT",
-                    grad=True,
-                    bias=(bias if (grad_bias is None and not ctx.fp8) else None),
-                    out=main_grad if ctx.fuse_wgrad_accumulation else None,
-                    use_split_accumulator=use_split_accumulator,
-                    accumulate=accumulate_wgrad_into_param_main_grad,
-                    quantization_params=ctx.grad_weight_quantizer,
-                    ub=ub_obj_wgrad,
-                    ub_type=ub_type_wgrad,
-                    extra_output=rs_out,
-                    bulk_overlap=ctx.ub_bulk_wgrad,
-                )
+                # Arguments to include in wgrad GEMM closure
+                wgrad_gemm_kwargs = {
+                    "workspace": get_workspace(),
+                    "out_dtype": (
+                        main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
+                    ),
+                    "quantization_params": ctx.grad_weight_quantizer,
+                    "accumulate": accumulate_wgrad_into_param_main_grad,
+                    "layout": "NT",
+                    "out": main_grad if ctx.fuse_wgrad_accumulation else None,
+                    "bias": (bias if (grad_bias is None and not ctx.fp8) else None),
+                    "use_split_accumulator": use_split_accumulator,
+                    "grad": True,
+                    "ub": ub_obj_wgrad,
+                    "ub_type": ub_type_wgrad,
+                    "extra_output": reduce_scatter_out,
+                    "bulk_overlap": ctx.ub_bulk_wgrad,
+                }

+                def wgrad_gemm(
+                    x: torch.Tensor,
+                    dy: torch.Tensor,
+                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+                    """Perform wgrad GEMM: dw = dy^T * x
+
+                    May be fused with bgrad computation.
+
+                    May be called outside of this function to enable
+                    some advanced communication/compute overlapping.
+                    """
+                    nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
+                    dw, db, *_ = general_gemm(x, dy, **wgrad_gemm_kwargs)
+                    nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")
+                    return dw, db

+                # Choose whether to call wgrad GEMM now or delay
                 if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
-                    ctx.wgrad_store.put([ln_out_total, grad_output], general_gemm_wgrad)
+                    if (
+                        wgrad_gemm_kwargs["ub"] is not None
+                        or wgrad_gemm_kwargs["ub_type"] is not None
+                        or wgrad_gemm_kwargs["extra_output"] is not None
+                        or wgrad_gemm_kwargs["bulk_overlap"]
+                    ):
+                        raise NotImplementedError(
+                            "Delayed weight grad computation is not supported "
+                            "with Userbuffers (tensor-parallel communication overlapping)"
+                        )
+                    ctx.wgrad_store.put([ln_out_total, grad_output], wgrad_gemm)
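The `wgrad_store.put(...)` call above stashes the GEMM inputs and a closure so the weight-gradient computation can run later, overlapped with other work. A minimal, hypothetical stand-in for that store pattern (the `WgradStore` class and toy 1-D `wgrad_gemm` below are illustrations, not TE's `WeightGradStore`):

```python
# Delayed-wgrad sketch: inputs and a closure are queued, and the actual
# computation happens when pop() is called later.

class WgradStore:
    def __init__(self, delay):
        self._delay = delay
        self._pending = []

    def delay_wgrad_compute(self):
        return self._delay

    def put(self, tensors, fn):
        self._pending.append((tensors, fn))

    def pop(self):
        tensors, fn = self._pending.pop(0)
        return fn(*tensors)

def wgrad_gemm(x, dy):
    # Toy outer product standing in for dw = dy^T * x
    return [[xi * dyi for xi in x] for dyi in dy]

store = WgradStore(delay=True)
if store.delay_wgrad_compute():
    store.put([[1.0, 2.0], [3.0]], wgrad_gemm)
print(store.pop())  # [[3.0, 6.0]]
```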
                 else:
-                    wgrad, grad_bias_, _, rs_out = general_gemm_wgrad(ln_out_total, grad_output)
+                    # Call wgrad GEMM now
+                    wgrad, grad_bias_ = wgrad_gemm(ln_out_total, grad_output)

+                # Update grad bias if needed
                 if grad_bias is None:
                     grad_bias = grad_bias_
                 del grad_bias_

-                # Deallocate input tensor
+                # Deallocate input tensor if permitted
                 if not ctx.return_layernorm_output:
-                    # TODO (pgadzinski) - deallocate transpose only
+                    # pylint: disable=fixme
                     clear_tensor_data(ln_out_total)
-                nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")

+                # Update grad input if overlapping reduce-scatter with wgrad GEMM
                 if ctx.ub_bulk_wgrad:
                     if ub_obj_wgrad.is_fp8_ubuf():
-                        dgrad = rs_out
+                        dgrad = reduce_scatter_out
                     else:
-                        dgrad = ub_obj_wgrad.get_buffer(None, local_chunk=True)
+                        dgrad = ub_obj_wgrad.get_buffer(local_chunk=True).clone()
+            # --------------------------------------------------
+            # Grad weight has been computed...
+            # --------------------------------------------------

             # Don't return grad bias if not needed
             if not ctx.use_bias:
@@ -870,7 +952,7 @@ class _LayerNormLinear(torch.autograd.Function):
             nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors")

         # Scatter fp8 weight buffers
-        # if ctx.fp8 and not isinstance(weight, QuantizedTensor):
+        # if ctx.fp8 and not isinstance(weight, QuantizedTensorBase):
         #     _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8)

         return (
@@ -1506,7 +1588,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
             grad_output_quantizer = None
             output_quantizer = None
             input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
-            input_quantizer.internal = False
+            input_quantizer.internal = True
             weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
             weight_quantizer.internal = True
             if fp8_output:
@@ -1574,3 +1656,12 @@ class LayerNormLinear(TransformerEngineBaseModule):
             self.quantizers["scaling_bwd"][
                 tex.FP8BwdTensors.GRAD_OUTPUT1
             ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
+
+        # parallel related
+        if self.sequence_parallel and self.parallel_mode == "row":
+            # customize grad_output_quantizer with amax reduction TP group
+            self.quantizers["scaling_bwd"][
+                tex.FP8BwdTensors.GRAD_OUTPUT1
+            ].with_amax_reduction = True
+            self.quantizers["scaling_bwd"][
+                tex.FP8BwdTensors.GRAD_OUTPUT1
+            ].amax_reduction_group = self.tp_group
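The amax-reduction settings added above make every rank's grad-output quantizer derive its FP8 scale from the same statistics: each rank's local amax is max-reduced over the tensor-parallel group before the scale is computed. A toy single-process sketch of that reduction (the 448.0 FP8 maximum corresponds to the E4M3 format; the helper names are illustrative):

```python
# Max-reduce per-rank amax values so all ranks share one FP8 scale.

def reduce_amax(per_rank_amax):
    """All ranks end up with the same (maximum) amax."""
    group_amax = max(per_rank_amax)
    return [group_amax for _ in per_rank_amax]

def fp8_scale(amax, fp8_max=448.0):
    return fp8_max / amax if amax > 0 else 1.0

amaxes = [1.5, 3.0, 0.5, 2.0]  # per-rank statistics
reduced = reduce_amax(amaxes)
print(reduced)                 # [3.0, 3.0, 3.0, 3.0]
print(fp8_scale(reduced[0]))   # 448.0 / 3.0, roughly 149.33
```

Without the reduction, ranks would quantize the same logical tensor with different scales, and the subsequent reduce-scatter would sum inconsistently scaled values.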
@@ -8,7 +8,6 @@ import warnings
 from typing import Callable, Optional, Tuple, Union
 from functools import reduce
 from operator import mul as multiply_op
-import functools
 import torch
 from torch.nn.parameter import Parameter
@@ -19,6 +18,7 @@ import transformer_engine_torch as tex
 from transformer_engine.common.recipe import Recipe
 from transformer_engine.pytorch import torch_version
 from .base import (
+    fill_userbuffers_buffer_for_all_gather,
     get_workspace,
     _ub_communicators,
     get_ub,
@@ -42,7 +42,6 @@ from ..utils import (
     assert_dim_for_fp8_exec,
     clear_tensor_data,
     requires_grad,
-    is_non_tn_fp8_gemm_supported,
     needs_quantized_gemm,
 )
 from ..distributed import (
@@ -66,10 +65,11 @@ from ..tensor.float8_tensor import (
 )
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
-from ._common import apply_normalization, _fix_gathered_fp8_transpose, WeightGradStore
+from ._common import apply_normalization, WeightGradStore
 from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ..tensor.quantized_tensor import (
     QuantizedTensor,
+    QuantizedTensorBase,
     Quantizer,
     prepare_for_saving,
     restore_from_saved,
@@ -200,24 +200,16 @@ class _LayerNormMLP(torch.autograd.Function):
     ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
         # pylint: disable=missing-function-docstring

+        in_features, inp_shape = ln_weight.numel(), inp.shape
+
         # Make sure input dimensions are compatible
-        in_features, inp_shape = ln_weight.numel(), inp.shape
         assert inp_shape[-1] == in_features, "GEMM not possible"
         inputmat = inp.view((-1, in_features))
         if fp8:
             assert_dim_for_fp8_exec(inputmat, fc1_weight, fc2_weight)
-            if any([ub_overlap_ag, ub_overlap_rs]) and not (
-                FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
-            ):
-                raise NotImplementedError(
-                    "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
-                    " current scaling"
-                )

         activation_func = _act_func(
             activation, FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
         )[0]
-        device = inp.device

         # Cast for native AMP
         inputmat = cast_if_needed(inputmat, activation_dtype)
@@ -225,6 +217,38 @@ class _LayerNormMLP(torch.autograd.Function):
         if ln_bias is not None:
             ln_bias = cast_if_needed(ln_bias, activation_dtype)

+        tp_world_size = get_distributed_world_size(tp_group)
+        backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad
+        device = inp.device
+
+        # Configure Userbuffers communication (comm+GEMM overlap)
+        ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output_gathered
+        ub_overlap_rs = ub_overlap_rs and is_grad_enabled
+
+        # Choose whether to use GEMM kernel with split accumulator
+        use_split_accumulator = _2X_ACC_FPROP
+        if fp8:
+            recipe = FP8GlobalStateManager.get_fp8_recipe()
+            if hasattr(recipe, "fp8_gemm_fprop"):
+                use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator
+
+        # Configure quantizer for norm output
+        if fp8:
+            if fc1_input_quantizer is None:
+                raise ValueError("Missing quantizer for FC1 input tensor")
+            fc1_input_quantizer.set_usage(rowwise=True, columnwise=backwards_needs_fc1_input)
+            if sequence_parallel and isinstance(
+                fc1_input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
+            ):
+                # All-gather is not supported with FP8 column-wise data
+                fc1_input_quantizer.set_usage(columnwise=False)
+
+        # Do TP communication in high precision if quantized format
+        # does not support communication
+        force_hp_fc1_input_gather = (
+            fp8 and sequence_parallel and isinstance(fc1_input_quantizer, Float8BlockQuantizer)
+        )
+
         # for fp8 DelayedScaling: layernorm output = FP8
         # only output of the linear is returned
         # for return_layernorm_output: layernorm output = High precision, then cast to FP8
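The `force_hp_fc1_input_gather` flag above exists because quantize-then-gather is only safe when every rank quantizes with the same per-tensor scale; with block-scaled formats each shard carries its own scaling metadata, so the gather runs in high precision and quantization happens afterwards. A toy illustration of why the two orders diverge, using a hypothetical int8-style `quantize` helper:

```python
# Toy model: symmetric int8-style quantization of a list of floats.
def quantize(xs, scale):
    return [round(x * scale) for x in xs]

shard_a, shard_b = [1.0, 2.0], [100.0, 200.0]

# Per-tensor scaling with a shared (group-reduced) scale: quantizing
# each shard then concatenating matches quantizing the gathered tensor.
shared_scale = 127.0 / 200.0
q_gathered = quantize(shard_a + shard_b, shared_scale)
q_sharded = quantize(shard_a, shared_scale) + quantize(shard_b, shared_scale)
assert q_gathered == q_sharded

# Per-shard scales (a stand-in for block scaling): the concatenated
# quantized shards no longer share metadata, so naively gathering the
# quantized data diverges from quantizing the gathered tensor.
scale_a, scale_b = 127.0 / 2.0, 127.0 / 200.0
q_mixed = quantize(shard_a, scale_a) + quantize(shard_b, scale_b)
assert q_mixed != q_gathered
```

Real block-scaled formats keep the per-block scales alongside the data, but an all-gather primitive that only moves the raw bytes cannot reconcile them, hence the high-precision fallback.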
@@ -240,29 +264,6 @@ class _LayerNormMLP(torch.autograd.Function):
             # Kernels not available for norm fusion.
             with_quantized_norm = False

-        tp_world_size = get_distributed_world_size(tp_group)
-        ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output_gathered
-        ub_overlap_rs = ub_overlap_rs and is_grad_enabled
-        backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad
-        # TODO(kwyss): Support FP8 allgather of Float8BlockQuantizer recipe
-        force_hp_fc1_input_gather = (
-            fp8 and sequence_parallel and isinstance(fc1_input_quantizer, Float8BlockQuantizer)
-        )  # Perform TP communication in high precision.
-
-        # Configure quantizer for norm output
-        if fp8:
-            if fc1_input_quantizer is None:
-                raise ValueError("Missing quantizer for FC1 input tensor")
-            columnwise_usage = backwards_needs_fc1_input
-            if (
-                columnwise_usage
-                and sequence_parallel
-                and not isinstance(fc1_input_quantizer, MXFP8Quantizer)
-            ):
-                columnwise_usage = False
-            fc1_input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
-
         # Apply normalization
         ln_out, mu, rsigma = apply_normalization(
             inputmat,
@@ -296,39 +297,43 @@ class _LayerNormMLP(torch.autograd.Function):
                 fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
                 ln_out_total = fc1_input_quantizer(ln_out_total)
             else:
+                quantizer = None
                 if fp8 or debug:
+                    quantizer = fc1_input_quantizer
                     if not with_quantized_norm and not force_hp_fc1_input_gather:
                         ln_out = fc1_input_quantizer(ln_out)
                     fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
                 if ub_overlap_ag:
                     # Copy into Userbuffers buffer
                     ub_obj_lnout = get_ub("fc1_fprop")
-                    ub_obj_lnout.get_buffer(fc1_input_quantizer, local_chunk=True).copy_(ln_out)
-                    ln_out_total = ub_obj_lnout.get_buffer(fc1_input_quantizer)
+                    ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
+                        ub_obj_lnout,
+                        ln_out,
+                        quantizer,
+                        tp_group,
+                    )
                 else:
                     # All-gather with NCCL
                     ln_out_total, _ = gather_along_first_dim(
                         ln_out,
                         tp_group,
-                        quantizer=(fc1_input_quantizer if fp8 or debug else None),
+                        quantizer=quantizer,
                     )
         else:
-            # NOTE: force_hp_fc1_input_gather is redundant with else, but
-            # here for clarity. We should not quantize ln_out if bwd needs
-            # to gather in hp.
-            if (fp8 or debug) and not with_quantized_norm and not force_hp_fc1_input_gather:
+            if (fp8 or debug) and not with_quantized_norm:
                 ln_out = fc1_input_quantizer(ln_out)
             ln_out_total = ln_out
         # Cast weights to expected dtype
         fc1_weight_final = fc1_weight
         fc2_weight_final = fc2_weight
         if fp8 or debug:
             # If weights are not quantized, we call get_weight_workspace,
             # which handles weight caching etc.
             # FP8 cast to workspace buffer
             update_workspace = is_first_microbatch is None or is_first_microbatch
+            fc1_weight_quantizer.set_usage(rowwise=True, columnwise=True)
+            fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True)
             fc1_weight_final = module.get_weight_workspace(
                 tensor=fc1_weight,
                 quantizer=fc1_weight_quantizer,
@@ -338,7 +343,6 @@ class _LayerNormMLP(torch.autograd.Function):
                 fsdp_group=fsdp_group,
                 workspace_dtype=activation_dtype,
             )
-            fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True)
             fc2_weight_final = module.get_weight_workspace(
                 tensor=fc2_weight,
                 quantizer=fc2_weight_quantizer,
@@ -348,6 +352,8 @@ class _LayerNormMLP(torch.autograd.Function):
                 fsdp_group=fsdp_group,
                 workspace_dtype=activation_dtype,
             )
+            fc1_weight_final.update_usage(rowwise_usage=True)
+            fc2_weight_final.update_usage(rowwise_usage=True)
         else:
             fc1_weight_final = cast_if_needed(fc1_weight_final, activation_dtype)
             fc2_weight_final = cast_if_needed(fc2_weight_final, activation_dtype)
@@ -355,6 +361,7 @@ class _LayerNormMLP(torch.autograd.Function):
         # Cast biases to expected dtype
         bias_dtype = activation_dtype
         if needs_quantized_gemm(ln_out_total) and activation_dtype == torch.float32:
+            # cuBLAS does not support FP8 GEMM with FP32 bias, so we cast to BF16
             bias_dtype = torch.bfloat16
         if fc1_bias is not None:
             fc1_bias = cast_if_needed(fc1_bias, bias_dtype)
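The `update_workspace = is_first_microbatch is None or is_first_microbatch` line above encodes a simple caching contract: quantized weight workspaces are recomputed only when they may be stale (`None` means "unknown, always update"; `True` means the first microbatch of a step, after which the weights do not change). A minimal sketch of that contract with a hypothetical cache class and a placeholder `quantize`:

```python
# Toy sketch of the microbatch weight-cache contract. `quantize` and
# the cache are stand-ins; the real code caches FP8 weight workspaces.
class WeightCache:
    def __init__(self):
        self.workspace = None
        self.quantize_calls = 0

    def quantize(self, weight):
        self.quantize_calls += 1
        return [w / 2 for w in weight]  # placeholder "quantization"

    def get_workspace(self, weight, is_first_microbatch):
        update = is_first_microbatch is None or is_first_microbatch
        if update or self.workspace is None:
            self.workspace = self.quantize(weight)
        return self.workspace

cache = WeightCache()
w = [2.0, 4.0]
cache.get_workspace(w, is_first_microbatch=True)   # quantizes
cache.get_workspace(w, is_first_microbatch=False)  # reuses cache
cache.get_workspace(w, is_first_microbatch=False)  # reuses cache
assert cache.quantize_calls == 1
cache.get_workspace(w, is_first_microbatch=None)   # unknown: re-quantize
assert cache.quantize_calls == 2
```

Skipping the re-quantization on later microbatches is what makes gradient accumulation cheap when the same weights are reused many times per optimizer step.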
@@ -368,7 +375,9 @@ class _LayerNormMLP(torch.autograd.Function):
             if fc1_weight_quantizer is not None:
                 fc1_weight_quantizer.calibrate(fc1_weight)

+        # ------------------------------------------------------
         # FC1 GEMM
+        # ------------------------------------------------------
         # There are 2 fusions possible:
         # - gemm_gelu_fusion - default for full precision, optional for fp8 - need to turn on gemm_gelu_fusion,
@@ -400,11 +409,16 @@ class _LayerNormMLP(torch.autograd.Function):
                 fc1_bias if not bias_gelu_fusion else None
             ),  # otherwise bias is added later (fused with gelu)
             gelu=gemm_gelu_fusion,
-            accumulate=_2X_ACC_FPROP,
+            use_split_accumulator=use_split_accumulator,
             ub=ub_obj_lnout,
             ub_type=tex.CommOverlapType.AG if ub_overlap_ag else None,
         )
+        # ------------------------------------------------------
+        # Finished FC1 GEMM...
+        # ------------------------------------------------------

+        # Deallocate FC1 GEMM input tensor if no longer needed
         if not is_grad_enabled and (ln_out_total is not ln_out_return):
             clear_tensor_data(ln_out_total)
@@ -438,45 +452,66 @@ class _LayerNormMLP(torch.autograd.Function):
                 fc2_input_quantizer.calibrate(act_out)
                 fc2_weight_quantizer.calibrate(fc2_weight)

+        # Configure Userbuffers reduce-scatter if needed
         ub_obj_fc2out = None
-        rs_out = None
-        fc2_out = None
+        reduce_scatter_out = None
         if ub_overlap_rs:
             ub_obj_fc2out = get_ub("fc2_fprop")
             dim_size = list(act_out.size())
-            dim_size[0] = dim_size[0] // tp_world_size
-            dim_size[1] = fc2_weight.size(0)
-            rs_out = torch.empty(dim_size, dtype=activation_dtype, device=device)
-        else:
-            dim_size = list(act_out.size())
-            dim_size[1] = fc2_weight.size(0)
-            fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=device)
+            dim_size[0] //= tp_world_size
+            dim_size[-1] = fc2_weight.size(0)
+            reduce_scatter_out = torch.empty(dim_size, dtype=activation_dtype, device=device)

+        # ------------------------------------------------------
         # FC2 GEMM
-        _ = general_gemm(
+        # ------------------------------------------------------
+        gemm_out, *_, reduce_scatter_out = general_gemm(
             fc2_weight_final,
             act_out,
             get_workspace(),
             out_dtype=activation_dtype,
             bias=fc2_bias,
             quantization_params=fc2_output_quantizer,
-            out=fc2_out,
-            use_split_accumulator=_2X_ACC_FPROP,
+            use_split_accumulator=use_split_accumulator,
             ub=ub_obj_fc2out,
             ub_type=tex.CommOverlapType.RS if ub_overlap_rs else None,
-            extra_output=rs_out,
+            extra_output=reduce_scatter_out,
         )
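The shape arithmetic for the reduce-scatter buffer follows directly from the GEMM: the output keeps the activation's leading dimensions with the last one replaced by FC2's output features, and the reduce-scatter then shrinks the leading (sequence) dimension by the tensor-parallel world size. A quick sketch of that computation with hypothetical sizes:

```python
# Toy shape computation for the FC2 reduce-scatter output buffer.
def reduce_scatter_out_shape(act_shape, out_features, tp_world_size):
    dim_size = list(act_shape)
    dim_size[0] //= tp_world_size   # reduce-scatter splits the leading dim
    dim_size[-1] = out_features     # GEMM replaces the feature dim
    return dim_size

# e.g. activations [1024, 4096] (tokens x ffn_hidden), 4 TP ranks,
# FC2 projecting back to a hidden size of 1024:
assert reduce_scatter_out_shape([1024, 4096], 1024, 4) == [256, 1024]
```

Indexing the feature dimension as `dim_size[-1]` rather than `dim_size[1]` (the change made in this hunk) keeps the computation correct when the activation has more than two dimensions.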
+        # ------------------------------------------------------
+        # Finished FC2 GEMM...
+        # ------------------------------------------------------

-        # Weight with column-wise usage is needed for dgrad GEMM.
+        # Deallocate tensors if no longer needed
+        if not is_grad_enabled:
+            clear_tensor_data(act_out, fc1_out_without_bias, fc1_out)
+
+        # Prepare output tensor
+        # Note: Perform tensor-parallel communication if needed
+        fc2_out = None
+        if ub_overlap_rs:
+            fc2_out = reduce_scatter_out
+        elif set_parallel_mode and sequence_parallel:
+            fc2_out, _ = reduce_scatter_along_first_dim(gemm_out, tp_group)
+        elif set_parallel_mode and tensor_parallel:
+            if symmetric_ar_type is not None:
+                fc2_out, _ = symmetric_all_reduce(
+                    gemm_out, tp_group, all_reduce_type=symmetric_ar_type
+                )
+            else:
+                fc2_out, _ = allreduce(gemm_out, tp_group)
+        else:
+            fc2_out = gemm_out
+        fc2_out = fc2_out.view(-1, *inp_shape[1:-1], fc2_out.shape[-1])

+        # Cache state for backward pass
         if is_grad_enabled:
-            if isinstance(fc1_weight_final, QuantizedTensor):
+            # Weight with column-wise usage is needed for dgrad GEMM.
+            if isinstance(fc1_weight_final, QuantizedTensorBase):
                 fc1_weight_final.update_usage(columnwise_usage=True)
-            if isinstance(fc2_weight_final, QuantizedTensor):
+            if isinstance(fc2_weight_final, QuantizedTensorBase):
                 fc2_weight_final.update_usage(columnwise_usage=True)

-        if not is_grad_enabled:
-            clear_tensor_data(act_out, fc1_out_without_bias, fc1_out)
-        else:
             if cpu_offloading:
                 mark_activation_offload(
                     inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out
@@ -503,8 +538,6 @@ class _LayerNormMLP(torch.autograd.Function):
             if not return_layernorm_output:
                 clear_tensor_data(ln_out)
                 ln_out = None
-            elif force_hp_fc1_input_gather:
-                assert not isinstance(ln_out, QuantizedTensor)
             if not fc2_weight.requires_grad:
                 clear_tensor_data(act_out)
                 act_out = None
@@ -591,22 +624,6 @@ class _LayerNormMLP(torch.autograd.Function):
             ctx.wgrad_store = wgrad_store

-        # Row Parallel Linear
-        if ub_overlap_rs:
-            fc2_out = rs_out
-        elif set_parallel_mode and sequence_parallel:
-            fc2_out, _ = reduce_scatter_along_first_dim(fc2_out, tp_group)
-        elif set_parallel_mode and tensor_parallel:
-            if symmetric_ar_type is not None:
-                fc2_out, _ = symmetric_all_reduce(
-                    fc2_out, tp_group, all_reduce_type=symmetric_ar_type
-                )
-            else:
-                fc2_out, _ = allreduce(fc2_out, tp_group)
-
-        # [*, in_features] -> [*, out_features] except first dimension changes for SP
-        fc2_out = fc2_out.view(-1, *inp_shape[1:-1], fc2_out.shape[-1])
-
         if return_layernorm_output:
             if return_layernorm_output_gathered:
                 shape = list(inp_shape)
@@ -621,24 +638,6 @@ class _LayerNormMLP(torch.autograd.Function):
     ) -> Tuple[Union[torch.Tensor, None], ...]:
         # pylint: disable=missing-function-docstring
         with torch.cuda.nvtx.range("_LayerNormMLP_backward"):
-            if (
-                ctx.fp8
-                and any(
-                    [
-                        ctx.ub_overlap_ag,
-                        ctx.ub_overlap_rs_dgrad,
-                        ctx.ub_bulk_dgrad,
-                        ctx.ub_bulk_wgrad,
-                    ]
-                )
-                and (ctx.fp8_recipe is not None)
-            ):
-                if not ctx.fp8_recipe.float8_per_tensor_scaling():
-                    raise NotImplementedError(
-                        "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
-                        " current scaling"
-                    )
             saved_tensors = ctx.saved_tensors
             (  # pylint: disable=unbalanced-tuple-unpacking
                 inputmat,
@@ -698,6 +697,16 @@ class _LayerNormMLP(torch.autograd.Function):
             # fc2_weight_fp8 if ctx.fp8 and not isinstance(fc2_weight, Float8Tensor) else None,
             # )
+            # Choose whether to use GEMM kernel with split accumulator
+            dgrad_use_split_accumulator = _2X_ACC_DGRAD
+            wgrad_use_split_accumulator = _2X_ACC_WGRAD
+            if ctx.fp8:
+                recipe = ctx.fp8_recipe
+                if hasattr(recipe, "fp8_gemm_dgrad"):
+                    dgrad_use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator
+                if hasattr(recipe, "fp8_gemm_wgrad"):
+                    wgrad_use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator

             # No need to do bulk DGRAD/WGRAD overlap if WGRAD is not required
             ctx.ub_bulk_dgrad = ctx.fc1_weight_requires_grad and ctx.ub_bulk_dgrad
             ctx.ub_bulk_wgrad = ctx.fc1_weight_requires_grad and ctx.ub_bulk_wgrad
@@ -706,20 +715,13 @@ class _LayerNormMLP(torch.autograd.Function):
             # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
             # requires column-wise usage
             if ctx.fc2_grad_output_quantizer is not None:
-                rowwise_usage = True
-                columnwise_usage = True
-                if ctx.ub_overlap_ag and isinstance(
-                    ctx.fc2_grad_output_quantizer,
-                    (Float8Quantizer, Float8CurrentScalingQuantizer),
-                ):
-                    # If data is in FP8 and communication is handled
-                    # with Userbuffers, we compute FP8 transposes
-                    # manually
-                    columnwise_usage = False
-                ctx.fc2_grad_output_quantizer.set_usage(
-                    rowwise=rowwise_usage,
-                    columnwise=columnwise_usage,
-                )
+                quantizer = ctx.fc2_grad_output_quantizer
+                quantizer.set_usage(rowwise=True, columnwise=True)
+                if ctx.ub_overlap_ag:
+                    # Userbuffers only supports communication for one
+                    # tensor usage at a time. Configure quantizer with
+                    # usage for only dgrad GEMM.
+                    quantizer.set_usage(columnwise=False)

             # Prepare FC2 grad output tensor
             # Note: Cast to expected dtype and perform tensor-parallel communication
@@ -737,14 +739,10 @@ class _LayerNormMLP(torch.autograd.Function):
             # Launch tensor-parallel communication for FC1 GEMM input
             ln_out_total = None
             ln_out_total_work = None
-            if (
-                ctx.fc1_weight_requires_grad
-                and ctx.tensor_parallel
-                and ctx.sequence_parallel
-                and not ctx.ub_bulk_dgrad
-            ):
+            ub_obj_fc1_dgrad = None
+            if ctx.fc1_weight_requires_grad and ctx.tensor_parallel and ctx.sequence_parallel:
                 quantizer = None
-                if ctx.fp8 or ctx.debug:
+                if ctx.fp8 or ctx.debug and not ctx.force_hp_fc1_input_gather:
                     quantizer = ctx.fc1_input_quantizer
                     if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
                         # If data is in FP8, we compute FP8 transposes manually
@@ -752,12 +750,20 @@ class _LayerNormMLP(torch.autograd.Function):
                     else:
                         # wgrad GEMM requires input with column-wise usage
                         quantizer.set_usage(rowwise=False, columnwise=True)
-                gather_quantizer = None if ctx.force_hp_fc1_input_gather else quantizer
-                ln_out_total, ln_out_total_work = gather_along_first_dim(
-                    ln_out,
-                    ctx.tp_group,
-                    async_op=True,
-                    quantizer=gather_quantizer,
-                )
+                if ctx.ub_bulk_dgrad:
+                    ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
+                    ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
+                        ub_obj_fc1_dgrad,
+                        ln_out,
+                        quantizer,
+                        ctx.tp_group,
+                    )
+                else:
+                    ln_out_total, ln_out_total_work = gather_along_first_dim(
+                        ln_out,
+                        ctx.tp_group,
+                        async_op=True,
+                        quantizer=quantizer,
+                    )
             else:
                 ln_out_total = ln_out
@@ -769,6 +775,11 @@ class _LayerNormMLP(torch.autograd.Function):
                 )
             else:
                 accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation

+            # --------------------------------------------------
+            # FC2 DGRAD
+            # --------------------------------------------------
+
             # There are 6 possible fusion paths
             # 1 high-precision bias_gelu_fusion: gemm, FC1_bias + gelu,
             # 2 high-precision fc2_dgrad_gemm_gelu_fusion: gemm + gelu, FC1_bias + quantize
@@ -783,12 +794,15 @@ class _LayerNormMLP(torch.autograd.Function):
                 and (not ctx.debug)
             )

-            # FC2 DGRAD; Unconditional
-            if ctx.fc2_weight_quantizer is not None and isinstance(ctx.fc2_weight, QuantizedTensor):
-                ctx.fc2_weight.update_usage(
-                    rowwise_usage=ctx.fc2_weight_quantizer.rowwise_usage,
-                    columnwise_usage=ctx.fc2_weight_quantizer.columnwise_usage,
-                )
+            # Make sure required data is available
+            if isinstance(grad_output, QuantizedTensorBase):
+                grad_output.update_usage(rowwise_usage=True)
+            if ctx.fc2_weight_quantizer is not None and isinstance(
+                ctx.fc2_weight, QuantizedTensorBase
+            ):
+                ctx.fc2_weight.update_usage(columnwise_usage=True)

+            # Perform GEMM
             gemm_output, *_ = general_gemm(
                 fc2_weight,
                 grad_output,
@@ -803,52 +817,107 @@ class _LayerNormMLP(torch.autograd.Function):
                 out_dtype=ctx.activation_dtype,
                 gelu=fc2_dgrad_gemm_gelu_fusion,
                 gelu_in=fc1_out if fc2_dgrad_gemm_gelu_fusion else None,
-                use_split_accumulator=_2X_ACC_DGRAD,
+                use_split_accumulator=dgrad_use_split_accumulator,
                 ub=ub_obj_fc2_dgrad,
                 ub_type=tex.CommOverlapType.AG if ctx.ub_overlap_ag else None,
             )
+            # Prepare input grad tensor
+            dact = None
+            fc2_dgrad = None
             if fc2_dgrad_gemm_gelu_fusion:
                 dact = gemm_output
-                fc2_dgrad = None
             else:
                 fc2_dgrad = gemm_output

+            # --------------------------------------------------
+            # Finished FC2 DGRAD...
+            # --------------------------------------------------
+            # --------------------------------------------------
             # FC2 WGRAD
+            # --------------------------------------------------
+            fc2_wgrad = None
             if ctx.fc2_weight_requires_grad:
-                if isinstance(act_out, QuantizedTensor):
-                    act_out.update_usage(rowwise_usage=True, columnwise_usage=True)
-                if isinstance(grad_output, QuantizedTensor):
-                    grad_output.update_usage(rowwise_usage=True, columnwise_usage=True)

+                # Prepare input tensor
+                # Note: Synchronize tensor-parallel communication and
+                # make sure required data is available
+                if ctx.fp8 or ctx.debug:
+                    if isinstance(act_out, QuantizedTensorBase):
+                        act_out.update_usage(columnwise_usage=True)
+                    else:
+                        ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True)
+                        act_out = ctx.fc2_input_quantizer(act_out)

+                # Prepare grad output tensor
+                # Note: Synchronize tensor-parallel communication and
+                # make sure required data is available
+                if ctx.ub_overlap_ag and isinstance(ctx.fc2_grad_output_quantizer, MXFP8Quantizer):
+                    # UB does not support overlapping grad output
+                    # all-gather with wgrad GEMM. Also, we can't
+                    # convert row-scaled MXFP8 to column-scaled, so we
+                    # can't reuse the grad output that was gathered
+                    # for the dgrad GEMM. We work around with blocking
+                    # all-gather for column-scaled MXFP8 data.
+                    ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
+                    grad_output, _ = gather_along_first_dim(
+                        grad_outputs[0],
+                        ctx.tp_group,
+                        quantizer=ctx.fc2_grad_output_quantizer,
+                    )
+                if ctx.fp8 or ctx.debug:
+                    if isinstance(grad_output, QuantizedTensorBase):
+                        grad_output.update_usage(columnwise_usage=True)
+                    else:
+                        ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
+                        grad_output = ctx.fc2_grad_output_quantizer(grad_output)

+                # Whether to set grad arg in general_gemm
                 grad_arg = True
                 if ctx.fp8 and ctx.fp8_recipe.float8_block_scaling():
                     grad_arg = False

-                general_gemm_fc2_wgrad = functools.partial(
-                    general_gemm,
-                    out_dtype=(
+                # Arguments to include in wgrad GEMM closure
+                fc2_wgrad_gemm_kwargs = {
+                    "workspace": get_workspace(),
+                    "out_dtype": (
                         origin_fc2_weight.main_grad.dtype
                         if ctx.fuse_wgrad_accumulation
                         else ctx.activation_dtype
                     ),
-                    workspace=get_workspace(),
-                    quantization_params=ctx.fc2_grad_weight_quantizer,  # wgrad in high precision
-                    layout="NT",
-                    grad=grad_arg,
-                    bias=fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None,
-                    accumulate=accumulate_wgrad_into_param_main_grad,
-                    use_split_accumulator=_2X_ACC_WGRAD,
-                    out=origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
-                )
+                    "quantization_params": ctx.fc2_grad_weight_quantizer,  # wgrad in high precision
+                    "accumulate": accumulate_wgrad_into_param_main_grad,
+                    "layout": "NT",
+                    "out": origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
+                    "bias": fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None,
+                    "use_split_accumulator": wgrad_use_split_accumulator,
+                    "grad": grad_arg,
+                }

+                def fc2_wgrad_gemm(
+                    x: torch.Tensor,
+                    dy: torch.Tensor,
+                ) -> Tuple[torch.Tensor, torch.Tensor]:
+                    """Perform FC2 WGRAD GEMM
+
+                    May be called outside of this function to enable
+                    some advanced communication/compute overlapping.
+
+                    """
+                    dw, db, *_ = general_gemm(x, dy, **fc2_wgrad_gemm_kwargs)
+                    return dw, db

+                # Choose whether to call wgrad GEMM now or delay
                 if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
-                    ctx.wgrad_store.put([act_out, grad_output], general_gemm_fc2_wgrad)
-                    fc2_wgrad = None
+                    ctx.wgrad_store.put([act_out, grad_output], fc2_wgrad_gemm)
                 else:
-                    fc2_wgrad, fc2_bias_grad_, *_ = general_gemm_fc2_wgrad(
-                        act_out,
-                        grad_output,
-                    )

+                    # Call wgrad GEMM now
+                    fc2_wgrad, fc2_bias_grad_ = fc2_wgrad_gemm(act_out, grad_output)

+                    # Update grad bias if needed
                     if fc2_bias_grad is None:
                         if (
                             ctx.fp8
@@ -857,12 +926,17 @@ class _LayerNormMLP(torch.autograd.Function):
                         ):
                             # BGRAD not fused with GEMM for float8 blockwise gemm.
                             fc2_bias_grad_ = act_out.view(-1, act_out.shape[-1]).sum(dim=0)
                         fc2_bias_grad = fc2_bias_grad_
                     del fc2_bias_grad_

+                # Deallocate input tensor if permitted
                 if ctx.wgrad_store is not None and not ctx.wgrad_store.delay_wgrad_compute():
                     clear_tensor_data(act_out)

+            # --------------------------------------------------
+            # Finished FC2 WGRAD...
+            # --------------------------------------------------
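Packaging the wgrad GEMM as a kwargs dict plus a named closure (replacing the earlier `functools.partial`) keeps the call site flexible: `wgrad_store.put` can stash the inputs together with the callable and execute them later, overlapping weight-gradient compute with communication. A stripped-down sketch of that store pattern (`ToyWgradStore` and `wgrad_gemm` are hypothetical stand-ins; the real `WeightGradStore` lives in `._common`):

```python
# Toy version of the delayed-wgrad pattern: queue (inputs, fn) pairs
# and run them later, e.g. while communication is in flight.
class ToyWgradStore:
    def __init__(self, delay: bool):
        self._delay = delay
        self._queue = []

    def delay_wgrad_compute(self) -> bool:
        return self._delay

    def put(self, tensors, fn):
        self._queue.append((tensors, fn))

    def pop(self):
        tensors, fn = self._queue.pop(0)
        return fn(*tensors)

def wgrad_gemm(x, dy):
    # Stand-in for general_gemm(x, dy, **kwargs): a dot product here.
    return sum(xi * dyi for xi, dyi in zip(x, dy))

store = ToyWgradStore(delay=True)
if store.delay_wgrad_compute():
    store.put([[1.0, 2.0], [3.0, 4.0]], wgrad_gemm)  # deferred
else:
    _ = wgrad_gemm([1.0, 2.0], [3.0, 4.0])  # immediate

# Later (e.g. after a reduce-scatter completes), drain the store.
assert store.pop() == 11.0
```

Capturing the GEMM arguments in a dict rather than baking them into a partial also lets the deferred call site inspect or log them before execution, which is handy when debugging overlap schedules.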
             # bias computation
             fc1_bias_grad = None
             fuse_gemm_and_bias_fc1_wgrad = False
@@ -926,63 +1000,69 @@ class _LayerNormMLP(torch.autograd.Function):
             ub_obj_fc1_dgrad = None
             ub_obj_fc1_wgrad = None
             ub_type_fc1_dgrad = None
+            ub_type_fc1_wgrad = None
             fc1_dgrad_shape = [reduce(multiply_op, inputmat.shape[:-1]), inputmat.shape[-1]]
-            fc1_dgrad_rs_out = None
-            fc1_dgrad_bulk = None
             if ctx.ub_overlap_rs_dgrad:
                 # Overlap DGRAD+RS
                 ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
                 ub_type_fc1_dgrad = tex.CommOverlapType.RS
-                fc1_dgrad_rs_out = torch.empty(
-                    fc1_dgrad_shape, dtype=ctx.activation_dtype, device="cuda"
-                )
             else:
                 if ctx.ub_bulk_dgrad:
                     # Overlap ln_out all-gather with DGRAD compute
-                    # NOTE: Copying into communication buffer will always prefer rowwise data,
-                    # and will copy columnwise data if rowwise does not exist. In that case,
-                    # the all-gather will apply to the leading dimension of the transpose,
-                    # which then needs to be interleaved correctly before WGRAD.
                     ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
                     ub_type_fc1_dgrad = tex.CommOverlapType.AG
-                    ub_obj_fc1_dgrad.copy_into_buffer(
-                        ln_out, ctx.fc1_input_quantizer, local_chunk=True
-                    )
                 if ctx.ub_bulk_wgrad:
                     # Overlap FC1 DGRAD reduce-scatter with WGRAD compute
                     ub_obj_fc1_wgrad = get_ub("fc1_wgrad")
-                    fc1_dgrad_bulk = ub_obj_fc1_wgrad.get_buffer(None)
+                    ub_type_fc1_wgrad = tex.CommOverlapType.RS

+            # --------------------------------------------------
+            # FC1 DGRAD
+            # --------------------------------------------------

-            # FC1 DGRAD: Unconditional
+            # Make sure required data is available
             if ctx.fc1_weight_quantizer is not None and isinstance(
-                ctx.fc1_weight_quantizer, QuantizedTensor
+                ctx.fc1_weight_quantizer, QuantizedTensorBase
             ):
-                ctx.fc1_weight.update_usage(
-                    rowwise_usage=ctx.fc1_weight_quantizer.rowwise_usage,
-                    columnwise_usage=ctx.fc1_weight_quantizer.columnwise_usage,
-                )
+                ctx.fc1_weight.update_usage(columnwise_usage=True)

+            # Output buffers for Userbuffers reduce-scatter
+            gemm_out = None
+            reduce_scatter_out = None
+            if ctx.ub_overlap_rs_dgrad:
+                reduce_scatter_out = torch.empty(
+                    fc1_dgrad_shape, dtype=ctx.activation_dtype, device="cuda"
+                )
+            if ctx.ub_bulk_wgrad:
+                gemm_out = ub_obj_fc1_wgrad.get_buffer(local_chunk=False)

-            fc1_dgrad, *_, fc1_dgrad_rs_out = general_gemm(
+            # dgrad GEMM
+            gemm_out, *_, reduce_scatter_out = general_gemm(
                 fc1_weight,
                 dact,
                 get_workspace(),
-                out=fc1_dgrad_bulk,
+                out=gemm_out,
                 out_dtype=ctx.activation_dtype,
                 quantization_params=ctx.fc1_grad_input_quantizer,
                 layout="NN",
                 grad=True,
+                use_split_accumulator=dgrad_use_split_accumulator,
                 ub=ub_obj_fc1_dgrad,
                 ub_type=ub_type_fc1_dgrad,
-                extra_output=fc1_dgrad_rs_out,
+                extra_output=reduce_scatter_out,
                 bulk_overlap=ctx.ub_bulk_dgrad,
             )

-            # Overlap dgrad-RS/AR with wgrad
+            # Prepare grad input tensor
+            # Note: Perform tensor-parallel communication
+            fc1_dgrad = None
             fc1_dgrad_work = None
             if ctx.ub_overlap_rs_dgrad:
-                fc1_dgrad = fc1_dgrad_rs_out
+                fc1_dgrad = reduce_scatter_out
+            elif ctx.ub_bulk_wgrad:
+                fc1_dgrad = ub_obj_fc1_wgrad.get_buffer(local_chunk=True)
elif ctx.set_parallel_mode and not ctx.ub_bulk_wgrad: elif ctx.set_parallel_mode and not ctx.ub_bulk_wgrad:
fc1_dgrad = gemm_out
if ctx.sequence_parallel: if ctx.sequence_parallel:
if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered: if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered:
fc1_dgrad = fc1_dgrad + grad_outputs[1].view_as(fc1_dgrad) fc1_dgrad = fc1_dgrad + grad_outputs[1].view_as(fc1_dgrad)
...@@ -993,90 +1073,125 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -993,90 +1073,125 @@ class _LayerNormMLP(torch.autograd.Function):
) )
elif ctx.tensor_parallel: elif ctx.tensor_parallel:
fc1_dgrad, fc1_dgrad_work = allreduce(fc1_dgrad, ctx.tp_group, async_op=True) fc1_dgrad, fc1_dgrad_work = allreduce(fc1_dgrad, ctx.tp_group, async_op=True)
else:
fc1_dgrad = gemm_out
# --------------------------------------------------
# Finished FC1 DGRAD...
# --------------------------------------------------
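The grad-input selection above is a three-way dispatch between the reduce-scatter output, the Userbuffers wgrad buffer, and the plain GEMM output. A condensed sketch of that logic (the helper name and arguments are hypothetical, for illustration only, not TE APIs):

```python
def select_grad_input(overlap_rs_dgrad, bulk_wgrad, gemm_out,
                      reduce_scatter_out, wgrad_buffer_chunk):
    """Pick which buffer holds the FC1 grad input (illustrative sketch)."""
    if overlap_rs_dgrad:
        # The GEMM epilogue already reduce-scattered into this buffer
        return reduce_scatter_out
    if bulk_wgrad:
        # Reduce-scatter runs later, overlapped with the wgrad GEMM;
        # the local chunk of the Userbuffers buffer will hold the result
        return wgrad_buffer_chunk
    # Plain GEMM output, possibly reduced or scattered afterward
    return gemm_out
```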
+        # --------------------------------------------------
         # FC1 WGRAD
+        # --------------------------------------------------
         fc1_wgrad = None
         if ctx.fc1_weight_requires_grad:

-            # Synchronize tensor-parallel communication for FC1 GEMM input tensor
-            if ctx.ub_bulk_dgrad:
-                ln_out_total = ub_obj_fc1_dgrad.get_buffer(ctx.fc1_input_quantizer)
-                if ctx.fp8:
-                    if ln_out._data is None:
-                        # All-gather executed on columnwise data and result is in rowwise data,
-                        # so we need to fix the interleaving before WGRAD.
-                        ln_out_total = _fix_gathered_fp8_transpose(ln_out_total, ctx.tp_size)
-                    elif not is_non_tn_fp8_gemm_supported():
-                        # FP8 GEMM on Hopper only supports TN layout so the gathered input must
-                        # have a valid transpose.
-                        ln_out_total._create_transpose()
+            # Prepare input tensor
+            # Note: Synchronize tensor-parallel communication and
+            # make sure required data is available
             if ln_out_total_work is not None:
                 ln_out_total_work.wait()
                 ln_out_total_work = None
-            if ctx.fc1_input_quantizer is not None and not isinstance(
-                ln_out_total, QuantizedTensor
-            ):
-                # Async gather in BF16 does not asynchronously
-                # call quantizer after gather.
-                ctx.fc1_input_quantizer.set_usage(rowwise=False, columnwise=True)
-                ln_out_total = ctx.fc1_input_quantizer(ln_out_total)
+            if ctx.fp8 or ctx.debug:
+                if isinstance(ln_out_total, QuantizedTensorBase):
+                    ln_out_total.update_usage(columnwise_usage=True)
+                else:
+                    ctx.fc1_input_quantizer.set_usage(rowwise=False, columnwise=True)
+                    ln_out_total = ctx.fc1_input_quantizer(ln_out_total)

-            # Make sure GEMM inputs have required data
-            if isinstance(ln_out_total, QuantizedTensor):
-                ln_out_total.update_usage(columnwise_usage=True)
-            if isinstance(dact, QuantizedTensor):
-                dact.update_usage(columnwise_usage=True)
+            # Prepare grad output tensor
+            # Note: Synchronize tensor-parallel communication and
+            # make sure required data is available
+            if ctx.fp8 or ctx.debug:
+                if isinstance(dact, QuantizedTensorBase):
+                    dact.update_usage(columnwise_usage=True)
+                else:
+                    ctx.fc1_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
+                    dact = ctx.fc1_grad_output_quantizer(dact)

             # Output buffer for overlapping grad input
             # reduce-scatter with wgrad GEMM
+            reduce_scatter_out = None
             if ctx.ub_bulk_wgrad and ub_obj_fc1_wgrad.is_fp8_ubuf():
-                fc1_dgrad_rs_out = torch.empty(
+                reduce_scatter_out = torch.empty(
                     fc1_dgrad_shape, dtype=ctx.activation_dtype, device="cuda"
                 )

-            # wgrad GEMM
-            general_gemm_fc1_wgrad = functools.partial(
-                general_gemm,
-                out_dtype=(
-                    origin_fc1_weight.main_grad.dtype
-                    if ctx.fuse_wgrad_accumulation
-                    else ctx.activation_dtype
-                ),
-                workspace=get_workspace(),
-                layout="NT",
-                quantization_params=ctx.fc1_grad_weight_quantizer,
-                grad=fuse_gemm_and_bias_fc1_wgrad,
-                bias=fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None,
-                accumulate=accumulate_wgrad_into_param_main_grad,
-                out=origin_fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
-                ub=ub_obj_fc1_wgrad,
-                ub_type=tex.CommOverlapType.RS if ctx.ub_bulk_wgrad else None,
-                extra_output=fc1_dgrad_rs_out,
-                bulk_overlap=ctx.ub_bulk_wgrad,
-            )
+            # Arguments to include in wgrad GEMM closure
+            fc1_wgrad_gemm_kwargs = {
+                "workspace": get_workspace(),
+                "out_dtype": (
+                    origin_fc1_weight.main_grad.dtype
+                    if ctx.fuse_wgrad_accumulation
+                    else ctx.activation_dtype
+                ),
+                "quantization_params": ctx.fc1_grad_weight_quantizer,
+                "accumulate": accumulate_wgrad_into_param_main_grad,
+                "layout": "NT",
+                "out": origin_fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
+                "bias": fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None,
+                "use_split_accumulator": wgrad_use_split_accumulator,
+                "grad": fuse_gemm_and_bias_fc1_wgrad,
+                "ub": ub_obj_fc1_wgrad,
+                "ub_type": ub_type_fc1_wgrad,
+                "extra_output": reduce_scatter_out,
+                "bulk_overlap": ctx.ub_bulk_wgrad,
+            }
+
+            def fc1_wgrad_gemm(
+                x: torch.Tensor,
+                dy: torch.Tensor,
+                _is_delayed: bool = True,
+            ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+                """Perform FC1 WGRAD GEMM
+
+                May be called outside of this function to enable
+                some advanced communication/compute overlapping.
+
+                """
+                dw, db, *_ = general_gemm(x, dy, **fc1_wgrad_gemm_kwargs)
+                return dw, db
+
+            # Choose whether to call wgrad GEMM now or delay
             if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
-                ctx.wgrad_store.put([ln_out_total, dact], general_gemm_fc1_wgrad)
+                if (
+                    fc1_wgrad_gemm_kwargs["ub"] is not None
+                    or fc1_wgrad_gemm_kwargs["ub_type"] is not None
+                    or fc1_wgrad_gemm_kwargs["extra_output"] is not None
+                    or fc1_wgrad_gemm_kwargs["bulk_overlap"]
+                ):
+                    raise NotImplementedError(
+                        "Delayed weight grad computation is not supported "
+                        "with Userbuffers (tensor-parallel communication overlapping)"
+                    )
+                ctx.wgrad_store.put([ln_out_total, dact], fc1_wgrad_gemm)
                 fc1_wgrad = None
                 if fuse_gemm_and_bias_fc1_wgrad:
                     fc1_bias_grad = None
             else:
-                fc1_wgrad_outputs = general_gemm_fc1_wgrad(
-                    ln_out_total,
-                    dact,
-                )
-                clear_tensor_data(ln_out_total, dact)
+                # Call wgrad GEMM now
+                fc1_wgrad_outputs = fc1_wgrad_gemm(ln_out_total, dact)
                 if fuse_gemm_and_bias_fc1_wgrad:
-                    fc1_wgrad, fc1_bias_grad, *_ = fc1_wgrad_outputs
+                    fc1_wgrad, fc1_bias_grad = fc1_wgrad_outputs
                 else:
-                    fc1_wgrad, *_ = fc1_wgrad_outputs
+                    fc1_wgrad, _ = fc1_wgrad_outputs
+
+                # Deallocate tensors if permitted
+                clear_tensor_data(dact)
+                if not ctx.return_layernorm_output_gathered:
+                    clear_tensor_data(ln_out_total)

+            # Update grad input if overlapping reduce-scatter with wgrad GEMM
             if ctx.ub_bulk_wgrad:
                 if ub_obj_fc1_wgrad.is_fp8_ubuf():
-                    fc1_dgrad = fc1_dgrad_rs_out
+                    fc1_dgrad = reduce_scatter_out
                 else:
-                    fc1_dgrad = ub_obj_fc1_wgrad.get_buffer(None, local_chunk=True)
+                    fc1_dgrad = ub_obj_fc1_wgrad.get_buffer(local_chunk=True).clone()

+        # --------------------------------------------------
+        # Finished FC1 WGRAD...
+        # --------------------------------------------------
         # Make sure all tensor-parallel communication is finished
         if ln_out_total_work is not None:
@@ -1746,7 +1861,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
             ) = [None] * 12
         if self.fp8:
             fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
-            fc1_input_quantizer.internal = False  # temporary
+            fc1_input_quantizer.internal = True
             fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
             fc1_weight_quantizer.internal = True
             fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
@@ -1754,6 +1869,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
                 rowwise=True,
                 columnwise=isinstance(fc2_input_quantizer, (MXFP8Quantizer, Float8BlockQuantizer)),
             )
+            fc1_input_quantizer.internal = True
             fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT]
             fc2_weight_quantizer.internal = True
             if fp8_output:
@@ -1762,11 +1878,11 @@ class LayerNormMLP(TransformerEngineBaseModule):
                 ]
             if torch.is_grad_enabled():
                 fc2_grad_output_quantizer = self.quantizers["scaling_bwd"][
-                    tex.FP8BwdTensors.GRAD_OUTPUT1
+                    tex.FP8BwdTensors.GRAD_OUTPUT2
                 ]
                 fc2_grad_output_quantizer.internal = True
                 fc1_grad_output_quantizer = self.quantizers["scaling_bwd"][
-                    tex.FP8BwdTensors.GRAD_INPUT1
+                    tex.FP8BwdTensors.GRAD_OUTPUT1
                 ]
                 fc1_grad_output_quantizer.internal = True
@@ -1851,25 +1967,25 @@ class LayerNormMLP(TransformerEngineBaseModule):
         else:
             # fc2_grad_output_quantizer: set amax epsilon and power_2_scale configs
             self.quantizers["scaling_bwd"][
-                tex.FP8BwdTensors.GRAD_OUTPUT1
+                tex.FP8BwdTensors.GRAD_OUTPUT2
             ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
             self.quantizers["scaling_bwd"][
-                tex.FP8BwdTensors.GRAD_OUTPUT1
+                tex.FP8BwdTensors.GRAD_OUTPUT2
             ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
             # fc1_grad_output_quantizer: also set numerical configs
             self.quantizers["scaling_bwd"][
-                tex.FP8BwdTensors.GRAD_INPUT1
+                tex.FP8BwdTensors.GRAD_OUTPUT1
             ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
             self.quantizers["scaling_bwd"][
-                tex.FP8BwdTensors.GRAD_INPUT1
+                tex.FP8BwdTensors.GRAD_OUTPUT1
             ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
             if self.sequence_parallel and self.set_parallel_mode:
                 # fc2_grad_output_quantizer: reduce amax over the TP group
                 # (row parallel + sequence parallel)
                 self.quantizers["scaling_bwd"][
-                    tex.FP8BwdTensors.GRAD_OUTPUT1
+                    tex.FP8BwdTensors.GRAD_OUTPUT2
                 ].with_amax_reduction = True
                 self.quantizers["scaling_bwd"][
-                    tex.FP8BwdTensors.GRAD_OUTPUT1
+                    tex.FP8BwdTensors.GRAD_OUTPUT2
                 ].amax_reduction_group = self.tp_group
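`force_pow_2_scales` and `amax_epsilon` control how a quantizer converts an observed amax into a scale factor. A schematic version of that computation (illustrative only, not TE's kernel; 448 is assumed here as the FP8 E4M3 representable max, and the epsilon guards against near-zero amax):

```python
import math

def compute_scale(amax, fp8_max=448.0, force_pow_2_scales=False, amax_epsilon=0.0):
    """Schematic FP8 scale computation (illustrative sketch)."""
    # amax_epsilon puts a floor under tiny amax values so the
    # scale stays finite and bounded
    amax = max(amax, amax_epsilon)
    scale = fp8_max / amax
    if force_pow_2_scales:
        # Round down to a power of two so scaling and unscaling
        # are exact in binary floating point
        scale = 2.0 ** math.floor(math.log2(scale))
    return scale
```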
     def backward_dw(self):
...
@@ -8,7 +8,6 @@ from functools import reduce
 from operator import mul as multiply_op
 import warnings

-import functools
 import torch
 import transformer_engine_torch as tex
@@ -17,15 +16,16 @@ from transformer_engine.common.recipe import Recipe
 from transformer_engine.pytorch import torch_version
 from .base import (
-    get_workspace,
-    get_dummy_wgrad,
+    fill_userbuffers_buffer_for_all_gather,
     get_ub,
+    get_workspace,
     TransformerEngineBaseModule,
+    get_dummy_wgrad,
     _2X_ACC_FPROP,
     _2X_ACC_DGRAD,
     _2X_ACC_WGRAD,
 )
-from ._common import noop_cat, _fix_gathered_fp8_transpose, WeightGradStore
+from ._common import noop_cat, WeightGradStore
 from ..fp8 import FP8GlobalStateManager
 from ..utils import (
     cast_if_needed,
@@ -34,7 +34,6 @@ from ..utils import (
     init_method_constant,
     requires_grad,
     needs_quantized_gemm,
-    is_non_tn_fp8_gemm_supported,
     assert_dim_for_fp8_exec,
     nvtx_range_pop,
     nvtx_range_push,
@@ -130,86 +129,98 @@ class _Linear(torch.autograd.Function):
         out_features, in_features = weight.shape
         assert inp.shape[-1] == in_features, "GEMM not possible"

+        # Configure tensor-parallel communication
         tp_world_size = get_distributed_world_size(tp_group)
         backward_needs_input = is_grad_enabled and weight.requires_grad
-
-        # Prepare input tensor
-        # Note: Cast to expected dtype and perform tensor-parallel communication
-        nvtx_range_push(f"{nvtx_label}.input_cast_comm")
-        inputmat = inp
-        inputmat_total = None
         with_input_all_gather_nccl = (
             parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
         )
-        own_quantized_input = False
-        # TODO(kwyss): Support FP8 allgather for FP8 block quantization.
+
+        # Do TP communication in high precision if quantized format
+        # does not support communication
         force_hp_input_gather = (
             fp8 and with_input_all_gather_nccl and isinstance(input_quantizer, Float8BlockQuantizer)
-        )  # Perform TP communication in high precision.
+        )
+
+        # Configure Userbuffers communication (comm+GEMM overlap)
+        ub_obj = None
+        ub_type = None
+        if ub_overlap_rs_fprop:
+            ub_obj = get_ub(ub_name + "_fprop")
+            ub_type = tex.CommOverlapType.RS
+        elif ub_overlap_ag_fprop:
+            ub_obj = get_ub(ub_name + "_fprop")
+            ub_type = tex.CommOverlapType.AG
+
+        # ------------------------------------------------------
+        # Prepare input tensor
+        # Note: Cast to expected dtype and perform tensor-parallel communication
+        # ------------------------------------------------------
+        nvtx_range_push(f"{nvtx_label}.input_cast_comm")
+        inputmat = inp  # Input tensor to save for backward (maybe sharded)
+        inputmat_total = None  # Input tensor to pass to GEMM (gathered)
+        own_quantized_input = False
         if fp8:
             assert_dim_for_fp8_exec(inputmat, weight)
-        if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and not (
-            FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
-        ):
-            raise NotImplementedError(
-                "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
-                " current scaling"
-            )
-        if fp8 or debug:
-            if input_quantizer is None:
-                raise ValueError("Missing quantizer for input tensor")
-            if with_input_all_gather_nccl:
-                if force_hp_input_gather:
-                    input_quantizer.set_usage(rowwise=True, columnwise=False)
-                    inputmat_total, _ = gather_along_first_dim(
-                        inputmat, tp_group, quantizer=input_quantizer
-                    )
-                else:
-                    if not isinstance(inputmat, QuantizedTensorBase):
-                        columnwise_usage = backward_needs_input and isinstance(
-                            input_quantizer, MXFP8Quantizer
-                        )
-                        # force_hp_input_gather should enforce this
-                        assert not isinstance(input_quantizer, Float8BlockQuantizer)
-                        input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
-                        inputmat = input_quantizer(inputmat)
-                        own_quantized_input = True
-                    input_quantizer.set_usage(rowwise=True, columnwise=False)
-                    inputmat_total, _ = gather_along_first_dim(
-                        inputmat,
-                        tp_group,
-                        quantizer=input_quantizer,
-                    )
-            else:
-                if (
-                    FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
-                    and ub_bulk_dgrad
-                ):
-                    # reduce duplicated transpose in `_fix_gathered_fp8_transpose`
-                    input_quantizer.set_usage(rowwise=True, columnwise=False)
-                else:
-                    input_quantizer.set_usage(
-                        rowwise=True,
-                        columnwise=backward_needs_input,
-                    )
-                if not isinstance(inputmat, QuantizedTensorBase):
-                    inputmat = input_quantizer(inputmat)
-                    own_quantized_input = True
-                elif backward_needs_input:
-                    inputmat.update_usage(rowwise_usage=True, columnwise_usage=True)
-                inputmat_total = inputmat
-        else:
-            inputmat = cast_if_needed(inp, activation_dtype)
-            if with_input_all_gather_nccl:
-                inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
-            else:
-                inputmat_total = inputmat
+        if with_input_all_gather_nccl or ub_overlap_ag_fprop:  # All-gather input tensor
+
+            # Cast local input tensor if needed
+            if fp8 or debug:
+                if input_quantizer is None:
+                    raise ValueError("Missing quantizer for input tensor")
+                if not force_hp_input_gather and not isinstance(inputmat, QuantizedTensorBase):
+                    input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
+                    if isinstance(
+                        input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
+                    ):
+                        # All-gather is not supported with FP8 column-wise data
+                        input_quantizer.set_usage(columnwise=False)
+                    inputmat = input_quantizer(inputmat)
+                    own_quantized_input = True
+            else:
+                inputmat = cast_if_needed(inp, activation_dtype)  # Cast for AMP
+
+            # Initialize gathered input tensor
+            quantizer = None
+            if fp8 or debug:
+                quantizer = input_quantizer
+                quantizer.set_usage(rowwise=True, columnwise=False)
+            if with_input_all_gather_nccl:  # Perform NCCL all-gather
+                inputmat_total, _ = gather_along_first_dim(
+                    inputmat,
+                    tp_group,
+                    quantizer=quantizer,
+                )
+            elif ub_overlap_ag_fprop:  # Initialize Userbuffers all-gather
+                inputmat_total, _ = fill_userbuffers_buffer_for_all_gather(
+                    ub_obj,
+                    inputmat,
+                    quantizer,
+                    tp_group,
+                )
+
+        else:  # Do not all-gather input tensor
+            if fp8 or debug:
+                if isinstance(inputmat, QuantizedTensorBase):
+                    inputmat.update_usage(rowwise_usage=True)
+                else:
+                    if input_quantizer is None:
+                        raise ValueError("Missing quantizer for input tensor")
+                    input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
+                    inputmat = input_quantizer(inputmat)
+                    own_quantized_input = True
+            else:
+                inputmat = cast_if_needed(inp, activation_dtype)  # Cast for AMP
+            inputmat_total = inputmat
         nvtx_range_pop(f"{nvtx_label}.input_cast_comm")
+        # ------------------------------------------------------
+        # Input tensor is ready for GEMM...
+        # ------------------------------------------------------
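The input-gather branching above reduces to a few predicates: NCCL all-gather for sequence-parallel column-parallel layers, Userbuffers all-gather when comm+GEMM overlap is enabled, and no gather otherwise, with a high-precision fallback for block-quantized formats. A condensed sketch (the helper name and string return values are hypothetical):

```python
def plan_input_comm(parallel_mode, sequence_parallel, ub_overlap_ag_fprop,
                    fp8, input_is_block_quantized):
    """Decide how the forward input is gathered across TP ranks (sketch)."""
    nccl_all_gather = (
        parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
    )
    # Float8 block quantization cannot be communicated directly,
    # so gather in high precision and quantize afterward
    force_hp_gather = fp8 and nccl_all_gather and input_is_block_quantized
    if nccl_all_gather:
        return "nccl-all-gather", force_hp_gather
    if ub_overlap_ag_fprop:
        return "userbuffers-all-gather", False
    return "no-gather", False
```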
-        # Cast weight to expected dtype
+        # ------------------------------------------------------
+        # Prepare weight tensor
+        # ------------------------------------------------------
         weightmat = weight
         if fp8 or debug:
             # Configure quantizer
             if weight_quantizer is not None:
@@ -220,7 +231,8 @@ class _Linear(torch.autograd.Function):
                     and not in_fp8_activation_recompute_phase()
                 )
                 weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
-            # FP8 cast to workspace buffer
+
+            # Get quantized weight
             update_workspace = is_first_microbatch is None or is_first_microbatch
             weightmat = module.get_weight_workspace(
                 tensor=weight,
@@ -231,19 +243,21 @@ class _Linear(torch.autograd.Function):
                 fsdp_group=fsdp_group,
                 workspace_dtype=activation_dtype,
             )
+            weightmat.update_usage(rowwise_usage=True)
         else:
-            weightmat = cast_if_needed(weightmat, activation_dtype)
+            weightmat = cast_if_needed(weightmat, activation_dtype)  # Cast for AMP
+        # ------------------------------------------------------
+        # Weight tensor is ready for GEMM...
+        # ------------------------------------------------------

         # Cast bias to expected dtype
         bias_dtype = activation_dtype
         if needs_quantized_gemm(inputmat_total) and activation_dtype == torch.float32:
+            # cuBLAS does not support FP8 GEMM with FP32 bias, so we cast to BF16
             bias_dtype = torch.bfloat16
         bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias

-        # Configure output quantizer
-        if output_quantizer is not None:
-            output_quantizer.set_usage(rowwise=True, columnwise=False)

         # Calibrate quantizers if needed
         if not fp8 and fp8_calibration:
             if input_quantizer is not None:
@@ -251,44 +265,74 @@ class _Linear(torch.autograd.Function):
             if weight_quantizer is not None:
                 weight_quantizer.calibrate(weight)
-        ub_obj = None
-        ub_type = None
-        rs_out = None
-        out_dtype = activation_dtype
-        if ub_overlap_rs_fprop:
-            ub_obj = get_ub(ub_name + "_fprop")
-            ub_type = tex.CommOverlapType.RS
-            out_shape = [reduce(multiply_op, inp.shape[:-1]) // tp_world_size, out_features]
-            rs_out = torch.empty(out_shape, dtype=activation_dtype, device=inp.device)
-        elif ub_overlap_ag_fprop:
-            ub_obj = get_ub(ub_name + "_fprop")
-            ub_type = tex.CommOverlapType.AG
-            if fp8:
-                assert ub_obj.is_fp8_ubuf(), "AG overlap with FP8 GEMM inputs requires FP8 buffer."
-            ub_obj.copy_into_buffer(inputmat_total, input_quantizer, local_chunk=True)
-            inputmat_total = ub_obj.get_buffer(input_quantizer)
-
-        nvtx_range_push(f"{nvtx_label}.gemm")
-        fprop_gemm_use_split_accumulator = _2X_ACC_FPROP
+        # Choose whether to use GEMM kernel with split accumulator
+        use_split_accumulator = _2X_ACC_FPROP
         if fp8:
             recipe = FP8GlobalStateManager.get_fp8_recipe()
             if hasattr(recipe, "fp8_gemm_fprop"):
-                fprop_gemm_use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator
+                use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator

+        # Configure output quantizer
+        if output_quantizer is not None:
+            output_quantizer.set_usage(rowwise=True, columnwise=False)
+
+        # Output buffer for Userbuffers reduce-scatter
+        reduce_scatter_out = None
+        if ub_overlap_rs_fprop:
+            out_shape = list(inp.shape)
+            out_shape[0] //= tp_world_size
+            out_shape[-1] = out_features
+            reduce_scatter_out = torch.empty(out_shape, dtype=activation_dtype, device=inp.device)
+
+        # ------------------------------------------------------
+        # Forward GEMM
+        # Note: y = x * w^T
+        # ------------------------------------------------------
         nvtx_range_push(f"{nvtx_label}.gemm")
-        out, *_, rs_out = general_gemm(
+        gemm_out, *_, reduce_scatter_out = general_gemm(
             weightmat,
             inputmat_total,
             get_workspace(),
             quantization_params=output_quantizer,
-            out_dtype=out_dtype,
+            out_dtype=activation_dtype,
             bias=bias,
-            use_split_accumulator=fprop_gemm_use_split_accumulator,
+            use_split_accumulator=use_split_accumulator,
             ub=ub_obj,
             ub_type=ub_type,
-            extra_output=rs_out,
+            extra_output=reduce_scatter_out,
         )
         nvtx_range_pop(f"{nvtx_label}.gemm")
+        # ------------------------------------------------------
+        # Finished forward GEMM...
+        # ------------------------------------------------------
+
+        # ------------------------------------------------------
+        # Prepare output tensor
+        # Note: Perform tensor-parallel communication
+        # ------------------------------------------------------
+        out = None
+        if ub_overlap_rs_fprop:
+            out = reduce_scatter_out
+        elif parallel_mode == "row" and tp_size > 1:
+            nvtx_range_push(f"{nvtx_label}.row_parallel_comm")
+            out = gemm_out
+            if sequence_parallel:
+                out, _ = reduce_scatter_along_first_dim(out, tp_group)
+            elif tensor_parallel:
+                if symmetric_ar_type is not None:
+                    out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type)
+                else:
+                    out, _ = allreduce(out, tp_group)
+            nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")
+        else:
+            out = gemm_out
+        # ------------------------------------------------------
+        # Output tensor is ready to return...
+        # ------------------------------------------------------

+        # ------------------------------------------------------
+        # Cache state for backward pass
+        # ------------------------------------------------------
         if is_grad_enabled:
             ctx.weight_quantizer = weight_quantizer
@@ -388,19 +432,9 @@ class _Linear(torch.autograd.Function):
             FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
             ctx.wgrad_store = wgrad_store
+        # ------------------------------------------------------
+        # Cached state for backward pass is ready...
+        # ------------------------------------------------------

-        # Row Parallel Linear
-        if ub_overlap_rs_fprop:
-            out = rs_out
-        elif parallel_mode == "row":
-            nvtx_range_push(f"{nvtx_label}.row_parallel_comm")
-            if sequence_parallel:
-                out, _ = reduce_scatter_along_first_dim(out, tp_group)
-            elif tensor_parallel:
-                if symmetric_ar_type is not None:
-                    out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type)
-                else:
-                    out, _ = allreduce(out, tp_group)
-            nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")
         return out
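The output-communication dispatch above can be summarized as follows (hypothetical helper and return strings; the real code also consults a `tensor_parallel` flag before the all-reduce branch):

```python
def output_comm_kind(ub_overlap_rs_fprop, parallel_mode, tp_size,
                     sequence_parallel, symmetric_ar_type):
    """Summarize how the forward GEMM output is communicated (sketch)."""
    if ub_overlap_rs_fprop:
        # Reduce-scatter was already fused into the GEMM via Userbuffers
        return "userbuffers-reduce-scatter"
    if parallel_mode == "row" and tp_size > 1:
        if sequence_parallel:
            return "reduce-scatter"
        if symmetric_ar_type is not None:
            return "symmetric-all-reduce"
        return "all-reduce"
    return "none"
```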
@@ -414,28 +448,11 @@ class _Linear(torch.autograd.Function):
             nvtx_label = f"{nvtx_label}.{ctx.ub_name}"

         with torch.cuda.nvtx.range("_Linear_backward"):
-            if (
-                ctx.fp8
-                and any(
-                    [
-                        ctx.ub_overlap_ag,
-                        ctx.ub_overlap_rs_dgrad,
-                        ctx.ub_bulk_dgrad,
-                        ctx.ub_bulk_wgrad,
-                    ]
-                )
-                and (ctx.fp8_recipe is not None)
-            ):
-                if not ctx.fp8_recipe.float8_per_tensor_scaling():
-                    raise NotImplementedError(
-                        "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
-                        " current scaling"
-                    )
             saved_tensors = ctx.saved_tensors
             inputmat, weight_fp8, weight, bias = (  # pylint: disable=unbalanced-tuple-unpacking
                 restore_from_saved(ctx.tensor_objects, saved_tensors)
             )

             # Delete the references to tensor objects once they've been consumed
             # by the `restore_from_saved` method to construct back the actual tensors.
             ctx.tensor_objects = None
@@ -465,69 +482,55 @@ class _Linear(torch.autograd.Function):
                 )
                 nvtx_range_pop(f"{nvtx_label}.fsdp_gather")

+            # Configure Userbuffers communication (comm+GEMM overlap)
             ctx.ub_obj_gradout = None
             ub_obj_dgrad = None
             ub_obj_wgrad = None
             ub_type_dgrad = None
             ub_type_wgrad = None
             dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]]
-            rs_out = None
-            dgrad_bulk = None
             if ctx.ub_overlap_ag:
                 # Overlap grad_output all-gather with dgrad compute
                 ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
                 ub_obj_dgrad = ctx.ub_obj_gradout
                 ub_type_dgrad = tex.CommOverlapType.AG
             elif ctx.ub_overlap_rs_dgrad:
                 # Overlap dgrad reduce-scatter with dgrad compute
                 ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
                 ub_obj_dgrad = ctx.ub_obj_gradout
                 ub_type_dgrad = tex.CommOverlapType.RS
-                rs_out = torch.empty(
-                    dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device
-                )
             else:
                 if ctx.ub_bulk_dgrad:
                     # Overlap inputmat all-gather with dgrad compute
-                    # NOTE: Copying into communication buffer will always prefer rowwise data,
-                    # and will copy columnwise data if rowwise does not exist. In that case,
-                    # the all-gather will apply to the leading dimension of the transpose,
-                    # which then needs to be interleaved correctly before WGRAD.
                     ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
                     ub_obj_dgrad = ctx.ub_obj_gradout
                     ub_type_dgrad = tex.CommOverlapType.AG
-                    ub_obj_dgrad.copy_into_buffer(inputmat, ctx.input_quantizer, local_chunk=True)
                 if ctx.ub_bulk_wgrad:
                     # Overlap dgrad reduce-scatter with wgrad compute
                     ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad")
                     ub_type_wgrad = tex.CommOverlapType.RS
-                    ub_obj_wgrad.set_buffer_params(ctx.grad_input_quantizer)
-                    dgrad_bulk = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer)

+            # --------------------------------------------------
+            # Prepare grad output tensor
+            # Note: Cast to expected dtype and perform tensor-parallel communication
+            # --------------------------------------------------
+
+            # Unmodified grad output tensor
+            grad_output_arg = grad_output

             # Configure quantizer for grad output tensor
             # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
             # requires column-wise usage
             if ctx.grad_output_quantizer is not None:
-                rowwise_usage = True
-                columnwise_usage = True
-                if ctx.ub_overlap_ag and isinstance(
-                    ctx.grad_output_quantizer,
-                    (Float8Quantizer, Float8CurrentScalingQuantizer),
-                ):
-                    # If data is in FP8 and communication is handled
-                    # with Userbuffers, we compute FP8 transposes
-                    # manually
-                    columnwise_usage = False
-                ctx.grad_output_quantizer.set_usage(
-                    rowwise=rowwise_usage,
+                quantizer = ctx.grad_output_quantizer
+                quantizer.set_usage(rowwise=True, columnwise=True)
+                if ctx.ub_overlap_ag:
+                    # Userbuffers only supports communication for one
+                    # tensor usage at a time. Configure quantizer with
+                    # usage for only dgrad GEMM.
+                    quantizer.set_usage(columnwise=False)
columnwise=columnwise_usage,
)
# Prepare grad output tensor # Prepare grad output tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
nvtx_range_push(f"{nvtx_label}.grad_output_preprocess") nvtx_range_push(f"{nvtx_label}.grad_output_preprocess")
( (
grad_output, grad_output,
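The usage restriction above can be sketched in isolation. This is a toy illustration (names are invented, not the Transformer Engine API): the dgrad GEMM consumes row-wise quantized data and the wgrad GEMM consumes column-wise data, but a Userbuffers all-gather can only communicate one usage at a time.

```python
def grad_output_usage(ub_overlap_ag: bool) -> dict:
    """Which quantized usages the grad output quantizer should produce."""
    usage = {"rowwise": True, "columnwise": True}
    if ub_overlap_ag:
        # Only the row-wise (dgrad) operand goes through the Userbuffers
        # all-gather; column-wise data is produced later for the wgrad GEMM.
        usage["columnwise"] = False
    return usage
```
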
@@ -540,12 +543,21 @@ class _Linear(torch.autograd.Function):
         )
         nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess")

-        # Launch tensor-parallel communication for input tensor
+        # --------------------------------------------------
+        # Grad output tensor is ready for computing grad input...
+        # --------------------------------------------------
+
+        # --------------------------------------------------
+        # Prepare input tensor
+        # Note: Input tensor is needed for wgrad GEMM.
+        # Tensor-parallel communication is overlapped with dgrad
+        # GEMM.
+        # --------------------------------------------------
         inputmat_total = None
         inputmat_total_work = None
-        if ctx.backward_input_needs_gather and not ctx.ub_bulk_dgrad:
+        if ctx.backward_input_needs_gather:
             quantizer = None
-            if ctx.fp8 or ctx.debug:
+            if (ctx.fp8 or ctx.debug) and not ctx.force_hp_input_gather:
                 quantizer = ctx.input_quantizer
                 if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
                     # If data is in FP8, we compute FP8 transposes manually
@@ -553,72 +565,92 @@ class _Linear(torch.autograd.Function):
                 else:
                     # wgrad GEMM requires input with column-wise usage
                     quantizer.set_usage(rowwise=False, columnwise=True)
-            nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
-            gather_quantizer = None if ctx.force_hp_input_gather else quantizer
-            inputmat_total, inputmat_total_work = gather_along_first_dim(
-                inputmat,
-                ctx.tp_group,
-                async_op=True,
-                quantizer=gather_quantizer,
-            )
-            nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
+            if ctx.ub_bulk_dgrad:
+                inputmat_total, _ = fill_userbuffers_buffer_for_all_gather(
+                    ub_obj_dgrad,
+                    inputmat,
+                    quantizer,
+                    ctx.tp_group,
+                )
+            else:
+                nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
+                inputmat_total, inputmat_total_work = gather_along_first_dim(
+                    inputmat,
+                    ctx.tp_group,
+                    async_op=True,
+                    quantizer=quantizer,
+                )
+                nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
         else:
             inputmat_total = inputmat

+        # --------------------------------------------------
+        # Input tensor is ready for computing grad weight...
+        # --------------------------------------------------
+
-        # Check whether to output wgrad GEMM directly into main grad
-        if ctx.is_first_microbatch is not None:
-            accumulate_wgrad_into_param_main_grad = (
-                ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
-            )
-        else:
-            accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
-
-        # Compute grad input tensor
+        # --------------------------------------------------
+        # Compute grad input tensor
+        # --------------------------------------------------
         dgrad = None
         dgrad_work = None
         if ctx.requires_dgrad:
-            # Update quantizer
-            if ctx.grad_input_quantizer is not None:
-                ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)
-            # dgrad GEMM
-            nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
-            dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD
+            # Make sure required data is available
+            if isinstance(grad_output, QuantizedTensorBase):
+                grad_output.update_usage(rowwise_usage=True)
+            if ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensorBase):
+                weight_fp8.update_usage(columnwise_usage=True)
+
+            # Choose whether to use GEMM kernel with split accumulator
+            use_split_accumulator = _2X_ACC_DGRAD
             if ctx.fp8:
                 recipe = ctx.fp8_recipe
                 if hasattr(recipe, "fp8_gemm_dgrad"):
-                    dgrad_gemm_use_split_accumulator = (
-                        recipe.fp8_gemm_dgrad.use_split_accumulator
-                    )
-            if ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensorBase):
-                weight_fp8.update_usage(
-                    rowwise_usage=ctx.weight_quantizer.rowwise_usage,
-                    columnwise_usage=ctx.weight_quantizer.columnwise_usage,
-                )
+                    use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator
+
+            # Update grad input quantizer
+            if ctx.grad_input_quantizer is not None:
+                ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)
+
+            # Output buffers for Userbuffers reduce-scatter
+            gemm_out = None
+            reduce_scatter_out = None
+            if ctx.ub_overlap_rs_dgrad:
+                reduce_scatter_out = torch.empty(
+                    dgrad_shape, dtype=ctx.activation_dtype, device=grad_output_arg.device
+                )
+            elif ctx.ub_bulk_wgrad:
+                gemm_out = ub_obj_wgrad.get_buffer(local_chunk=False)

-            dgrad, *_, rs_out = general_gemm(
+            # dgrad GEMM
+            # Note: dx = dy * w
+            nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
+            gemm_out, *_, reduce_scatter_out = general_gemm(
                 weight_fp8,
                 grad_output,
                 get_workspace(),
                 layout="NN",
                 grad=True,
                 quantization_params=ctx.grad_input_quantizer,
-                out=dgrad_bulk,
+                out=gemm_out,
                 out_dtype=ctx.activation_dtype,
-                use_split_accumulator=dgrad_gemm_use_split_accumulator,
+                use_split_accumulator=use_split_accumulator,
                 ub=ub_obj_dgrad,
                 ub_type=ub_type_dgrad,
-                extra_output=rs_out,
+                extra_output=reduce_scatter_out,
                 bulk_overlap=ctx.ub_bulk_dgrad,
             )
             nvtx_range_pop(f"{nvtx_label}.dgrad_gemm")

-            # Launch tensor-parallel communication
+            # Prepare grad input tensor
+            # Note: Perform tensor-parallel communication
             if ctx.ub_overlap_rs_dgrad:
-                dgrad = rs_out
-            elif ctx.parallel_mode == "column" and not ctx.ub_bulk_wgrad:
+                dgrad = reduce_scatter_out
+            elif ctx.ub_bulk_wgrad:
+                dgrad = ub_obj_wgrad.get_buffer(local_chunk=True)
+            elif ctx.parallel_mode == "column" and ctx.tp_size > 1:
                 nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad")
+                dgrad = gemm_out
                 if ctx.sequence_parallel:
                     dgrad, dgrad_work = reduce_scatter_along_first_dim(
                         dgrad,
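The grad-input routing in this hunk can be summarized as a small dispatch, mirroring the if/elif chain. This is a hypothetical sketch (function name and return strings are invented for illustration, not Transformer Engine code):

```python
def grad_input_path(ub_overlap_rs_dgrad: bool, ub_bulk_wgrad: bool,
                    parallel_mode: str, tp_size: int,
                    sequence_parallel: bool) -> str:
    """Describe where the dgrad GEMM output ends up."""
    if ub_overlap_rs_dgrad:
        # Reduce-scatter is fused into the dgrad GEMM via Userbuffers
        return "ub-reduce-scatter-out"
    if ub_bulk_wgrad:
        # dgrad lands in the wgrad Userbuffers buffer; the local chunk
        # becomes valid once the bulk reduce-scatter completes
        return "ub-wgrad-buffer"
    if parallel_mode == "column" and tp_size > 1:
        # Fall back to asynchronous NCCL communication
        return "nccl-reduce-scatter" if sequence_parallel else "nccl-all-reduce"
    return "gemm-out"
```
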
@@ -628,41 +660,55 @@ class _Linear(torch.autograd.Function):
                 else:
                     dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True)
                 nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad")
+            else:
+                dgrad = gemm_out
+
+        # --------------------------------------------------
+        # Grad input tensor has been computed...
+        # --------------------------------------------------
+
+        # --------------------------------------------------
+        # Compute grad weight
+        # --------------------------------------------------

-        # Compute grad weight tensor
         wgrad = None
         if ctx.requires_wgrad:
-            # Synchronize tensor-parallel communication for input tensor
-            if ctx.ub_bulk_dgrad:
-                inputmat_total = ub_obj_dgrad.get_buffer(ctx.input_quantizer)
-                if ctx.fp8:
-                    if inputmat._data is None:
-                        # All-gather executed on columnwise data and result is in rowwise data,
-                        # so we need to fix the interleaving before WGRAD.
-                        inputmat_total = _fix_gathered_fp8_transpose(
-                            inputmat_total, ctx.tp_size
-                        )
-                    elif not is_non_tn_fp8_gemm_supported():
-                        # FP8 GEMM on Hopper only supports TN layout so the gathered input must
-                        # have a valid transpose.
-                        inputmat_total._create_transpose()
+            # Prepare input tensor
+            # Note: Synchronize tensor-parallel communication and
+            # make sure required data is available
             if inputmat_total_work is not None:
                 inputmat_total_work.wait()
                 inputmat_total_work = None
-            if ctx.input_quantizer is not None and not isinstance(
-                inputmat_total, QuantizedTensorBase
-            ):
-                # Async gather in BF16 does not asynchronously
-                # call quantizer after gather.
-                ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
-                inputmat_total = ctx.input_quantizer(inputmat_total)
-            # Make sure GEMM inputs have required data
-            if isinstance(inputmat_total, QuantizedTensorBase):
-                inputmat_total.update_usage(columnwise_usage=True)
+            if ctx.fp8 or ctx.debug:
+                if isinstance(inputmat_total, QuantizedTensorBase):
+                    inputmat_total.update_usage(columnwise_usage=True)
+                else:
+                    ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
+                    inputmat_total = ctx.input_quantizer(inputmat_total)

+            # Prepare grad output tensor
+            # Note: Synchronize tensor-parallel communication and
+            # make sure required data is available
+            if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer):
+                # UB does not support overlapping grad output
+                # all-gather with wgrad GEMM. Also, we can't
+                # convert row-scaled MXFP8 to column-scaled, so we
+                # can't reuse the grad output that was gathered
+                # for the dgrad GEMM. We work around with blocking
+                # all-gather for column-scaled MXFP8 data.
+                ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
+                grad_output, _ = gather_along_first_dim(
+                    grad_output_arg,
+                    ctx.tp_group,
+                    quantizer=ctx.grad_output_quantizer,
+                )
+            if ctx.fp8 or ctx.debug:
                 if isinstance(grad_output, QuantizedTensorBase):
                     grad_output.update_usage(columnwise_usage=True)
+                else:
+                    ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
+                    grad_output = ctx.grad_output_quantizer(grad_output)

             # Figure out whether to use split accumulator
             use_split_accumulator = _2X_ACC_WGRAD
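The MXFP8 workaround in this hunk reduces to a simple predicate: MXFP8 scales along a single axis, so a row-scaled gathered tensor cannot be reinterpreted as column-scaled, and the grad output gathered for dgrad cannot be reused for wgrad. A toy sketch (names are illustrative, not the Transformer Engine API):

```python
def wgrad_needs_blocking_gather(ub_overlap_ag: bool, recipe: str) -> bool:
    """Whether wgrad needs a separate blocking all-gather of grad output.

    Userbuffers cannot overlap the grad output all-gather with the wgrad
    GEMM, and for MXFP8 the dgrad gather (row-scaled) cannot be converted
    to the column-scaled data that the wgrad GEMM consumes.
    """
    return ub_overlap_ag and recipe == "mxfp8"
```
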
@@ -671,54 +717,95 @@ class _Linear(torch.autograd.Function):
                 if hasattr(recipe, "fp8_gemm_wgrad"):
                     use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator

-            # Output buffer for overlapping grad input
+            # Figure out whether to output wgrad GEMM directly into main grad
+            if ctx.is_first_microbatch is not None:
+                accumulate_wgrad_into_param_main_grad = (
+                    ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
+                )
+            else:
+                accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
+
+            # Output buffer for overlapping FP8 grad input
             # reduce-scatter with wgrad GEMM
+            reduce_scatter_out = None
             if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf():
-                rs_out = torch.empty(
-                    dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device
+                reduce_scatter_out = torch.empty(
+                    dgrad_shape, dtype=ctx.activation_dtype, device=grad_output_arg.device
                 )

-            # wgrad GEMM
-            # Note: Fuse with bgrad computation if needed
-            nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
-            general_gemm_wgrad = functools.partial(
-                general_gemm,
-                out_dtype=(
-                    main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
-                ),
-                workspace=get_workspace(),
-                layout="NT",
-                grad=True,
-                bias=(bias if (grad_bias is None and not ctx.fp8) else None),
-                out=main_grad if ctx.fuse_wgrad_accumulation else None,
-                use_split_accumulator=use_split_accumulator,
-                accumulate=accumulate_wgrad_into_param_main_grad,
-                quantization_params=ctx.grad_weight_quantizer,
-                ub=ub_obj_wgrad,
-                ub_type=ub_type_wgrad,
-                extra_output=rs_out,
-                bulk_overlap=ctx.ub_bulk_wgrad,
-            )
+            # Arguments to include in wgrad GEMM closure
+            wgrad_gemm_kwargs = {
+                "workspace": get_workspace(),
+                "out_dtype": (
+                    main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
+                ),
+                "quantization_params": ctx.grad_weight_quantizer,
+                "accumulate": accumulate_wgrad_into_param_main_grad,
+                "layout": "NT",
+                "out": main_grad if ctx.fuse_wgrad_accumulation else None,
+                "bias": (bias if (grad_bias is None and not ctx.fp8) else None),
+                "use_split_accumulator": use_split_accumulator,
+                "grad": True,
+                "ub": ub_obj_wgrad,
+                "ub_type": ub_type_wgrad,
+                "extra_output": reduce_scatter_out,
+                "bulk_overlap": ctx.ub_bulk_wgrad,
+            }
+
+            def wgrad_gemm(
+                x: torch.Tensor,
+                dy: torch.Tensor,
+            ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+                """Perform wgrad GEMM: dw = dy^T * x
+
+                May be fused with bgrad computation.
+
+                May be called outside of this function to enable
+                some advanced communication/compute overlapping.
+                """
+                nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
+                dw, db, *_ = general_gemm(x, dy, **wgrad_gemm_kwargs)
+                nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")
+                return dw, db
+
+            # Choose whether to call wgrad GEMM now or delay
             if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
-                ctx.wgrad_store.put([inputmat_total, grad_output], general_gemm_wgrad)
+                if (
+                    wgrad_gemm_kwargs["ub"] is not None
+                    or wgrad_gemm_kwargs["ub_type"] is not None
+                    or wgrad_gemm_kwargs["extra_output"] is not None
+                    or wgrad_gemm_kwargs["bulk_overlap"]
+                ):
+                    raise NotImplementedError(
+                        "Delayed weight grad computation is not supported "
+                        "with Userbuffers (tensor-parallel communication overlapping)"
+                    )
+                ctx.wgrad_store.put([inputmat_total, grad_output], wgrad_gemm)
             else:
-                wgrad, grad_bias_, _, rs_out = general_gemm_wgrad(inputmat_total, grad_output)
+                # Call wgrad GEMM now
+                wgrad, grad_bias_ = wgrad_gemm(inputmat_total, grad_output)
+
+                # Update grad bias if needed
                 if grad_bias is None:
                     grad_bias = grad_bias_
                 del grad_bias_

-                # Deallocate input tensor
+                # Deallocate input tensor if permitted
                 if ctx.owns_input:
                     clear_tensor_data(inputmat_total)
-            nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")

+            # Update grad input if overlapping reduce-scatter with wgrad GEMM
             if ctx.ub_bulk_wgrad:
                 if ub_obj_wgrad.is_fp8_ubuf():
-                    dgrad = rs_out
+                    dgrad = reduce_scatter_out
                 else:
-                    dgrad = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer, local_chunk=True)
+                    dgrad = ub_obj_wgrad.get_buffer(local_chunk=True).clone()
+
+        # --------------------------------------------------
+        # Grad weight has been computed...
+        # --------------------------------------------------

         # Don't return grad bias if not needed
         if not ctx.use_bias:
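The closure-based delayed-wgrad pattern introduced in this hunk can be sketched with a toy store. This is a simplification for illustration only (`ToyWgradStore` is invented, not the actual `WgradStore` API), and the GEMM is reduced to scalar arithmetic:

```python
class ToyWgradStore:
    """Stash GEMM inputs and a closure, then execute them later."""

    def __init__(self, delay: bool):
        self._delay = delay
        self._pending = []

    def delay_wgrad_compute(self) -> bool:
        return self._delay

    def put(self, tensors, func):
        # Keep the inputs alive together with the closure that consumes them
        self._pending.append((tensors, func))

    def pop(self):
        # Execute the oldest deferred computation
        tensors, func = self._pending.pop(0)
        return func(*tensors)


def wgrad_gemm(x, dy):
    # dw = dy^T @ x, reduced to scalar arithmetic for the sketch
    return dy * x


store = ToyWgradStore(delay=True)
if store.delay_wgrad_compute():
    store.put([2.0, 3.0], wgrad_gemm)  # defer the wgrad GEMM
# ... other backward work (e.g. dgrad GEMM and communication) runs here ...
dw = store.pop()  # executes the deferred wgrad GEMM -> 6.0
```

The closure captures the GEMM arguments once, so the deferred call sees the same configuration as an immediate call; this is why the real code must reject Userbuffers arguments, which are only valid while the communicator state is live.
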
@@ -756,6 +843,7 @@ class _Linear(torch.autograd.Function):
         else:
             wgrad = None

+        # Update FP8 scaling factors if needed
         if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
             nvtx_range_push(f"{nvtx_label}.reduce_and_update_fp8_tensors")
             FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
......
@@ -534,6 +534,8 @@ class BasicLinear(BasicOperation):

         # Configure input tensor for backward pass
         if with_quantized_compute and isinstance(x_local, QuantizedTensor):
-            x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
+            if not (isinstance(x_local, Float8TensorBase) and with_x_all_gather):
+                # FP8 does not support all-gather of transpose data
+                x_local.update_usage(rowwise_usage=False, columnwise_usage=True)

         return y, x_local, w
@@ -622,7 +624,10 @@ class BasicLinear(BasicOperation):

         # Check datatype
         if dtype is None:
-            dtype = weight.dtype
+            if weight is not None:
+                dtype = weight.dtype
+            else:
+                dtype = grad_output.dtype
         dtype = canonicalize_dtype(dtype)
         if dtype not in (torch.float32, torch.float16, torch.bfloat16):
             raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")
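The fallback added above (prefer an explicit dtype, then the weight's dtype, then the grad output's) can be exercised in isolation. A toy sketch; `DummyTensor` is invented for illustration and stands in for any object with a `dtype` attribute:

```python
from dataclasses import dataclass


@dataclass
class DummyTensor:
    dtype: str


def resolve_dtype(dtype, weight, grad_output):
    """Prefer an explicit dtype, then the weight's, then grad output's."""
    if dtype is None:
        if weight is not None:
            dtype = weight.dtype
        else:
            # Backward pass may be called without a weight tensor
            # (e.g. when only the weight gradient is needed)
            dtype = grad_output.dtype
    return dtype
```
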
@@ -814,7 +819,7 @@ class BasicLinear(BasicOperation):
         x_async = None
         dy_async = None

-        # Check grad input tensor
+        # Check grad weight tensor
         dw = grad_weight
         dw_dtype = dtype
         if dw is None:
......
@@ -4,30 +4,27 @@

 """Linear layer backward with Userbuffers communication."""

-# pylint: skip-file
+### TODO Debug Userbuffers support

 from __future__ import annotations
-from collections.abc import Iterable
-from typing import Any, Optional
+from typing import Optional
 import warnings

 import torch

-from transformer_engine_torch import CommOverlapAlgo
+from transformer_engine_torch import CommOverlapType

 from ...cpp_extensions import general_gemm
-from ...distributed import get_distributed_world_size
-from ...float8_tensor import Float8Tensor
-from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype
-from ...module.base import get_ub, get_workspace
+from ...distributed import gather_along_first_dim, get_distributed_world_size
+from ...module.base import (
+    fill_userbuffers_buffer_for_all_gather,
+    get_ub,
+    get_workspace,
+    _2X_ACC_DGRAD,
+    _2X_ACC_WGRAD,
+)
+from ...tensor.quantized_tensor import QuantizedTensorBase, Quantizer
+from ...tensor.mxfp8_tensor import MXFP8Quantizer
 from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data
 from ..basic import BasicLinear, Bias, ReduceScatter
 from ..op import FusedOperation, FusibleOperation, OperationContext
-from .._common import (
-    convert_tensor,
-    get_fp8_meta_from_fp8_tensor,
-    is_float8_tensor,
-    reshape,
-)


 class UserbuffersBackwardLinear(FusedOperation):
class UserbuffersBackwardLinear(FusedOperation): class UserbuffersBackwardLinear(FusedOperation):
@@ -47,9 +44,6 @@ class UserbuffersBackwardLinear(FusedOperation):
         reduce_scatter: Optional[ReduceScatter],
     ) -> None:

-        ### TODO Debug Userbuffers support
-        raise NotImplementedError("Userbuffers support has been broken by recent refactors")
-
         # Basic operations that comprise this fused operation
         op_idxs = {"linear": None, "bias": None, "reduce_scatter": None}
         ops = []
@@ -89,9 +83,8 @@ class UserbuffersBackwardLinear(FusedOperation):
         grad_output: torch.Tensor,
         input: Optional[torch.Tensor],  # pylint: disable=redefined-builtin
         weight: Optional[torch.Tensor],
-        input_dims: Iterable[int],
-        weight_dims: Iterable[int],
         *,
+        input_requires_grad: bool = True,
         weight_requires_grad: bool = True,
         bias_requires_grad: bool = False,
         device: Optional[torch.device] = None,
@@ -102,11 +95,11 @@ class UserbuffersBackwardLinear(FusedOperation):
         tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
         tensor_parallel_size: Optional[int] = None,
         sequence_parallel: bool = False,
-        with_fp8_compute: bool = False,
-        input_fp8_meta: Optional[dict[str, Any]] = None,
-        weight_fp8_meta: Optional[dict[str, Any]] = None,
-        grad_output_fp8_meta: Optional[dict[str, Any]] = None,
-        grad_input_fp8_meta: Optional[dict[str, Any]] = None,
+        with_quantized_compute: bool = False,
+        input_quantizer: Optional[Quantizer] = None,
+        weight_quantizer: Optional[Quantizer] = None,
+        grad_output_quantizer: Optional[Quantizer] = None,
+        grad_input_quantizer: Optional[Quantizer] = None,
         ub_comm_name: str,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], dict]:
         """Functional API for backward pass
@@ -121,10 +114,6 @@ class UserbuffersBackwardLinear(FusedOperation):
         weight: torch.Tensor, optional
             Weight tensor. Required to compute loss gradient w.r.t.
             input.
-        input_dims: iterable of int
-            Input tensor dimensions
-        weight_dims: iterable of int
-            Weight tensor dimensions
         weight_requires_grad: bool
             Whether to compute loss gradient w.r.t. weight tensor
         bias_requires_grad: bool
@@ -146,21 +135,18 @@ class UserbuffersBackwardLinear(FusedOperation):
             parallelism, i.e. distributing input or output tensors
             along outer dimension (sequence or batch dim) when not
             distributing along inner dimension (embedding dim)
-        with_fp8_compute: bool, default = `False`
-            Whether to perform compute in FP8
-        input_fp8_meta: dict, optional
-            FP8 metadata for casting input tensor to FP8. Required for
-            FP8 compute if input is not already in FP8.
-        weight_fp8_meta: dict, optional
-            FP8 metadata for casting weight tensor to FP8. Required for
-            FP8 compute if weight is not already in FP8.
-        grad_output_fp8_meta: dict, optional
-            FP8 metadata for casting loss gradient w.r.t. output
-            tensor to FP8. Required if output grad is not already in
-            FP8.
-        grad_input_fp8_meta: dict, optional
-            FP8 metadata for casting loss gradient w.r.t. input
-            tensor to FP8
+        with_quantized_compute: bool, default = `False`
+            Whether to perform compute with quantized data.
+        input_quantizer: Quantizer, optional
+            Builder class for quantized input tensor.
+        weight_quantizer: Quantizer, optional
+            Builder class for quantized weight tensor.
+        grad_output_quantizer: Quantizer, optional
+            Builder class for quantized loss gradient w.r.t. output
+            tensor.
+        grad_input_quantizer: Quantizer, optional
+            Builder class for quantized loss gradient w.r.t. input
+            tensor.
         ub_comm_name: str
             Layer type (e.g. "qkv", "proj", "fc1", "fc2"). This is
             used to access the corresponding Userbuffers communicators
@@ -183,37 +169,24 @@ class UserbuffersBackwardLinear(FusedOperation):

         # Check device
         if device is None:
-            device = weight.device
+            if weight is not None:
+                device = weight.device
+            else:
+                device = grad_output.device
         device = canonicalize_device(device)
         if device.type != "cuda":
             raise ValueError(f"Only CUDA devices are supported (got {device})")

         # Check datatype
         if dtype is None:
-            dtype = weight.dtype
+            if weight is not None:
+                dtype = weight.dtype
+            else:
+                dtype = grad_output.dtype
         dtype = canonicalize_dtype(dtype)
         if dtype not in (torch.float32, torch.float16, torch.bfloat16):
             raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")

-        # Input tensor dims
-        output_dims = tuple(grad_output.size())
-        input_dims = tuple(input_dims)
-        weight_dims = tuple(weight_dims)
-        if len(weight_dims) != 2:
-            raise ValueError(f"Weight tensor is not 2D (shape={weight_dims})")
-        if len(input_dims) == 0 or weight_dims[1] != input_dims[-1]:
-            raise ValueError(
-                f"Input tensor (shape={input_dims}) "
-                f"and weight tensor (shape={weight_dims}) "
-                "are not compatible"
-            )
-        if weight_dims[0] != output_dims[-1]:
-            raise ValueError(
-                f"Grad output tensor (shape={output_dims}) "
-                f"and weight tensor (shape={weight_dims}) "
-                "are not compatible"
-            )
-
         # Check tensor parallel group
         if tensor_parallel_size is None:
             tensor_parallel_size = get_distributed_world_size(tensor_parallel_group)
...@@ -227,373 +200,283 @@ class UserbuffersBackwardLinear(FusedOperation): ...@@ -227,373 +200,283 @@ class UserbuffersBackwardLinear(FusedOperation):
if not sequence_parallel: if not sequence_parallel:
raise RuntimeError(f"Invalid configuration for Userbuffers ({sequence_parallel=})") raise RuntimeError(f"Invalid configuration for Userbuffers ({sequence_parallel=})")
# Check if FP8 is enabled # dgrad GEMM is required
if with_fp8_compute: if not input_requires_grad:
if grad_output_fp8_meta is None and not is_float8_tensor(grad_output): warnings.warn(
raise ValueError("No FP8 metadata was provided for casting output gradient to FP8") "Linear input doesn't require gradient, "
else: "but Userbuffers implementation requires dgrad GEMM."
input_fp8_meta = None
weight_fp8_meta = None
grad_output_fp8_meta = None
grad_input_fp8_meta = None
with_fp8_grad_input = (
with_fp8_compute
and tensor_parallel_mode != "column"
and grad_input_fp8_meta is not None
) )
input_requires_grad = True
# Check quantizers
if with_quantized_compute:
if weight_requires_grad and input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
if input_requires_grad and weight_quantizer is None:
raise ValueError("Missing quantizer for weight tensor")
if grad_output_quantizer is None:
raise ValueError("Missing quantizer for grad output tensor")
if grad_input_quantizer is not None:
raise ValueError("Quantized grad input is not supported")
else:
input_quantizer = None
weight_quantizer = None
grad_output_quantizer = None
grad_input_quantizer = None
# Get Userbuffers communicators and algorithms # Get Userbuffers communicators
# Note: communication patterns are (1) overlap dy all-gather # Note: Communication patterns are (1) overlap dy all-gather
# with dgrad GEMM, (2) overlap x all-gather with dgrad GEMM # with dgrad GEMM, (2) overlap x all-gather with dgrad GEMM
# and dx reduce-scatter with wgrad GEMM, (3) overlap dx # and dx reduce-scatter with wgrad GEMM, (3) overlap dx
# reduce-scatter with dgrad GEMM. # reduce-scatter with dgrad GEMM
with_ub_all_gather_dy = False ub_comm_dgrad = None
with_ub_reduce_scatter_dx = False ub_comm_wgrad = None
with_ub_all_gather_x = False ub_type_dgrad = None
ub_comm_dy = None ub_type_wgrad = None
ub_comm_dx = None with_bulk_overlap = False
ub_comm_x = None with_dgrad_all_gather_dy = False
ub_algo_dy = None with_dgrad_reduce_scatter_dx = False
-        ub_algo_dx = None
-        ub_algo_x = None
+        with_dgrad_all_gather_x = False
+        with_wgrad_reduce_scatter_dx = False
         if tensor_parallel_mode == "row":
-            with_ub_all_gather_dy = True
-            ub_comm_dy = get_ub(ub_comm_name + "_dgrad")
-            if with_fp8_compute and ub_comm_dy.is_atomic_gemm():
-                ub_algo_dy = CommOverlapAlgo.ATOMIC_GEMM_AG_P2P
-            else:
-                ub_algo_dy = CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P
+            ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad")
+            ub_type_dgrad = CommOverlapType.AG
+            with_dgrad_all_gather_dy = True
         elif tensor_parallel_mode == "column":
-            with_ub_reduce_scatter_dx = True
-            if weight_requires_grad:
-                with_ub_all_gather_x = True
-                ub_comm_dx = get_ub(ub_comm_name + "_wgrad")
-                ub_comm_x = get_ub(ub_comm_name + "_dgrad")
-                ub_algo_dx = CommOverlapAlgo.BULK_OVERLAP_RS
-                ub_algo_x = CommOverlapAlgo.BULK_OVERLAP_AG
+            if input_requires_grad and weight_requires_grad:
+                with_bulk_overlap = True
+                ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad")
+                ub_type_dgrad = CommOverlapType.AG
+                with_dgrad_all_gather_x = True
+                ub_comm_wgrad = get_ub(ub_comm_name + "_wgrad")
+                ub_type_wgrad = CommOverlapType.RS
+                with_wgrad_reduce_scatter_dx = True
+                if ub_comm_wgrad.is_fp8_ubuf():
+                    raise RuntimeError(
+                        "Userbuffers reduce-scatter is not supported with FP8 buffers"
+                    )
             else:
-                with_ub_all_gather_x = False
-                ub_comm_dx = get_ub(ub_comm_name + "_dgrad")
-                is_atomic_gemm = with_fp8_compute and ub_comm_dx.is_atomic_gemm()
-                ub_algo_dx = {
-                    (True, True): CommOverlapAlgo.ATOMIC_GEMM_RS_P2P,
-                    (True, False): CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P,
-                    (False, True): CommOverlapAlgo.ATOMIC_GEMM_RS,
-                    (False, False): CommOverlapAlgo.SPLIT_PIPELINED_RS,
-                }[(ub_comm_dx.is_p2p_overlap(), is_atomic_gemm)]
+                ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad")
+                ub_type_dgrad = CommOverlapType.RS
+                with_dgrad_reduce_scatter_dx = True
+                if ub_comm_dgrad.is_fp8_ubuf():
+                    raise RuntimeError(
+                        "Userbuffers reduce-scatter is not supported with FP8 buffers"
+                    )

-        # Check grad output tensor
-        # Note: Possibly fuse cast with computing grad bias
-        dy_local = reshape(
-            grad_output,
-            (-1, output_dims[-1]),
-            device=device,
-            dtype=dtype,
-        )
+        # Compute grad bias if needed
         db = None
         db_async = None
-        if bias_requires_grad and with_fp8_compute and with_ub_all_gather_dy:
-            # We don't have a grad bias impl that takes FP8 input. For
-            # cases where we cast to FP8 and all-gather, it's better
-            # to compute the grad bias on ungathered, non-FP8 values.
-            db = dy_local.sum(dim=0)
-            db_async = torch.distributed.all_reduce(
-                db,
-                group=tensor_parallel_group,
-                async_op=True,
-            )
+        if bias_requires_grad:
+            db = grad_output.sum(tuple(range(grad_output.dim() - 1)))
+            if tensor_parallel_mode == "row":
+                db_async = torch.distributed.all_reduce(
+                    db,
+                    group=tensor_parallel_group,
+                    async_op=True,
+                )
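The new grad-bias path reduces the grad output over every leading dimension, keeping only the last (feature) dimension, and overlaps the tensor-parallel all-reduce with later work via `async_op=True`. A minimal sketch of the reduction itself, using NumPy with illustrative shapes (not taken from the PR):

```python
import numpy as np

# Hypothetical grad output of shape (batch, seq, hidden); the bias grad
# sums over all leading dims, mirroring
# grad_output.sum(tuple(range(grad_output.dim() - 1))) in the diff.
grad_output = np.ones((2, 3, 4))
axes = tuple(range(grad_output.ndim - 1))
db = grad_output.sum(axis=axes)

assert db.shape == (4,)          # one bias grad per hidden channel
assert float(db[0]) == 6.0       # 2 * 3 elements summed per channel
```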
-        if with_fp8_compute and not is_float8_tensor(dy_local):
-            fp8_dtype = get_fp8_te_dtype(
-                grad_output_fp8_meta["recipe"],
-                fprop_tensor=False,
-            )
-            if bias_requires_grad and db is None:
-                # Fused cast-transpose-bgrad
-                fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False)
-                fp8_scale_inv = torch.empty([1], dtype=torch.float32, device=device)
-                db, data, data_transpose = fp8_cast_transpose_bgrad_fused(
-                    dy_local,
-                    grad_output_fp8_meta[fp8_meta_key],
-                    0,
-                    fp8_dtype,
-                    scale_inv=fp8_scale_inv,
-                )
-                if with_ub_all_gather_dy:
-                    data = ub_comm_dy.get_ubuf_output(0).copy_(data)
-                dy_local = Float8Tensor(
-                    data=data,
-                    fp8_meta=grad_output_fp8_meta,
-                    fp8_meta_forward=False,
-                    fp8_meta_index=0,
-                    fp8_dtype=fp8_dtype,
-                    fp8_scale_inv=fp8_scale_inv,
-                    dtype=dtype,
-                    data_transpose=data_transpose,
-                )
-            else:
-                dy_local = Float8Tensor.to_float8(
-                    dy_local,
-                    fp8_meta=grad_output_fp8_meta,
-                    fp8_meta_forward=False,
-                    fp8_meta_index=0,
-                    fp8_dtype=fp8_dtype,
-                    data=(ub_comm_dy.get_ubuf_output(0) if with_ub_all_gather_dy else None),
-                    with_transpose_cache=(not with_ub_all_gather_dy),
-                )
-        elif not with_fp8_compute and is_float8_tensor(dy_local):
-            if with_ub_all_gather_dy:
-                ub_local_buffer = ub_comm_dy.get_ubuf_output(0)
-                dy_local = ub_local_buffer.copy_(dy_local)
-            else:
-                dy_local = dy_local.dequantize()
-
-        if bias_requires_grad and db is None and with_fp8_compute and with_ub_all_gather_dy:
-            # We don't have a fused grad bias impl that takes FP8
-            # input. For cases where we cast to FP8 and all-gather,
-            # it's better to compute the grad bias on ungathered,
-            # non-FP8 values.
-            db = dy_local.sum(dim=0)
-            db_async = torch.distributed.all_reduce(
-                db,
-                group=tensor_parallel_group,
-                async_op=True,
-            )
+        # Cast grad output tensor dtype if needed
+        dy_local = grad_output
+        if with_quantized_compute:
+            if not isinstance(dy_local, QuantizedTensorBase):
+                with_columnwise = weight_requires_grad
+                if (
+                    with_columnwise
+                    and with_dgrad_all_gather_dy
+                    and not isinstance(grad_output_quantizer, MXFP8Quantizer)
+                ):
+                    with_columnwise = False
+                grad_output_quantizer.set_usage(
+                    rowwise=True,
+                    columnwise=with_columnwise,
+                )
+                dy_local = grad_output_quantizer(dy_local)
+        else:
+            if isinstance(dy_local, QuantizedTensorBase):
+                dy_local = dy_local.dequantize(dtype=dtype)
+            elif dy_local.dtype != dtype:
+                dy_local = dy_local.to(dtype=dtype)
+
+        # Cast weight tensor dtype if needed
+        if weight is None:
+            raise ValueError("Weight tensor is required to compute input grad")
+        w = weight
+        if with_quantized_compute:
+            if not isinstance(w, QuantizedTensorBase):
+                weight_quantizer.set_usage(columnwise=True)
+                w = weight_quantizer(w)
+        else:
+            if isinstance(w, QuantizedTensorBase):
+                w = w.dequantize(dtype=dtype)
+            elif w.dtype != dtype:
+                w = w.to(dtype=dtype)
-        # Check input tensor
+        # Cast input tensor dtype if needed
         x_local = None
         if weight_requires_grad:
-            x_local = reshape(
-                input,
-                (-1, input_dims[-1]),
-                device=device,
-                dtype=dtype,
-            )
-            if with_fp8_compute and not is_float8_tensor(x_local):
-                fp8_dtype = get_fp8_te_dtype(
-                    input_fp8_meta["recipe"],
-                    fprop_tensor=True,
-                )
-                x_local = Float8Tensor.to_float8(
-                    x_local,
-                    fp8_meta=input_fp8_meta,
-                    fp8_meta_forward=True,
-                    fp8_meta_index=0,
-                    fp8_dtype=fp8_dtype,
-                    data=(ub_comm_x.get_ubuf_output(0) if with_ub_all_gather_x else None),
-                    with_transpose_cache=(not with_ub_all_gather_x),
-                )
-            elif not with_fp8_compute and is_float8_tensor(x_local):
-                if with_ub_all_gather_x:
-                    ub_local_buffer = ub_comm_x.get_ubuf_output(0)
-                    x_local = ub_local_buffer.copy_(x_local)
-                else:
-                    x_local = x_local.dequantize()
+            if input is None:
+                raise ValueError("Input tensor is required to compute weight grad")
+            x_local = input
+            if with_quantized_compute:
+                if not isinstance(x_local, QuantizedTensorBase):
+                    input_quantizer.set_usage(columnwise=True)
+                    x_local = input_quantizer(x_local)
+            else:
+                if isinstance(x_local, QuantizedTensorBase):
+                    x_local = x_local.dequantize(dtype=dtype)
+                elif x_local.dtype != dtype:
+                    x_local = x_local.to(dtype=dtype)
+
+        # dgrad GEMM
+        dx_local = None
+        dx = None
+        dy = None
+        x = None
+        if input_requires_grad:
+
+            # Initialize grad output
+            if with_dgrad_all_gather_dy:
+                if grad_output_quantizer is not None:
+                    grad_output_quantizer.set_usage(rowwise=True, columnwise=False)
+                dy, _ = fill_userbuffers_buffer_for_all_gather(
+                    ub_comm_dgrad,
+                    dy_local,
+                    grad_output_quantizer,
+                    tensor_parallel_group,
+                )
+            else:
+                dy = dy_local
-        # Check weight tensor
-        w = convert_tensor(
-            weight,
-            device=device,
-            dtype=dtype,
-            memory_format=torch.contiguous_format,
-        )
-        if with_fp8_compute and not is_float8_tensor(w):
-            fp8_dtype = get_fp8_te_dtype(
-                weight_fp8_meta["recipe"],
-                fprop_tensor=True,
-            )
-            w = Float8Tensor.to_float8(
-                w,
-                fp8_meta=weight_fp8_meta,
-                fp8_meta_forward=True,
-                fp8_meta_index=0,
-                fp8_dtype=fp8_dtype,
-                with_transpose_cache=True,
-            )
-        elif not with_fp8_compute and is_float8_tensor(w):
-            w = w.dequantize()
-
-        # Initialize buffers for UB all-gather if needed
-        dy = dy_local
-        x = x_local
-        if with_ub_all_gather_dy:
-            ub_local_buffer = ub_comm_dy.get_ubuf_output(0)
-            ub_global_buffer = ub_comm_dy.get_ubuf_output(1)
-            if with_fp8_compute:
-                dy = Float8Tensor.make_like(dy_local, data=ub_global_buffer)
-                if dy_local._data.data_ptr() != ub_local_buffer.data_ptr():
-                    ub_local_buffer.copy_(dy_local._data)
-            else:
-                dy = ub_global_buffer
-                if dy_local.data_ptr() != ub_local_buffer.data_ptr():
-                    ub_local_buffer.copy_(dy_local)
-        if with_ub_all_gather_x:
-            ub_local_buffer = ub_comm_x.get_ubuf_output(0)
-            ub_global_buffer = ub_comm_x.get_ubuf_output(1)
-            if with_fp8_compute:
-                x = Float8Tensor.make_like(x_local, data=ub_global_buffer)
-                if x_local._data.data_ptr() != ub_local_buffer.data_ptr():
-                    ub_local_buffer.copy_(x_local._data)
-            else:
-                x = ub_global_buffer
-                if x_local.data_ptr() != ub_local_buffer.data_ptr():
-                    ub_local_buffer.copy_(x_local)
-
-        # Construct grad input tensor
-        dx = None
-        dx_local = None
-        if with_ub_reduce_scatter_dx:
-            # Initialize buffers for UB reduce-scatter
-            dx = ub_comm_dx.get_ubuf_output(1)
-            ub_local_buffer = ub_comm_dx.get_ubuf_output(0)
-            if with_ub_all_gather_x:
-                dx_local = ub_local_buffer
-            else:
-                dx_local = torch.empty_like(ub_local_buffer)
-        else:
-            # Allocate grad input tensor
-            if with_fp8_grad_input:
-                fp8_dtype = get_fp8_te_dtype(
-                    grad_input_fp8_meta["recipe"],
-                    fprop_tensor=False,
-                )
-                data = torch.empty(
-                    (dy.size(0), w.size(-1)),
-                    dtype=torch.uint8,
-                    device=device,
-                )
-                dx = Float8Tensor(
-                    data=data,
-                    fp8_meta=grad_input_fp8_meta,
-                    fp8_meta_forward=False,
-                    fp8_meta_index=0,
-                    fp8_dtype=fp8_dtype,
-                    dtype=dtype,
-                )
-            else:
-                dx = torch.empty(
-                    (dy.size(0), w.size(-1)),
-                    dtype=dtype,
-                    device=device,
-                )
-            dx_local = dx
+            # Construct grad input tensor if needed
+            if with_dgrad_reduce_scatter_dx or with_wgrad_reduce_scatter_dx:
+                dx_size = list(dy.size())
+                dx_size[-1] = w.size(-1)
+                dx_local_size = list(dx_size)
+                dx_local_size[0] //= tensor_parallel_size
+                if with_dgrad_reduce_scatter_dx:
+                    dx_local = torch.empty(
+                        dx_local_size,
+                        dtype=dtype,
+                        device=device,
+                    )
+                elif with_wgrad_reduce_scatter_dx:
+                    dx_local = ub_comm_wgrad.get_buffer(
+                        local_chunk=True,
+                        shape=dx_local_size,
+                    )
+                    dx = ub_comm_wgrad.get_buffer(
+                        local_chunk=False,
+                        shape=dx_size,
+                    )
+
+            # Initialize input tensor if needed
+            if with_dgrad_all_gather_x:
+                if input_quantizer is not None:
+                    input_quantizer.set_usage(rowwise=False, columnwise=True)
+                x, _ = fill_userbuffers_buffer_for_all_gather(
+                    ub_comm_dgrad,
+                    x_local,
+                    input_quantizer,
+                    tensor_parallel_group,
+                )
+
+            # Perform dgrad GEMM
+            dx, *_ = general_gemm(
+                w,
+                dy,
+                get_workspace(),
+                out_dtype=dtype,
+                quantization_params=grad_input_quantizer,
+                layout="NN",
+                out=dx,
+                use_split_accumulator=_2X_ACC_DGRAD,
+                grad=True,
+                ub=ub_comm_dgrad,
+                ub_type=ub_type_dgrad,
+                extra_output=dx_local if with_dgrad_reduce_scatter_dx else None,
+                bulk_overlap=with_bulk_overlap,
+            )
+            if not (with_dgrad_reduce_scatter_dx or with_wgrad_reduce_scatter_dx):
+                dx_local = dx
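The new grad-input path sizes the full dgrad output from the gathered grad output and the weight's input-feature dimension, then derives the per-rank chunk by dividing the leading dimension by the tensor-parallel size (reduce-scatter splits along the sequence dim). A small sketch of that shape bookkeeping, with illustrative values and a hypothetical helper name:

```python
def reduce_scatter_shapes(dy_shape, weight_in_features, tp_size):
    """Full grad-input shape and the per-rank chunk left after reduce-scatter.

    Mirrors the dx_size/dx_local_size arithmetic in the diff; names and
    values here are illustrative, not part of the PR.
    """
    dx_size = list(dy_shape)
    dx_size[-1] = weight_in_features      # dgrad restores the input feature dim
    dx_local_size = list(dx_size)
    dx_local_size[0] //= tp_size          # reduce-scatter splits the leading dim
    return dx_size, dx_local_size

full, local = reduce_scatter_shapes((16, 1024), weight_in_features=512, tp_size=4)
assert full == [16, 512]
assert local == [4, 512]
```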
-        # Allocate grad input tensor
-        if grad_weight is None:
-            if accumulate_into_grad_weight:
-                raise ValueError(
-                    "Attempted to accumulate into grad weight buffer "
-                    "without providing grad weight"
-                )
-            grad_weight = torch.empty(
-                weight_dims,
-                dtype=dtype,
-                device=device,
-                memory_format=torch.contiguous_format,
-            )
-
-        # Perform dgrad GEMM
-        if with_fp8_compute:
-            kwargs = {"out": dx, "use_split_accumulator": True}
-            if with_ub_all_gather_dy:
-                kwargs["ub_algo"] = ub_algo_dy
-                kwargs["ub"] = ub_comm_dy
-            elif with_ub_all_gather_x:
-                kwargs["ub_algo"] = ub_algo_x
-                kwargs["ub"] = ub_comm_x
-            elif with_ub_reduce_scatter_dx:
-                kwargs["ub_algo"] = ub_algo_dx
-                kwargs["ub"] = ub_comm_dx
-                kwargs["extra_output_tensor"] = dx_local
-            if with_fp8_grad_input:
-                fp8_meta, fp8_meta_index = get_fp8_meta_from_fp8_tensor(dx)
-                kwargs.update(
-                    {
-                        "out": dx._data,
-                        "out_index": fp8_meta_index,
-                        "fp8_meta_tensor": fp8_meta,
-                        "D_dtype": dx._fp8_dtype,
-                    }
-                )
-            fp8_gemm(
-                w.transpose_2d(),
-                w._scale_inv,
-                0,
-                w._fp8_dtype,
-                dy._data,
-                dy._scale_inv,
-                0,
-                dy._fp8_dtype,
-                dy.dtype,
-                get_workspace(),
-                **kwargs,
-            )
-        else:
-            kwargs = {"grad": True, "layout": "NN", "out": dx}
-            if with_ub_all_gather_dy:
-                kwargs["ub_algo"] = ub_algo_dy
-                kwargs["ub"] = ub_comm_dy
-            elif with_ub_all_gather_x:
-                kwargs["ub_algo"] = ub_algo_x
-                kwargs["ub"] = ub_comm_x
-            elif with_ub_reduce_scatter_dx:
-                kwargs["ub_algo"] = ub_algo_dx
-                kwargs["ub"] = ub_comm_dx
-                kwargs["extra_output_tensor"] = dx_local
-            gemm(w, dy, dx.dtype, get_workspace(), **kwargs)
-        grad_input = reshape(dx_local, input_dims)
+        # wgrad GEMM
+        dw = None
+        if weight_requires_grad:
+
+            # Initialize grad output
+            if tensor_parallel_mode == "row" and isinstance(grad_output_quantizer, MXFP8Quantizer):
+                # UB does not support overlapping grad output
+                # all-gather with wgrad GEMM. Also, MXFP8 does not
+                # allow reusing the grad output that was gathered for
+                # the dgrad GEMM. We work around with blocking
+                # all-gather for column-scaled MXFP8 data.
+                grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
+                dy, _ = gather_along_first_dim(
+                    grad_output,
+                    tensor_parallel_group,
+                    quantizer=grad_output_quantizer,
+                )
+            if tensor_parallel_mode == "column":
+                dy = dy_local
+            if dy is None:
+                raise RuntimeError(
+                    "wgrad GEMM requires grad output tensor, which has not been initialized"
+                )
+            if isinstance(dy, QuantizedTensorBase):
+                dy.update_usage(rowwise_usage=False, columnwise_usage=True)
+
+            # Initialize input tensor
+            if tensor_parallel_mode == "row":
+                x = x_local
+            if x is None:
+                raise RuntimeError(
+                    "wgrad GEMM requires input tensor, which has not been initialized"
+                )
+            if isinstance(x, QuantizedTensorBase):
+                x.update_usage(rowwise_usage=False, columnwise_usage=True)
+
+            # Check grad weight tensor
+            dw = grad_weight
+            dw_dtype = dtype
+            if dw is None:
+                if accumulate_into_grad_weight:
+                    raise ValueError(
+                        "Attempted to accumulate into grad weight tensor "
+                        "without providing grad weight tensor"
+                    )
+            else:
+                dw_dtype = dw.dtype
             # Perform wgrad GEMM
-        if not weight_requires_grad:
-            pass
-        elif with_fp8_compute:
-            kwargs = {
-                "accumulate": accumulate_into_grad_weight,
-                "out": grad_weight,
-                "use_split_accumulator": True,
-            }
-            if with_ub_reduce_scatter_dx:
-                kwargs["ub_algo"] = ub_algo_dx
-                kwargs["ub"] = ub_comm_dx
-            fp8_gemm(
-                x.transpose_2d(),
-                x._scale_inv,
-                0,
-                x._fp8_dtype,
-                dy.transpose_2d(),
-                dy._scale_inv,
-                0,
-                dy._fp8_dtype,
-                grad_weight.dtype,
-                get_workspace(),
-                **kwargs,
-            )
-        else:
-            kwargs = {
-                "accumulate": accumulate_into_grad_weight,
-                "layout": "NT",
-                "grad": True,
-                "use_bias": bias_requires_grad,
-                "out": grad_weight,
-            }
-            if with_ub_reduce_scatter_dx:
-                kwargs["ub_algo"] = ub_algo_dx
-                kwargs["ub"] = ub_comm_dx
-            grad_weight, db, _ = gemm(
-                x,
-                dy,
-                grad_weight.dtype,
-                get_workspace(),
-                **kwargs,
-            )
+            dw, *_ = general_gemm(
+                x,
+                dy,
+                get_workspace(),
+                out_dtype=dw_dtype,
+                accumulate=accumulate_into_grad_weight,
+                layout="NT",
+                out=dw,
+                use_split_accumulator=_2X_ACC_WGRAD,
+                grad=True,
+                ub=ub_comm_wgrad,
+                ub_type=ub_type_wgrad,
+                bulk_overlap=with_bulk_overlap,
+            )
+
+            # Bulk overlap reduce-scatter with non-FP8 buffer is
+            # in-place. Need to copy grad input tensor to avoid data
+            # corruption in Userbuffers buffer.
+            if with_wgrad_reduce_scatter_dx:
+                dx_local = dx_local.clone()

         # Compute grad bias if needed
         if db_async is not None:
             db_async.wait()
         if bias_requires_grad:
-            if db is None:
-                db = dy.sum(dim=0)
             extra_outputs["grad_bias"] = db

-        return grad_input, grad_weight, extra_outputs
+        return dx_local, dw, extra_outputs
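The backward path above repeatedly toggles quantizer usage flags: the dgrad GEMM consumes rowwise (untransposed) data, while the wgrad GEMM consumes columnwise (transposed) data, so tensors are requested with only the layout the next GEMM needs. A toy mock of that bookkeeping, purely illustrative and not Transformer Engine's actual `Quantizer` API:

```python
class MockQuantizer:
    """Minimal stand-in for a quantizer that tracks which data layouts
    the next GEMM will consume (rowwise for dgrad, columnwise for wgrad)."""

    def __init__(self):
        self.rowwise = True
        self.columnwise = False

    def set_usage(self, rowwise=None, columnwise=None):
        # Only update the flags that are explicitly passed, as in the diff's
        # set_usage(rowwise=..., columnwise=...) calls.
        if rowwise is not None:
            self.rowwise = rowwise
        if columnwise is not None:
            self.columnwise = columnwise

q = MockQuantizer()
q.set_usage(rowwise=True, columnwise=False)   # dgrad: untransposed data only
assert (q.rowwise, q.columnwise) == (True, False)
q.set_usage(rowwise=False, columnwise=True)   # wgrad: transposed data only
assert (q.rowwise, q.columnwise) == (False, True)
```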
     def fuser_backward(
         self,
@@ -633,40 +516,24 @@ class UserbuffersBackwardLinear(FusedOperation):
         else:
             accumulate_into_main_grad = False

-        # Hackily workaround Userbuffers bug with non-FP8 dgrad
-        # reduce-scatter overlap
-        weight_requires_grad = linear_op_ctx.weight_requires_grad
-        if not linear_op_ctx.with_fp8_compute and not weight_requires_grad:
-            warnings.warn(
-                "There is a correctness bug when using Userbuffers "
-                "to overlap a dgrad reduce-scatter with a non-FP8 dgrad GEMM. "
-                "Hackily working around by overlapping dgrad reduce-scatter "
-                "with wgrad GEMM, even though wgrad isn't needed. "
-                "Please contact Transformer Engine team "
-                "if you encounter this use-case."
-            )
-            weight_requires_grad = True
-
         # Linear backward pass
         retval = UserbuffersBackwardLinear._functional_backward(
             grad_output=grad_output,
             input=x_local,
             weight=linear_op.weight,
-            input_dims=linear_op_ctx.input_dims,
-            weight_dims=linear_op.weight.size(),
-            weight_requires_grad=weight_requires_grad,
+            weight_requires_grad=linear_op_ctx.weight_requires_grad,
             bias_requires_grad=(bias_op is not None),
-            device=linear_op.device,
             dtype=linear_op_ctx.dtype,
             grad_weight=grad_weight,
             accumulate_into_grad_weight=accumulate_into_main_grad,
             tensor_parallel_mode=self.tensor_parallel_mode,
             tensor_parallel_group=self.tensor_parallel_group,
             sequence_parallel=self.sequence_parallel,
-            with_fp8_compute=linear_op_ctx.with_fp8_compute,
-            weight_fp8_meta=linear_op_ctx.weight_fp8_meta,
-            grad_output_fp8_meta=linear_op_ctx.grad_output_fp8_meta,
-            grad_input_fp8_meta=linear_op_ctx.grad_input_fp8_meta,
+            with_quantized_compute=linear_op_ctx.with_quantized_compute,
+            input_quantizer=linear_op_ctx.input_quantizer,
+            weight_quantizer=linear_op_ctx.weight_quantizer,
+            grad_output_quantizer=linear_op_ctx.grad_output_quantizer,
+            grad_input_quantizer=None,  # Not supported
             ub_comm_name=linear_op._userbuffers_options["comm_name"],
         )
         grad_input, grad_weight, extra_outputs = retval
@@ -707,8 +574,6 @@ def fuse_userbuffers_backward_linear(
     """

-    return ops  ### TODO Debug Userbuffers support
-
     # Return immediately if environment is not distributed
     if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1:
         return ops
@@ -4,20 +4,25 @@
 """Linear layer forward with Userbuffers communication."""

-# pylint: skip-file  ### TODO Debug Userbuffers support
-
 from __future__ import annotations
 from collections.abc import Iterable
 from typing import Any, Optional

 import torch

-from transformer_engine_torch import CommOverlapAlgo
+from transformer_engine_torch import CommOverlapType
 from ...cpp_extensions import general_gemm
 from ...distributed import get_distributed_world_size
-from ...float8_tensor import Float8Tensor
-from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype
-from ...module.base import get_ub, get_workspace
+from ...fp8 import FP8GlobalStateManager
+from ...module.base import (
+    fill_userbuffers_buffer_for_all_gather,
+    get_ub,
+    get_workspace,
+    _2X_ACC_FPROP,
+)
+from ...tensor.quantized_tensor import QuantizedTensorBase, Quantizer
+from ...tensor.float8_tensor import Float8Quantizer
+from ...tensor._internal.float8_tensor_base import Float8TensorBase
 from ...utils import canonicalize_device, canonicalize_dtype
 from ..basic import BasicLinear, Bias, ReduceScatter
 from ..op import (
@@ -26,12 +31,6 @@ from ..op import (
     FusibleOperation,
     OperationContext,
 )
-from .._common import (
-    convert_tensor,
-    get_fp8_meta_from_fp8_tensor,
-    is_float8_tensor,
-    reshape,
-)


 class UserbuffersForwardLinear(FusedOperation):
@@ -51,9 +50,6 @@ class UserbuffersForwardLinear(FusedOperation):
         reduce_scatter: Optional[ReduceScatter],
     ) -> None:

-        ### TODO Debug Userbuffers support
-        raise NotImplementedError("Userbuffers support has been broken by recent refactors")
-
         # Basic operations that comprise this fused operation
         op_idxs = {"linear": 0, "bias": None, "reduce_scatter": None}
         ops = [linear]
@@ -98,10 +94,10 @@ class UserbuffersForwardLinear(FusedOperation):
         tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
         tensor_parallel_size: Optional[int] = None,
         sequence_parallel: bool = False,
-        with_fp8_compute: bool = False,
-        input_fp8_meta: Optional[dict[str, Any]] = None,
-        weight_fp8_meta: Optional[dict[str, Any]] = None,
-        output_fp8_meta: Optional[dict[str, Any]] = None,
+        with_quantized_compute: bool = False,
+        input_quantizer: Optional[Quantizer] = None,
+        weight_quantizer: Optional[Quantizer] = None,
+        output_quantizer: Optional[Quantizer] = None,
         ub_comm_name: str,
     ) -> tuple[torch.Tensor, dict]:
         """Functional API for forward pass
@@ -127,16 +123,14 @@ class UserbuffersForwardLinear(FusedOperation):
             parallelism, i.e. distributing input or output tensors
             along outer dimension (sequence or batch dim) when not
             distributing along inner dimension (embedding dim)
-        with_fp8_compute: bool, default = `False`
-            Whether to perform compute in FP8
-        input_fp8_meta: dict, optional
-            FP8 metadata for casting input tensor to FP8. Required for
-            FP8 compute if input is not already in FP8.
-        weight_fp8_meta: dict, optional
-            FP8 metadata for casting weight tensor to FP8. Required for
-            FP8 compute if weight is not already in FP8.
-        output_fp8_meta: dict, optional
-            FP8 metadata for casting output tensor to FP8
+        with_quantized_compute: bool, default = `False`
+            Whether to perform compute with quantized data.
+        input_quantizer: Quantizer, optional
+            Builder class for quantized input tensor.
+        weight_quantizer: Quantizer, optional
+            Builder class for quantized weight tensor.
+        output_quantizer: Quantizer, optional
+            Builder class for quantized output tensor.
         ub_comm_name: str
             Layer type (e.g. "qkv", "proj", "fc1", "fc2"). This is
             used to access the corresponding Userbuffers communicators
@@ -166,23 +160,6 @@ class UserbuffersForwardLinear(FusedOperation):
         if dtype not in (torch.float32, torch.float16, torch.bfloat16):
             raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")

-        # Input tensor dims
-        input_dims = tuple(input.size())
-        weight_dims = tuple(weight.size())
-        if len(weight_dims) != 2:
-            raise ValueError(f"Weight tensor is not 2D (shape={weight_dims})")
-        if len(input_dims) == 0 or weight_dims[1] != input_dims[-1]:
-            raise ValueError(
-                f"Input tensor (shape={input_dims}) "
-                f"and weight tensor (shape={weight_dims}) "
-                "are not compatible"
-            )
-
-        # Output tensor dims
-        output_dims = list(input_dims)
-        output_dims[0] = -1
-        output_dims[-1] = weight_dims[0]
-
         # Check tensor parallel group
         if tensor_parallel_size is None:
             tensor_parallel_size = get_distributed_world_size(tensor_parallel_group)
@@ -196,235 +173,106 @@ class UserbuffersForwardLinear(FusedOperation):
         if not sequence_parallel:
             raise RuntimeError(f"Invalid configuration for Userbuffers ({sequence_parallel=})")

-        # Check if FP8 is enabled
-        if with_fp8_compute:
-            if input_fp8_meta is None and not is_float8_tensor(input):
-                raise ValueError("No FP8 metadata was provided for casting input to FP8")
-            if weight_fp8_meta is None and not is_float8_tensor(weight):
-                raise ValueError("No FP8 metadata was provided for casting weight to FP8")
+        # Check quantizers
+        if with_quantized_compute:
+            if input_quantizer is None:
+                raise ValueError("Missing quantizer for input tensor")
+            if weight_quantizer is None:
+                raise ValueError("Missing quantizer for weight tensor")
+            if output_quantizer is not None:
+                raise ValueError("FP8 output is not supported")
         else:
-            input_fp8_meta = None
-            weight_fp8_meta = None
-            output_fp8_meta = None
-        with_fp8_output = (
-            with_fp8_compute and tensor_parallel_mode != "row" and output_fp8_meta is not None
-        )
+            input_quantizer = None
+            weight_quantizer = None
+            output_quantizer = None

         # Get Userbuffers communicator
         ub_comm = get_ub(ub_comm_name + "_fprop")
-        ub_local_buffer = ub_comm.get_ubuf_output(0)
-        ub_global_buffer = ub_comm.get_ubuf_output(1)
         with_ub_all_gather = tensor_parallel_mode == "column"
         with_ub_reduce_scatter = tensor_parallel_mode == "row"
+        ub_type = CommOverlapType.AG if with_ub_all_gather else CommOverlapType.RS
-        # Choose Userbuffers communication algorithm
-        ub_algo = None
-        if with_ub_all_gather:
-            if with_fp8_compute and ub_comm.is_atomic_gemm():
-                ub_algo = CommOverlapAlgo.ATOMIC_GEMM_AG_P2P
-            else:
-                ub_algo = CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P
-        elif with_ub_reduce_scatter:
-            is_atomic_gemm = with_fp8_compute and ub_comm.is_atomic_gemm()
-            ub_algo = {
-                (True, True): CommOverlapAlgo.ATOMIC_GEMM_RS_P2P,
-                (True, False): CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P,
-                (False, True): CommOverlapAlgo.ATOMIC_GEMM_RS,
-                (False, False): CommOverlapAlgo.SPLIT_PIPELINED_RS,
-            }[(ub_comm.is_p2p_overlap(), is_atomic_gemm)]
-        else:
-            raise RuntimeError("Could not choose Userbuffers communication algorithm")
-
-        # Cast input tensor to correct dtype
-        x_local = reshape(
-            input,
-            (-1, input_dims[-1]),
-            device=device,
-            dtype=dtype,
-        )
-        if with_fp8_compute and not is_float8_tensor(x_local):
-            fp8_dtype = get_fp8_te_dtype(
-                input_fp8_meta["recipe"],
-                fprop_tensor=True,
-            )
-            with_transpose_cache = weight.requires_grad
-            if tensor_parallel_mode == "column" and sequence_parallel:
-                with_transpose_cache = False
-            x_local = Float8Tensor.to_float8(
-                x_local,
-                fp8_meta=input_fp8_meta,
-                fp8_meta_forward=True,
-                fp8_meta_index=0,
-                fp8_dtype=fp8_dtype,
-                data=(ub_local_buffer if with_ub_all_gather else None),
-                with_transpose_cache=with_transpose_cache,
-            )
-        elif not with_fp8_compute and is_float8_tensor(x_local):
-            if with_ub_all_gather:
-                x_local = ub_local_buffer.copy_(x_local)
-            else:
-                x_local = x_local.dequantize()
-
-        # Initialize buffers for UB all-gather if needed
-        x = x_local
-        if with_ub_all_gather:
-            if with_fp8_compute:
-                x = Float8Tensor.make_like(x_local, data=ub_global_buffer)
-                if x_local._data.data_ptr() != ub_local_buffer.data_ptr():
-                    ub_local_buffer.copy_(x_local._data)
-                else:
-                    x_local._data = torch.empty_like(x_local._data)
-            else:
-                x = ub_global_buffer
-                if x_local.data_ptr() != ub_local_buffer.data_ptr():
-                    ub_local_buffer.copy_(x_local)
-                else:
-                    x_local = torch.empty_like(x_local)
-
-        # Check weight tensor
-        w = convert_tensor(
-            weight,
-            device=device,
-            dtype=dtype,
-            memory_format=torch.contiguous_format,
-        )
-        if with_fp8_compute and not is_float8_tensor(w):
-            fp8_dtype = get_fp8_te_dtype(
-                weight_fp8_meta["recipe"],
-                fprop_tensor=True,
-            )
-            w = Float8Tensor.to_float8(
-                w,
-                fp8_meta=weight_fp8_meta,
-                fp8_meta_index=0,
-                fp8_dtype=fp8_dtype,
-            )
-        elif not with_fp8_compute and is_float8_tensor(w):
-            w = w.dequantize()
+        # Initialize input tensor
+        x_local = input
+        x = None
+        if with_ub_all_gather:
+            if input_quantizer is not None:
+                if not isinstance(x_local, QuantizedTensorBase):
+                    input_quantizer.set_usage(rowwise=True, columnwise=True)
+                    if isinstance(input_quantizer, Float8Quantizer):
+                        input_quantizer.set_usage(columnwise=False)
+                    x_local = input_quantizer(x_local)
+                input_quantizer.set_usage(rowwise=True, columnwise=False)
+            x, x_local = fill_userbuffers_buffer_for_all_gather(
+                ub_comm,
+                x_local,
+                input_quantizer,
+                tensor_parallel_group,
+            )
+        else:
+            if with_quantized_compute:
+                if not isinstance(x_local, QuantizedTensorBase):
+                    input_quantizer.set_usage(rowwise=True, columnwise=True)
+                    x_local = input_quantizer(x_local)
+            else:
+                if isinstance(x_local, QuantizedTensorBase):
+                    x_local = x_local.dequantize(dtype=dtype)
+                if x_local.dtype != dtype:
+                    x_local = x_local.to(dtype=dtype)
+            x = x_local
+
+        # Initialize weight tensor
+        w = weight
+        w_is_quantized = isinstance(w, QuantizedTensorBase)
+        if with_quantized_compute and not w_is_quantized:
+            weight_quantizer.set_usage(rowwise=True)
+            w = weight_quantizer(w)
+        elif not with_quantized_compute and w_is_quantized:
+            w = w.dequantize()
+        if not with_quantized_compute and w.dtype != dtype:
+            w = w.to(dtype=dtype)
-        # Check bias tensor
-        b = None
-        if bias is not None:
-            b = convert_tensor(
-                bias,
-                device=device,
-                dtype=dtype,
-                memory_format=torch.contiguous_format,
-            )
-
-        # Construct output tensor
-        y = None
-        y_local = None
-        if with_ub_reduce_scatter:
-            # Initialize buffers for UB reduce-scatter
-            if with_fp8_output:
-                fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)
-                fp8_dtype = get_fp8_te_dtype(
-                    output_fp8_meta["recipe"],
-                    fprop_tensor=True,
-                )
-                y = Float8Tensor(
-                    data=ub_global_buffer,
-                    fp8_meta=output_fp8_meta,
-                    fp8_meta_forward=True,
-                    fp8_meta_index=0,
-                    fp8_dtype=fp8_dtype,
-                    fp8_scale_inv=output_fp8_meta[fp8_meta_key].scale_inv[0],
-                    dtype=dtype,
-                )
-                ub_comm.set_ubuf_scale_inv(y._scale_inv)
-            else:
-                y = ub_global_buffer
-            y_local = torch.empty(
-                (x.size(0) // tensor_parallel_size, weight_dims[0]),
-                dtype=dtype,
-                device=device,
-            )
-        else:
-            # Allocate output tensor
-            if with_fp8_output:
-                fp8_dtype = get_fp8_te_dtype(
-                    output_fp8_meta["recipe"],
-                    fprop_tensor=True,
-                )
-                data = torch.empty(
-                    (x.size(0), weight_dims[0]),
-                    dtype=torch.uint8,
-                    device=device,
-                )
-                y = Float8Tensor(
-                    data=data,
-                    fp8_meta=output_fp8_meta,
-                    fp8_meta_forward=True,
-                    fp8_meta_index=0,
-                    fp8_dtype=fp8_dtype,
-                    dtype=dtype,
-                )
-            else:
-                y = torch.empty(
-                    (x.size(0), weight_dims[0]),
-                    dtype=dtype,
-                    device=device,
-                )
-            y_local = y
+        # Construct output tensor if needed
+        reduce_scatter_output = None
+        if with_ub_reduce_scatter:
+            y_local_size = list(x.size())
+            y_local_size[0] //= tensor_parallel_size
+            y_local_size[-1] = w.size(0)
+            reduce_scatter_output = torch.empty(y_local_size, dtype=dtype, device=device)

         # Perform GEMM
-        if with_fp8_compute:
-            kwargs = {
-                "out": y,
-                "bias": b,
-                "use_bias": (b is not None),
-                "use_split_accumulator": False,
-                "ub_algo": ub_algo,
-                "ub": ub_comm,
-            }
-            if with_ub_all_gather:
-                kwargs["extra_output_tensor"] = x_local._data
-            if with_ub_reduce_scatter:
-                kwargs["extra_output_tensor"] = y_local
-            if with_fp8_output:
-                fp8_meta, fp8_meta_index = get_fp8_meta_from_fp8_tensor(y)
-                kwargs.update(
-                    {
-                        "out": y._data,
-                        "out_index": fp8_meta_index,
-                        "fp8_meta_tensor": fp8_meta,
-                        "D_dtype": y._fp8_dtype,
-                    }
-                )
-            fp8_gemm(
-                w._data,
-                w._scale_inv,
-                0,
-                w._fp8_dtype,
-                x._data,
-                x._scale_inv,
-                0,
-                x._fp8_dtype,
-                y.dtype,
-                get_workspace(),
-                **kwargs,
-            )
-        else:
-            kwargs = {
-                "out": y,
-                "bias": b,
-                "use_bias": (b is not None),
-                "ub_algo": ub_algo,
-                "ub": ub_comm,
-            }
-            if with_ub_all_gather:
-                kwargs["extra_output_tensor"] = x_local
-            if with_ub_reduce_scatter:
-                kwargs["extra_output_tensor"] = y_local
-            gemm(w, x, y.dtype, get_workspace(), **kwargs)
-
-        # Reshape output tensor
-        out = reshape(y_local, output_dims)
+        gemm_output, *_, reduce_scatter_output = general_gemm(
+            w,
+            x,
+            get_workspace(),
+            out_dtype=dtype,
+            quantization_params=output_quantizer,
+            bias=bias,
+            use_split_accumulator=_2X_ACC_FPROP,
+            ub=ub_comm,
+            ub_type=ub_type,
+            extra_output=reduce_scatter_output,
+        )
+        if with_ub_reduce_scatter:
+            y_local = reduce_scatter_output
+        else:
+            y_local = gemm_output
+
+        # Detach input tensor if needed
+        # Note: PyTorch autograd produces esoteric errors if we save
+        # input tensor as context for backward pass.
+        if x_local is input:
+            x_local = x_local.detach()
+
+        # Configure input tensor for backward pass
+        if with_quantized_compute and isinstance(x_local, QuantizedTensorBase):
+            if not (isinstance(x_local, Float8TensorBase) and with_ub_all_gather):
+                # FP8 does not support all-gather of transpose data
+                x_local.update_usage(rowwise_usage=False, columnwise_usage=True)

         # Return cast tensors
         extra_outputs = {"input": x_local, "weight": w}
-        return out, extra_outputs
+        return y_local, extra_outputs
def fuser_forward( def fuser_forward(
self, self,
...@@ -450,23 +298,22 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -450,23 +298,22 @@ class UserbuffersForwardLinear(FusedOperation):
if basic_op_kwargs[idx]: if basic_op_kwargs[idx]:
raise ValueError("Bias operation forward does not expect keyword arguments") raise ValueError("Bias operation forward does not expect keyword arguments")
# Quantization metadata
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
input_quantizer = None
weight_quantizer = None
grad_output_quantizer = None
grad_input_quantizer = None
if with_quantized_compute:
recipe = FP8GlobalStateManager.get_fp8_recipe()
if not recipe.delayed() and not recipe.mxfp8():
raise RuntimeError("Userbuffers is only supported with FP8 delayed scaling and MXFP8 recipes")
input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1)
grad_output_quantizer = linear_op.get_quantizer("backward", 0)
prev_op = basic_op_prev_ops[0]
if prev_op is not None and prev_op.num_quantizers("backward") > 0 and recipe.delayed():
grad_input_quantizer = prev_op.get_quantizer("backward", 0)
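The recipe gate above rejects quantization recipes other than FP8 delayed scaling and MXFP8. A minimal sketch of that check with stand-in recipe objects (`ToyRecipe` and `check_userbuffers_recipe` are hypothetical names; real recipes come from `FP8GlobalStateManager.get_fp8_recipe()`):

```python
class ToyRecipe:
    """Stand-in for a quantization recipe exposing kind predicates."""

    def __init__(self, kind):
        self._kind = kind

    def delayed(self):
        return self._kind == "delayed"

    def mxfp8(self):
        return self._kind == "mxfp8"

def check_userbuffers_recipe(recipe):
    # Userbuffers overlap only supports delayed-scaling FP8 and MXFP8
    if not recipe.delayed() and not recipe.mxfp8():
        raise RuntimeError(
            "Userbuffers is only supported with FP8 delayed scaling and MXFP8 recipes"
        )

check_userbuffers_recipe(ToyRecipe("mxfp8"))  # accepted
try:
    check_userbuffers_recipe(ToyRecipe("per_tensor_current"))
except RuntimeError as err:
    print("rejected:", err)
```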
# Get autocast dtype if needed
dtype = None
@@ -482,26 +329,26 @@ class UserbuffersForwardLinear(FusedOperation):
input=input_,
weight=linear_op.weight,
bias=bias,
dtype=dtype,
tensor_parallel_mode=self.tensor_parallel_mode,
tensor_parallel_group=self.tensor_parallel_group,
tensor_parallel_size=self.tensor_parallel_size,
sequence_parallel=self.sequence_parallel,
with_quantized_compute=with_quantized_compute,
input_quantizer=input_quantizer,
weight_quantizer=weight_quantizer,
output_quantizer=None,  # Not supported
ub_comm_name=linear_op._userbuffers_options["comm_name"],
)
x_local = extra_outputs["input"]
# Save state for backward pass
linear_op_ctx.save_for_backward(x_local)
linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer
linear_op_ctx.weight_quantizer = weight_quantizer
linear_op_ctx.grad_output_quantizer = grad_output_quantizer
linear_op_ctx.grad_input_quantizer = grad_input_quantizer
linear_op_ctx.dtype = dtype
linear_op_ctx.input_dims = input_.size()
linear_op_ctx.input_requires_grad = input_.requires_grad
@@ -529,8 +376,6 @@ def fuse_userbuffers_forward_linear(
""" """
return ops ### TODO Debug Userbuffers support
# Return immediately if environment is not distributed # Return immediately if environment is not distributed
if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1: if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1:
return ops return ops
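The early-return guard above skips the fusion pass outside a multi-rank run; note that the short-circuit ordering matters, since `get_world_size()` raises if the process group is not initialized. A minimal sketch of the same guard (`maybe_fuse` is a hypothetical name for illustration):

```python
import torch
import torch.distributed as dist

def maybe_fuse(ops):
    """Skip Userbuffers fusion outside a multi-rank distributed run (sketch)."""
    # is_initialized() is checked first so get_world_size() is never called
    # without an initialized process group
    if not dist.is_initialized() or dist.get_world_size() == 1:
        return ops  # nothing to overlap on a single rank
    # ... fusion logic would go here ...
    return ops

ops = ["linear"]
print(maybe_fuse(ops))  # ['linear'] -- unchanged in a non-distributed run
```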
...