Unverified Commit 24024061 authored by Alp Dener's avatar Alp Dener Committed by GitHub
Browse files

[PyTorch] Adding TP overlap support for `te.Linear` with `parallel_mode="column"` (#1343)



* support AG overlap in sequence-parallel Linear forward and RS overlap in sequence-parallel Linear backward
Signed-off-by: default avatarAlp Dener <adener@nvidia.com>

* implemented TP overlap support for column-parallel te.Linear
Signed-off-by: default avatarAlp Dener <adener@nvidia.com>

* fixed backward pass for te.Linear column-parallel with TP overlap, updated unit tests
Signed-off-by: default avatarAlp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* improved error messages for internal failure to infer TP overlap options in te.Linear
Signed-off-by: default avatarAlp Dener <adener@nvidia.com>

* fixed linting errors
Signed-off-by: default avatarAlp Dener <adener@nvidia.com>

* fixed incorrect TP overlap option asserts
Signed-off-by: default avatarAlp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarAlp Dener <adener@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent cbc46531
......@@ -51,15 +51,23 @@ def _get_layer_args(config, tp_group, tp_size, reference=False):
kwargs["ub_overlap_ag"] = not reference
if config.layer_type is te.Linear:
input_shape[2] = hidden_size // tp_size
args.append(hidden_size)
kwargs["parallel_mode"] = "row"
kwargs["ub_overlap_rs"] = not reference
kwargs["ub_name"] = "proj"
if config.linear_parallel_mode == "row":
input_shape[2] = hidden_size // tp_size
args.append(hidden_size)
kwargs["ub_overlap_rs"] = not reference
elif config.linear_parallel_mode == "column":
input_shape[0] = config.seq_length // tp_size
args.append(3 * hidden_size)
kwargs["ub_overlap_rs"] = config.overlap_rs_dgrad and not reference
kwargs["ub_bulk_dgrad"] = not config.overlap_rs_dgrad and not reference
kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference
kwargs["parallel_mode"] = config.linear_parallel_mode
kwargs["ub_name"] = "proj" if config.linear_parallel_mode == "row" else "qkv"
else:
input_shape[0] = config.seq_length // tp_size
kwargs["ub_bulk_wgrad"] = not reference
kwargs["ub_bulk_dgrad"] = not reference
kwargs["ub_overlap_rs_dgrad"] = config.overlap_rs_dgrad and not reference
kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference
kwargs["ub_bulk_dgrad"] = not config.overlap_rs_dgrad and not reference
if config.layer_type is te.LayerNormLinear:
args.append(3 * hidden_size)
kwargs["parallel_mode"] = "column"
......@@ -125,6 +133,19 @@ def _parse_args(argv=None, namespace=None):
parser.add_argument(
"--use-cuda-graphs", action="store_true", default=False, help="Use CUDA Graphs."
)
parser.add_argument(
"--linear-parallel-mode",
type=str.lower,
default="row",
choices=["row", "column"],
help="Parallel mode for te.Linear.",
)
parser.add_argument(
"--overlap-rs-dgrad",
action="store_true",
default=False,
help="Overlap reduce-scatter with DGRAD in the backward pass instead of bulk overlaps.",
)
parser.add_argument(
"--debug",
action="store_true",
......@@ -230,12 +251,19 @@ def _train(opts):
dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs")
# Intialize userbuffers
ub_cfgs = None
if opts.overlap_rs_dgrad:
ub_cfgs = {
"proj_dgrad": {"method": "ring_exchange"},
"qkv_dgrad": {"method": "ring_exchange"},
}
te.module.base.initialize_ub(
[opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim],
WORLD_SIZE,
use_fp8=opts.fp8,
dtype=torch.bfloat16,
bootstrap_backend=opts.bootstrap_backend,
ub_cfgs=ub_cfgs,
)
# Initialize the Transformer Engine layer with overlap
......@@ -314,27 +342,29 @@ def _train(opts):
ref_grads.append(ref_param.grad)
# Make sure we have the same number of gradients
numerics_failed = torch.tensor([0], dtype=torch.uint8, device="cuda")
num_grads_failed = torch.tensor([0], dtype=torch.uint8, device="cuda")
if len(test_grads) != len(ref_grads):
numerics_failed[0] = 1
num_grads_failed[0] = 1
numerics_info = (
"NUMERICAL CHECK FAILED: Incorrect number of gradients, "
+ f"expected {len(ref_grads)} but got {len(test_grads)}."
)
dist_print(numerics_info, src=WORLD_RANK, error=True)
dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world)
dist.all_reduce(num_grads_failed, dist.ReduceOp.MAX, nccl_world)
# Now validate accuracy
if not bool(numerics_failed.item()):
numerics_failed = torch.zeros(len(test_grads), dtype=torch.uint8, device="cuda")
if not bool(num_grads_failed.item()):
for i, (test_g, ref_g) in enumerate(zip(test_grads, ref_grads)):
rtol = 0.125 if opts.fp8 else 0.025
atol = 0.0625 if opts.fp8 else 0.00125
grad_failed, grad_info = _compare_tensors(names[i], test_g, ref_g, rtol, atol)
dist_print(grad_info, src=WORLD_RANK, error=grad_failed)
numerics_failed[0] = int(grad_failed)
dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world)
if bool(numerics_failed.item()):
break
numerics_failed[i] = int(grad_failed)
return_code = torch.max(numerics_failed)
dist.all_reduce(return_code, dist.ReduceOp.MAX, nccl_world)
else:
return_code = num_grads_failed
te.module.base.destroy_ub()
dist_print("Destroying Userbuffers objects...", debug=True)
......@@ -344,7 +374,7 @@ def _train(opts):
if opts.debug and WORLD_RANK == 0:
print("Exiting...\n", end="", flush=True)
return numerics_failed[0].item()
return return_code.item()
if __name__ == "__main__":
......
......@@ -21,8 +21,10 @@ SEQ_LENGTH: int = 512
BATCH_SIZE: int = 2
NUM_HEADS: int = 12
HEAD_DIM: int = 64
# NOTE: te.Linear is intentionally omitted here and manually added later for testing both
# row and column parallel layouts.
TE_LAYERS = [
te.Linear,
te.LayerNormLinear,
te.LayerNormMLP,
te.MultiheadAttention,
......@@ -86,7 +88,7 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8_in, fp8_out, aggreg
raise AssertionError(result.stderr.decode())
def _run_layer_with_overlap(layer_type, fp8, fp8_init):
def _run_layer_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init):
test_path = TEST_ROOT / "run_layer_with_overlap.py"
test_cmd = LAUNCH_CMD + [
str(test_path),
......@@ -97,6 +99,8 @@ def _run_layer_with_overlap(layer_type, fp8, fp8_init):
f"--head-dim={HEAD_DIM}",
f"--layer-type={layer_type}",
]
if layer_type == te.Linear.__name__:
test_cmd.append(f"--linear-parallel-mode={linear_parallel_mode}")
if fp8:
if not fp8_available:
......@@ -245,9 +249,15 @@ def test_bulk_overlaps(comm_type, fp8, connections):
@pytest.mark.parametrize(
"layer_type",
[layer.__name__ for layer in TE_LAYERS],
ids=[(" " + layer.__name__ + " ") for layer in TE_LAYERS],
"layer_type,linear_parallel_mode",
(
[(te.Linear.__name__, "row"), (te.Linear.__name__, "column")]
+ list(zip([layer.__name__ for layer in TE_LAYERS], [None for _ in range(len(TE_LAYERS))]))
),
ids=(
[f" {te.Linear.__name__} (row-parallel) ", f" {te.Linear.__name__} (column-parallel) "]
+ [(" " + layer.__name__ + " ") for layer in TE_LAYERS]
),
)
@pytest.mark.parametrize(
"fp8,fp8_init",
......@@ -262,8 +272,8 @@ def test_bulk_overlaps(comm_type, fp8, connections):
" FP8 GEMM - FP8 PARAMS ",
],
)
def test_layers_with_overlap(layer_type, fp8, fp8_init):
def test_layers_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init):
"""
Test Transformer Engine layers with comm+GEMM overlap.
"""
_run_layer_with_overlap(layer_type, fp8, fp8_init)
_run_layer_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init)
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment