Unverified Commit c257bf31 authored by vasunvidia, committed by GitHub

Blackwell devel commoverlap mlperftests (#1529)



* Add options to comm overlap tests
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Fix Typo
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Update tests/pytorch/distributed/run_layer_with_overlap.py
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

---------
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 12c3e323
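For orientation, here is a minimal sketch (not part of the commit) of the rank-grouping arithmetic behind the new --tp option, using assumed example values; it mirrors the tp_base_rank/tp_rank_list logic added to run_layer_with_overlap.py in the diff below.

# Rank-grouping sketch for the new --tp option (assumed example values, not a real launch).
WORLD_RANK, WORLD_SIZE, tp = 5, 8, 4
# Each rank joins the contiguous block of `tp` ranks that contains it; this list
# becomes the `ranks` argument passed to dist.new_group(backend="nccl", ...).
tp_base_rank = (WORLD_RANK // tp) * tp                         # -> 4
tp_rank_list = list(range(tp_base_rank, tp_base_rank + tp))    # -> [4, 5, 6, 7]
print(tp_rank_list)

When --tp is omitted, the script falls back to opts.tp = WORLD_SIZE, i.e. a single tensor-parallel group spanning all ranks.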
@@ -106,7 +106,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
         install_reqs.extend(["torch>=2.1"])
         # Blackwell is not supported as of Triton 3.2.0, need custom internal build
         # install_reqs.append("triton")
-        test_reqs.extend(["numpy", "torchvision", "prettytable"])
+        test_reqs.extend(["numpy", "torchvision", "prettytable", "PyYAML"])
     if "jax" in frameworks:
         install_reqs.extend(["jax", "flax>=0.7.1"])
         # test_reqs.extend(["numpy", "praxis"])
tests/pytorch/distributed/run_layer_with_overlap.py

@@ -11,6 +11,7 @@ import subprocess
 import argparse
 import warnings
 import pprint
+import yaml

 import torch
 import torch.distributed as dist
@@ -46,7 +47,7 @@ def _get_layer_args(config, tp_group, tp_size, reference=False):
     input_shape = [config.seq_length, config.batch_size, hidden_size]
     args = [hidden_size]
     kwargs = {
-        "params_dtype": torch.float32,
+        "params_dtype": torch.float32 if not config.use_bf16_params else torch.bfloat16,
         "device": "cuda",
         "tp_group": tp_group,
         "tp_size": tp_size,
@@ -59,11 +60,18 @@ def _get_layer_args(config, tp_group, tp_size, reference=False):
     if config.linear_parallel_mode == "row":
         input_shape[-1] = ffn_hidden_size // tp_size
         args = [ffn_hidden_size, hidden_size]
+        if config.in_features is not None:
+            input_shape[-1] = config.in_features // tp_size
+            args = [config.in_features, hidden_size]
         kwargs["ub_name"] = "proj" if config.layer_type == te.Linear else "fc2"
+        kwargs["ub_name"] = kwargs["ub_name"] if config.ub_name is None else config.ub_name
     elif config.linear_parallel_mode == "column":
         input_shape[0] = config.seq_length // tp_size
-        args.append(qkv_size)
-        kwargs["ub_name"] = "qkv"
+        if config.out_features is not None:
+            args.append(config.out_features)
+        else:
+            args.append(qkv_size)
+        kwargs["ub_name"] = "qkv" if config.ub_name is None else config.ub_name
         kwargs["ub_overlap_rs_dgrad"] = config.overlap_rs_dgrad and not reference
         kwargs["ub_bulk_dgrad"] = not config.overlap_rs_dgrad and not reference
         kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference
@@ -87,6 +95,9 @@ def _get_layer_args(config, tp_group, tp_size, reference=False):
         kwargs["ub_bulk_dgrad"] = not config.overlap_rs_dgrad and not reference
         kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference

+    if config.ub_cfg is not None and isinstance(config.ub_cfg, str):
+        with open(config.ub_cfg, "r") as stream:
+            config.ub_cfg = yaml.safe_load(stream)
     return args, kwargs, input_shape
@@ -103,6 +114,30 @@ def _parse_args(argv=None, namespace=None):
     parser.add_argument(
         "-d", "--head-dim", type=int, default=48, help="Dimension of each attention head."
     )
+    parser.add_argument(
+        "--in-features",
+        type=int,
+        default=None,
+        help="Optional input feature size for weight. Only used for Linear layer.",
+    )
+    parser.add_argument(
+        "--out-features",
+        type=int,
+        default=None,
+        help="Optional output feature size for weight. Only used for LayerNormLinear layer.",
+    )
+    parser.add_argument(
+        "--tp",
+        type=int,
+        default=None,
+        help="Optional tensor_model_parallel_size used to initialize UB.",
+    )
+    parser.add_argument(
+        "--use-bf16-params",
+        action="store_true",
+        default=False,
+        help="Use BF16 params instead of FP32.",
+    )
     parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
     parser.add_argument(
         "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context."
@@ -132,6 +167,28 @@ def _parse_args(argv=None, namespace=None):
     parser.add_argument(
         "--use-cuda-graphs", action="store_true", default=False, help="Use CUDA Graphs."
     )
+    parser.add_argument(
+        "--ub-cfg", type=str, default=None, help="Optional TP config yaml file input."
+    )
+    parser.add_argument("--ub-name", type=str, default=None, help="Optional TP layer name.")
+    parser.add_argument(
+        "--skip-verify",
+        action="store_true",
+        default=False,
+        help="Skip numerics check.",
+    )
+    parser.add_argument(
+        "--benchmark",
+        action="store_true",
+        default=False,
+        help="Benchmark comm-gemm overlap perf.",
+    )
+    parser.add_argument(
+        "--benchmark-iter",
+        type=int,
+        default=100,
+        help="Number of iterations for benchmarking perf.",
+    )
     parser.add_argument(
         "--linear-parallel-mode",
         type=str.lower,
@@ -223,9 +280,36 @@ def _train(opts):
         shell=True,
     )
-    if result.stdout == "0":  # Extra checks for non-MNNVL platforms
+    if result.stdout == "0" and opts.tp is None:  # Extra checks for non-MNNVL platforms
         assert WORLD_SIZE == LOCAL_SIZE

+    # Initialize torch.distributed tp process group
+    new_group_kwargs = {
+        "backend": "nccl",
+    }
+    if opts.tp is not None:
+        LOCAL_SIZE = opts.tp
+        tp_base_rank = (WORLD_RANK // LOCAL_SIZE) * LOCAL_SIZE
+        tp_rank_list = list(range(tp_base_rank, tp_base_rank + LOCAL_SIZE))
+        new_group_kwargs = {
+            "backend": "nccl",
+            "ranks": tp_rank_list,
+        }
+    else:
+        opts.tp = WORLD_SIZE
+
+    # Tensor dim overrides for tensors that do not require TP communication
+    if opts.in_features is not None:
+        assert opts.layer_type is te.Linear and opts.linear_parallel_mode == "row", (
+            "--in-features is only used to configure row-tensor-parallel Linear layers. Use"
+            " --num-heads or --head-dim for other cases."
+        )
+    if opts.out_features is not None:
+        assert opts.layer_type is te.LayerNormLinear and opts.linear_parallel_mode == "column", (
+            "--out-features is only used to configure column-tensor-parallel LayerNormLinear"
+            " layers. Use --num-heads or --head-dim for other cases."
+        )
+
     def dist_print(msg, src=None, end="\n", debug=False, error=False):
         if debug and not opts.debug:
             return
@@ -253,9 +337,11 @@
         dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}")
     assert dist.is_nccl_available()
     dist.init_process_group(**dist_init_kwargs)
-    nccl_world = dist.new_group(backend="nccl")
+    nccl_world = dist.new_group(**new_group_kwargs)
     dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs")

+    # Initialize the Transformer Engine layer with overlap
+    args, kwargs, input_shape = _get_layer_args(opts, nccl_world, opts.tp)
     # Intialize userbuffers
     ub_cfgs = None
     if opts.overlap_rs_dgrad:
@@ -265,15 +351,13 @@
         }
     te.module.base.initialize_ub(
         [opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim],
-        WORLD_SIZE,
+        opts.tp,
         use_fp8=opts.fp8,
         dtype=torch.bfloat16,
         bootstrap_backend=opts.bootstrap_backend,
-        ub_cfgs=ub_cfgs,
+        ub_cfgs=ub_cfgs if opts.ub_cfg is None else opts.ub_cfg,
     )

-    # Initialize the Transformer Engine layer with overlap
-    args, kwargs, input_shape = _get_layer_args(opts, nccl_world, WORLD_SIZE)
     with te.fp8_model_init(enabled=opts.fp8_init):
         test_model = opts.layer_type(*args, **kwargs)
     dist_print("Initialized test model...", debug=True)
@@ -283,7 +367,7 @@
     dist.barrier()

     # Initialize the reference model and copy all parameters
-    ref_args, ref_kwargs, _ = _get_layer_args(opts, nccl_world, WORLD_SIZE, reference=True)
+    ref_args, ref_kwargs, _ = _get_layer_args(opts, nccl_world, opts.tp, reference=True)
     with te.fp8_model_init(enabled=opts.fp8_init):
         ref_model = opts.layer_type(*ref_args, **ref_kwargs)
     dist_print("Initialized reference model...", debug=True)
@@ -326,6 +410,7 @@
         with torch.cuda.graph(test_graph):
             test_out = run_fwd_bwd(test_model, test_x)
         test_graph.replay()
-        del test_graph
+        if not opts.benchmark:
+            del test_graph
     else:
         test_out = run_fwd_bwd(test_model, test_x)
@@ -351,8 +436,9 @@
         if ref_param.requires_grad and "layer_norm" not in ref_name:
             ref_grads.append(ref_param.grad)

-    # Make sure we have the same number of gradients
     numerics_failed = torch.tensor([0], dtype=torch.uint8, device="cuda")
-    if len(test_grads) != len(ref_grads):
-        numerics_failed[0] = 1
-        numerics_info = (
+    if not opts.skip_verify:
+        # Make sure we have the same number of gradients
+        if len(test_grads) != len(ref_grads):
+            numerics_failed[0] = 1
+            numerics_info = (
@@ -374,6 +460,23 @@
             if bool(numerics_failed.item()) and not opts.debug:
                 break

+    if opts.benchmark:
+        # Warmup to not profile CPU overhead
+        for _ in range(100):
+            if opts.use_cuda_graphs:
+                test_graph.replay()
+            else:
+                test_out = run_fwd_bwd(test_model, test_x)
+        torch.cuda.cudart().cudaProfilerStart()
+        for _ in range(opts.benchmark_iter):
+            if opts.use_cuda_graphs:
+                test_graph.replay()
+            else:
+                test_out = run_fwd_bwd(test_model, test_x)
+        torch.cuda.cudart().cudaProfilerStop()
+        if opts.use_cuda_graphs:
+            del test_graph
+
     te.module.base.destroy_ub()
     dist_print("Destroying Userbuffers objects...", debug=True)