Commit 3d57ff8c authored by yuguo
Browse files

Merge branch 'develop_v2.4' into 'main'

[DCU] add NVTE_TP_OVERLAP_AGGREGATE

See merge request dcutoolkit/deeplearing/TransformerEngine!28
parents bfd4074f b1864da3
@@ -329,7 +329,7 @@ def initialize_ub(
 "cga_size": 1 if method == "ring_exchange" else 2,
 "set_sm_margin": not method == "ring_exchange",
 "num_splits": tp_size if method == "ring_exchange" else 4,
-"aggregate": False,
+"aggregate": bool(int(os.getenv("NVTE_TP_OVERLAP_AGGREGATE", "0"))),
 "atomic_gemm": False,
 "use_ce": True,
 "fp8_buf": name in layers_all_gather_overlap,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment