Unverified commit 24024061, authored by Alp Dener and committed by GitHub
Browse files

[PyTorch] Adding TP overlap support for `te.Linear` with `parallel_mode="column"` (#1343)



* support all-gather (AG) overlap in sequence-parallel Linear forward and reduce-scatter (RS) overlap in sequence-parallel Linear backward
Signed-off-by: Alp Dener <adener@nvidia.com>

* implemented TP overlap support for column-parallel te.Linear
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed backward pass for te.Linear column-parallel with TP overlap, updated unit tests
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* improved error messages for internal failure to infer TP overlap options in te.Linear
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed linting errors
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed incorrect TP overlap option asserts
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Alp Dener <adener@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent cbc46531
...@@ -51,15 +51,23 @@ def _get_layer_args(config, tp_group, tp_size, reference=False): ...@@ -51,15 +51,23 @@ def _get_layer_args(config, tp_group, tp_size, reference=False):
kwargs["ub_overlap_ag"] = not reference kwargs["ub_overlap_ag"] = not reference
if config.layer_type is te.Linear: if config.layer_type is te.Linear:
input_shape[2] = hidden_size // tp_size if config.linear_parallel_mode == "row":
args.append(hidden_size) input_shape[2] = hidden_size // tp_size
kwargs["parallel_mode"] = "row" args.append(hidden_size)
kwargs["ub_overlap_rs"] = not reference kwargs["ub_overlap_rs"] = not reference
kwargs["ub_name"] = "proj" elif config.linear_parallel_mode == "column":
input_shape[0] = config.seq_length // tp_size
args.append(3 * hidden_size)
kwargs["ub_overlap_rs"] = config.overlap_rs_dgrad and not reference
kwargs["ub_bulk_dgrad"] = not config.overlap_rs_dgrad and not reference
kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference
kwargs["parallel_mode"] = config.linear_parallel_mode
kwargs["ub_name"] = "proj" if config.linear_parallel_mode == "row" else "qkv"
else: else:
input_shape[0] = config.seq_length // tp_size input_shape[0] = config.seq_length // tp_size
kwargs["ub_bulk_wgrad"] = not reference kwargs["ub_overlap_rs_dgrad"] = config.overlap_rs_dgrad and not reference
kwargs["ub_bulk_dgrad"] = not reference kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference
kwargs["ub_bulk_dgrad"] = not config.overlap_rs_dgrad and not reference
if config.layer_type is te.LayerNormLinear: if config.layer_type is te.LayerNormLinear:
args.append(3 * hidden_size) args.append(3 * hidden_size)
kwargs["parallel_mode"] = "column" kwargs["parallel_mode"] = "column"
...@@ -125,6 +133,19 @@ def _parse_args(argv=None, namespace=None): ...@@ -125,6 +133,19 @@ def _parse_args(argv=None, namespace=None):
parser.add_argument( parser.add_argument(
"--use-cuda-graphs", action="store_true", default=False, help="Use CUDA Graphs." "--use-cuda-graphs", action="store_true", default=False, help="Use CUDA Graphs."
) )
parser.add_argument(
"--linear-parallel-mode",
type=str.lower,
default="row",
choices=["row", "column"],
help="Parallel mode for te.Linear.",
)
parser.add_argument(
"--overlap-rs-dgrad",
action="store_true",
default=False,
help="Overlap reduce-scatter with DGRAD in the backward pass instead of bulk overlaps.",
)
parser.add_argument( parser.add_argument(
"--debug", "--debug",
action="store_true", action="store_true",
...@@ -230,12 +251,19 @@ def _train(opts): ...@@ -230,12 +251,19 @@ def _train(opts):
dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs") dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs")
# Intialize userbuffers # Intialize userbuffers
ub_cfgs = None
if opts.overlap_rs_dgrad:
ub_cfgs = {
"proj_dgrad": {"method": "ring_exchange"},
"qkv_dgrad": {"method": "ring_exchange"},
}
te.module.base.initialize_ub( te.module.base.initialize_ub(
[opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim], [opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim],
WORLD_SIZE, WORLD_SIZE,
use_fp8=opts.fp8, use_fp8=opts.fp8,
dtype=torch.bfloat16, dtype=torch.bfloat16,
bootstrap_backend=opts.bootstrap_backend, bootstrap_backend=opts.bootstrap_backend,
ub_cfgs=ub_cfgs,
) )
# Initialize the Transformer Engine layer with overlap # Initialize the Transformer Engine layer with overlap
...@@ -314,27 +342,29 @@ def _train(opts): ...@@ -314,27 +342,29 @@ def _train(opts):
ref_grads.append(ref_param.grad) ref_grads.append(ref_param.grad)
# Make sure we have the same number of gradients # Make sure we have the same number of gradients
numerics_failed = torch.tensor([0], dtype=torch.uint8, device="cuda") num_grads_failed = torch.tensor([0], dtype=torch.uint8, device="cuda")
if len(test_grads) != len(ref_grads): if len(test_grads) != len(ref_grads):
numerics_failed[0] = 1 num_grads_failed[0] = 1
numerics_info = ( numerics_info = (
"NUMERICAL CHECK FAILED: Incorrect number of gradients, " "NUMERICAL CHECK FAILED: Incorrect number of gradients, "
+ f"expected {len(ref_grads)} but got {len(test_grads)}." + f"expected {len(ref_grads)} but got {len(test_grads)}."
) )
dist_print(numerics_info, src=WORLD_RANK, error=True) dist_print(numerics_info, src=WORLD_RANK, error=True)
dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world) dist.all_reduce(num_grads_failed, dist.ReduceOp.MAX, nccl_world)
# Now validate accuracy # Now validate accuracy
if not bool(numerics_failed.item()): numerics_failed = torch.zeros(len(test_grads), dtype=torch.uint8, device="cuda")
if not bool(num_grads_failed.item()):
for i, (test_g, ref_g) in enumerate(zip(test_grads, ref_grads)): for i, (test_g, ref_g) in enumerate(zip(test_grads, ref_grads)):
rtol = 0.125 if opts.fp8 else 0.025 rtol = 0.125 if opts.fp8 else 0.025
atol = 0.0625 if opts.fp8 else 0.00125 atol = 0.0625 if opts.fp8 else 0.00125
grad_failed, grad_info = _compare_tensors(names[i], test_g, ref_g, rtol, atol) grad_failed, grad_info = _compare_tensors(names[i], test_g, ref_g, rtol, atol)
dist_print(grad_info, src=WORLD_RANK, error=grad_failed) dist_print(grad_info, src=WORLD_RANK, error=grad_failed)
numerics_failed[0] = int(grad_failed) numerics_failed[i] = int(grad_failed)
dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world) return_code = torch.max(numerics_failed)
if bool(numerics_failed.item()): dist.all_reduce(return_code, dist.ReduceOp.MAX, nccl_world)
break else:
return_code = num_grads_failed
te.module.base.destroy_ub() te.module.base.destroy_ub()
dist_print("Destroying Userbuffers objects...", debug=True) dist_print("Destroying Userbuffers objects...", debug=True)
...@@ -344,7 +374,7 @@ def _train(opts): ...@@ -344,7 +374,7 @@ def _train(opts):
if opts.debug and WORLD_RANK == 0: if opts.debug and WORLD_RANK == 0:
print("Exiting...\n", end="", flush=True) print("Exiting...\n", end="", flush=True)
return numerics_failed[0].item() return return_code.item()
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -21,8 +21,10 @@ SEQ_LENGTH: int = 512 ...@@ -21,8 +21,10 @@ SEQ_LENGTH: int = 512
BATCH_SIZE: int = 2 BATCH_SIZE: int = 2
NUM_HEADS: int = 12 NUM_HEADS: int = 12
HEAD_DIM: int = 64 HEAD_DIM: int = 64
# NOTE: te.Linear is intentionally omitted here and manually added later for testing both
# row and column parallel layouts.
TE_LAYERS = [ TE_LAYERS = [
te.Linear,
te.LayerNormLinear, te.LayerNormLinear,
te.LayerNormMLP, te.LayerNormMLP,
te.MultiheadAttention, te.MultiheadAttention,
...@@ -86,7 +88,7 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8_in, fp8_out, aggreg ...@@ -86,7 +88,7 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8_in, fp8_out, aggreg
raise AssertionError(result.stderr.decode()) raise AssertionError(result.stderr.decode())
def _run_layer_with_overlap(layer_type, fp8, fp8_init): def _run_layer_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init):
test_path = TEST_ROOT / "run_layer_with_overlap.py" test_path = TEST_ROOT / "run_layer_with_overlap.py"
test_cmd = LAUNCH_CMD + [ test_cmd = LAUNCH_CMD + [
str(test_path), str(test_path),
...@@ -97,6 +99,8 @@ def _run_layer_with_overlap(layer_type, fp8, fp8_init): ...@@ -97,6 +99,8 @@ def _run_layer_with_overlap(layer_type, fp8, fp8_init):
f"--head-dim={HEAD_DIM}", f"--head-dim={HEAD_DIM}",
f"--layer-type={layer_type}", f"--layer-type={layer_type}",
] ]
if layer_type == te.Linear.__name__:
test_cmd.append(f"--linear-parallel-mode={linear_parallel_mode}")
if fp8: if fp8:
if not fp8_available: if not fp8_available:
...@@ -245,9 +249,15 @@ def test_bulk_overlaps(comm_type, fp8, connections): ...@@ -245,9 +249,15 @@ def test_bulk_overlaps(comm_type, fp8, connections):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"layer_type", "layer_type,linear_parallel_mode",
[layer.__name__ for layer in TE_LAYERS], (
ids=[(" " + layer.__name__ + " ") for layer in TE_LAYERS], [(te.Linear.__name__, "row"), (te.Linear.__name__, "column")]
+ list(zip([layer.__name__ for layer in TE_LAYERS], [None for _ in range(len(TE_LAYERS))]))
),
ids=(
[f" {te.Linear.__name__} (row-parallel) ", f" {te.Linear.__name__} (column-parallel) "]
+ [(" " + layer.__name__ + " ") for layer in TE_LAYERS]
),
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"fp8,fp8_init", "fp8,fp8_init",
...@@ -262,8 +272,8 @@ def test_bulk_overlaps(comm_type, fp8, connections): ...@@ -262,8 +272,8 @@ def test_bulk_overlaps(comm_type, fp8, connections):
" FP8 GEMM - FP8 PARAMS ", " FP8 GEMM - FP8 PARAMS ",
], ],
) )
def test_layers_with_overlap(layer_type, fp8, fp8_init): def test_layers_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init):
""" """
Test Transformer Engine layers with comm+GEMM overlap. Test Transformer Engine layers with comm+GEMM overlap.
""" """
_run_layer_with_overlap(layer_type, fp8, fp8_init) _run_layer_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init)
This diff is collapsed.
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment