Unverified Commit d541d208 authored by Sangkug Lym's avatar Sangkug Lym Committed by GitHub
Browse files

Fix the default userbuffer communicator init settings (#755)



fix the default userbuffer communicator init settings
Signed-off-by: default avatarSangkug Lym <slym@nvidia.com>
parent e3de4037
......@@ -153,13 +153,13 @@ def initialize_ub(
def add_ub(
name: str,
method: str,
is_reduce_scatter: int,
num_sm: int = 16,
cga_size: int = 2,
set_sm_margin: int = 0,
num_splits: int = 4,
num_splits: int = 0,
aggregate: int = 0,
atomic_gemm: int = 0,
is_reduce_scatter: int = 0,
fp8_buf: bool = False,
) -> None:
if atomic_gemm:
......@@ -243,7 +243,7 @@ def initialize_ub(
method = ub_cfg.get("method", get_method(name))
num_sm = ub_cfg.get("num_sm", 16)
cga_size = ub_cfg.get("cga_size", 2)
num_splits = ub_cfg.get("num_splits", 4)
num_splits = ub_cfg.get("num_splits", 4 if method == "pipeline" else 0)
set_sm_margin = ub_cfg.get("set_sm_margin", 0)
aggregate = ub_cfg.get("aggregate", 0)
atomic_gemm = ub_cfg.get("atomic_gemm", 0)
......@@ -254,21 +254,24 @@ def initialize_ub(
add_ub(
name,
method,
is_reduce_scatter,
num_sm,
cga_size,
set_sm_margin,
num_splits,
aggregate,
atomic_gemm,
is_reduce_scatter,
fp8_buf,
)
else:
method = get_method(name)
if method == "pipeline":
add_ub(name, method)
else:
add_ub(name, method, num_splits=0)
add_ub(
name,
method=method,
is_reduce_scatter=1 if name in layers_reduce_scatter_overlap else 0,
num_splits=4 if method == "pipeline" else 0,
fp8_buf=name in layers_all_gather_overlap,
)
def get_ub(name: str):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment