"...git@developer.sourcefind.cn:yaoyuping/nndetection.git" did not exist on "131a40e9fd36c5525e292be74feb8a286421392b"
Unverified commit 7d8ef9bf authored by Sangkug Lym, committed by GitHub
Browse files

userbuffer: support fp8 buffer for individual overlap instance (#750)



* userbuffer fp8 reduction support for individual overlap
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* cleanup dict ub_cfg dict value load
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* cleanup
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* Remove unnecessary fence from producer

From @erhoo82
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Sangkug Lym <slym@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 27cb9950
...@@ -48,7 +48,6 @@ ...@@ -48,7 +48,6 @@
#define ATOMIC_PRODUCER(chunk) \ #define ATOMIC_PRODUCER(chunk) \
if (counters) { \ if (counters) { \
((unsigned int *)counters)[chunk] = 0; \ ((unsigned int *)counters)[chunk] = 0; \
asm volatile("fence.sc.gpu;\n"); \
} }
// Return true if producer > consumer, otherwise false while preventing integer overflow // Return true if producer > consumer, otherwise false while preventing integer overflow
......
...@@ -126,18 +126,16 @@ def initialize_ub( ...@@ -126,18 +126,16 @@ def initialize_ub(
_cublas_workspace = get_workspace().repeat(_NUM_MAX_UB_STREAMS) _cublas_workspace = get_workspace().repeat(_NUM_MAX_UB_STREAMS)
# Default buffer precision: AllGather buffers use fp8 when using fp8 recipe # Default buffer precision: AllGather buffers use fp8 when using fp8 recipe
fp8_buf = [ layers_all_gather_overlap = [
"qkv_fprop", "qkv_dgrad", "proj_dgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad" "qkv_fprop", "qkv_dgrad", "proj_dgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad"
] ]
if bool(int(os.getenv("NVTE_UB_FP8_RS", "0"))): layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"]
fp8_buf += ["proj_fprop", "fc2_fprop"]
# Default overlap methods for layers # Default overlap methods for layers
methods = { methods = {
"ring_exchange":["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"], "ring_exchange":["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"],
"pipeline":["proj_fprop", "fc2_fprop"], "pipeline":["proj_fprop", "fc2_fprop"],
"bulk":["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"], "bulk":["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
} }
layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"]
# AG-RS overlap pairs of layers forming a tensor-parallel block # AG-RS overlap pairs of layers forming a tensor-parallel block
ag_rs_pairs = {"qkv_fprop":"proj_fprop", "fc1_fprop":"fc2_fprop"} ag_rs_pairs = {"qkv_fprop":"proj_fprop", "fc1_fprop":"fc2_fprop"}
...@@ -161,6 +159,7 @@ def initialize_ub( ...@@ -161,6 +159,7 @@ def initialize_ub(
aggregate: int = 0, aggregate: int = 0,
atomic_gemm: int = 0, atomic_gemm: int = 0,
is_reduce_scatter: int = 0, is_reduce_scatter: int = 0,
fp8_buf: bool = False,
) -> None: ) -> None:
if atomic_gemm: if atomic_gemm:
warnings.warn( warnings.warn(
...@@ -198,7 +197,7 @@ def initialize_ub( ...@@ -198,7 +197,7 @@ def initialize_ub(
sample_buffer = torch.empty( sample_buffer = torch.empty(
shape, shape,
dtype=torch.uint8 if (use_fp8 and name in fp8_buf) else dtype, dtype=torch.uint8 if (use_fp8 and fp8_buf) else dtype,
device='cuda') device='cuda')
if method == 'ring_exchange': if method == 'ring_exchange':
ub_obj = tex.UbufP2PCommOverlap( ub_obj = tex.UbufP2PCommOverlap(
...@@ -232,14 +231,17 @@ def initialize_ub( ...@@ -232,14 +231,17 @@ def initialize_ub(
for name in (methods["ring_exchange"]+methods["pipeline"]+methods["bulk"]): for name in (methods["ring_exchange"]+methods["pipeline"]+methods["bulk"]):
if ub_cfgs is not None and name in ub_cfgs: if ub_cfgs is not None and name in ub_cfgs:
ub_cfg = ub_cfgs[name] ub_cfg = ub_cfgs[name]
method = ub_cfg["method"] if "method" in ub_cfg else get_method(name) method = ub_cfg.get("method", get_method(name))
num_sm = ub_cfg["num_sm"] if "num_sm" in ub_cfg else 16 num_sm = ub_cfg.get("num_sm", 16)
cga_size = ub_cfg["cga_size"] if "cga_size" in ub_cfg else 2 cga_size = ub_cfg.get("cga_size", 2)
num_splits = ub_cfg["num_splits"] if "num_splits" in ub_cfg else 4 num_splits = ub_cfg.get("num_splits", 4)
set_sm_margin = ub_cfg["set_sm_margin"] if "set_sm_margin" in ub_cfg else 0 set_sm_margin = ub_cfg.get("set_sm_margin", 0)
aggregate = ub_cfg["aggregate"] if "aggregate" in ub_cfg else 0 aggregate = ub_cfg.get("aggregate", 0)
atomic_gemm = ub_cfg["atomic_gemm"] if "atomic_gemm" in ub_cfg else 0 atomic_gemm = ub_cfg.get("atomic_gemm", 0)
is_reduce_scatter = 1 if name in layers_reduce_scatter_overlap else 0 is_reduce_scatter = 1 if name in layers_reduce_scatter_overlap else 0
# Support FP8 userbuffer when (1) AllGather and (2) FP8-GEMM output ReduceScatter
fp8_buf = ((name in layers_all_gather_overlap) or
(ub_cfg.get("fp8_buf", False) and name in methods["pipeline"]))
add_ub( add_ub(
name, name,
method, method,
...@@ -250,6 +252,7 @@ def initialize_ub( ...@@ -250,6 +252,7 @@ def initialize_ub(
aggregate, aggregate,
atomic_gemm, atomic_gemm,
is_reduce_scatter, is_reduce_scatter,
fp8_buf,
) )
else: else:
method = get_method(name) method = get_method(name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment