Unverified commit d541d208, authored by Sangkug Lym, committed by GitHub
Browse files

Fix the default userbuffer communicator init settings (#755)



fix the default userbuffer communicator init settings
Signed-off-by: Sangkug Lym <slym@nvidia.com>
parent e3de4037
...@@ -153,13 +153,13 @@ def initialize_ub( ...@@ -153,13 +153,13 @@ def initialize_ub(
def add_ub( def add_ub(
name: str, name: str,
method: str, method: str,
is_reduce_scatter: int,
num_sm: int = 16, num_sm: int = 16,
cga_size: int = 2, cga_size: int = 2,
set_sm_margin: int = 0, set_sm_margin: int = 0,
num_splits: int = 4, num_splits: int = 0,
aggregate: int = 0, aggregate: int = 0,
atomic_gemm: int = 0, atomic_gemm: int = 0,
is_reduce_scatter: int = 0,
fp8_buf: bool = False, fp8_buf: bool = False,
) -> None: ) -> None:
if atomic_gemm: if atomic_gemm:
...@@ -243,7 +243,7 @@ def initialize_ub( ...@@ -243,7 +243,7 @@ def initialize_ub(
method = ub_cfg.get("method", get_method(name)) method = ub_cfg.get("method", get_method(name))
num_sm = ub_cfg.get("num_sm", 16) num_sm = ub_cfg.get("num_sm", 16)
cga_size = ub_cfg.get("cga_size", 2) cga_size = ub_cfg.get("cga_size", 2)
num_splits = ub_cfg.get("num_splits", 4) num_splits = ub_cfg.get("num_splits", 4 if method == "pipeline" else 0)
set_sm_margin = ub_cfg.get("set_sm_margin", 0) set_sm_margin = ub_cfg.get("set_sm_margin", 0)
aggregate = ub_cfg.get("aggregate", 0) aggregate = ub_cfg.get("aggregate", 0)
atomic_gemm = ub_cfg.get("atomic_gemm", 0) atomic_gemm = ub_cfg.get("atomic_gemm", 0)
...@@ -254,21 +254,24 @@ def initialize_ub( ...@@ -254,21 +254,24 @@ def initialize_ub(
add_ub( add_ub(
name, name,
method, method,
is_reduce_scatter,
num_sm, num_sm,
cga_size, cga_size,
set_sm_margin, set_sm_margin,
num_splits, num_splits,
aggregate, aggregate,
atomic_gemm, atomic_gemm,
is_reduce_scatter,
fp8_buf, fp8_buf,
) )
else: else:
method = get_method(name) method = get_method(name)
if method == "pipeline": add_ub(
add_ub(name, method) name,
else: method=method,
add_ub(name, method, num_splits=0) is_reduce_scatter=1 if name in layers_reduce_scatter_overlap else 0,
num_splits=4 if method == "pipeline" else 0,
fp8_buf=name in layers_all_gather_overlap,
)
def get_ub(name: str): def get_ub(name: str):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment