Unverified commit bdf1afee, authored by Jaemin Choi, committed by GitHub
Browse files

Use arguments instead of env vars for TP comm overlap (#649)



* Pass knobs for TP comm overlap instead of env vars
Signed-off-by: Jaemin Choi <jaeminc@nvidia.com>

* Comment out debugging print
Signed-off-by: Jaemin Choi <jaeminc@nvidia.com>

* Remove docstring
Signed-off-by: Jaemin Choi <jaeminc@nvidia.com>

* Remove debugging output
Signed-off-by: Jaemin Choi <jaeminc@nvidia.com>

---------
Signed-off-by: Jaemin Choi <jaeminc@nvidia.com>
Co-authored-by: Jaemin Choi <jaeminc@nvidia.com>
parent a174985b
...@@ -95,17 +95,6 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -95,17 +95,6 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
} }
_ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options()); _ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options());
const char *env_p = std::getenv("NVTE_RS_STRIDED_ATOMIC");
const char *env_q = std::getenv("NVTE_UB_ATOMIC_GEMM_RS");
if (rank == 0 && env_p != nullptr && env_q != nullptr && env_q[0] == '1') {
if (env_p[0] == '1')
printf("!! Using reducescatter2_userbuff_strided_atomic\n");
else if (env_p[0] == '2')
printf("!! Using reducescatter2_userbuff_strided_multiatomic\n");
else
printf("!! Using reducescatter2_userbuff_strided\n");
}
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
for (int i = 0; i < std::min(num_max_streams, num_splits); i++) { for (int i = 0; i < std::min(num_max_streams, num_splits); i++) {
cudaStream_t stream; cudaStream_t stream;
......
...@@ -257,6 +257,12 @@ class TransformerLayer(torch.nn.Module): ...@@ -257,6 +257,12 @@ class TransformerLayer(torch.nn.Module):
zero_centered_gamma: bool = False, zero_centered_gamma: bool = False,
qkv_weight_interleaved: bool = True, qkv_weight_interleaved: bool = True,
ub_tp_comm_overlap: bool = False, ub_tp_comm_overlap: bool = False,
ub_bulk_wgrad: bool = True,
ub_bulk_dgrad: bool = True,
ub_split_ag: bool = True,
ub_split_rs: bool = True,
ub_atomic_gemm_ag: bool = False,
ub_atomic_gemm_rs: bool = False,
bias: bool = True, bias: bool = True,
activation: str = 'gelu', activation: str = 'gelu',
normalization: str = "LayerNorm", normalization: str = "LayerNorm",
...@@ -274,21 +280,18 @@ class TransformerLayer(torch.nn.Module): ...@@ -274,21 +280,18 @@ class TransformerLayer(torch.nn.Module):
self.window_size = window_size self.window_size = window_size
self.window_size = check_set_window_size(self_attn_mask_type, self.window_size) self.window_size = check_set_window_size(self_attn_mask_type, self.window_size)
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
ub_tp_comm_overlap = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_OVERLAP", "1"))) ub_bulk_wgrad = ub_tp_comm_overlap and ub_bulk_wgrad
ub_bulk_wgrad = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_BULK_WGRAD", "1"))) ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad
ub_bulk_dgrad = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_BULK_DGRAD", "1"))) ub_split_ag = ub_tp_comm_overlap and ub_split_ag
ub_split_ag = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_SPLIT_AG", "1"))) ub_split_rs = ub_tp_comm_overlap and ub_split_rs
ub_split_rs = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_SPLIT_RS", "1"))) ub_atomic_gemm_rs = ub_tp_comm_overlap and ub_atomic_gemm_rs
ub_atomic_gemm_rs = (ub_tp_comm_overlap
and bool(int(os.getenv("NVTE_UB_ATOMIC_GEMM_RS", "0"))))
assert ( assert (
not (ub_split_rs and ub_atomic_gemm_rs) not (ub_split_rs and ub_atomic_gemm_rs)
), "Only one type of RS overlap NVTE_UB_SPLIT_RS/NVTE_UB_ATOMIC_GEMM_RS should be enabled." ), "Only one type of RS overlap ub_split_rs/ub_atomic_gemm_rs should be enabled."
ub_atomic_gemm_ag = (ub_tp_comm_overlap ub_atomic_gemm_ag = ub_tp_comm_overlap and ub_atomic_gemm_ag
and bool(int(os.getenv("NVTE_UB_ATOMIC_GEMM_AG", "0"))))
assert ( assert (
not (ub_split_ag and ub_atomic_gemm_ag) not (ub_split_ag and ub_atomic_gemm_ag)
), "Only one type of AG overlap NVTE_UB_SPLIT_AG/NVTE_UB_ATOMIC_GEMM_AG should be enabled." ), "Only one type of AG overlap ub_split_ag/ub_atomic_gemm_ag should be enabled."
if ub_atomic_gemm_rs or ub_atomic_gemm_ag: if ub_atomic_gemm_rs or ub_atomic_gemm_ag:
warnings.warn( warnings.warn(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment