"git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "10cceae9cc7b0bc453565a76f87602b1c824ea19"
Unverified Commit 8c0a0c93 authored by vasunvidia, committed by GitHub
Browse files

DGRAD_RS UB overlap Bug fixes (#1004)



* DGRAD_RS UB overlap Bug fixes
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent e39674b9
@@ -1890,11 +1890,18 @@ template void reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e4m3>(
     void *output, float *scale, const int handler, const int offset, const int rowelements,
     const int colelements, const int strideelements_out, const int strideelements_in,
     const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
template void reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e5m2>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
 template void reducescatter2_userbuff_strided_multiatomic_fp8<__nv_fp8_e4m3>(
     void *output, float *scale, const int handler, const int offset, const int rowelements,
     const int colelements, const int strideelements_out, const int strideelements_in,
     const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
template void reducescatter2_userbuff_strided_multiatomic_fp8<__nv_fp8_e5m2>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
 __global__ void kuserbuffers_pullsend(int myrank, int peer, int *send_id, int *flagptr) {
   atomicAdd_system(flagptr, 1);
......
@@ -233,7 +233,9 @@ def initialize_ub(
                 wgrad_name = name.replace("dgrad", "wgrad")
                 assert wgrad_name not in ub_cfgs
                 layers_reduce_scatter_overlap.remove(wgrad_name)
layers_all_gather_overlap.remove(name)
                 layers_reduce_scatter_overlap.append(name)
methods["pipeline"].append(name)
     for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]:
         if ub_cfgs is not None and name in ub_cfgs:
......
@@ -184,7 +184,7 @@ class _LayerNormLinear(torch.autograd.Function):
                         fp8_dtype_forward,
                         out=ln_out_fp8,
                     )
-                    ln_out = ln_out_fp8
+                    ln_out = torch.empty_like(ln_out_fp8)
                 else:
                     ln_out_total = tex.cast_to_fp8(
                         ln_out_total,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment