Unverified Commit 7d8ef9bf authored by Sangkug Lym's avatar Sangkug Lym Committed by GitHub
Browse files

userbuffer: support fp8 buffer for individual overlap instance (#750)



* userbuffer fp8 reduction support for individual overlap
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* cleanup ub_cfg dict value load
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* cleanup
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* Remove unnecessary fence from producer

From @erhoo82 
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Sangkug Lym <slym@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 27cb9950
......@@ -48,7 +48,6 @@
#define ATOMIC_PRODUCER(chunk) \
if (counters) { \
((unsigned int *)counters)[chunk] = 0; \
asm volatile("fence.sc.gpu;\n"); \
}
// Return true if producer > consumer, otherwise false while preventing integer overflow
......
......@@ -126,18 +126,16 @@ def initialize_ub(
_cublas_workspace = get_workspace().repeat(_NUM_MAX_UB_STREAMS)
# Default buffer precision: AllGather buffers use fp8 when using fp8 recipe
fp8_buf = [
layers_all_gather_overlap = [
"qkv_fprop", "qkv_dgrad", "proj_dgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad"
]
if bool(int(os.getenv("NVTE_UB_FP8_RS", "0"))):
fp8_buf += ["proj_fprop", "fc2_fprop"]
layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"]
# Default overlap methods for layers
methods = {
"ring_exchange":["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"],
"pipeline":["proj_fprop", "fc2_fprop"],
"bulk":["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
}
layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"]
# AG-RS overlap pairs of layers forming a tensor-parallel block
ag_rs_pairs = {"qkv_fprop":"proj_fprop", "fc1_fprop":"fc2_fprop"}
......@@ -161,6 +159,7 @@ def initialize_ub(
aggregate: int = 0,
atomic_gemm: int = 0,
is_reduce_scatter: int = 0,
fp8_buf: bool = False,
) -> None:
if atomic_gemm:
warnings.warn(
......@@ -198,7 +197,7 @@ def initialize_ub(
sample_buffer = torch.empty(
shape,
dtype=torch.uint8 if (use_fp8 and name in fp8_buf) else dtype,
dtype=torch.uint8 if (use_fp8 and fp8_buf) else dtype,
device='cuda')
if method == 'ring_exchange':
ub_obj = tex.UbufP2PCommOverlap(
......@@ -232,14 +231,17 @@ def initialize_ub(
for name in (methods["ring_exchange"]+methods["pipeline"]+methods["bulk"]):
if ub_cfgs is not None and name in ub_cfgs:
ub_cfg = ub_cfgs[name]
method = ub_cfg["method"] if "method" in ub_cfg else get_method(name)
num_sm = ub_cfg["num_sm"] if "num_sm" in ub_cfg else 16
cga_size = ub_cfg["cga_size"] if "cga_size" in ub_cfg else 2
num_splits = ub_cfg["num_splits"] if "num_splits" in ub_cfg else 4
set_sm_margin = ub_cfg["set_sm_margin"] if "set_sm_margin" in ub_cfg else 0
aggregate = ub_cfg["aggregate"] if "aggregate" in ub_cfg else 0
atomic_gemm = ub_cfg["atomic_gemm"] if "atomic_gemm" in ub_cfg else 0
method = ub_cfg.get("method", get_method(name))
num_sm = ub_cfg.get("num_sm", 16)
cga_size = ub_cfg.get("cga_size", 2)
num_splits = ub_cfg.get("num_splits", 4)
set_sm_margin = ub_cfg.get("set_sm_margin", 0)
aggregate = ub_cfg.get("aggregate", 0)
atomic_gemm = ub_cfg.get("atomic_gemm", 0)
is_reduce_scatter = 1 if name in layers_reduce_scatter_overlap else 0
# Support FP8 userbuffer when (1) AllGather and (2) FP8-GEMM output ReduceScatter
fp8_buf = ((name in layers_all_gather_overlap) or
(ub_cfg.get("fp8_buf", False) and name in methods["pipeline"]))
add_ub(
name,
method,
......@@ -250,6 +252,7 @@ def initialize_ub(
aggregate,
atomic_gemm,
is_reduce_scatter,
fp8_buf,
)
else:
method = get_method(name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment