"vscode:/vscode.git/clone" did not exist on "c6cece77686b0c548043dd52feb9c345b5ae0a68"
Unverified Commit 64126aa8 authored by Youngeun Kwon's avatar Youngeun Kwon Committed by GitHub
Browse files

Improving communication overlap for the case of multi kernel queue usage (#1308)



* draft implementation
Signed-off-by: default avatarYoungeun Kwon <youngeunk@nvidia.com>

* compile error fix
Signed-off-by: default avatarYoungeun Kwon <youngeunk@nvidia.com>

* fix compile error
Signed-off-by: default avatarYoungeun Kwon <youngeunk@nvidia.com>

* remove print
Signed-off-by: default avatarYoungeun Kwon <youngeunk@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Edit comments
Signed-off-by: default avatarYoungeun Kwon <youngeunk@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* edit the bulk-overlap test case
Signed-off-by: default avatarYoungeun Kwon <youngeunk@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add version guard
Signed-off-by: default avatarYoungeun Kwon <youngeunk@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add runtime version guard
Signed-off-by: default avatarYoungeun Kwon <youngeunk@nvidia.com>

* fix the version guard
Signed-off-by: default avatarYoungeun Kwon <youngeunk@nvidia.com>

---------
Signed-off-by: default avatarYoungeun Kwon <youngeunk@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 09519718
...@@ -209,19 +209,39 @@ def test_atomic_gemm_overlaps(ag_type, rs_type, p2p, fp8_out): ...@@ -209,19 +209,39 @@ def test_atomic_gemm_overlaps(ag_type, rs_type, p2p, fp8_out):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"comm_type,fp8", "comm_type, fp8, connections",
[ [
("AG", False), ("AG", False, 1),
("RS", False), ("RS", False, 1),
("RS", True), ("RS", True, 1),
("AG", False, 8),
("RS", False, 8),
("RS", True, 8),
],
ids=[
"ALL-GATHER - BF16 - 1 connections",
"REDUCE-SCATTER - BF16 - 1 connections",
"REDUCE-SCATTER - FP8 - 1 connections",
"ALL-GATHER - BF16 - 8 connections",
"REDUCE-SCATTER - BF16 - 8 connections",
"REDUCE-SCATTER - FP8 - 8 connections",
], ],
ids=[" ALL-GATHER - BF16 ", " REDUCE-SCATTER - BF16 ", " REDUCE-SCATTER - FP8 "],
) )
def test_bulk_overlaps(comm_type, fp8): def test_bulk_overlaps(comm_type, fp8, connections):
""" """
Test bulk overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm. Test bulk overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm.
""" """
_run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False) if connections == 8:
if torch.cuda.get_device_properties(0).major != 9:
pytest.skip(
"CUDA_DEVICE_MAX_CONNECTIONS=8 test only applies to devices with compute capability"
" 9.0 (HOPPER ARCH)."
)
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8"
_run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False)
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
else:
_run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False)
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
...@@ -90,6 +90,23 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl ...@@ -90,6 +90,23 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
cudaEventCreateWithFlags(&_stop_compute, 0); cudaEventCreateWithFlags(&_stop_compute, 0);
cudaEventCreateWithFlags(&_start_comm, 0); cudaEventCreateWithFlags(&_start_comm, 0);
cudaEventCreateWithFlags(&_stop_comm, 0); cudaEventCreateWithFlags(&_stop_comm, 0);
/*
Defining the launcher order between the communication and GEMM kernels
using Fast Dependent Launch when CUDA_DEVICE_MAX_CONNECTIONS>1.
The event is used to schedule the communication kernel before the GEMM.
This is needed only for Hopper, which uses persistent CTA execution.
*/
int max_connection = transformer_engine::getenv<int>("CUDA_DEVICE_MAX_CONNECTIONS", 8);
int runtime_version = 0;
cudaRuntimeGetVersion(&runtime_version);
cudaDeviceProp deviceProp;
cudaGetDeviceProperties(&deviceProp, 0);
if (runtime_version >= 12030 && deviceProp.major == 9 && max_connection > 1) {
cudaEventCreateWithFlags(&_comm_launch_event, cudaEventDisableTiming);
} else {
_comm_launch_event = 0;
}
} }
CommOverlapCore::~CommOverlapCore() { CommOverlapCore::~CommOverlapCore() {
...@@ -97,6 +114,7 @@ CommOverlapCore::~CommOverlapCore() { ...@@ -97,6 +114,7 @@ CommOverlapCore::~CommOverlapCore() {
cudaEventDestroy(_start_comm); cudaEventDestroy(_start_comm);
cudaEventDestroy(_stop_compute); cudaEventDestroy(_stop_compute);
cudaEventDestroy(_start_compute); cudaEventDestroy(_start_compute);
if (_comm_launch_event) cudaEventDestroy(_comm_launch_event);
if (_atomic_gemm) cudaFree(_counter.dptr()); if (_atomic_gemm) cudaFree(_counter.dptr());
...@@ -168,7 +186,8 @@ void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper ...@@ -168,7 +186,8 @@ void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper
// Communication: AG and RS // Communication: AG and RS
int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses 2Byte element size int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses 2Byte element size
if (comm_type == CommOverlapType::AG) { if (comm_type == CommOverlapType::AG) {
allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm); allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm,
(cudaEvent_t)_comm_launch_event);
} else { } else {
if (_ubuf.element_size() == 1) { if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized); assert(_ubuf_scale_inv_initialized);
...@@ -178,13 +197,18 @@ void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper ...@@ -178,13 +197,18 @@ void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper
assert(rs_output.element_size() == 2); assert(rs_output.element_size() == 2);
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr()); char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf_scale_inv, _ub_reg, 0, reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf_scale_inv, _ub_reg, 0,
comm_elements, _ub_comm, _stream_comm); comm_elements, _ub_comm, _stream_comm,
(cudaEvent_t)_comm_launch_event);
} else { } else {
reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm); reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm,
(cudaEvent_t)_comm_launch_event);
} }
} }
assert(pre_gelu_out.numel() == 0); assert(pre_gelu_out.numel() == 0);
// When the kernel launch order is defined, enforce the GEMM kernel launch to wait for the communication kernel launch
if (_comm_launch_event)
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _comm_launch_event, 0));
nvte_cublas_gemm(A.data(), B.data(), D.data(), bias.data(), pre_gelu_out.data(), transa, transb, nvte_cublas_gemm(A.data(), B.data(), D.data(), bias.data(), pre_gelu_out.data(), transa, transb,
grad, workspace.data(), accumulate, use_split_accumulator, _math_sms, grad, workspace.data(), accumulate, use_split_accumulator, _math_sms,
stream_main); stream_main);
......
...@@ -1366,6 +1366,28 @@ __global__ void __launch_bounds__(MAX_THREADS) ...@@ -1366,6 +1366,28 @@ __global__ void __launch_bounds__(MAX_THREADS)
cfg.attrs = attribute_ub; \ cfg.attrs = attribute_ub; \
cfg.numAttrs = comm->sm_arch >= 9 ? 2 : 1; cfg.numAttrs = comm->sm_arch >= 9 ? 2 : 1;
#if (CUDART_VERSION >= 12030)
#define ADD_LAUNCH_COMPLETION_EVENT(attribute_ub, comm_launch_event) \
attribute_ub[2].id = cudaLaunchAttributeLaunchCompletionEvent; \
attribute_ub[2].val.launchCompletionEvent.event = comm_launch_event;
#define NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH 3
#else
#define ADD_LAUNCH_COMPLETION_EVENT(attribute_ub, comm_launch_event)
#define NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH 2
#endif
#define SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, threads, stream, comm_launch_event) \
cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \
cudaLaunchAttribute attribute_ub[NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH] = {}; \
ADD_LAUNCH_COMPLETION_EVENT(attribute_ub, comm_launch_event) \
attribute_ub[1].id = cudaLaunchAttributeClusterDimension; \
attribute_ub[1].val.clusterDim.x = sms % comm->cga_size == 0 ? comm->cga_size : 1; \
attribute_ub[1].val.clusterDim.y = 1; \
attribute_ub[1].val.clusterDim.z = 1; \
attribute_ub[0].id = cudaLaunchAttributeCooperative; \
cfg.attrs = attribute_ub; \
cfg.numAttrs = NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH;
#define callranks_ag(x) \ #define callranks_ag(x) \
if (ar_nvsize == x) { \ if (ar_nvsize == x) { \
int arg1 = op - NVTE_MAX_OPS, \ int arg1 = op - NVTE_MAX_OPS, \
...@@ -1753,7 +1775,8 @@ void reducescatter2_userbuff_strided_multiatomic(void *output, const int handler ...@@ -1753,7 +1775,8 @@ void reducescatter2_userbuff_strided_multiatomic(void *output, const int handler
} }
void allgather2_userbuff_inplace(const int handler, const int offset, const int elements, void allgather2_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) { communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event) {
const int op = userbuffers_allreduceop_nonsharp2; const int op = userbuffers_allreduceop_nonsharp2;
const int ar_firstgpu = const int ar_firstgpu =
op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu;
...@@ -1766,11 +1789,20 @@ void allgather2_userbuff_inplace(const int handler, const int offset, const int ...@@ -1766,11 +1789,20 @@ void allgather2_userbuff_inplace(const int handler, const int offset, const int
int warps = comm->threads / 32; int warps = comm->threads / 32;
if (warps < ar_nvsize) warps = ar_nvsize; if (warps < ar_nvsize) warps = ar_nvsize;
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); if (comm_launch_event) {
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
callranks_agMC(2) callranks_agMC(4) callranks_agMC(8)
} else {
callranks_ag(2) callranks_ag(4) callranks_ag(8)
}
} else { } else {
callranks_ag(2) callranks_ag(4) callranks_ag(8) SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
callranks_agMC(2) callranks_agMC(4) callranks_agMC(8)
} else {
callranks_ag(2) callranks_ag(4) callranks_ag(8)
}
} }
} }
...@@ -1790,7 +1822,8 @@ void allgather2_userbuff_inplace_sliced(const int handler, const int offset, con ...@@ -1790,7 +1822,8 @@ void allgather2_userbuff_inplace_sliced(const int handler, const int offset, con
} }
void reducescatter2_userbuff_inplace(const int handler, const int offset, const int elements, void reducescatter2_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) { communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event) {
const int op = userbuffers_allreduceop_nonsharp2; const int op = userbuffers_allreduceop_nonsharp2;
const int ar_firstgpu = const int ar_firstgpu =
op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu;
...@@ -1803,17 +1836,26 @@ void reducescatter2_userbuff_inplace(const int handler, const int offset, const ...@@ -1803,17 +1836,26 @@ void reducescatter2_userbuff_inplace(const int handler, const int offset, const
int warps = comm->threads / 32; int warps = comm->threads / 32;
if (warps < ar_nvsize) warps = ar_nvsize; if (warps < ar_nvsize) warps = ar_nvsize;
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); if (comm_launch_event) {
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8)
} else {
callranks_rs(2) callranks_rs(4) callranks_rs(8)
}
} else { } else {
callranks_rs(2) callranks_rs(4) callranks_rs(8) SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8)
} else {
callranks_rs(2) callranks_rs(4) callranks_rs(8)
}
} }
} }
void reducescatter2_userbuff_stridedoutput(void *output, const int handler, const int offset, void reducescatter2_userbuff_stridedoutput(void *output, const int handler, const int offset,
const int rowelements, const int colelements, const int rowelements, const int colelements,
const int strideelements, communicator *comm, const int strideelements, communicator *comm,
cudaStream_t stream) { cudaStream_t stream, cudaEvent_t comm_launch_event) {
const int elements = rowelements * colelements; const int elements = rowelements * colelements;
const int op = userbuffers_allreduceop_nonsharp2; const int op = userbuffers_allreduceop_nonsharp2;
const int ar_firstgpu = const int ar_firstgpu =
...@@ -1827,23 +1869,35 @@ void reducescatter2_userbuff_stridedoutput(void *output, const int handler, cons ...@@ -1827,23 +1869,35 @@ void reducescatter2_userbuff_stridedoutput(void *output, const int handler, cons
int warps = comm->threads / 32; int warps = comm->threads / 32;
if (warps < ar_nvsize) warps = ar_nvsize; if (warps < ar_nvsize) warps = ar_nvsize;
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); if (comm_launch_event) {
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8) if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8)
} else {
callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8)
}
} else { } else {
callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8) SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8)
} else {
callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8)
}
} }
} }
void reducescatter2_userbuff(void *output, const int handler, const int offset, const int elements, void reducescatter2_userbuff(void *output, const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) { communicator *comm, cudaStream_t stream,
reducescatter2_userbuff_stridedoutput(output, handler, offset, elements, 1, 0, comm, stream); cudaEvent_t comm_launch_event) {
reducescatter2_userbuff_stridedoutput(output, handler, offset, elements, 1, 0, comm, stream,
comm_launch_event);
} }
template <typename fp8type> template <typename fp8type>
void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const int handler, void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const int handler,
const int offset, const int rowelements, const int offset, const int rowelements,
const int colelements, const int strideelements, const int colelements, const int strideelements,
communicator *comm, cudaStream_t stream) { communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event) {
const int elements = rowelements * colelements; const int elements = rowelements * colelements;
const int op = userbuffers_allreduceop_nonsharp2; const int op = userbuffers_allreduceop_nonsharp2;
const int ar_firstgpu = const int ar_firstgpu =
...@@ -1857,33 +1911,43 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const ...@@ -1857,33 +1911,43 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const
int warps = comm->threads / 32; int warps = comm->threads / 32;
if (warps < ar_nvsize) warps = ar_nvsize; if (warps < ar_nvsize) warps = ar_nvsize;
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); if (comm_launch_event) {
callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8)
} else {
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8)
}
} }
template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e5m2>( template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e5m2>(
void *output, float *scale, const int handler, const int offset, const int rowelements, void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements, communicator *comm, cudaStream_t stream); const int colelements, const int strideelements, communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event);
template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e4m3>( template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e4m3>(
void *output, float *scale, const int handler, const int offset, const int rowelements, void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements, communicator *comm, cudaStream_t stream); const int colelements, const int strideelements, communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event);
template <typename fp8type> template <typename fp8type>
void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset, void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset,
const int elements, communicator *comm, cudaStream_t stream) { const int elements, communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event) {
reducescatter2_userbuff_stridedoutput_fp8<fp8type>(output, scale, handler, offset, elements, 1, 0, reducescatter2_userbuff_stridedoutput_fp8<fp8type>(output, scale, handler, offset, elements, 1, 0,
comm, stream); comm, stream, comm_launch_event);
} }
template void reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(void *output, float *scale, template void reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(void *output, float *scale,
const int handler, const int offset, const int handler, const int offset,
const int elements, communicator *comm, const int elements, communicator *comm,
cudaStream_t stream); cudaStream_t stream,
cudaEvent_t comm_launch_event);
template void reducescatter2_userbuff_fp8<__nv_fp8_e4m3>(void *output, float *scale, template void reducescatter2_userbuff_fp8<__nv_fp8_e4m3>(void *output, float *scale,
const int handler, const int offset, const int handler, const int offset,
const int elements, communicator *comm, const int elements, communicator *comm,
cudaStream_t stream); cudaStream_t stream,
cudaEvent_t comm_launch_event);
template void reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e4m3>( template void reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e4m3>(
void *output, float *scale, const int handler, const int offset, const int rowelements, void *output, float *scale, const int handler, const int offset, const int rowelements,
......
...@@ -213,7 +213,8 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * ...@@ -213,7 +213,8 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
// for TP-parallelism, only single node is implemented // for TP-parallelism, only single node is implemented
void allgather2_userbuff_inplace(const int handler, const int offset, const int elements, void allgather2_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream = 0); communicator *comm, cudaStream_t stream = 0,
cudaEvent_t comm_launch_event = 0);
/* /*
each Rank input is each Rank input is
allgather2_userbuff_inplace: offset+myrank*elements allgather2_userbuff_inplace: offset+myrank*elements
...@@ -228,21 +229,26 @@ for(int slice=0;slice<ncslices;slice++) ...@@ -228,21 +229,26 @@ for(int slice=0;slice<ncslices;slice++)
allgather2_userbuff_inplace(hndl,offset, elements*nslices,comm,stream); allgather2_userbuff_inplace(hndl,offset, elements*nslices,comm,stream);
*/ */
void reducescatter2_userbuff_inplace(const int handler, const int offset, const int elements, void reducescatter2_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream = 0); communicator *comm, cudaStream_t stream = 0,
cudaEvent_t comm_launch_event = 0);
void reducescatter2_userbuff(void *output, const int handler, const int offset, const int elements, void reducescatter2_userbuff(void *output, const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream = 0); communicator *comm, cudaStream_t stream = 0,
cudaEvent_t comm_launch_event = 0);
void reducescatter2_userbuff_stridedoutput(void *output, const int handler, const int offset, void reducescatter2_userbuff_stridedoutput(void *output, const int handler, const int offset,
const int rowelements, const int colelements, const int rowelements, const int colelements,
const int strideelements, communicator *comm, const int strideelements, communicator *comm,
cudaStream_t stream = 0); cudaStream_t stream = 0,
cudaEvent_t comm_launch_event = 0);
template <typename fp8type> template <typename fp8type>
void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const int handler, void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const int handler,
const int offset, const int rowelements, const int offset, const int rowelements,
const int colelements, const int strideelements, const int colelements, const int strideelements,
communicator *comm, cudaStream_t stream = 0); communicator *comm, cudaStream_t stream = 0,
cudaEvent_t comm_launch_event = 0);
template <typename fp8type> template <typename fp8type>
void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset, void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset,
const int elements, communicator *comm, cudaStream_t stream = 0); const int elements, communicator *comm, cudaStream_t stream = 0,
cudaEvent_t comm_launch_event = 0);
template <typename fp8type> template <typename fp8type>
void reducescatter2_userbuff_strided_atomic_fp8(void *output, float *scale, const int handler, void reducescatter2_userbuff_strided_atomic_fp8(void *output, float *scale, const int handler,
const int offset, const int rowelements, const int offset, const int rowelements,
......
...@@ -62,7 +62,7 @@ class CommOverlapCore { ...@@ -62,7 +62,7 @@ class CommOverlapCore {
bool _ubuf_scale_inv_initialized{false}; bool _ubuf_scale_inv_initialized{false};
std::vector<cudaStream_t> _stream_compute; std::vector<cudaStream_t> _stream_compute;
cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm; cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm, _comm_launch_event;
public: public:
CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment