Unverified commit 958e1889 authored by vasunvidia, committed by GitHub

Atomic gemm and FP8 Reduce Scatter (#449)



* Initial commit
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Repro for RS output mismatch with Single GEMM + Split pipelined RS
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* minor changes for AG->GEMM pipelined overlap
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Add Atomic Gemm cublasApi attributes and initial implementation of AG->Atomic GEMM
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* AtomicGemm+RS functional with workaround
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* add amax update to layernorm_linear for FP8 unit test accuracy
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Enable reducescatter2_userbuff_strided variants
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Bug fix
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* AG+AtomicGemm overlap functional but gemm doesn't overlap with comm
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Add userbuffers_sendrecv kernel variants
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* TransformerLayer API changes to enable AtomicGemm+RS overlap
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Code cleanup
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Code cleanup2
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* [UB] AllGather Atomic GEMM overlap using userbuffer_sendrecv kernels
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Code cleanup + bug fix for multiatomic sendrecv kernel
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* cleanup
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Bug fixes
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* [UB] Add shuffling for better AG AtomicGEMM overlap
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Bug fix for AG AtomicGemm overlap
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Bug fix for multiAtomicAG and singleAtomicAG
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Use chunk_i+1 as recv_chunk for multiatomic_AG with shuffling
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Launch AtomicGEMM after first-chunk AG
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Rebase to main
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Add FP8 ReduceScatter kernels, AtomicGEMM+FP8 RS not functional
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Revert "Add FP8 ReduceScatter kernels, AtomicGEMM+FP8 RS not functional"

This reverts commit 80a47a76355440cd5fb4314c96fe9fda632d87f9.
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Add support for NVLS-MC and FP8 Reduce Scatter
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Bug fix
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Atomic and Multiatomic FP8 RS functional
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Remove debug print
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* UB comm initialization hang fix
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Code cleanup
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Create new GEMM API for Atomic GEMM
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* CI ready
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* more fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* license
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Bug fix
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Revert NVLS-MC
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Check cu* versions for running atomic gemms
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Cleanup
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Add experimental warning
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Better wording
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add warning to c api
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix wording
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent be67f219
...
@@ -506,7 +506,7 @@ def test_export_gemm(
            self.fp8_tensor_weight,
            self.weights_type)
        ret, _ = fp8_gemm(
            weight_fp8,
            self.meta_weight.scale_inv,
            self.fp8_tensor_weight,
...
@@ -1323,7 +1323,7 @@ def test_export_gemm_layernorm(
            self.fp8_tensor_weight,
            self.weights_type)
        ret, _ = fp8_gemm(
            weight_fp8,
            self.meta_weight.scale_inv,
            self.fp8_tensor_weight,
...
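The two hunks above switch the test call sites from `ret = fp8_gemm(...)` to `ret, _ = fp8_gemm(...)`: `fp8_gemm` returns a pair `(out, gelu_input)` (see the `return out, gelu_input` context further down in this diff), so callers that only need the GEMM output discard the second element. A minimal sketch of that pattern, using a hypothetical stand-in instead of the real `fp8_gemm`, whose full argument list is elided here:

# Stand-in with the same return arity as fp8_gemm; the real function also takes
# FP8 scaling metadata, an output buffer, a workspace tensor, and so on.
from typing import Optional, Tuple
import torch

def fp8_gemm_stub(weight: torch.Tensor, inp: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    out = inp @ weight.t()
    gelu_input = None  # only populated when a fused GELU epilogue is requested
    return out, gelu_input

ret, _ = fp8_gemm_stub(torch.randn(8, 16), torch.randn(4, 16))  # discard gelu_input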
...
@@ -7,6 +7,7 @@
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/logging.h>
#include <transformer_engine/gemm.h>
#include <cuda.h>
#include <cublasLt.h>
#include <cublas_v2.h>
#include "../common.h"
...
@@ -50,6 +51,10 @@ void cublas_gemm(const Tensor *inputA,
                 bool accumulate,
                 bool use_split_accumulator,
                 int math_sm_count,
                 int m_split,
                 int n_split,
                 bool gemm_producer,
                 const Tensor *inputCounter,
                 cudaStream_t stream
) {
  void *A = inputA->data.dptr;
...
@@ -63,6 +68,10 @@ void cublas_gemm(const Tensor *inputA,
  void *bias_ptr = inputBias->data.dptr;
  const bool bias = bias_ptr != nullptr;
  void *pre_gelu_out = outputPreGelu->data.dptr;
  void *counter = nullptr;
  if (inputCounter != nullptr) {
    counter = inputCounter->data.dptr;
  }
  const bool gelu = pre_gelu_out != nullptr;
  const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
                       is_fp8_dtype(inputB->data.dtype);
...
@@ -223,6 +232,27 @@ void cublas_gemm(const Tensor *inputA,
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                   CUBLASLT_MATMUL_DESC_EPILOGUE,
                                                   &epilogue, sizeof(epilogue)));
#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205
if (counter != nullptr) {
if (m_split == 0) m_split=1;
if (n_split == 0) n_split=1;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_ROWS,
&m_split, sizeof(m_split)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_COLS,
&n_split, sizeof(n_split)));
if (gemm_producer) {
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_OUT_COUNTERS_POINTER,
&counter, sizeof(counter)));
} else {
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_IN_COUNTERS_POINTER,
&counter, sizeof(counter)));
}
}
#endif
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference)); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference));
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
...@@ -254,7 +284,6 @@ void cublas_gemm(const Tensor *inputA, ...@@ -254,7 +284,6 @@ void cublas_gemm(const Tensor *inputA,
workspaceSize, workspaceSize,
stream)); /* stream */ stream)); /* stream */
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceDestroy(preference)); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceDestroy(preference));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Ddesc)); NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Ddesc));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Cdesc)); NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Cdesc));
...
@@ -320,5 +349,82 @@ void nvte_cublas_gemm(const NVTETensor A,
              wspace->data.shape[0],
              accumulate, use_split_accumulator,
              math_sm_count,
0,
0,
false,
nullptr,
stream);
}
void nvte_cublas_atomic_gemm(const NVTETensor A,
const NVTETensor B,
NVTETensor D,
const NVTETensor bias,
NVTETensor pre_gelu_out,
bool transa,
bool transb,
bool grad,
NVTETensor workspace,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
int m_split,
int n_split,
bool gemm_producer,
const NVTETensor counter,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cublas_atomic_gemm);
int cudart_version;
NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&cudart_version));
NVTE_CHECK(cudart_version >= 12020, "Cuda version 12.2 is required for atomic gemm.");
NVTE_CHECK(cublasLtGetVersion() >= 120205, "Cublas version 12.2.5 is required for atomic gemm.");
using namespace transformer_engine;
const Tensor *inputA = reinterpret_cast<const Tensor*>(A);
const Tensor *inputB = reinterpret_cast<const Tensor*>(B);
Tensor *outputD = reinterpret_cast<Tensor*>(D);
const Tensor *biasTensor = reinterpret_cast<const Tensor*>(bias);
Tensor *outputGelu = reinterpret_cast<Tensor*>(pre_gelu_out);
const Tensor *inputCounter = reinterpret_cast<const Tensor*>(counter);
Tensor *wspace = reinterpret_cast<Tensor*>(workspace);
const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1];
const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0];
const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0];
int lda, ldb, ldd;
if (transa && !transb) { // TN
lda = k;
ldb = k;
ldd = m;
} else if (!transa && !transb) { // NN
lda = m;
ldb = k;
ldd = m;
} else if (!transa && transb) { // NT
lda = m;
ldb = n;
ldd = m;
} else { // TT
NVTE_ERROR("TT layout not allowed.");
}
cublas_gemm(inputA,
inputB,
outputD,
biasTensor,
outputGelu,
m, n, k,
lda, ldb, ldd,
(transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
(transb) ? CUBLAS_OP_T : CUBLAS_OP_N,
grad, wspace->data.dptr,
wspace->data.shape[0],
accumulate, use_split_accumulator,
math_sm_count,
m_split,
n_split,
gemm_producer,
inputCounter,
              stream);
}
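nvte_cublas_atomic_gemm refuses to run unless the CUDA runtime reports at least 12.2 and cublasLtGetVersion() reports at least 120205, because the CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_* attributes above are only compiled in under those versions. A hedged sketch of an equivalent guard on the Python side; the helper name is illustrative and not an API added by this PR, and since the cuBLAS build version is not readily visible from Python only the runtime check is modeled:

import torch

def atomic_gemm_runtime_ok() -> bool:
    # torch.version.cuda is e.g. "12.2"; it is None for CPU-only builds.
    cuda = torch.version.cuda
    if cuda is None:
        return False
    major, minor = (int(x) for x in cuda.split(".")[:2])
    return (major, minor) >= (12, 2)

if not atomic_gemm_runtime_ok():
    print("Atomic GEMM unavailable; fall back to the SPLIT_PIPELINED overlap paths.")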
...@@ -54,6 +54,52 @@ void nvte_cublas_gemm(const NVTETensor A, ...@@ -54,6 +54,52 @@ void nvte_cublas_gemm(const NVTETensor A,
cudaStream_t stream cudaStream_t stream
); );
/*! \brief Compute matrix multiplication of 2 matrices with chunking and atomic counters.
*
* \warning Cublas atomic gemm uses a beta API and is not tested for all use cases.
*
* Computes:
* - `D = AB` if both `bias` and `pre_gelu_out` are empty tensors
* - `D = AB + bias` if `pre_gelu_out` is empty and `bias` is not empty
* - `D = GELU(AB + bias)` if both `bias` and `pre_gelu_out` are not empty tensors
*
* \param[in] A The A matrix.
* \param[in] B The B matrix.
* \param[in,out] D Output matrix.
* \param[in] bias Bias tensor.
* \param[in,out] pre_gelu_out Output matrix before GELU activation.
* \param[in] transa Whether A matrix is transposed.
* \param[in] transb Whether B matrix is transposed.
* \param[in] grad Whether this operation is part of the
* gradient computation.
* \param[out] workspace Workspace tensor.
* \param[in] accumulate Whether to accumulate the result into the D matrix.
* \param[in] use_split_accumulator Whether to use split accumulator in the FP8 GEMM.
* \param[in] math_sm_count Number of GPU SMs to use (default=0: use cuBLAS heuristics)
* \param[in] m_split Number of chunks/splits along m-dimension for Atomic GEMM.
* \param[in] n_split Number of chunks/splits along n-dimension for Atomic GEMM.
* \param[in] gemm_producer Whether Atomic GEMM is the producer or consumer.
* \param[in,out] counter counter[chunk_i]=0 indicates chunk_i has been produced.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_cublas_atomic_gemm(const NVTETensor A,
const NVTETensor B,
NVTETensor D,
const NVTETensor bias,
NVTETensor pre_gelu_out,
bool transa,
bool transb,
bool grad,
NVTETensor workspace,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
int m_split,
int n_split,
bool gemm_producer,
const NVTETensor counter,
cudaStream_t stream
);
#ifdef __cplusplus #ifdef __cplusplus
} // extern "C" } // extern "C"
#endif #endif
......
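The counter argument documented above is the synchronization channel between the atomic GEMM and the communication kernels: each flag starts at 1, the producer GEMM clears counter[chunk_i] to 0 once output chunk i has been written (via CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_OUT_COUNTERS_POINTER), and a consumer waits on that flag before touching the chunk. A plain-Python simulation of the handshake, which only models the flags and is not the device-side implementation (the real buffer allocated in UbufCommOverlap holds num_splits * 2 entries, with only the first half initialized to 1):

import torch

num_splits = 4
counter = torch.ones(num_splits, dtype=torch.int32)  # 1 means "not produced yet"

def produce(chunk_id: int) -> None:
    # The atomic GEMM clears the flag after writing output chunk `chunk_id`.
    counter[chunk_id] = 0

def consume(chunk_id: int) -> None:
    # The communication kernel (or the consumer() helper) waits until the flag
    # is 0, then reduce-scatters or all-gathers that chunk.
    assert counter[chunk_id].item() == 0, f"chunk {chunk_id} not produced yet"

for i in range(num_splits):
    produce(i)
    consume(i)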
...@@ -2262,6 +2262,8 @@ class MultiheadAttention(torch.nn.Module): ...@@ -2262,6 +2262,8 @@ class MultiheadAttention(torch.nn.Module):
ub_bulk_dgrad: bool = False, ub_bulk_dgrad: bool = False,
ub_split_rs: bool = False, ub_split_rs: bool = False,
ub_split_ag: bool = False, ub_split_ag: bool = False,
ub_atomic_gemm_rs: bool = False,
ub_atomic_gemm_ag: bool = False,
bias: bool = True, bias: bool = True,
normalization: str = "LayerNorm", normalization: str = "LayerNorm",
device: Union[torch.device, str] = "cuda", device: Union[torch.device, str] = "cuda",
...@@ -2342,6 +2344,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -2342,6 +2344,7 @@ class MultiheadAttention(torch.nn.Module):
ub_bulk_dgrad=ub_bulk_dgrad, ub_bulk_dgrad=ub_bulk_dgrad,
ub_split_ag=ub_split_ag, ub_split_ag=ub_split_ag,
normalization=normalization, normalization=normalization,
ub_atomic_gemm_ag=ub_atomic_gemm_ag,
**common_gemm_kwargs, **common_gemm_kwargs,
) )
else: else:
...@@ -2372,6 +2375,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -2372,6 +2375,7 @@ class MultiheadAttention(torch.nn.Module):
ub_bulk_dgrad=ub_bulk_dgrad, ub_bulk_dgrad=ub_bulk_dgrad,
ub_split_ag=ub_split_ag, ub_split_ag=ub_split_ag,
normalization=normalization, normalization=normalization,
ub_atomic_gemm_ag=ub_atomic_gemm_ag,
**common_gemm_kwargs, **common_gemm_kwargs,
) )
else: else:
...@@ -2418,6 +2422,8 @@ class MultiheadAttention(torch.nn.Module): ...@@ -2418,6 +2422,8 @@ class MultiheadAttention(torch.nn.Module):
parallel_mode="row" if set_parallel_mode else None, parallel_mode="row" if set_parallel_mode else None,
ub_split_rs=ub_split_rs, ub_split_rs=ub_split_rs,
ub_split_ag=ub_split_ag, ub_split_ag=ub_split_ag,
ub_atomic_gemm_rs=ub_atomic_gemm_rs,
ub_atomic_gemm_ag=ub_atomic_gemm_ag,
**common_gemm_kwargs, **common_gemm_kwargs,
) )
......
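The new ub_atomic_gemm_rs and ub_atomic_gemm_ag flags are forwarded from MultiheadAttention to its QKV and projection GEMMs alongside the existing ub_split_* flags. A hedged usage sketch; the surrounding arguments are ordinary TransformerEngine options, and tensor parallelism plus userbuffers (e.g. initialize_ub) must already be set up for the flags to take effect:

import transformer_engine.pytorch as te

mha = te.MultiheadAttention(
    hidden_size=4096,
    num_attention_heads=32,
    set_parallel_mode=True,    # row/column parallel projections
    ub_split_ag=False,
    ub_split_rs=False,
    ub_atomic_gemm_ag=True,    # AllGather overlapped with an atomic GEMM
    ub_atomic_gemm_rs=True,    # atomic GEMM overlapped with ReduceScatter
)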
...
@@ -91,22 +91,40 @@ def fp8_gemm(
        assert ub is not None, 'ub object is None!'
        if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG:
            fn = ub.bulk_overlap
            extra_output_tensor = (
                empty_tensor if extra_output_tensor is None else extra_output_tensor
            )
            args = tuple(args + (1, extra_output_tensor,))
        elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS:
            fn = ub.bulk_overlap
            extra_output_tensor = (
                empty_tensor if extra_output_tensor is None else extra_output_tensor
            )
            args = tuple(args + (0, extra_output_tensor,))
        elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG:
            fn = ub.split_overlap_ag
            extra_output_tensor = (
                empty_tensor if extra_output_tensor is None else extra_output_tensor
            )
            args = tuple(args + (extra_output_tensor,))
        elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG:
            fn = ub.atomic_gemm_overlap_ag
            extra_output_tensor = (
                empty_tensor if extra_output_tensor is None else extra_output_tensor
            )
            args = tuple(args + (extra_output_tensor,))
        elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS:
            fn = ub.split_overlap_rs
            assert (
                extra_output_tensor is not None
            ), 'SPLIT_PIPELINED_RS requires extra output tensor'
            args = tuple(args + (True, extra_output_tensor,))
        elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_RS:
            fn = ub.atomic_gemm_overlap_rs
            assert (
                extra_output_tensor is not None
            ), 'ATOMIC_GEMM_RS requires extra output tensor'
            args = tuple(args + (True, extra_output_tensor,))
        _ = fn(*args)

    return out, gelu_input
...
@@ -195,10 +213,10 @@ def gemm(
        assert ub is not None, 'ub object is None!'
        if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG:
            fn = ub.bulk_overlap
            args = tuple(args + (1, empty_tensor))
        elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS:
            fn = ub.bulk_overlap
            args = tuple(args + (0, empty_tensor))
        elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG:
            fn = ub.split_overlap_ag
            extra_output_tensor = (
...
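In fp8_gemm's userbuffer dispatch above, the two bulk-overlap branches now pass an output tensor in addition to the AG/RS selector (previously `args + (1,)` and `args + (0,)`), and the new ATOMIC_GEMM_AG / ATOMIC_GEMM_RS branches route to ub.atomic_gemm_overlap_ag / ub.atomic_gemm_overlap_rs, with ATOMIC_GEMM_RS requiring an explicit extra output tensor for the reduce-scattered result. A reduced, self-contained model of just that argument-packing logic; the real code builds `args` from the full GEMM parameter list first:

import torch

class UbufOverlapAlgo:  # stand-in mirroring the C++ enum ordering (0..5)
    BULK_OVERLAP_AG, BULK_OVERLAP_RS, SPLIT_PIPELINED_AG, SPLIT_PIPELINED_RS, \
        ATOMIC_GEMM_RS, ATOMIC_GEMM_AG = range(6)

def pack_ub_args(args, ub_algo, extra_output_tensor=None):
    empty_tensor = torch.Tensor()
    if ub_algo in (UbufOverlapAlgo.BULK_OVERLAP_AG, UbufOverlapAlgo.BULK_OVERLAP_RS):
        extra = empty_tensor if extra_output_tensor is None else extra_output_tensor
        selector = 1 if ub_algo == UbufOverlapAlgo.BULK_OVERLAP_AG else 0  # AG=1, RS=0
        return args + (selector, extra)
    if ub_algo in (UbufOverlapAlgo.SPLIT_PIPELINED_AG, UbufOverlapAlgo.ATOMIC_GEMM_AG):
        extra = empty_tensor if extra_output_tensor is None else extra_output_tensor
        return args + (extra,)
    if ub_algo in (UbufOverlapAlgo.SPLIT_PIPELINED_RS, UbufOverlapAlgo.ATOMIC_GEMM_RS):
        assert extra_output_tensor is not None, "RS overlap requires an extra output tensor"
        return args + (True, extra_output_tensor)
    raise ValueError("unknown ub_algo")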
...@@ -4,17 +4,18 @@ ...@@ -4,17 +4,18 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "userbuffers/userbuffers.h"
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_fp8.h>
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include <torch/cuda.h> #include <torch/cuda.h>
#include <torch/custom_class.h> #include <torch/custom_class.h>
#include <torch/extension.h> #include <torch/extension.h>
#include <torch/types.h> #include <torch/types.h>
#include "userbuffers/userbuffers.h"
#define HALF_BYTES 2 #define HALF_BYTES 2
#define UB_MAX_SM 32 #define UB_MAX_SM 32
...@@ -28,6 +29,7 @@ ...@@ -28,6 +29,7 @@
} \ } \
} while (0) } while (0)
using namespace torch::indexing;
namespace ubuf { namespace ubuf {
enum class COMM_TYPE { RS = 0, AG = 1 }; enum class COMM_TYPE { RS = 0, AG = 1 };
...
@@ -36,11 +38,16 @@ enum class UBOverlapAlgo {
  BULK_OVERLAP_AG = 0,
  BULK_OVERLAP_RS = 1,
  SPLIT_PIPELINED_AG = 2,
  SPLIT_PIPELINED_RS = 3,
  ATOMIC_GEMM_RS = 4,
  ATOMIC_GEMM_AG = 5
};

struct UbufBase {
  static inline communicator *_ub_comm{nullptr};
  static inline bool comm_created{false};
};

struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
  int _tp_id;
  int _tp_size;
  int _num_splits;
...@@ -49,24 +56,53 @@ struct UbufCommOverlap : torch::CustomClassHolder { ...@@ -49,24 +56,53 @@ struct UbufCommOverlap : torch::CustomClassHolder {
void *_ubuf_ptr; void *_ubuf_ptr;
torch::Tensor _ubuf; torch::Tensor _ubuf;
torch::Tensor output_tensor; torch::Tensor output_tensor;
torch::Tensor _ubuf_scale_inv;
bool _ubuf_scale_inv_initialized;
torch::Tensor counter;
torch::Tensor _empty_tensor;
at::cuda::CUDAStream _stream_comm = at::cuda::getStreamFromPool(true); at::cuda::CUDAStream _stream_comm = at::cuda::getStreamFromPool(true);
std::vector<at::cuda::CUDAStream> _stream_compute; std::vector<at::cuda::CUDAStream> _stream_compute;
cudaEvent_t _start_compute, _stop_compute, _start_d2dcopy, _start_comm, _stop_comm; cudaEvent_t _start_compute, _stop_compute, _start_d2dcopy, _start_comm, _stop_comm;
int comm_sms;
int cga_size;
int use_ce;
  UbufCommOverlap(torch::Tensor sample, int rank, int tp_size, int num_comm_sm, int comm_cga_size,
                  int num_splits, bool set_sm_margin, int num_max_streams,
                  torch::Tensor empty_tensor) {
    // Initialize userbuf communicator
    if (!comm_created) {
      if (rank == 0) {
        printf("!!! [UB] Create UbufCommOverlap Communicator\n");
      }
      create_communicator_grouped2(&_ub_comm, 1, 1, tp_size, 1);
      comm_created = true;
    }
    use_ce = 0;
    comm_sms = num_comm_sm;
    cga_size = comm_cga_size;
    _empty_tensor = empty_tensor;
// Allocate and register extra userbuffers // Allocate and register extra userbuffers
int ubuf_bytes = sample.numel() * sample.element_size(); int ubuf_bytes = sample.numel() * sample.element_size();
_ub_reg = register_user_buffer_collective(reinterpret_cast<void **>(&_ubuf_ptr), ubuf_bytes, _ub_reg = register_user_buffer_collective(reinterpret_cast<void **>(&_ubuf_ptr), ubuf_bytes,
_ub_comm, true); _ub_comm, true);
if (rank == 0) {
printf("!!! [UB] Register UBuf %d\n", _ub_reg);
}
_ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options()); _ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options());
const char *env_p = std::getenv("NVTE_RS_STRIDED_ATOMIC");
const char *env_q = std::getenv("NVTE_UB_ATOMIC_GEMM_RS");
if (rank == 0 && env_p != nullptr && env_q != nullptr && env_q[0] == '1') {
if (env_p[0] == '1')
printf("!! Using reducescatter2_userbuff_strided_atomic\n");
else if (env_p[0] == '2')
printf("!! Using reducescatter2_userbuff_strided_multiatomic\n");
else
printf("!! Using reducescatter2_userbuff_strided\n");
}
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
for (int i = 0; i < std::min(num_max_streams, num_splits); i++) { for (int i = 0; i < std::min(num_max_streams, num_splits); i++) {
cudaStream_t stream; cudaStream_t stream;
...@@ -78,6 +114,7 @@ struct UbufCommOverlap : torch::CustomClassHolder { ...@@ -78,6 +114,7 @@ struct UbufCommOverlap : torch::CustomClassHolder {
_num_splits = num_splits; _num_splits = num_splits;
_tp_size = tp_size; _tp_size = tp_size;
_tp_id = (rank % tp_size); _tp_id = (rank % tp_size);
_ubuf_scale_inv_initialized = false;
// Set the number of SMs for GEMM with margin // Set the number of SMs for GEMM with margin
cudaDeviceProp prop; cudaDeviceProp prop;
...@@ -85,6 +122,9 @@ struct UbufCommOverlap : torch::CustomClassHolder { ...@@ -85,6 +122,9 @@ struct UbufCommOverlap : torch::CustomClassHolder {
_math_sms = (set_sm_margin) ? prop.multiProcessorCount - num_comm_sm : prop.multiProcessorCount; _math_sms = (set_sm_margin) ? prop.multiProcessorCount - num_comm_sm : prop.multiProcessorCount;
output_tensor = torch::Tensor(); output_tensor = torch::Tensor();
auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
counter = torch::zeros({num_splits * 2}, counter_options);
counter.index_put_({Slice(None, num_splits)}, 1);
// CUDA event creation // CUDA event creation
cudaEventCreateWithFlags(&_start_compute, 0); cudaEventCreateWithFlags(&_start_compute, 0);
cudaEventCreateWithFlags(&_stop_compute, 0); cudaEventCreateWithFlags(&_stop_compute, 0);
...@@ -97,13 +137,17 @@ struct UbufCommOverlap : torch::CustomClassHolder { ...@@ -97,13 +137,17 @@ struct UbufCommOverlap : torch::CustomClassHolder {
** Bulk GEMM + COMM ** Bulk GEMM + COMM
** This function assumes the communication input is pre-copied to _ubuf ** This function assumes the communication input is pre-copied to _ubuf
*/ */
std::vector<at::Tensor> bulk_overlap( std::vector<at::Tensor>
at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, bulk_overlap(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, transformer_engine::DType A_type, bool transa, at::Tensor B,
int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, transformer_engine::DType B_type,
at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, bool transb, at::Tensor D, at::Tensor D_scale, transformer_engine::DType D_type,
transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, at::Tensor D_amax, at::Tensor bias, transformer_engine::DType bias_type,
size_t workspaceSize, bool accumulate, bool use_split_accumulator, int comm_type) { at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, size_t workspaceSize,
bool accumulate, bool use_split_accumulator, int comm_type, at::Tensor rs_output) {
_ub_comm->use_ce = use_ce;
_ub_comm->sms = comm_sms;
_ub_comm->cga_size = cga_size;
// Get the current userbuf offset // Get the current userbuf offset
char *ubuf_wt_ptr = reinterpret_cast<char *>(_ubuf.data_ptr()); char *ubuf_wt_ptr = reinterpret_cast<char *>(_ubuf.data_ptr());
int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses 2Byte element size int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses 2Byte element size
...@@ -121,15 +165,30 @@ struct UbufCommOverlap : torch::CustomClassHolder { ...@@ -121,15 +165,30 @@ struct UbufCommOverlap : torch::CustomClassHolder {
if (_comm_type == COMM_TYPE::AG) { if (_comm_type == COMM_TYPE::AG) {
allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, (cudaStream_t)_stream_comm); allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, (cudaStream_t)_stream_comm);
} else if (_comm_type == COMM_TYPE::RS) { } else if (_comm_type == COMM_TYPE::RS) {
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
comm_elements *= 2;
float *scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
assert(rs_output.numel() == _ubuf.numel() / _tp_size);
assert(rs_output.size(0) == _ubuf.size(0) / _tp_size);
assert(rs_output.element_size() == 2);
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, scale_inv_ptr, _ub_reg, 0,
comm_elements, _ub_comm,
(cudaStream_t)_stream_comm);
} else {
reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm,
(cudaStream_t)_stream_comm); (cudaStream_t)_stream_comm);
}
} else { } else {
NVTE_ERROR("Not supported communication type."); NVTE_ERROR("Not supported communication type.");
} }
if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; if (A_scale_inverse.numel())
A_scale_inverse = A_scale_inverse[A_fp8_tensor];
if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; if (B_scale_inverse.numel())
B_scale_inverse = B_scale_inverse[B_fp8_tensor];
assert(pre_gelu_out.numel() == 0); assert(pre_gelu_out.numel() == 0);
te_gemm(A, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, D, D_scale, te_gemm(A, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, D, D_scale,
...@@ -147,6 +206,117 @@ struct UbufCommOverlap : torch::CustomClassHolder { ...@@ -147,6 +206,117 @@ struct UbufCommOverlap : torch::CustomClassHolder {
return {D, output_tensor}; return {D, output_tensor};
} // bulk_overlap } // bulk_overlap
/*
** Split FPROP GEMM + ReduceScatter
*/
void atomic_gemm_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
transformer_engine::DType A_type, bool transa, at::Tensor B,
at::Tensor B_scale_inverse, int64_t B_fp8_tensor,
transformer_engine::DType B_type, bool transb, at::Tensor D,
at::Tensor D_scale, transformer_engine::DType D_type,
at::Tensor D_amax, at::Tensor bias,
transformer_engine::DType bias_type, at::Tensor pre_gelu_out,
bool grad, at::Tensor workspace, size_t workspaceSize,
bool accumulate, bool use_split_accumulator, bool gemm_overlap,
at::Tensor rs_output) {
_ub_comm->use_ce = use_ce;
_ub_comm->sms = comm_sms;
_ub_comm->cga_size = cga_size;
// Get GEMM dimensions
int m = A.size(0);
int k = A.size(1);
int n = B.size(0);
int m_chunk = m / _num_splits;
int workspace_size_chunk = workspaceSize / _stream_compute.size();
// Get input, output, and workspace data pointers
char *input_a_chunk_ptr = reinterpret_cast<char *>(A.data_ptr());
char *output_buf_chunk_ptr = reinterpret_cast<char *>(_ubuf.data_ptr());
char *workspace_ptr = reinterpret_cast<char *>(workspace.data_ptr());
int *counter_ptr = reinterpret_cast<int *>(counter.data_ptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
int ori_sms = _ub_comm->sms;
// Catch up the default torch stream
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm));
for (int i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _stop_comm, 0));
}
if (A_scale_inverse.numel())
A_scale_inverse = A_scale_inverse[A_fp8_tensor];
if (B_scale_inverse.numel())
B_scale_inverse = B_scale_inverse[B_fp8_tensor];
assert(pre_gelu_out.numel() == 0);
torch::Tensor input_a = torch::from_blob(input_a_chunk_ptr, {m, k}, A.options());
torch::Tensor output_d = torch::from_blob(output_buf_chunk_ptr, {n, m}, _ubuf.options());
// torch::zeros({n, m}, _ubuf.options());
torch::Tensor workspace_chunk =
torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options());
at::cuda::setCurrentCUDAStream(_stream_compute[0]);
te_atomic_gemm(input_a, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb,
output_d, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad,
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms, _num_splits /*m_split*/, 0 /*n_split*/, true /*gemm_producer*/,
counter);
for (int i = 0; i < _num_splits; i++) {
const char *env_p = std::getenv("NVTE_RS_STRIDED_ATOMIC");
if (env_p != nullptr && env_p[0] == '1') {
if (i == _num_splits - 1) {
_ub_comm->sms = UB_MAX_SM;
}
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e4m3>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, m, _num_splits,
&counter_ptr[i], _ub_comm, (cudaStream_t)_stream_comm);
} else {
reducescatter2_userbuff_strided_atomic(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m,
_num_splits, &counter_ptr[i], _ub_comm,
(cudaStream_t)_stream_comm);
}
} else if (env_p != nullptr && env_p[0] == '2') {
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
reducescatter2_userbuff_strided_multiatomic_fp8<__nv_fp8_e4m3>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits,
counter_ptr, _ub_comm, (cudaStream_t)_stream_comm);
} else {
reducescatter2_userbuff_strided_multiatomic(rs_output_ptr, _ub_reg, m_chunk, m_chunk, n,
m, _num_splits, counter_ptr, _ub_comm,
(cudaStream_t)_stream_comm);
}
break;
} else {
consumer(counter_ptr, i, (cudaStream_t)_stream_comm);
// if (i == _num_splits-1) {
// _ub_comm->sms = UB_MAX_SM;
// }
reducescatter2_userbuff_strided(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m,
_ub_comm, (cudaStream_t)_stream_comm);
}
rs_output_ptr += m_chunk * rs_output.element_size();
}
_ub_comm->sms = ori_sms;
CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[0]));
CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0));
at::cuda::setCurrentCUDAStream(stream_main);
return;
  } // atomic_gemm_overlap_rs
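atomic_gemm_overlap_rs picks one of three reduce-scatter strategies based on the NVTE_RS_STRIDED_ATOMIC environment variable: "1" launches one strided-atomic kernel per chunk that spins on its counter entry, "2" launches a single multi-atomic kernel that walks all chunks itself, and anything else enqueues a consumer() wait followed by a plain strided reduce-scatter per chunk. A small control-flow sketch of the host-side decisions only; the kernel names refer to the CUDA functions in this diff:

import os

def rs_overlap_plan(num_splits: int) -> list:
    mode = os.getenv("NVTE_RS_STRIDED_ATOMIC", "0")
    if mode == "1":
        # One launch per chunk; each kernel waits on counter[i] written by the GEMM.
        return [f"reducescatter2_userbuff_strided_atomic(chunk={i})" for i in range(num_splits)]
    if mode == "2":
        # A single launch handles every chunk, so the host loop breaks after it.
        return ["reducescatter2_userbuff_strided_multiatomic(all chunks)"]
    plan = []
    for i in range(num_splits):
        # Default: host-enqueued wait on the counter, then a regular strided RS.
        plan += [f"consumer(counter, chunk={i})", f"reducescatter2_userbuff_strided(chunk={i})"]
    return plan

print(rs_overlap_plan(4))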
/* /*
** Split FPROP GEMM + ReduceScatter ** Split FPROP GEMM + ReduceScatter
*/ */
...@@ -160,6 +330,9 @@ struct UbufCommOverlap : torch::CustomClassHolder { ...@@ -160,6 +330,9 @@ struct UbufCommOverlap : torch::CustomClassHolder {
size_t workspaceSize, bool accumulate, bool use_split_accumulator, size_t workspaceSize, bool accumulate, bool use_split_accumulator,
bool gemm_overlap, at::Tensor rs_output) { bool gemm_overlap, at::Tensor rs_output) {
// Get GEMM dimensions // Get GEMM dimensions
_ub_comm->use_ce = use_ce;
_ub_comm->sms = comm_sms;
_ub_comm->cga_size = cga_size;
int m = A.size(0); int m = A.size(0);
int k = A.size(1); int k = A.size(1);
int n = B.size(0); int n = B.size(0);
...@@ -174,7 +347,6 @@ struct UbufCommOverlap : torch::CustomClassHolder { ...@@ -174,7 +347,6 @@ struct UbufCommOverlap : torch::CustomClassHolder {
char *workspace_ptr = reinterpret_cast<char *>(workspace.data_ptr()); char *workspace_ptr = reinterpret_cast<char *>(workspace.data_ptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr()); char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
int ubuf_offset = 0;
int ori_sms = _ub_comm->sms; int ori_sms = _ub_comm->sms;
// Catch up the default torch stream // Catch up the default torch stream
...@@ -184,9 +356,11 @@ struct UbufCommOverlap : torch::CustomClassHolder { ...@@ -184,9 +356,11 @@ struct UbufCommOverlap : torch::CustomClassHolder {
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0));
} }
if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; if (A_scale_inverse.numel())
A_scale_inverse = A_scale_inverse[A_fp8_tensor];
if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; if (B_scale_inverse.numel())
B_scale_inverse = B_scale_inverse[B_fp8_tensor];
assert(pre_gelu_out.numel() == 0); assert(pre_gelu_out.numel() == 0);
...@@ -223,10 +397,19 @@ struct UbufCommOverlap : torch::CustomClassHolder { ...@@ -223,10 +397,19 @@ struct UbufCommOverlap : torch::CustomClassHolder {
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0));
// Communication chunk // Communication chunk
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e4m3>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, m,
_ub_comm, (cudaStream_t)_stream_comm);
} else {
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, (i - 1) * output_chunk_size, reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, (i - 1) * output_chunk_size,
m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm); m_chunk, n, m, _ub_comm,
(cudaStream_t)_stream_comm);
}
rs_output_ptr += m_chunk * _ubuf.element_size(); rs_output_ptr += m_chunk * rs_output.element_size();
} }
int last_compute_stream_id = int last_compute_stream_id =
(_num_splits + _stream_compute.size() - 1) % _stream_compute.size(); (_num_splits + _stream_compute.size() - 1) % _stream_compute.size();
...@@ -236,9 +419,17 @@ struct UbufCommOverlap : torch::CustomClassHolder { ...@@ -236,9 +419,17 @@ struct UbufCommOverlap : torch::CustomClassHolder {
// Last communication chunk with max SM // Last communication chunk with max SM
_ub_comm->sms = UB_MAX_SM; _ub_comm->sms = UB_MAX_SM;
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e4m3>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, (_num_splits - 1) * output_chunk_size, m_chunk,
n, m, _ub_comm, (cudaStream_t)_stream_comm);
} else {
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg,
(_num_splits - 1) * output_chunk_size, m_chunk, n, m, (_num_splits - 1) * output_chunk_size, m_chunk, n, m,
_ub_comm, (cudaStream_t)_stream_comm); _ub_comm, (cudaStream_t)_stream_comm);
}
} else { } else {
for (int i = 0; i < _num_splits; i++) { for (int i = 0; i < _num_splits; i++) {
torch::Tensor input_a_chunk = torch::Tensor input_a_chunk =
...@@ -259,13 +450,21 @@ struct UbufCommOverlap : torch::CustomClassHolder { ...@@ -259,13 +450,21 @@ struct UbufCommOverlap : torch::CustomClassHolder {
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0));
// Communication chunk. Uses MAX_SM at the last chunk // Communication chunk. Uses MAX_SM at the last chunk
if (i == _num_splits-1) { if (i == _num_splits - 1) {
_ub_comm->sms = UB_MAX_SM; _ub_comm->sms = UB_MAX_SM;
} }
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e4m3>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * output_chunk_size, m_chunk, n, m,
_ub_comm, (cudaStream_t)_stream_comm);
} else {
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size, reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size,
m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm); m_chunk, n, m, _ub_comm,
(cudaStream_t)_stream_comm);
rs_output_ptr += m_chunk * _ubuf.element_size(); }
rs_output_ptr += m_chunk * rs_output.element_size();
input_a_chunk_ptr += input_a_chunk_size * B.element_size(); input_a_chunk_ptr += input_a_chunk_size * B.element_size();
output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size(); output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size();
} }
...@@ -283,6 +482,12 @@ struct UbufCommOverlap : torch::CustomClassHolder { ...@@ -283,6 +482,12 @@ struct UbufCommOverlap : torch::CustomClassHolder {
return; return;
} // split_overlap_rs } // split_overlap_rs
void set_ubuf_scale_inv(const torch::Tensor &scale_inv) {
_ubuf_scale_inv = scale_inv;
_ubuf_scale_inv_initialized = true;
}
bool is_fp8_ubuf() { return (_ubuf.element_size() == 1); }
/* /*
** Helper function to copy input to _ubuf ** Helper function to copy input to _ubuf
*/ */
...@@ -311,7 +516,8 @@ struct UbufCommOverlap : torch::CustomClassHolder { ...@@ -311,7 +516,8 @@ struct UbufCommOverlap : torch::CustomClassHolder {
torch::Tensor &get_ubuf_output(int comm_type) { torch::Tensor &get_ubuf_output(int comm_type) {
char *ubuf_wt_ptr = reinterpret_cast<char *>(_ubuf.data_ptr()); char *ubuf_wt_ptr = reinterpret_cast<char *>(_ubuf.data_ptr());
COMM_TYPE _comm_type = static_cast<COMM_TYPE>(comm_type); COMM_TYPE _comm_type = static_cast<COMM_TYPE>(comm_type);
if (_comm_type != COMM_TYPE::AG && _comm_type != COMM_TYPE::RS) NVTE_ERROR("Invalid comm_type"); if (_comm_type != COMM_TYPE::AG && _comm_type != COMM_TYPE::RS)
NVTE_ERROR("Invalid comm_type");
if (_comm_type == COMM_TYPE::RS) if (_comm_type == COMM_TYPE::RS)
ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size();
int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size;
...@@ -321,35 +527,51 @@ struct UbufCommOverlap : torch::CustomClassHolder { ...@@ -321,35 +527,51 @@ struct UbufCommOverlap : torch::CustomClassHolder {
} }
}; // UbufCommOverlap }; // UbufCommOverlap
struct UbufP2PCommOverlap : torch::CustomClassHolder { struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
communicator *_ub_comm;
int _tp_id; int _tp_id;
int _tp_size; int _tp_size;
int _ub_reg; int _ub_reg;
int _next_rank, _prev_rank, _rank, _rank_round_tp; int _next_rank, _prev_rank, _rank, _rank_round_tp;
int _aggregate2; int _aggregate2;
int _math_sms; int _math_sms;
int _self_chunk_id;
void *_ubuf_ptr; void *_ubuf_ptr;
torch::Tensor _ubuf; torch::Tensor _ubuf;
torch::Tensor counter;
torch::Tensor _empty_tensor;
std::vector<torch::Tensor> _ubufs; std::vector<torch::Tensor> _ubufs;
at::cuda::CUDAStream _stream_send = at::cuda::getStreamFromPool(true); at::cuda::CUDAStream _stream_send = at::cuda::getStreamFromPool(true);
at::cuda::CUDAStream _stream_recv = at::cuda::getStreamFromPool(true); at::cuda::CUDAStream _stream_recv = at::cuda::getStreamFromPool(true);
std::vector<at::cuda::CUDAStream> _stream_compute; std::vector<at::cuda::CUDAStream> _stream_compute;
cudaEvent_t _start_compute, _stop_compute, _stop_send, _stop_recv; cudaEvent_t _start_compute, _stop_compute, _stop_send, _stop_recv;
int use_ce;
int sms;
int cga_size;
UbufP2PCommOverlap(torch::Tensor sample, int rank, int tp_size, bool aggregate2, UbufP2PCommOverlap(torch::Tensor sample, int rank, int tp_size, int num_comm_sm,
int num_max_streams) { int comm_cga_size, bool set_sm_margin, bool aggregate2, int num_max_streams,
torch::Tensor empty_tensor) {
// Initialize userbuf communicator // Initialize userbuf communicator
if (!comm_created) {
if (rank == 0) {
printf("!!! [UB] Create UbufP2PCommOverlap Communicator\n");
}
create_communicator_grouped2(&_ub_comm, 1, 1, tp_size, 1); create_communicator_grouped2(&_ub_comm, 1, 1, tp_size, 1);
_ub_comm->use_ce = 1; comm_created = true;
_ub_comm->sms = 1; }
_ub_comm->cga_size = 1; use_ce = 1;
sms = 1;
cga_size = 1;
_empty_tensor = empty_tensor;
// Create workspace tensor with userbuffer // Create workspace tensor with userbuffer
int ubuf_bytes = sample.numel() * sample.element_size(); int ubuf_bytes = sample.numel() * sample.element_size();
int ubuf_chunk_bytes = ubuf_bytes / tp_size; int ubuf_chunk_bytes = ubuf_bytes / tp_size;
_ub_reg = register_user_buffer_collective(reinterpret_cast<void **>(&_ubuf_ptr), ubuf_bytes, _ub_reg = register_user_buffer_collective(reinterpret_cast<void **>(&_ubuf_ptr), ubuf_bytes,
_ub_comm, true); _ub_comm, true);
if (rank == 0) {
printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg);
}
_ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options()); _ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options());
// Create tensor chunks for easy management // Create tensor chunks for easy management
...@@ -372,7 +594,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder { ...@@ -372,7 +594,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder {
// Set the number of SMs for GEMM with margin // Set the number of SMs for GEMM with margin
cudaDeviceProp prop; cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0); cudaGetDeviceProperties(&prop, 0);
_math_sms = prop.multiProcessorCount; _math_sms = (set_sm_margin) ? prop.multiProcessorCount - num_comm_sm : prop.multiProcessorCount;
_tp_size = tp_size; _tp_size = tp_size;
_aggregate2 = aggregate2; _aggregate2 = aggregate2;
...@@ -383,6 +605,26 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder { ...@@ -383,6 +605,26 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder {
_next_rank = (tp_size + rank + 1) % tp_size + _rank_round_tp; _next_rank = (tp_size + rank + 1) % tp_size + _rank_round_tp;
_prev_rank = (tp_size + rank + -1) % tp_size + _rank_round_tp; _prev_rank = (tp_size + rank + -1) % tp_size + _rank_round_tp;
auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
counter = torch::zeros({tp_size * 2}, counter_options);
counter.index_put_({Slice(None, tp_size)}, 1);
_self_chunk_id = _tp_id;
const char *env_p = std::getenv("NVTE_AG_P2P_ATOMIC");
if (rank == 0 && env_p != nullptr) {
if (env_p[0] == '1') {
printf("!!userbuffers_sendrecv_atomic\n");
} else if (env_p[0] == '2') {
printf("!!userbuffers_sendrecv_multiatomic\n");
} else if (env_p[0] == '3') {
printf("!!userbuffers_sendrecv_multiatomic_shuffle\n");
_self_chunk_id = 0;
} else {
printf("!!userbuffers_sendrecv\n");
}
}
counter.index_put_({_self_chunk_id}, 0);
// CUDA event creation // CUDA event creation
cudaEventCreateWithFlags(&_start_compute, 0); cudaEventCreateWithFlags(&_start_compute, 0);
cudaEventCreateWithFlags(&_stop_compute, 0); cudaEventCreateWithFlags(&_stop_compute, 0);
...@@ -390,11 +632,144 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder { ...@@ -390,11 +632,144 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder {
cudaEventCreateWithFlags(&_stop_recv, 0); cudaEventCreateWithFlags(&_stop_recv, 0);
} }
/*
** Split AllGather + AtomicGEMM using P2P communication
** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is
*needed to have AG outputs
** in each rank to be in the contiguous memory space after all ring exchange
*phases.
*/
torch::Tensor atomic_gemm_overlap_ag(
at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse,
int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D,
at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias,
transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy) {
_ub_comm->use_ce = use_ce;
_ub_comm->sms = sms;
_ub_comm->cga_size = cga_size;
// Get GEMM dimensions between TN and NN input layouts
const int m = (transa) ? A.size(0) : A.size(1);
const int k = (transa) ? A.size(1) : A.size(0);
const int n_chunk = _ubufs[0].size(0);
// Get communication and GEMM output chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
// Get output and workspace data pointers
char *output_ptr = reinterpret_cast<char *>(D.data_ptr());
char *workspace_ptr = reinterpret_cast<char *>(workspace.data_ptr());
int *counter_ptr = reinterpret_cast<int *>(counter.data_ptr());
int workspace_size_chunk = workspaceSize / _stream_compute.size();
if (A_scale_inverse.numel())
A_scale_inverse = A_scale_inverse[A_fp8_tensor];
if (B_scale_inverse.numel())
B_scale_inverse = B_scale_inverse[B_fp8_tensor];
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
assert(pre_gelu_out.numel() == 0);
// Catch up the default torch stream
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0));
torch::Tensor output_chunk = torch::from_blob(output_ptr, {_ubuf.size(0), m}, D.options());
torch::Tensor workspace_chunk =
torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options());
for (int i = 0; i < _tp_size; i++) {
// Set the userbuffer id. Buffer under send is the input for the current
// GEMM chunk The initial input chunk is stored _ubuf[rank]. This is to
// have the AG output in all ranks to be contiguous after the ring
// exchanges
int send_chunk_id = (_tp_size + _tp_id - i) % _tp_size;
int recv_chunk_id = (_tp_size + _tp_id - i - 1) % _tp_size;
int send_offset = comm_bytes * send_chunk_id;
int recv_offset = comm_bytes * recv_chunk_id;
if (i < _tp_size - 1) {
const char *env_p = std::getenv("NVTE_AG_P2P_ATOMIC");
if (env_p != nullptr && env_p[0] == '1') {
userbuffers_sendrecv_atomic(_ub_reg, _ub_reg, send_offset, recv_offset, comm_bytes,
_ub_comm, _next_rank, _prev_rank, &counter_ptr[recv_chunk_id],
(cudaStream_t)_stream_recv);
} else if (env_p != nullptr && env_p[0] == '2') {
if (i == 0) {
userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes,
_ub_comm, _next_rank, _prev_rank, _tp_size,
counter_ptr, false, (cudaStream_t)_stream_recv);
}
} else if (env_p != nullptr && env_p[0] == '3') {
if (i == 0) {
userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes,
_ub_comm, _next_rank, _prev_rank, _tp_size,
counter_ptr, true, (cudaStream_t)_stream_recv);
}
} else {
// P2P communication
// userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset,
// comm_bytes, _ub_comm,
// _next_rank, (cudaStream_t)_stream_send);
// userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset,
// comm_bytes, _ub_comm,
// _prev_rank, (cudaStream_t)_stream_recv);
// CHECK_CUDA(cudaEventRecord(_stop_recv,
// (cudaStream_t)_stream_recv));
// CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send,
// _stop_recv, 0));
userbuffers_sendrecv(_ub_reg, _ub_reg, send_offset, recv_offset, comm_bytes, _ub_comm,
_next_rank, _prev_rank, (cudaStream_t)_stream_recv);
producer(counter_ptr, recv_chunk_id, (cudaStream_t)_stream_recv);
}
if (i == 0) {
at::cuda::setCurrentCUDAStream(_stream_compute[0]);
te_atomic_gemm(A, A_scale_inverse, A_type, transa, _ubuf, B_scale_inverse, B_type, transb,
output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad,
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms, 0, _tp_size, false, counter);
}
} else {
// GEMM
// userbuffers_send_multiatomic(_ub_reg, 0, _ub_reg, 0, comm_bytes,
// _ub_comm,
// _next_rank, _tp_size, comm_bytes, comm_bytes,
// (cudaStream_t)_stream_send);
// userbuffers_recv_multiatomic(_ub_reg, 0, _ub_reg, 0, comm_bytes,
// _ub_comm,
// _prev_rank, _tp_size, counter_ptr,
// (cudaStream_t)_stream_recv);
if (B_copy.numel() > 0) {
assert(B_copy.numel() == _ubufs[_tp_id].numel());
assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(),
_ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send));
CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
}
}
}
for (int i = 0; i < _tp_size; i++) {
if (i != _self_chunk_id) {
consumer(counter_ptr, i, (cudaStream_t)_stream_compute[0]);
}
}
at::cuda::setCurrentCUDAStream(stream_main);
CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[0]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
return D;
  } // atomic_gemm_overlap_ag
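The chunk indices in atomic_gemm_overlap_ag follow the same ring schedule as split_overlap_ag: each rank seeds _ubufs[tp_id] with its own chunk, then at step i sends chunk (tp_id - i) mod tp_size and receives chunk (tp_id - i - 1) mod tp_size, so after tp_size - 1 exchanges the all-gathered input is contiguous in the user buffer, and producer() marks each received chunk for the atomic GEMM to consume. A small sketch that prints the schedule for an assumed tp_size of 4:

tp_size = 4  # example tensor-parallel size
for tp_id in range(tp_size):
    steps = []
    for i in range(tp_size - 1):
        send_chunk_id = (tp_size + tp_id - i) % tp_size
        recv_chunk_id = (tp_size + tp_id - i - 1) % tp_size
        steps.append((i, send_chunk_id, recv_chunk_id))
    # After these steps, rank tp_id has received every other rank's chunk.
    print(f"rank {tp_id}: (step, send_chunk, recv_chunk) = {steps}")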
/* /*
** Split AllGather + GEMM using P2P communication ** Split AllGather + GEMM using P2P communication
** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is
*outputs *needed to have AG outputs
** in each rank to be in the contiguous memory space after all ring exchange phases. ** in each rank to be in the contiguous memory space after all ring exchange
*phases.
*/ */
torch::Tensor split_overlap_ag(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, torch::Tensor split_overlap_ag(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
transformer_engine::DType A_type, bool transa, at::Tensor B, transformer_engine::DType A_type, bool transa, at::Tensor B,
...@@ -405,6 +780,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder { ...@@ -405,6 +780,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder {
transformer_engine::DType bias_type, at::Tensor pre_gelu_out, transformer_engine::DType bias_type, at::Tensor pre_gelu_out,
bool grad, at::Tensor workspace, size_t workspaceSize, bool grad, at::Tensor workspace, size_t workspaceSize,
bool accumulate, bool use_split_accumulator, at::Tensor B_copy) { bool accumulate, bool use_split_accumulator, at::Tensor B_copy) {
_ub_comm->use_ce = use_ce;
_ub_comm->sms = sms;
_ub_comm->cga_size = cga_size;
// Get GEMM dimensions between TN and NN input layouts // Get GEMM dimensions between TN and NN input layouts
const int m = (transa) ? A.size(0) : A.size(1); const int m = (transa) ? A.size(0) : A.size(1);
const int k = (transa) ? A.size(1) : A.size(0); const int k = (transa) ? A.size(1) : A.size(0);
...@@ -419,9 +797,11 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder { ...@@ -419,9 +797,11 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder {
char *workspace_ptr = reinterpret_cast<char *>(workspace.data_ptr()); char *workspace_ptr = reinterpret_cast<char *>(workspace.data_ptr());
int workspace_size_chunk = workspaceSize / _stream_compute.size(); int workspace_size_chunk = workspaceSize / _stream_compute.size();
if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; if (A_scale_inverse.numel())
A_scale_inverse = A_scale_inverse[A_fp8_tensor];
if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; if (B_scale_inverse.numel())
B_scale_inverse = B_scale_inverse[B_fp8_tensor];
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
...@@ -506,9 +886,10 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder { ...@@ -506,9 +886,10 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder {
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0));
for (int i = 0; i < _tp_size; i++) { for (int i = 0; i < _tp_size; i++) {
// Set the userbuffer id. Buffer under send is the input for the current GEMM chunk // Set the userbuffer id. Buffer under send is the input for the current
// The initial input chunk is stored _ubuf[rank]. This is to have the AG output in all ranks // GEMM chunk The initial input chunk is stored _ubuf[rank]. This is to
// to be contiguous after the ring exchanges // have the AG output in all ranks to be contiguous after the ring
// exchanges
int send_chunk_id = (_tp_size + _tp_id - i) % _tp_size; int send_chunk_id = (_tp_size + _tp_id - i) % _tp_size;
int recv_chunk_id = (_tp_size + _tp_id - i - 1) % _tp_size; int recv_chunk_id = (_tp_size + _tp_id - i - 1) % _tp_size;
int send_offset = comm_bytes * send_chunk_id; int send_offset = comm_bytes * send_chunk_id;
...@@ -581,7 +962,8 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder {
torch::Tensor get_ubuf_output(int comm_type) {
char *ubuf_wt_ptr = reinterpret_cast<char *>(_ubuf.data_ptr());
COMM_TYPE _comm_type = static_cast<COMM_TYPE>(comm_type);
if (_comm_type != COMM_TYPE::AG && _comm_type != COMM_TYPE::RS)
NVTE_ERROR("Invalid comm_type");
if (_comm_type == COMM_TYPE::RS)
ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size();
int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size;
...
...@@ -179,6 +179,32 @@ void te_gemm(at::Tensor A,
int math_sm_count
);
void te_atomic_gemm(at::Tensor A,
at::Tensor A_scale_inverse,
transformer_engine::DType A_type,
bool transa,
at::Tensor B,
at::Tensor B_scale_inverse,
transformer_engine::DType B_type,
bool transb,
at::Tensor D,
at::Tensor D_scale,
transformer_engine::DType D_type,
at::Tensor D_amax,
at::Tensor bias,
transformer_engine::DType bias_type,
at::Tensor pre_gelu_out,
bool grad,
at::Tensor workspace,
size_t workspaceSize,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
int m_split,
int n_split,
bool gemm_producer,
at::Tensor counter
);
void fused_cast_transpose(at::Tensor input,
at::Tensor scale,
...
...@@ -6,6 +6,7 @@
#include "extensions.h"
void te_gemm(at::Tensor A,
at::Tensor A_scale_inverse,
transformer_engine::DType A_type,
...@@ -73,3 +74,82 @@ void te_gemm(at::Tensor A,
math_sm_count,
at::cuda::getCurrentCUDAStream());
}
void te_atomic_gemm(at::Tensor A,
at::Tensor A_scale_inverse,
transformer_engine::DType A_type,
bool transa,
at::Tensor B,
at::Tensor B_scale_inverse,
transformer_engine::DType B_type,
bool transb,
at::Tensor D,
at::Tensor D_scale,
transformer_engine::DType D_type,
at::Tensor D_amax,
at::Tensor bias,
transformer_engine::DType bias_type,
at::Tensor pre_gelu_out,
bool grad,
at::Tensor workspace,
size_t workspaceSize,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
int m_split,
int n_split,
bool gemm_producer,
at::Tensor counter
) {
using namespace transformer_engine;
auto te_A = makeTransformerEngineTensor(A.data_ptr(),
{static_cast<size_t>(A.size(0)),
static_cast<size_t>(A.size(1))},
A_type, nullptr, nullptr,
A_scale_inverse.data_ptr());
auto te_B = makeTransformerEngineTensor(B.data_ptr(),
{static_cast<size_t>(B.size(0)),
static_cast<size_t>(B.size(1))},
B_type, nullptr, nullptr,
B_scale_inverse.data_ptr());
auto te_D = makeTransformerEngineTensor(D.data_ptr(),
{static_cast<size_t>(D.size(0)),
static_cast<size_t>(D.size(1))},
D_type, D_amax.data_ptr(),
D_scale.data_ptr(), nullptr);
auto te_bias = makeTransformerEngineTensor(bias.data_ptr(), {static_cast<size_t>(bias.size(0))},
bias_type);
auto te_counter = makeTransformerEngineTensor(counter.data_ptr(),
{static_cast<size_t>(counter.size(0))},
DType::kInt32);
const auto gelu_shape = pre_gelu_out.data_ptr() == nullptr
? std::vector<size_t>{static_cast<size_t>(pre_gelu_out.size(0))}
: std::vector<size_t>{static_cast<size_t>(pre_gelu_out.size(0)),
static_cast<size_t>(pre_gelu_out.size(1))};
auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out.data_ptr(),
gelu_shape,
GetTransformerEngineDType(
pre_gelu_out.scalar_type()));
auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(),
{workspaceSize},
DType::kByte);
nvte_cublas_atomic_gemm(te_A.data(),
te_B.data(),
te_D.data(),
te_bias.data(),
te_pre_gelu_out.data(),
transa,
transb,
grad,
te_workspace.data(),
accumulate,
use_split_accumulator,
math_sm_count,
m_split,
n_split,
gemm_producer,
te_counter.data(),
at::cuda::getCurrentCUDAStream());
}
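// Illustrative call sketch (hypothetical, not part of this change): the extra arguments over
// te_gemm are the chunking factors and a per-chunk int32 counter tensor. Following the
// ATOMIC_PRODUCER / ATOMIC_CONSUMER convention in userbuffers.cu (nonzero = pending, 0 = chunk
// ready), the counter would typically be armed with ones before launch, e.g.:
//   auto counter = torch::ones({n_split}, torch::dtype(torch::kInt32).device(torch::kCUDA));
//   te_atomic_gemm(A, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb,
//                  D, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, /*grad=*/false,
//                  workspace, workspaceSize, /*accumulate=*/false, /*use_split_accumulator=*/false,
//                  /*math_sm_count=*/0, m_split, n_split, /*gemm_producer=*/true, counter);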
...@@ -91,18 +91,24 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.value("BULK_OVERLAP_AG", ubuf::UBOverlapAlgo::BULK_OVERLAP_AG)
.value("BULK_OVERLAP_RS", ubuf::UBOverlapAlgo::BULK_OVERLAP_RS)
.value("SPLIT_PIPELINED_RS", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_RS)
.value("SPLIT_PIPELINED_AG", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_AG)
.value("ATOMIC_GEMM_RS", ubuf::UBOverlapAlgo::ATOMIC_GEMM_RS)
.value("ATOMIC_GEMM_AG", ubuf::UBOverlapAlgo::ATOMIC_GEMM_AG);
py::class_<ubuf::UbufCommOverlap>(m, "UbufCommOverlap")
.def(py::init<torch::Tensor&, int, int, int, int, int, bool, int, torch::Tensor>())
.def("bulk_overlap", &ubuf::UbufCommOverlap::bulk_overlap)
.def("split_overlap_rs", &ubuf::UbufCommOverlap::split_overlap_rs)
.def("set_ubuf_scale_inv", &ubuf::UbufCommOverlap::set_ubuf_scale_inv)
.def("atomic_gemm_overlap_rs", &ubuf::UbufCommOverlap::atomic_gemm_overlap_rs)
.def("is_fp8_ubuf", &ubuf::UbufCommOverlap::is_fp8_ubuf)
.def("copy_input_to_ubuf", &ubuf::UbufCommOverlap::copy_input_to_ubuf)
.def("get_ubuf_output", &ubuf::UbufCommOverlap::get_ubuf_output);
py::class_<ubuf::UbufP2PCommOverlap>(m, "UbufP2PCommOverlap")
.def(py::init<torch::Tensor&, int, int, int, int, bool, bool, int, torch::Tensor>())
.def("split_overlap_ag", &ubuf::UbufP2PCommOverlap::split_overlap_ag)
.def("atomic_gemm_overlap_ag", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_ag)
.def("copy_input_to_ubuf", &ubuf::UbufP2PCommOverlap::copy_input_to_ubuf)
.def("get_ubuf_output", &ubuf::UbufP2PCommOverlap::get_ubuf_output);
#else  // NVTE_WITH_USERBUFFERS
...
...@@ -4,10 +4,13 @@
* See LICENSE for license information.
************************************************************************/
#include "userbuffers.h"
#include <assert.h>
#include <chrono>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <immintrin.h>
#include <iostream>
#include <math.h>
#include <mpi.h>
#include <sched.h>
...@@ -15,9 +18,6 @@
#include <string.h>
#include <unistd.h>
#include <x86intrin.h>
#include <chrono>
#include <iostream>
#include "userbuffers.h"
static int oob_bcast(void *comm_context, void *buf, int size, int root) {
MPI_Bcast(buf, size, MPI_BYTE, root,
...@@ -47,6 +47,17 @@ int stringCmp(const void *a, const void *b) { return strcmp((const char *)a, (co
} \
} while (0)
#define CUCHECK(cmd) \
do { \
CUresult retval = cmd; \
if (retval != CUDA_SUCCESS) { \
const char *error_string; \
cuGetErrorString(retval, &error_string); \
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, error_string); \
exit(EXIT_FAILURE); \
} \
} while (0);
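// Usage sketch (illustrative): CUCHECK wraps CUDA driver-API calls, mirroring what CUDACHECK
// does for runtime-API calls, e.g.:
//   CUdevice dev;
//   CUCHECK(cuCtxGetDevice(&dev));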
#define NVTE_UB_ERROR(x) \
do { \
throw std::runtime_error(std::string(__FILE__ ":") + std::to_string(__LINE__) + \
...@@ -89,12 +100,14 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
(*comm)->push = 1;
(*comm)->use_ce = 0;
(*comm)->cga_size = 2;
for (int i = 0; i < userbuffers_op_types; i++)
(*comm)->basecounter[i] = 0;
(*comm)->head = 0;
(*comm)->tail = 0;
(*comm)->activeproxy = 1;
(*comm)->active_nreqs = 0;
for (int i = 0; i < userbuffers_op_types; i++)
(*comm)->active_req[i].active = -1;
int ret = 0;
// split communicator
...@@ -112,8 +125,10 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
color = 0;
for (int n = 0; n < size; n++) {
if (n > 0 && strcmp(host_names[n - 1], host_names[n]))
color++;
if (strcmp(host_name, host_names[n]) == 0)
break;
}
free(host_names);
...@@ -128,14 +143,22 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
cpu_set_t cpuset;
CPU_ZERO(&cpuset);
int core;
if (mylocal == 0)
core = 50;
if (mylocal == 1)
core = 58;
if (mylocal == 2)
core = 18;
if (mylocal == 3)
core = 26;
if (mylocal == 4)
core = 114;
if (mylocal == 5)
core = 122;
if (mylocal == 6)
core = 82;
if (mylocal == 7)
core = 90;
CPU_SET(core, &cpuset);
if (!getenv("NVTE_NODOUBLE")) {
...@@ -144,7 +167,8 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
else
CPU_SET(core + 128, &cpuset);
}
if (getenv("NVTE_DOPIN")) pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpuset); if (getenv("NVTE_DOPIN"))
pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpuset);
if (ndev == numlocal) { // all visible devices if (ndev == numlocal) { // all visible devices
if (cur_dev != mylocal) if (cur_dev != mylocal)
...@@ -175,7 +199,8 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode ...@@ -175,7 +199,8 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
int datanodegroup_id = int datanodegroup_id =
myrank / numlocal / datanodes; // data reduction group node belongs, equals 0 for all if both myrank / numlocal / datanodes; // data reduction group node belongs, equals 0 for all if both
// pipenodes=1 and tensornodes=1 // pipenodes=1 and tensornodes=1
// mpi communicator only needed for SHARP which is always allreduce1/data-parallel // mpi communicator only needed for SHARP which is always
// allreduce1/data-parallel
MPI_Comm_split(MPI_COMM_WORLD, mylocal + numlocal * datanodegroup_id, rank, &(*comm)->comm_inter); MPI_Comm_split(MPI_COMM_WORLD, mylocal + numlocal * datanodegroup_id, rank, &(*comm)->comm_inter);
// different rails from same group are in different subcommunicators // different rails from same group are in different subcommunicators
...@@ -192,19 +217,37 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode ...@@ -192,19 +217,37 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
char *ib_dev_list; char *ib_dev_list;
int ZIONROCE = getenv("NVTE_ZIONROCE") ? atoi(getenv("NVTE_ZIONROCE")) : 0; int ZIONROCE = getenv("NVTE_ZIONROCE") ? atoi(getenv("NVTE_ZIONROCE")) : 0;
int ROCE = getenv("NVTE_ROCE") ? atoi(getenv("NVTE_ROCE")) : 0; int ROCE = getenv("NVTE_ROCE") ? atoi(getenv("NVTE_ROCE")) : 0;
if (ZIONROCE) ROCE = 1; if (ZIONROCE)
ROCE = 1;
int DGX_H100 = device_prop.major == 9; int DGX_H100 = device_prop.major == 9;
switch (mylocal) { switch (mylocal) {
case 0:ib_dev_list = "mlx5_0:1"; break; // NOLINT(*) case 0:
case 1:ib_dev_list = (char*)(DGX_H100?"mlx5_3:1":"mlx5_1:1"); break; // NOLINT(*) ib_dev_list = "mlx5_0:1";
case 2:ib_dev_list = (char*)(ZIONROCE?"mlx5_4:1":DGX_H100?"mlx5_4:1":"mlx5_2:1"); break; // NOLINT(*) break; // NOLINT(*)
case 3:ib_dev_list = (char*)(DGX_H100?"mlx5_5:1":"mlx5_3:1"); break; // NOLINT(*) case 1:
case 4:ib_dev_list = (char*)(DGX_H100?"mlx5_6:1":"mlx5_6:1"); break; // NOLINT(*) ib_dev_list = (char *)(DGX_H100 ? "mlx5_3:1" : "mlx5_1:1"); // NOLINT(*)
case 5:ib_dev_list = (char*)(DGX_H100?"mlx5_9:1":"mlx5_7:1"); break; // NOLINT(*) break; // NOLINT(*)
case 6:ib_dev_list = (char*)(ZIONROCE?"mlx5_10:1":DGX_H100?"mlx5_10:1":"mlx5_8:1"); break; // NOLINT(*) case 2:
case 7:ib_dev_list = (char*)(DGX_H100?"mlx5_11:1":"mlx5_9:1"); break; // NOLINT(*) ib_dev_list = (char *)(ZIONROCE ? "mlx5_4:1" : DGX_H100 ? "mlx5_4:1" : "mlx5_2:1"); // NOLINT(*)
default: break; break; // NOLINT(*)
case 3:
ib_dev_list = (char *)(DGX_H100 ? "mlx5_5:1" : "mlx5_3:1"); // NOLINT(*)
break; // NOLINT(*)
case 4:
ib_dev_list = (char *)(DGX_H100 ? "mlx5_6:1" : "mlx5_6:1"); // NOLINT(*)
break; // NOLINT(*)
case 5:
ib_dev_list = (char *)(DGX_H100 ? "mlx5_9:1" : "mlx5_7:1"); // NOLINT(*)
break; // NOLINT(*)
case 6:
ib_dev_list = (char *)(ZIONROCE ? "mlx5_10:1" : DGX_H100 ? "mlx5_10:1" : "mlx5_8:1"); // NOLINT(*)
break; // NOLINT(*)
case 7:
ib_dev_list = (char *)(DGX_H100 ? "mlx5_11:1" : "mlx5_9:1"); // NOLINT(*)
break; // NOLINT(*)
default:
break;
} }
(*comm)->fifo = reinterpret_cast<ub_request *>(malloc(sizeof(ub_request) * NVTE_MAX_REQUESTS));
...@@ -215,7 +258,8 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
CUDACHECK(cudaMallocHost((void **)&(*comm)->hostflags,  // NOLINT(*)
(NVTE_MAX_SMS + 100) * sizeof(int)));
for (int i = 0; i < 100 + NVTE_MAX_SMS; i++)
(*comm)->hostflags[i] = 0;
_mm_mfence();
sleep(1);
...@@ -223,13 +267,16 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
(*comm)->ibnvsize = (*comm)->nvsize;
#define NBUF 2
#define LOCALSIZE 4 * (NVTE_REG0_OFFSET(*comm) + NVTE_REG0_FLAGS + NVTE_REG0_COMMBUFFER * NBUF)
// peer pointers + op flags + comm buffer
CUDACHECK(cudaMalloc(&(*comm)->gpu_ptrs,
LOCALSIZE));  // flags and pointers, no block data yet
CUDACHECK(cudaMemset((*comm)->gpu_ptrs, 0, LOCALSIZE));
CUDACHECK(cudaDeviceSynchronize());
register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE,
*comm);  // will use handler 0
CUDACHECK(cudaMalloc(&(*comm)->send_id, (*comm)->nranks * sizeof(int)));
CUDACHECK(cudaMalloc(&(*comm)->recv_id, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int)));
CUDACHECK(cudaMemset((*comm)->send_id, 0, (*comm)->nranks * sizeof(int)));
...@@ -243,7 +290,6 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
#define GPU_PAGE_MASK (~GPU_PAGE_OFFSET)
CUDACHECK(cudaMalloc(&(*comm)->flags, 2 * GPU_PAGE_SIZE));
unsigned int flag = 1;
// cuPointerSetAttribute(&flag, CU_POINTER_ATTRIBUTE_SYNC_MEMOPS, (CUdeviceptr)(*comm)->flags);
CUDACHECK(cudaMemset((*comm)->flags, 0, 2 * GPU_PAGE_SIZE));
(*comm)->flags =
reinterpret_cast<int *>(((CUdeviceptr)(*comm)->flags + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK);
...@@ -275,7 +321,8 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
pthread_attr_setschedparam(&attr, &param);
if (getenv("NVTE_UBDEBUG"))
printf("%d/%d:(%d x %d): DP %d x %d TP %d x %d, DPGROUP %dx%d TPGROUP "
"%dx%d PIPE_ID %d/%d\n",
myrank, nranks, myrank / numlocal, myrank % numlocal, (*comm)->my_node,
(*comm)->ar_nvrank, (*comm)->my2_node, (*comm)->ar2_nvrank, (*comm)->num_nodes,
(*comm)->ar_nvsize, (*comm)->num2_nodes, (*comm)->ar2_nvsize, (*comm)->pipe_id,
...@@ -300,9 +347,9 @@ void destroy_communicator(communicator *comm) {
}
int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm, bool alloc) {
if (comm->free_region > NVTE_MAX_REGIONS)
return -1;
int hndl = comm->free_region;
// printf("%d register %d size %lld\n",comm->myrank,hndl,bytes);fflush(NULL);
comm->peer_ptr[hndl] = reinterpret_cast<void **>(malloc(sizeof(void *) * (comm->nvsize)));
if (alloc) {
...@@ -313,25 +360,22 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
reinterpret_cast<cudaIpcMemHandle_t *>(malloc(sizeof(cudaIpcMemHandle_t) * (comm->nvsize)));
CUDACHECK(cudaIpcGetMemHandle(&memhndl[comm->nvrank], *gpubuff));
MPI_Allgather(&memhndl[comm->nvrank], sizeof(cudaIpcMemHandle_t), MPI_BYTE, memhndl,
sizeof(cudaIpcMemHandle_t), MPI_BYTE, comm->comm_intra);
for (int i = 0; i < comm->nvsize; i++)
if (i != comm->nvrank)
CUDACHECK(cudaIpcOpenMemHandle((void **)&(comm->peer_ptr[hndl][i]),  // NOLINT(*)
memhndl[i], cudaIpcMemLazyEnablePeerAccess));
comm->peer_ptr[hndl][comm->nvrank] = *gpubuff;
CUDACHECK(cudaDeviceSynchronize());
CUDACHECK(
cudaMemcpy(reinterpret_cast<char *>(comm->gpu_ptrs) + (hndl * comm->nvsize * sizeof(void *)),
comm->peer_ptr[hndl], comm->nvsize * sizeof(void *), cudaMemcpyHostToDevice));
CUDACHECK(cudaDeviceSynchronize());
free(memhndl);
comm->mem_ptr[hndl] = *gpubuff;
return comm->free_region++;
}
...@@ -352,8 +396,10 @@ int allgather2_userbuff_inplace_gpu(const int maxcredit, const int handler, cons
void allreduce_nonsharp_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream, int op) {
if (elements < 64)
NVTE_UB_ERROR("Userbuffer comm for given config not implemented.");
// if(comm->myrank==0) fprintf(stderr,"AR2(%d) user call
// launch_mode=%d\n",op,comm->launch_mode);
const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
int blocksize = elements * 2;
int maxcredit = 0;
...@@ -361,19 +407,19 @@ void allreduce_nonsharp_inplace(const int handler, const int offset, const int e
blocksize = (comm->nblocks - 1 + (comm->alignblock - 1 + elements * 2) / comm->alignblock) /
comm->nblocks;  // FIXME TUNING
blocksize *= comm->alignblock;
if (blocksize < comm->minblock)
blocksize = comm->minblock;
maxcredit = (elements * 2 + blocksize - 1) / blocksize;
// if(maxcredit>4) maxcredit=4;
// if(maxcredit>4 && ar_nvsize==1) maxcredit=4;
size_t peerblock = sizeof(int) * NVTE_REG0_COMMBUFFER / maxcredit;  // max size we can fit
if (blocksize > peerblock * ar_nvsize)
blocksize = peerblock * ar_nvsize;
int sms = allreduce2_userbuff_inplace_gpu(maxcredit, handler, offset, elements, blocksize, comm,
stream, op);
if (num_nodes > 1 && comm->launch_mode & NVTE_LAUNCH_CPU) {
if (!sms)
return;
comm->fifo[comm->head].optype = op;
comm->fifo[comm->head].basecounter = comm->basecounter[op];
comm->fifo[comm->head].blocksize = blocksize;
...@@ -399,7 +445,8 @@ void allreduce2_userbuff_inplace(const int handler, const int offset, const int
void allreduce_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) {
if (elements < 64)
NVTE_UB_ERROR("Userbuffer comm for given config not implemented.");
allreduce_nonsharp_inplace(handler, offset, elements, comm, stream,
userbuffers_allreduceop_nonsharp);
return;
...@@ -407,7 +454,8 @@ void allreduce_userbuff_inplace(const int handler, const int offset, const int e
void reducescatter_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) {
if (elements < 64)
NVTE_UB_ERROR("Userbuffer comm for given config not implemented.");
int op = userbuffers_allreduceop_nonsharp;
const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
...@@ -418,17 +466,20 @@ void reducescatter_userbuff_inplace(const int handler, const int offset, const i
blocksize = (comm->nblocks - 1 + (comm->alignblock - 1 + elements * 2) / comm->alignblock) /
comm->nblocks;  // FIXME TUNING
blocksize *= comm->alignblock;
if (blocksize < comm->minblock)
blocksize = comm->minblock;
maxcredit = (elements * 2 + blocksize - 1) / blocksize;
size_t peerblock = sizeof(int) * NVTE_REG0_COMMBUFFER / maxcredit;  // max size we can fit
if (blocksize > peerblock * ar_nvsize)
blocksize = peerblock * ar_nvsize;
int sms = reducescatter2_userbuff_inplace_gpu(maxcredit, handler, offset, elements, blocksize,
comm, stream, op);
if (num_nodes > 1 && comm->launch_mode & NVTE_LAUNCH_CPU) {
if (!sms)
return;
comm->fifo[comm->head].optype = op;
comm->fifo[comm->head].basecounter = comm->basecounter[op];
comm->fifo[comm->head].blocksize = blocksize;
...@@ -448,7 +499,8 @@ void reducescatter_userbuff_inplace(const int handler, const int offset, const i
void allgather_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) {
if (elements < 64)
NVTE_UB_ERROR("Userbuffer comm for given config not implemented.");
int op = userbuffers_allreduceop_nonsharp;
const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
int blocksize = elements * 2;
...@@ -458,11 +510,13 @@ void allgather_userbuff_inplace(const int handler, const int offset, const int e
blocksize = (comm->nblocks - 1 + (comm->alignblock - 1 + elements * 2) / comm->alignblock) /
comm->nblocks;  // FIXME TUNING
blocksize *= comm->alignblock;
if (blocksize < comm->minblock)
blocksize = comm->minblock;
maxcredit = (elements * 2 + blocksize - 1) / blocksize;
size_t peerblock = sizeof(int) * NVTE_REG0_COMMBUFFER / maxcredit;  // max size we can fit
if (blocksize > peerblock * ar_nvsize)
blocksize = peerblock * ar_nvsize;
int sms = allgather2_userbuff_inplace_gpu(maxcredit, handler, offset, elements, blocksize, comm,
stream, op);
...
...@@ -12,9 +12,10 @@
#else
#include <cuda_fp16.h>
#endif
#include "userbuffers.h"
#include <assert.h>
#include <cuda_fp8.h>
#include <stdio.h>
#include "userbuffers.h"
#define MAX_THREADS 1024
#define TIMEOUT 200000000000ull
...@@ -28,6 +29,25 @@
} \
} while (0)
#define ATOMIC_CONSUMER(chunk) \
if (counters) { \
if (threadIdx.x == 0 && blockIdx.x == 0) { \
int old_val; \
while (0 != (old_val = atomicCAS(((unsigned int *)counters) + chunk, 0, 0))) { \
} \
((unsigned int *)counters)[chunk] = 1; \
asm volatile("fence.sc.gpu;\n"); \
} \
if (blockIdx.x == 0) \
__syncthreads(); \
}
#define ATOMIC_PRODUCER(chunk) \
if (counters) { \
((unsigned int *)counters)[chunk] = 0; \
asm volatile("fence.sc.gpu;\n"); \
}
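// Illustrative sketch (hypothetical kernel, not part of this change): a single-block consumer
// that uses the macros above to wait until the producer has published chunk c (counters[c]
// cleared to 0), copies it, and leaves the flag re-armed at 1 for the next round:
//   __global__ void consume_chunks_sketch(const half *src, half *dst, int chunk_elems,
//                                         int nchunks, unsigned int *counters) {
//     for (int c = 0; c < nchunks; c++) {
//       ATOMIC_CONSUMER(c);  // spin on counters[c] == 0, then set it back to 1
//       for (int i = threadIdx.x; i < chunk_elems; i += blockDim.x)
//         dst[c * chunk_elems + i] = src[c * chunk_elems + i];
//     }
//   }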
template <int RANKS>
__global__ void __launch_bounds__(MAX_THREADS)
userbuffers_fp16_sum_inplace_gpu_rw(const int op, const int flagoffset, const int firstrank,
...@@ -36,8 +56,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
__shared__ int4 *userptr[RANKS];
int *flagptr, physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
// if(blockIdx.x==0 && threadIdx.x==0) printf("%d/%d(phys %d gpustep %d firstrank %d):RRkernel(d)
// start, size %lld\n",myrank,RANKS,gpustep*myrank+firstrank,gpustep,firstrank,numlines*16ull);
if (threadIdx.x < RANKS) {
physgpu = myrank * gpustep + firstrank;
targetgpu = threadIdx.x * gpustep + firstrank;
...@@ -66,7 +85,8 @@ __global__ void __launch_bounds__(MAX_THREADS)
int warp = blockIdx.x + (threadIdx.x >> 5);
int dest[RANKS];
#pragma unroll
for (int i = 0; i < RANKS; i++)
dest[i] = (i + myrank + warp) & (RANKS - 1);
__syncthreads();
for (int line = threadIdx.x + blockDim.x * (myrank + RANKS * blockIdx.x); line < numlines;
...@@ -86,7 +106,8 @@ __global__ void __launch_bounds__(MAX_THREADS)
for (int i = 1; i < RANKS; i++) {
half *x = reinterpret_cast<half *>(&val[i]);
#pragma unroll
for (int j = 0; j < 8; j++)
s[j] += x[j];
}
#pragma unroll
for (int i = 0; i < RANKS; i++) {
...@@ -96,7 +117,8 @@ __global__ void __launch_bounds__(MAX_THREADS)
}
__syncthreads();
if (threadIdx.x == 0)
__threadfence_system();
__syncthreads();
if (threadIdx.x < RANKS) {
...@@ -111,7 +133,8 @@ __global__ void __launch_bounds__(MAX_THREADS)
}
}
}
if (threadIdx.x == 0 && blockIdx.x == 0)
*reduceidptr = reduce_id;
}  // fp16 inplace reduce kernel (Volta,Hopper)
template <int RANKS>
...@@ -150,7 +173,8 @@ __global__ void __launch_bounds__(MAX_THREADS)
int warp = blockIdx.x + (threadIdx.x >> 5);
int dest[RANKS];
#pragma unroll
for (int i = 0; i < RANKS; i++)
dest[i] = (i + myrank + warp) & (RANKS - 1);
__syncthreads();
for (int line = threadIdx.x + blockDim.x * (myrank + RANKS * blockIdx.x); line < numlines;
...@@ -169,13 +193,15 @@ __global__ void __launch_bounds__(MAX_THREADS)
for (int i = 1; i < RANKS; i++) {
half *x = reinterpret_cast<half *>(&val[i]);
#pragma unroll
for (int j = 0; j < 8; j++)
s[j] += x[j];
}
userptr[myrank][lineoffset + line] = sum;
}
__syncthreads();
if (threadIdx.x == 0)
__threadfence();
__syncthreads();
if (threadIdx.x < RANKS) {
...@@ -217,7 +243,8 @@ __global__ void __launch_bounds__(MAX_THREADS)
userptr[myrank][lineoffset + line + blockDim.x * dest[i]] = val[i];
}
}
if (threadIdx.x == 0 && blockIdx.x == 0)
*reduceidptr = reduce_id;
}  // fp16 inplace reduce kernel (Ampere)
template <int RANKS>
...@@ -227,18 +254,18 @@ __global__ void __launch_bounds__(MAX_THREADS)
const int mylineoffset, const int totallines,
void **commbuff, const int handleridx) {
__shared__ int4 *userptr[RANKS];
volatile int *flagptr;
int physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
int lastSM = 0;
if (threadIdx.x < RANKS) {
physgpu = myrank * gpustep + firstrank;
targetgpu = threadIdx.x * gpustep + firstrank;
const int blockflagoffset = NVTE_MAX_NVLINK * 2 * blockIdx.x;
myptr = (reinterpret_cast<int *>(commbuff[physgpu])) + flagoffset;
reduceidptr = myptr - NVTE_MAX_OPS;  // +op;
reduce_id = (*reduceidptr) + 1;
flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset;
if (blockIdx.x == 0)
flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&(myptr[targetgpu]);
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
...@@ -252,11 +279,18 @@ __global__ void __launch_bounds__(MAX_THREADS)
}
}
__syncthreads();
if (threadIdx.x == 0) {
const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1;
int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder);
if (old_val + adder == NVTE_MAX_SMS * reduce_id)
lastSM = 1;
}
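// Every block adds its share (block 0 adds NVTE_MAX_SMS - gridDim.x + 1, the rest add 1), so
// the counter grows by exactly NVTE_MAX_SMS per reduction and the block that raises it to
// NVTE_MAX_SMS * reduce_id knows it is the last SM to finish and may publish reduce_id below.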
int warp = blockIdx.x + (threadIdx.x >> 5);
int dest[RANKS];
#pragma unroll
for (int i = 0; i < RANKS; i++)
dest[i] = (i + myrank + warp) & (RANKS - 1);
__syncthreads();
for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines;
...@@ -275,13 +309,15 @@ __global__ void __launch_bounds__(MAX_THREADS)
for (int i = 1; i < RANKS; i++) {
half *x = reinterpret_cast<half *>(&val[i]);
#pragma unroll
for (int j = 0; j < 8; j++)
s[j] += x[j];
}
userptr[myrank][mylineoffset + line] = sum;
}
if (threadIdx.x == 0 && lastSM)
*reduceidptr = reduce_id;
}  // fp16 inplace reduce-scatter kernel
template <int RANKS>
...@@ -293,18 +329,18 @@ __global__ void __launch_bounds__(MAX_THREADS)
const int skiplines, void **commbuff,
const int handleridx, void *outbuf) {
__shared__ int4 *userptr[RANKS];
volatile int *flagptr;
int physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
int lastSM = 0;
if (threadIdx.x < RANKS) {
physgpu = myrank * gpustep + firstrank;
targetgpu = threadIdx.x * gpustep + firstrank;
const int blockflagoffset = NVTE_MAX_NVLINK * 2 * blockIdx.x;
myptr = (reinterpret_cast<int *>(commbuff[physgpu])) + flagoffset;
reduceidptr = myptr - NVTE_MAX_OPS;  // +op;
reduce_id = (*reduceidptr) + 1;
flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset;
if (blockIdx.x == 0)
flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&(myptr[targetgpu]);
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
...@@ -318,11 +354,18 @@ __global__ void __launch_bounds__(MAX_THREADS)
}
}
__syncthreads();
if (threadIdx.x == 0) {
const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1;
int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder);
if (old_val + adder == NVTE_MAX_SMS * reduce_id)
lastSM = 1;
}
int warp = blockIdx.x + (threadIdx.x >> 5);
int dest[RANKS];
#pragma unroll
for (int i = 0; i < RANKS; i++)
dest[i] = (i + myrank + warp) & (RANKS - 1);
__syncthreads();
for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines;
...@@ -341,24 +384,28 @@ __global__ void __launch_bounds__(MAX_THREADS)
for (int i = 1; i < RANKS; i++) {
half *x = reinterpret_cast<half *>(&val[i]);
#pragma unroll
for (int j = 0; j < 8; j++)
s[j] += x[j];
}
(reinterpret_cast<int4 *>(outbuf))[(line / rowlines) * skiplines + (line % rowlines)] = sum;
}
if (threadIdx.x == 0 && lastSM)
*reduceidptr = reduce_id;
}  // fp16 reduce-scatter kernel (out of place)
#if 0
// All MC kernels here
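// Note on the MC ("multicast") variants below: instead of having each rank read every peer's
// buffer over NVLink, they issue multimem.ld_reduce / multimem.st PTX instructions on an NVLS
// multicast pointer (mc_ptr), so the switch performs the reduction and broadcast in-network;
// the half vs. bfloat16 flavor is selected with the NVTE_UB_FP16 define.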
template <int RANKS>
__global__ void __launch_bounds__(MAX_THREADS)
userbuffers_fp16_sum_inplace_gpu_mc(const int op, const int flagoffset, const int firstrank,
const int myrank, const int gpustep, const int lineoffset,
const int numlines, void **commbuff, const int handleridx,
float4 *mc_ptr) {
__shared__ int4 *userptr[RANKS];
int *flagptr, physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
if (threadIdx.x < RANKS) {
physgpu = myrank * gpustep + firstrank;
targetgpu = threadIdx.x * gpustep + firstrank;
...@@ -371,114 +418,322 @@ __global__ void __launch_bounds__(MAX_THREADS)
flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&(myptr[targetgpu]);
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
clock_t s = clock64();
while (*flag < reduce_id) {
if (clock64() - s > TIMEOUT) {
printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id,
*flag);
break;
} }
int warp = blockIdx.x + (threadIdx.x >> 5);
int dest[RANKS];
int skipmy = 0;
#pragma unroll
for (int i = 0; i < RANKS; i++) {
int dst = (i + warp + myrank) & (RANKS - 1);
if (dst == myrank) {
skipmy++;
continue;
} }
dest[i - skipmy] = dst; reduce_id++;
} }
__syncthreads();
#define UNROLL_MC 8
const int loop_step0 = blockDim.x * gridDim.x * RANKS;
const int loop_step = loop_step0 * UNROLL_MC;
const int start_elem = threadIdx.x + blockDim.x * (myrank + RANKS * blockIdx.x);
const int end_elem = max(start_elem, numlines);
const int aligned_elem = ((end_elem - start_elem) / loop_step) * loop_step;
const int end_aligned = start_elem + aligned_elem;
for (int line = start_elem; line < end_aligned; line += loop_step) {
uint4 val[UNROLL_MC];
int4 val[RANKS - 1];
#pragma unroll
for (int i = 0; i < UNROLL_MC; i++)
#if defined(NVTE_UB_FP16)
asm("multimem.ld_reduce.global.add.v4.f16x2 {%0,%1,%2,%3}, [%4];"
: "=r"(val[i].x), "=r"(val[i].y), "=r"(val[i].z), "=r"(val[i].w)
: "l"(mc_ptr + (lineoffset + line + i * loop_step0))
: "memory");
#else
asm("multimem.ld_reduce.global.add.v4.bf16x2 {%0,%1,%2,%3}, [%4];"
: "=r"(val[i].x), "=r"(val[i].y), "=r"(val[i].z), "=r"(val[i].w)
: "l"(mc_ptr + (lineoffset + line + i * loop_step0))
: "memory");
#endif
#pragma unroll
for (int i = 0; i < UNROLL_MC; i++)
asm volatile("multimem.st.global.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"(
mc_ptr + (lineoffset + line + i * loop_step0)),
"r"(val[i].x), "r"(val[i].y), "r"(val[i].z), "r"(val[i].w)
: "memory");
}
for (int line = end_aligned; line < end_elem; line += loop_step0) {
uint4 val;
#if defined(NVTE_UB_FP16)
asm("multimem.ld_reduce.global.add.v4.f16x2 {%0,%1,%2,%3}, [%4];"
: "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
: "l"(mc_ptr + (lineoffset + line))
: "memory");
#else
asm("multimem.ld_reduce.global.add.v4.bf16x2 {%0,%1,%2,%3}, [%4];"
: "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
: "l"(mc_ptr + (lineoffset + line))
: "memory");
#endif
asm volatile(
"multimem.st.global.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"(mc_ptr + (lineoffset + line)),
"r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w)
: "memory");
}
__syncthreads();
if (threadIdx.x == 0)
__threadfence_system();
__syncthreads();
if (threadIdx.x < RANKS) {
flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&myptr[targetgpu];
clock_t s = clock64();
while (*flag < reduce_id) {
if (clock64() - s > 2ull * TIMEOUT) {
printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id,
*flag);
break;
}
}
}
if (threadIdx.x == 0 && blockIdx.x == 0)
*reduceidptr = reduce_id;
}  // fp16 inplace reduce kernel (Hopper) MC
template <int RANKS>
__global__ void __launch_bounds__(MAX_THREADS)
userbuffers_fp16_sum_inplace_gpu_mc_rs(const int op, const int flagoffset, const int firstrank,
const int myrank, const int gpustep,
const int mylineoffset, const int totallines,
void **commbuff, const int handleridx, float4 *mc_ptr) {
volatile int *flagptr;
int physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
uint4 *localptr = reinterpret_cast<uint4 *>(commbuff[myrank * gpustep + firstrank + handleridx]);
int lastSM = 0;
if (threadIdx.x < RANKS) {
physgpu = myrank * gpustep + firstrank;
targetgpu = threadIdx.x * gpustep + firstrank;
const int blockflagoffset = NVTE_MAX_NVLINK * 2 * blockIdx.x;
myptr = (reinterpret_cast<int *>(commbuff[physgpu])) + flagoffset;
reduceidptr = myptr - NVTE_MAX_OPS;  // +op;
reduce_id = (*reduceidptr) + 1;
flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset;
if (blockIdx.x == 0)
flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&(myptr[targetgpu]);
clock_t s = clock64();
while (*flag < reduce_id) {
if (clock64() - s > TIMEOUT) {
printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id,
*flag);
break;
} }
__syncthreads();
localptr = userptr[myrank];
int warp = blockIdx.x + (threadIdx.x >> 5);
int dest[RANKS - 1];
int skipmy = 0;
#pragma unroll
for (int i = 0; i < RANKS; i++) {
int dst = (i + warp + myrank) & (RANKS - 1);
if (dst == myrank) {
skipmy++;
continue;
} }
dest[i - skipmy] = dst;
} }
#define UNROLLAG 4
__syncthreads();
if (threadIdx.x == 0) {
const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1;
int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder);
if (old_val + adder == NVTE_MAX_SMS * reduce_id)
lastSM = 1;
}
const int loop_step0 = blockDim.x * gridDim.x;
const int loop_step = loop_step0 * UNROLL_MC;
const int start_elem = threadIdx.x + blockDim.x * blockIdx.x;
const int end_elem = max(start_elem, totallines);
const int aligned_elem = ((end_elem - start_elem) / loop_step) * loop_step;
const int end_aligned = start_elem + aligned_elem;
for (int line = start_elem; line < end_aligned; line += loop_step) {
uint4 val[UNROLL_MC];
#pragma unroll
for (int i = 0; i < UNROLL_MC; i++)
#if defined(NVTE_UB_FP16)
asm("multimem.ld_reduce.global.add.v4.f16x2 {%0,%1,%2,%3}, [%4];"
: "=r"(val[i].x), "=r"(val[i].y), "=r"(val[i].z), "=r"(val[i].w)
: "l"(mc_ptr + (mylineoffset + line + i * loop_step0))
: "memory");
#else
asm("multimem.ld_reduce.global.add.v4.bf16x2 {%0,%1,%2,%3}, [%4];"
: "=r"(val[i].x), "=r"(val[i].y), "=r"(val[i].z), "=r"(val[i].w)
: "l"(mc_ptr + (mylineoffset + line + i * loop_step0))
: "memory");
#endif
#pragma unroll #pragma unroll
for (int i = 0; i < UNROLL_MC; i++)
localptr[mylineoffset + line + i * loop_step0] = val[i];
}
for (int line = end_aligned; line < end_elem; line += loop_step0) {
uint4 val;
#if defined(NVTE_UB_FP16)
asm("multimem.ld_reduce.global.add.v4.f16x2 {%0,%1,%2,%3}, [%4];"
: "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
: "l"(mc_ptr + (mylineoffset + line))
: "memory");
#else
asm("multimem.ld_reduce.global.add.v4.bf16x2 {%0,%1,%2,%3}, [%4];"
: "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
: "l"(mc_ptr + (mylineoffset + line))
: "memory");
#endif
localptr[mylineoffset + line] = val;
}
if (threadIdx.x == 0 && lastSM)
*reduceidptr = reduce_id;
} // fp16 inplace reduce-scatter kernel MC
template <int RANKS>
__global__ void __launch_bounds__(MAX_THREADS)
userbuffers_fp16_sum_inplace_gpu_mc_rs_oop(const int op, const int flagoffset,
const int firstrank, const int myrank,
const int gpustep, const int mylineoffset,
const int totallines, const int rowlines,
const int skiplines, void **commbuff,
const int handleridx, void *outbuf, float4 *mc_ptr) {
volatile int *flagptr;
int physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
int lastSM = 0;
if (threadIdx.x < RANKS) {
physgpu = myrank * gpustep + firstrank;
targetgpu = threadIdx.x * gpustep + firstrank;
myptr = (reinterpret_cast<int *>(commbuff[physgpu])) + flagoffset;
reduceidptr = myptr - NVTE_MAX_OPS; // +op;
reduce_id = (*reduceidptr) + 1;
flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset;
if (blockIdx.x == 0)
flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&(myptr[targetgpu]);
clock_t s = clock64();
while (*flag < reduce_id) {
if (clock64() - s > TIMEOUT) {
printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x,
threadIdx.x, reduce_id, *flag);
break;
}
}
}
__syncthreads();
if (threadIdx.x == 0) {
const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1;
int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder);
if (old_val + adder == NVTE_MAX_SMS * reduce_id)
lastSM = 1;
}
const int loop_step0 = blockDim.x * gridDim.x;
const int loop_step = loop_step0 * UNROLL_MC;
const int start_elem = threadIdx.x + blockDim.x * blockIdx.x;
const int end_elem = max(start_elem, totallines);
const int aligned_elem = ((end_elem - start_elem) / loop_step) * loop_step;
const int end_aligned = start_elem + aligned_elem;
for (int line = start_elem; line < end_aligned; line += loop_step) {
uint4 val[UNROLL_MC];
#pragma unroll #pragma unroll
for (int i = 0; i < UNROLL_MC; i++)
#if defined(NVTE_UB_FP16)
asm("multimem.ld_reduce.global.add.v4.f16x2 {%0,%1,%2,%3}, [%4];"
: "=r"(val[i].x), "=r"(val[i].y), "=r"(val[i].z), "=r"(val[i].w)
: "l"(mc_ptr + (mylineoffset + line + i * loop_step0))
: "memory");
#else
asm("multimem.ld_reduce.global.add.v4.bf16x2 {%0,%1,%2,%3}, [%4];"
: "=r"(val[i].x), "=r"(val[i].y), "=r"(val[i].z), "=r"(val[i].w)
: "l"(mc_ptr + (mylineoffset + line + i * loop_step0))
: "memory");
#endif
#pragma unroll
for (int i = 0; i < UNROLL_MC; i++)
(reinterpret_cast<uint4 *>(outbuf))[((line + i * loop_step0) / rowlines) * skiplines +
((line + i * loop_step0) % rowlines)] = val[i];
}
for (int line = end_aligned; line < end_elem; line += loop_step0) {
uint4 val;
#if defined(NVTE_UB_FP16)
asm("multimem.ld_reduce.global.add.v4.f16x2 {%0,%1,%2,%3}, [%4];"
: "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
: "l"(mc_ptr + (mylineoffset + line))
: "memory");
#else
asm("multimem.ld_reduce.global.add.v4.bf16x2 {%0,%1,%2,%3}, [%4];"
: "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
: "l"(mc_ptr + (mylineoffset + line))
: "memory");
#endif
reinterpret_cast<uint4 *> (outbuf)[(line / rowlines) * skiplines + (line % rowlines)] = val;
}
if (threadIdx.x == 0 && lastSM)
*reduceidptr = reduce_id;
} // fp16 reduce-scatter kernel (out of place) fp16 MC
template <int RANKS>
__global__ void __launch_bounds__(MAX_THREADS)
userbuffers_fp16_sum_inplace_gpu_mc_ag(const int op, const int flagoffset, const int firstrank,
const int myrank, const int gpustep,
const int mylineoffset, const int totallines,
void **commbuff, const int handleridx, uint4 *mc_ptr) {
volatile int *flagptr;
int physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
uint4 *localptr = reinterpret_cast<uint4 *>(commbuff[myrank * gpustep + firstrank + handleridx]);
if (threadIdx.x < RANKS) {
physgpu = myrank * gpustep + firstrank;
targetgpu = threadIdx.x * gpustep + firstrank;
myptr = (reinterpret_cast<int *>(commbuff[physgpu])) + flagoffset;
reduceidptr = myptr - NVTE_MAX_OPS; // +op;
reduce_id = (*reduceidptr) + 1;
flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset;
}
__syncthreads();
const int loop_step0 = blockDim.x * gridDim.x;
const int loop_step = loop_step0 * UNROLL_MC;
const int start_elem = threadIdx.x + blockDim.x * blockIdx.x;
const int end_elem = max(start_elem, totallines);
const int aligned_elem = ((end_elem - start_elem) / loop_step) * loop_step;
const int end_aligned = start_elem + aligned_elem;
for (int line = start_elem; line < end_aligned; line += loop_step) {
uint4 val[UNROLL_MC];
#pragma unroll #pragma unroll
for (int i = 0; i < UNROLL_MC; i++)
val[i] = localptr[mylineoffset + line + i * loop_step0];
#pragma unroll
for (int i = 0; i < UNROLL_MC; i++)
asm volatile("multimem.st.global.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"(
mc_ptr + (mylineoffset + line + i * loop_step0)),
"r"(val[i].x), "r"(val[i].y), "r"(val[i].z), "r"(val[i].w)
: "memory");
  }
for (int line = end_aligned; line < end_elem; line += loop_step0) {
uint4 val = localptr[mylineoffset + line];
asm volatile(
"multimem.st.global.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"(mc_ptr + (mylineoffset + line)),
"r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w)
: "memory");
  }
  __syncthreads();
  if (threadIdx.x == 0)
    __threadfence_system();
  __syncthreads();
  __shared__ int lastSM;
if (threadIdx.x == 0) {
const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1;
int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder);
if (old_val + adder == NVTE_MAX_SMS * reduce_id)
lastSM = 1;
else
lastSM = 0;
}
__syncthreads();
if (lastSM && threadIdx.x < RANKS) {
if (threadIdx.x == 0)
*reduceidptr = reduce_id;
    flagptr[physgpu] = reduce_id;
    volatile int *flag = (volatile int *)&myptr[targetgpu];
    clock_t s = clock64();
@@ -490,43 +745,789 @@ __global__ void __launch_bounds__(MAX_THREADS)
      }
    }
  }
} // fp16 inplace allgather kernel (Hopper) MC
#else
template <int RANKS>
__global__ void __launch_bounds__(MAX_THREADS)
    userbuffers_fp16_sum_inplace_gpu_mc(const int op, const int flagoffset, const int firstrank,
                                        const int myrank, const int gpustep, const int lineoffset,
                                        const int numlines, void **commbuff, const int handleridx,
                                        float4 *mc_ptr) {}
template <int RANKS>
__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc_rs_oop(
    const int op, const int flagoffset, const int firstrank, const int myrank, const int gpustep,
const int mylineoffset, const int totallines, const int rowlines, const int skiplines,
void **commbuff, const int handleridx, void *outbuf, float4 *mc_ptr) {}
template <int RANKS>
__global__ void __launch_bounds__(MAX_THREADS)
userbuffers_fp16_sum_inplace_gpu_mc_ag(const int op, const int flagoffset, const int firstrank,
const int myrank, const int gpustep,
const int mylineoffset, const int totallines,
void **commbuff, const int handleridx, uint4 *mc_ptr) {}
template <int RANKS>
__global__ void __launch_bounds__(MAX_THREADS)
userbuffers_fp16_sum_inplace_gpu_mc_rs(const int op, const int flagoffset, const int firstrank,
const int myrank, const int gpustep,
const int mylineoffset, const int totallines,
void **commbuff, const int handleridx, float4 *mc_ptr) {}
#endif
template <int RANKS, typename fp8type>
__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_fp8(
const int op, const int flagoffset, const int firstrank, const int myrank, const int gpustep,
const int mylineoffset, const int totallines, const int rowlines, const int skiplines,
void **commbuff, const int handleridx, void *outbuf, float *scale) {
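  // FP8 reduce-scatter (out of place): every thread loads one int4 (16 fp8 values) from each
  // rank, dequantizes with *scale into half precision, accumulates, and writes the two resulting
  // int4 lines to the strided fp16 output buffer.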
__shared__ int4 *userptr[RANKS];
volatile int *flagptr;
int physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
int lastSM = 0;
half hscale = (half)*scale;
  if (threadIdx.x < RANKS) {
    physgpu = myrank * gpustep + firstrank;
    targetgpu = threadIdx.x * gpustep + firstrank;
    myptr = (reinterpret_cast<int *>(commbuff[physgpu])) + flagoffset;
reduceidptr = myptr - NVTE_MAX_OPS; // +op;
reduce_id = (*reduceidptr) + 1;
flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset;
if (blockIdx.x == 0)
flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&(myptr[targetgpu]);
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
clock_t s = clock64();
while (*flag < reduce_id) {
if (clock64() - s > TIMEOUT) {
printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x,
threadIdx.x, reduce_id, *flag);
break;
      }
    }
  }
  __syncthreads();
  if (threadIdx.x == 0) {
    const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1;
    int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder);
    if (old_val + adder == NVTE_MAX_SMS * reduce_id)
lastSM = 1;
}
int warp = blockIdx.x + (threadIdx.x >> 5);
int dest[RANKS];
#pragma unroll
for (int i = 0; i < RANKS; i++)
dest[i] = (i + myrank + warp) & (RANKS - 1);
__syncthreads();
for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines;
line += blockDim.x * gridDim.x) {
int4 val[RANKS];
#pragma unroll
for (int i = 0; i < RANKS; i++) {
val[i] = userptr[dest[i]][mylineoffset + line];
}
int4 sum[2] = {{0, 0, 0, 0}, {0, 0, 0, 0}};
half *s = reinterpret_cast<half *>(&sum);
#pragma unroll
for (int i = 0; i < RANKS; i++) {
fp8type *x = reinterpret_cast<fp8type *>(&val[i]);
#pragma unroll
for (int j = 0; j < sizeof(int4) / sizeof(fp8type); j++)
s[j] += hscale * (half)(x[j]);
}
int hline = 2 * line;
(reinterpret_cast<int4 *>(outbuf))[(hline / rowlines) * skiplines + (hline % rowlines)] =
sum[0];
hline++;
(reinterpret_cast<int4 *>(outbuf))[(hline / rowlines) * skiplines + (hline % rowlines)] =
sum[1];
}
if (threadIdx.x == 0 && lastSM)
*reduceidptr = reduce_id;
} // fp16 reduce-scatter kernel (out of place) (fp8->fp16)
template <int RANKS, typename fp8type>
__global__ void __launch_bounds__(MAX_THREADS)
userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_atomic_fp8(
const int op, const int flagoffset, const int firstrank, const int myrank,
const int gpustep, const int mylineoffset, const int totallines, const int rowlines,
const int skiplines_out, const int skiplines_in, void **commbuff, const int handleridx,
void *outbuf, float *scale, void *counters, const int numchunks, const int atomicindex) {
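  // Consumer half of the atomic GEMM + FP8 reduce-scatter overlap: each chunk_i iteration first
  // waits on the producer counter (ATOMIC_CONSUMER), then reduce-scatters that chunk with
  // fp8 -> fp16 dequantization via *scale.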
__shared__ int4 *userptr[RANKS];
volatile int *flagptr;
int physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
int lastSM = 0;
half hscale = (half)*scale;
if (threadIdx.x < RANKS) {
physgpu = myrank * gpustep + firstrank;
targetgpu = threadIdx.x * gpustep + firstrank;
// const int blockflagoffset = MAX_NVLINK * 2 * blockIdx.x;
myptr = (reinterpret_cast<int *>(commbuff[physgpu])) + flagoffset;
reduceidptr = myptr - NVTE_MAX_OPS; // +op;
reduce_id = (*reduceidptr);
flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset; // + blockflagoffset;
}
for (int chunk_i = 0; chunk_i < numchunks; chunk_i++) {
ATOMIC_CONSUMER(chunk_i);
lastSM = 0;
if (threadIdx.x < RANKS) {
reduce_id++;
if (blockIdx.x == 0)
flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&(myptr[targetgpu]);
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
clock_t s = clock64();
while (*flag < reduce_id) {
if (clock64() - s > TIMEOUT) {
printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x,
threadIdx.x, reduce_id, *flag);
break;
}
}
}
__syncthreads();
if (threadIdx.x == 0) {
const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1;
int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), /*numchunks * */ adder);
if (old_val + adder == NVTE_MAX_SMS * (reduce_id /* + numchunks*/))
lastSM = 1;
}
int warp = blockIdx.x + (threadIdx.x >> 5);
int dest[RANKS];
#pragma unroll
for (int i = 0; i < RANKS; i++)
dest[i] = (i + myrank + warp) & (RANKS - 1);
__syncthreads();
for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines;
line += blockDim.x * gridDim.x) {
int4 val[RANKS];
const int rowlines_in = rowlines / 2;
const int index_in = skiplines_in == 0
? mylineoffset + myrank * totallines + line
: (numchunks <= 1 ? 1 : chunk_i) * mylineoffset +
myrank * (totallines * skiplines_in / rowlines_in) +
(line / rowlines_in) * skiplines_in + (line % rowlines_in);
const int index1_out = chunk_i * mylineoffset * 2 + ((2 * line) / rowlines) * skiplines_out +
((2 * line) % rowlines);
const int index2_out = chunk_i * mylineoffset * 2 +
((2 * line + 1) / rowlines) * skiplines_out +
((2 * line + 1) % rowlines);
#pragma unroll
for (int i = 0; i < RANKS; i++) {
val[i] = userptr[dest[i]][index_in];
}
int4 sum[2] = {{0, 0, 0, 0}, {0, 0, 0, 0}};
half *s = reinterpret_cast<half *>(&sum);
#pragma unroll
for (int i = 0; i < RANKS; i++) {
fp8type *x = reinterpret_cast<fp8type *>(&val[i]);
#pragma unroll
for (int j = 0; j < sizeof(int4) / sizeof(fp8type); j++)
s[j] += hscale * (half)(x[j]);
}
(reinterpret_cast<int4 *>(outbuf))[index1_out] = sum[0];
(reinterpret_cast<int4 *>(outbuf))[index2_out] = sum[1];
}
}
if (threadIdx.x == 0 && lastSM)
*reduceidptr = reduce_id;
} // fp16 reduce-scatter kernel (out of place) (fp8->fp16)
template <int RANKS>
__global__ void __launch_bounds__(MAX_THREADS)
userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride(const int op, const int flagoffset,
const int firstrank, const int myrank,
const int gpustep, const int mylineoffset,
const int totallines, const int rowlines,
const int skiplines, void **commbuff,
const int handleridx, void *outbuf) {
__shared__ int4 *userptr[RANKS];
volatile int *flagptr;
int physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
int lastSM = 0;
if (threadIdx.x < RANKS) {
physgpu = myrank * gpustep + firstrank;
targetgpu = threadIdx.x * gpustep + firstrank;
myptr = (reinterpret_cast<int *>(commbuff[physgpu])) + flagoffset;
reduceidptr = myptr - NVTE_MAX_OPS; // +op;
reduce_id = (*reduceidptr) + 1;
flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset;
if (blockIdx.x == 0)
flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&(myptr[targetgpu]);
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
clock_t s = clock64();
while (*flag < reduce_id) {
if (clock64() - s > TIMEOUT) {
printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x,
threadIdx.x, reduce_id, *flag);
break;
}
}
}
__syncthreads();
if (threadIdx.x == 0) {
const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1;
int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder);
if (old_val + adder == NVTE_MAX_SMS * reduce_id)
lastSM = 1;
}
int warp = blockIdx.x + (threadIdx.x >> 5);
int dest[RANKS];
#pragma unroll
for (int i = 0; i < RANKS; i++)
dest[i] = (i + myrank + warp) & (RANKS - 1);
for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines;
line += blockDim.x * gridDim.x) {
int4 val[RANKS];
int index_in = mylineoffset + myrank * (totallines * skiplines / rowlines) +
(line / rowlines) * skiplines + (line % rowlines);
#pragma unroll
for (int i = 0; i < RANKS; i++) {
val[i] = userptr[dest[i]][index_in];
}
int4 sum = val[0];
half *s = reinterpret_cast<half *>(&sum);
#pragma unroll
for (int i = 1; i < RANKS; i++) {
half *x = reinterpret_cast<half *>(&val[i]);
#pragma unroll
for (int j = 0; j < 8; j++)
s[j] += x[j];
}
int index_out = (line / rowlines) * skiplines + (line % rowlines);
(reinterpret_cast<int4 *>(outbuf))[index_out] = sum;
}
if (threadIdx.x == 0 && lastSM)
*reduceidptr = reduce_id;
} // fp16 reduce-scatter kernel (out of place) fp16
#if 0
template<int RANKS, typename fp8type>
__global__ void
__launch_bounds__(MAX_THREADS)
userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_atomic_fp8(
const int op, const int flagoffset, const int firstrank, const int myrank, const int gpustep,
const int mylineoffset, const int totallines, const int rowlines, const int skiplines,
const int numchunks, void **commbuff, const int handleridx, void* outbuf, void *counters,
float* scale) {
if (counters) {
if ( threadIdx.x == 0 ) {
// spin-lock on counter from producer
int old_val;
while (0 != (old_val = atomicCAS(((unsigned int*)counters), 0, 0) )) {}
// make sure all threadblocks have read/waited on counters.
int old_val2;
atomicInc(((unsigned int *)counters)+numchunks, gridDim.x-1);
while (0 != (old_val2 = atomicCAS(((unsigned int*)counters)+numchunks, 0, 0) )) {}
// reset counter for next producer.
((unsigned int*)counters)[0] = 1;
asm volatile ("fence.sc.gpu;\n");
}
}
__syncthreads();
__shared__ int4* userptr[RANKS];
volatile int *flagptr;
int physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
int lastSM = 0;
half hscale = (half) *scale;
if (threadIdx.x < RANKS) {
physgpu = myrank*gpustep+firstrank;
targetgpu = threadIdx.x*gpustep+firstrank;
myptr = (reinterpret_cast<int*>(commbuff[physgpu])) + flagoffset;
reduceidptr = myptr-NVTE_MAX_OPS; // +op;
reduce_id =(*reduceidptr)+1;
flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset;
if (blockIdx.x == 0) flagptr[physgpu] = reduce_id;
volatile int* flag = (volatile int*)&(myptr[targetgpu]);
userptr[threadIdx.x] = reinterpret_cast<int4*>(commbuff[targetgpu+handleridx]);
clock_t s = clock64();
while (*flag < reduce_id) {
if (clock64()-s > TIMEOUT) {
printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n",
myrank, blockIdx.x, threadIdx.x, reduce_id, *flag);
break;
}
}
}
__syncthreads();
if (threadIdx.x == 0) {
const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS-gridDim.x+1 : 1;
int old_val = atomicAdd(myptr+(NVTE_MAX_NVLINK*2), adder);
if (old_val+adder == NVTE_MAX_SMS*reduce_id) lastSM = 1;
}
int warp = blockIdx.x+(threadIdx.x>>5);
int dest[RANKS];
#pragma unroll
for (int i = 0; i < RANKS; i++)
dest[i] = (i+myrank+warp)&(RANKS-1);
for (int line = threadIdx.x+blockDim.x*blockIdx.x;
line < totallines; line+=blockDim.x*gridDim.x) {
int4 val[RANKS];
int index_in = mylineoffset + myrank*(totallines*skiplines/rowlines/2) +
(line/rowlines)*skiplines/2+(line%rowlines);
#pragma unroll
for (int i = 0; i < RANKS; i++) {
val[i] = userptr[dest[i]][index_in];
}
int4 sum[2] = {{0, 0, 0, 0}, {0, 0, 0, 0}};
half *s = reinterpret_cast<half*>(&sum);
#pragma unroll
for (int i = 0; i < RANKS; i++) {
fp8type *x = reinterpret_cast<fp8type*>(&val[i]);
#pragma unroll
for (int j=0; j < sizeof(int4)/sizeof(fp8type); j++) s[j] += hscale * (half)(x[j]);
}
int hline = 2*line;
int index_out1 = (hline/rowlines)*skiplines+(hline%rowlines);
(reinterpret_cast<int4*>(outbuf))[index_out1] = sum[0];
hline++;
int index_out2 = (hline/rowlines)*skiplines+(hline%rowlines);
(reinterpret_cast<int4*>(outbuf))[index_out2] = sum[1];
}
if (threadIdx.x == 0 && lastSM) *reduceidptr = reduce_id;
} // fp16 reduce-scatter kernel (out of place) fp16
#endif
template <int RANKS>
__global__ void __launch_bounds__(MAX_THREADS)
userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_atomic(
const int op, const int flagoffset, const int firstrank, const int myrank,
const int gpustep, const int mylineoffset, const int totallines, const int rowlines,
const int skiplines, const int numchunks, void **commbuff, const int handleridx,
void *outbuf, void *counters) {
if (counters) {
if (threadIdx.x == 0) {
// spin-lock on counter from producer
int old_val;
while (0 != (old_val = atomicCAS(((unsigned int *)counters), 0, 0))) {
}
// make sure all threadblocks have read/waited on counters.
int old_val2;
atomicInc(((unsigned int *)counters) + numchunks, gridDim.x - 1);
while (0 != (old_val2 = atomicCAS(((unsigned int *)counters) + numchunks, 0, 0))) {
}
// reset counter for next producer.
((unsigned int *)counters)[0] = 1;
asm volatile("fence.sc.gpu;\n");
}
}
__syncthreads();
__shared__ int4 *userptr[RANKS];
volatile int *flagptr;
int physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
int lastSM = 0;
if (threadIdx.x < RANKS) {
physgpu = myrank * gpustep + firstrank;
targetgpu = threadIdx.x * gpustep + firstrank;
myptr = (reinterpret_cast<int *>(commbuff[physgpu])) + flagoffset;
reduceidptr = myptr - NVTE_MAX_OPS; // +op;
reduce_id = (*reduceidptr) + 1;
flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset;
if (blockIdx.x == 0)
flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&(myptr[targetgpu]);
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
clock_t s = clock64();
while (*flag < reduce_id) {
if (clock64() - s > TIMEOUT) {
printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x,
threadIdx.x, reduce_id, *flag);
break;
}
}
}
__syncthreads();
if (threadIdx.x == 0) {
const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1;
int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder);
if (old_val + adder == NVTE_MAX_SMS * reduce_id)
lastSM = 1;
}
int warp = blockIdx.x + (threadIdx.x >> 5);
int dest[RANKS];
#pragma unroll
for (int i = 0; i < RANKS; i++)
dest[i] = (i + myrank + warp) & (RANKS - 1);
for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines;
line += blockDim.x * gridDim.x) {
int4 val[RANKS];
int index_in = mylineoffset + myrank * (totallines * skiplines / rowlines) +
(line / rowlines) * skiplines + (line % rowlines);
#pragma unroll
for (int i = 0; i < RANKS; i++) {
val[i] = userptr[dest[i]][index_in];
}
int4 sum = val[0];
half *s = reinterpret_cast<half *>(&sum);
#pragma unroll
for (int i = 1; i < RANKS; i++) {
half *x = reinterpret_cast<half *>(&val[i]);
#pragma unroll
for (int j = 0; j < 8; j++)
s[j] += x[j];
}
int index_out = (line / rowlines) * skiplines + (line % rowlines);
(reinterpret_cast<int4 *>(outbuf))[index_out] = sum;
}
if (threadIdx.x == 0 && lastSM)
*reduceidptr = reduce_id;
} // fp16 reduce-scatter kernel (out of place) fp16
template <int RANKS>
__global__ void __launch_bounds__(MAX_THREADS)
userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_multiatomic(
const int op, const int flagoffset, const int firstrank, const int myrank,
const int gpustep, const int mylineoffset, const int totallines, const int rowlines,
const int skiplines, const int numchunks, void **commbuff, const int handleridx,
void *outbuf, void *counters) {
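  // Multi-chunk variant: one launch consumes numchunks producer counters, spinning on
  // counters[chunk_i] before reduce-scattering each chunk and advancing reduce_id per iteration.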
for (int chunk_i = 0; chunk_i < numchunks; chunk_i++) {
if (counters) {
if (threadIdx.x == 0) {
// spin-lock on counter from producer
int old_val;
while (0 != (old_val = atomicCAS(((unsigned int *)counters) + chunk_i, 0, 0))) {
}
// make sure all threadblocks have read/waited on counters.
int old_val2;
atomicInc(((unsigned int *)counters) + numchunks + chunk_i, gridDim.x - 1);
while (0 !=
(old_val2 = atomicCAS(((unsigned int *)counters) + numchunks + chunk_i, 0, 0))) {
}
// reset counter for next producer.
((unsigned int *)counters)[chunk_i] = 1;
asm volatile("fence.sc.gpu;\n");
}
}
__syncthreads();
__shared__ int4 *userptr[RANKS];
volatile int *flagptr;
int physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
int lastSM = 0;
if (threadIdx.x < RANKS) {
physgpu = myrank * gpustep + firstrank;
targetgpu = threadIdx.x * gpustep + firstrank;
myptr = (reinterpret_cast<int *>(commbuff[physgpu])) + flagoffset;
reduceidptr = myptr - NVTE_MAX_OPS; // +op;
reduce_id = (*reduceidptr) + 1;
flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset;
if (blockIdx.x == 0)
flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&(myptr[targetgpu]);
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
clock_t s = clock64();
while (*flag < reduce_id) {
if (clock64() - s > TIMEOUT) {
printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x,
threadIdx.x, reduce_id, *flag);
break;
}
}
}
__syncthreads();
if (threadIdx.x == 0) {
const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1;
int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder);
if (old_val + adder == NVTE_MAX_SMS * reduce_id)
lastSM = 1;
}
int warp = blockIdx.x + (threadIdx.x >> 5);
int dest[RANKS];
#pragma unroll
for (int i = 0; i < RANKS; i++)
dest[i] = (i + myrank + warp) & (RANKS - 1);
for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines;
line += blockDim.x * gridDim.x) {
int4 val[RANKS];
int index_in = chunk_i * mylineoffset + myrank * (totallines * skiplines / rowlines) +
(line / rowlines) * skiplines + (line % rowlines);
#pragma unroll
for (int i = 0; i < RANKS; i++) {
val[i] = userptr[dest[i]][index_in];
}
int4 sum = val[0];
half *s = reinterpret_cast<half *>(&sum);
#pragma unroll
for (int i = 1; i < RANKS; i++) {
half *x = reinterpret_cast<half *>(&val[i]);
#pragma unroll
for (int j = 0; j < 8; j++)
s[j] += x[j];
}
int index_out = chunk_i * mylineoffset + (line / rowlines) * skiplines + (line % rowlines);
(reinterpret_cast<int4 *>(outbuf))[index_out] = sum;
}
if (threadIdx.x == 0 && lastSM)
*reduceidptr = reduce_id;
}
} // fp16 reduce-scatter kernel (out of place) fp16
template <int RANKS>
__global__ void __launch_bounds__(MAX_THREADS)
userbuffers_fp16_sum_inplace_gpu_rr_ag(const int op, const int flagoffset, const int firstrank,
const int myrank, const int gpustep,
const int mylineoffset, const int totallines,
void **commbuff, const int handleridx) {
__shared__ int4 *userptr[RANKS];
volatile int *flagptr;
int physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
if (threadIdx.x < RANKS) {
physgpu = myrank * gpustep + firstrank;
targetgpu = threadIdx.x * gpustep + firstrank;
myptr = (reinterpret_cast<int *>(commbuff[physgpu])) + flagoffset;
reduceidptr = myptr - NVTE_MAX_OPS; // +op;
reduce_id = (*reduceidptr) + 1;
flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset;
volatile int *flag = (volatile int *)&(myptr[targetgpu]);
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
clock_t s = clock64();
}
int warp = blockIdx.x + (threadIdx.x >> 5);
int dest[RANKS];
int skipmy = 0;
#pragma unroll
for (int i = 0; i < RANKS; i++) {
int dst = (i + warp + myrank) & (RANKS - 1);
if (dst == myrank) {
skipmy++;
continue;
}
dest[i - skipmy] = dst;
}
__syncthreads();
for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines;
line += blockDim.x * gridDim.x) {
int4 val[RANKS - 1];
#pragma unroll
for (int i = 0; i < RANKS - 1; i++) {
val[i] = userptr[dest[i]][mylineoffset + line + totallines * dest[i]];
}
#pragma unroll
for (int i = 0; i < RANKS - 1; i++) {
userptr[myrank][mylineoffset + line + totallines * dest[i]] = val[i];
}
}
__shared__ int lastSM;
if (threadIdx.x == 0) {
const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1;
int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder);
if (old_val + adder == NVTE_MAX_SMS * reduce_id)
lastSM = 1;
else
lastSM = 0;
}
__syncthreads();
if (lastSM && threadIdx.x < RANKS) {
if (threadIdx.x == 0)
*reduceidptr = reduce_id;
flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&myptr[targetgpu];
clock_t s = clock64();
while (*flag < reduce_id) {
if (clock64() - s > 2ull * TIMEOUT) {
printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id,
*flag);
break;
}
}
}
} // fp16 inplace reduce kernel (Ampere)
template <int RANKS>
__global__ void __launch_bounds__(MAX_THREADS)
userbuffers_fp16_sum_inplace_gpu_rw_ag(const int op, const int flagoffset, const int firstrank,
const int myrank, const int gpustep,
const int mylineoffset, const int totallines,
void **commbuff, const int handleridx) {
__shared__ int4 *userptr[RANKS];
volatile int *flagptr;
int physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
int4 *localptr;
if (threadIdx.x < RANKS) {
physgpu = myrank * gpustep + firstrank;
targetgpu = threadIdx.x * gpustep + firstrank;
myptr = (reinterpret_cast<int *>(commbuff[physgpu])) + flagoffset;
reduceidptr = myptr - NVTE_MAX_OPS; // +op;
reduce_id = (*reduceidptr) + 1;
flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset;
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
}
__syncthreads();
localptr = userptr[myrank];
int warp = blockIdx.x + (threadIdx.x >> 5);
int dest[RANKS - 1];
int skipmy = 0;
#pragma unroll
for (int i = 0; i < RANKS; i++) {
int dst = (i + warp + myrank) & (RANKS - 1);
if (dst == myrank) {
skipmy++;
continue;
}
dest[i - skipmy] = dst;
}
#define UNROLLAG 4
__syncthreads();
const int loop_step0 = blockDim.x * gridDim.x;
const int loop_step = loop_step0 * UNROLLAG;
const int start_elem = threadIdx.x + blockDim.x * blockIdx.x;
const int end_elem = max(start_elem, totallines);
const int aligned_elem = ((end_elem - start_elem) / loop_step) * loop_step;
const int end_aligned = start_elem + aligned_elem;
for (int line = start_elem; line < end_aligned; line += loop_step) {
int4 val[UNROLLAG];
#pragma unroll
for (int j = 0; j < UNROLLAG; j++)
val[j] = localptr[mylineoffset + line + loop_step0 * j];
#pragma unroll
for (int j = 0; j < UNROLLAG; j++)
#pragma unroll
for (int i = 0; i < RANKS - 1; i++) {
userptr[dest[i]][mylineoffset + line + j * loop_step0] = val[j];
}
}
for (int line = end_aligned; line < end_elem; line += loop_step0) {
int4 sum = localptr[mylineoffset + line];
#pragma unroll
for (int i = 0; i < RANKS - 1; i++) {
userptr[dest[i]][mylineoffset + line] = sum;
}
}
__syncthreads();
if (threadIdx.x == 0)
__threadfence_system();
__syncthreads();
__shared__ int lastSM;
if (threadIdx.x == 0) {
const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1;
int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder);
if (old_val + adder == NVTE_MAX_SMS * reduce_id)
lastSM = 1;
else
lastSM = 0;
}
__syncthreads();
if (lastSM && threadIdx.x < RANKS) {
if (threadIdx.x == 0)
*reduceidptr = reduce_id;
flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&myptr[targetgpu];
clock_t s = clock64();
while (*flag < reduce_id) {
if (clock64() - s > 2ull * TIMEOUT) {
printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id,
*flag);
break;
}
}
}
} // fp16 inplace allgather kernel (Volta,Hopper)
template <int RANKS>
__global__ void __launch_bounds__(MAX_THREADS)
userbuffers_fp16_sum_inplace_gpu_rr_blocked(const int op, const int flagoffset,
const int firstrank, const int myrank,
const int lineoffset, const int numlines,
void **commbuff, const int handleridx,
const int peerblocklines, int *hostflags,
int *gpuflag, const int numblocks) {
const int basecounter = gpuflag[NVTE_GF_STATE + op];
#define REDUCETHREADS (blockDim.x - 32)
if (threadIdx.x < 32) {
int *flagptr;
if (threadIdx.x < RANKS) {
if (!blockIdx.x) {
flagptr = reinterpret_cast<int *>(commbuff[threadIdx.x + firstrank]);
flagptr[flagoffset + myrank + firstrank] = basecounter;
}
volatile int *flag = (volatile int *)&((reinterpret_cast<int *>(
commbuff[myrank + firstrank]))[flagoffset + threadIdx.x + firstrank]);
while (*flag < basecounter) {
}
}
__syncthreads();
int startblock = 0, endblock = numblocks;
for (int nblock = 0; nblock < endblock; nblock++) {
asm volatile("bar.sync 13, %0;" ::"r"(REDUCETHREADS + 32)); asm volatile("bar.sync 13, %0;" ::"r"(REDUCETHREADS + 32));
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
__threadfence(); __threadfence();
if (blockIdx.x) gpuflag[op * NVTE_MAX_SMS * 2 + blockIdx.x] = nblock + basecounter + 1; if (blockIdx.x)
gpuflag[op * NVTE_MAX_SMS * 2 + blockIdx.x] = nblock + basecounter + 1;
} else if (blockIdx.x == 0) { } else if (blockIdx.x == 0) {
int expecting = (basecounter + nblock + 1); int expecting = (basecounter + nblock + 1);
if (threadIdx.x < gridDim.x) if (threadIdx.x < gridDim.x)
@@ -535,7 +1536,8 @@ __global__ void __launch_bounds__(MAX_THREADS)
      }
      if (!blockIdx.x) {
        asm volatile("bar.sync 15, %0;" ::"r"(32));
        if (!threadIdx.x)
          hostflags[0] = nblock + basecounter + 1;
      }
    }
@@ -546,13 +1548,15 @@ __global__ void __launch_bounds__(MAX_THREADS)
    if (blockIdx.x == 0 && threadIdx.x < RANKS) {
      while (cachedflag < basecounter + numblocks) {
        int newflag = ((volatile int *)gpuflag)[ALLGATHERFLAG];
        if (newflag == cachedflag)
          continue;
        cachedflag = newflag;
        flagptr[flagoffset + myrank + 32 + firstrank] = cachedflag;
      }
    }
    if (blockIdx.x == 0 && threadIdx.x == 0)
      gpuflag[NVTE_GF_STATE + op] = basecounter + numblocks;
  } else {
    const int warp = blockIdx.x + (threadIdx.x >> 5);
    int4 *userptr[RANKS];
@@ -587,7 +1591,8 @@ __global__ void __launch_bounds__(MAX_THREADS)
      for (int i = 1; i < RANKS; i++) {
        half *x = reinterpret_cast<half *>(&val[i]);
#pragma unroll
        for (int j = 0; j < sizeof(int4) / sizeof(half); j++)
          s[j] += x[j];
      }
      userptrmyrank[blockstart + line] = sum;
@@ -637,9 +1642,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
    for (int line = start_elem; line < end_aligned; line += myblockDim * gridDim.x * UNROLL) {
      int4 val[UNROLL];
#pragma unroll
      for (int i = 0; i < UNROLL; i++)
        val[i] = peerptr[line + i * myblockDim * gridDim.x];
#pragma unroll
      for (int i = 0; i < UNROLL; i++)
        myptr[line + i * myblockDim * gridDim.x] = val[i];
    }
    for (int line = end_aligned; line < end_elem; line += myblockDim * gridDim.x)
      myptr[line] = peerptr[line];
@@ -654,14 +1661,16 @@ __global__ void __launch_bounds__(MAX_THREADS)
    asm volatile("bar.sync 13, %0;" ::"r"(blockDim.x)); \
    if (threadIdx.x == 0) { \
      __threadfence_system(); \
      if (blockIdx.x) \
        gpuflag[offset + blockIdx.x] = block + basecounter + 1; \
    } else if (blockIdx.x == 0) { \
      int expecting = (basecounter + block + 1); \
      if (threadIdx.x < gridDim.x) \
        while (((volatile int *)gpuflag)[offset + threadIdx.x] < expecting) { \
        } \
    } \
  if (blockIdx.x == 0) \
    asm volatile("bar.sync 15, %0;" ::"r"(32));
template <int RANKS>
__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_rr_blocked2(
@@ -722,7 +1731,8 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
      }
    }
    if (blockIdx.x == 0 && threadIdx.x == 0)
      gpuflag[NVTE_GF_STATE + op] = basecounter + numblocks;
  } else {  // sync warp
    // reducethreads
    const int warp = blockIdx.x + (threadIdx.x >> 5);
@@ -762,7 +1772,8 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
        for (int i = 1; i < RANKS; i++) {
          half *x = reinterpret_cast<half *>(&val[i]);
#pragma unroll
          for (int j = 0; j < sizeof(int4) / sizeof(half); j++)
            s[j] += x[j];
        }
        userptrmyrank[blockstart + line] = sum;
@@ -801,13 +1812,15 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
                       : tempbufptr[i * ibblocklines + line];
          half *x = reinterpret_cast<half *>(&val[(i + 1) % UNROLLRS]);
#pragma unroll
          for (int j = 0; j < 16; j++)
            s[j] += x[j];
        }
#pragma unroll
        for (int i = 1; i < UNROLLRS; i++) {
          half *x = reinterpret_cast<half *>(&val[i]);
#pragma unroll
          for (int j = 0; j < 16; j++)
            s[j] += x[j];
        }
        userptrmyrank[tempstart + line] = sum;
      }
@@ -858,9 +1871,11 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
    for (int line = start_elem; line < end_aligned; line += myblockDim * gridDim.x * UNROLL) {
      int4 val[UNROLL];
#pragma unroll
      for (int i = 0; i < UNROLL; i++)
        val[i] = peerptr[line + i * myblockDim * gridDim.x];
#pragma unroll
      for (int i = 0; i < UNROLL; i++)
        myptr[line + i * myblockDim * gridDim.x] = val[i];
    }
    for (int line = end_aligned; line < end_elem; line += myblockDim * gridDim.x)
      myptr[line] = peerptr[line];
@@ -952,7 +1967,8 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
      for (int i = 1; i < RANKS; i++) {
        half *x = reinterpret_cast<half *>(&val[i]);
#pragma unroll
        for (int j = 0; j < sizeof(int4) / sizeof(half); j++)
          s[j] += x[j];
      }
      userptrmyrank[blockstart + line] = sum;
@@ -971,9 +1987,6 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
      int4 *tempbufptr = &internalbuf[((nblock - headstart) % maxcredit) * peerblocklines];
      const int tempstart = lineoffset + (nblock - headstart) * peerblocklines * RANKS +
                            myrank * blocklines + ibblocklines * myibrank;
      asm volatile("bar.sync 13, %0;" ::"r"(REDUCETHREADS + 32));
@@ -994,13 +2007,15 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
                       : tempbufptr[i * ibblocklines + line];
          half *x = reinterpret_cast<half *>(&val[(i + 1) % UNROLLRS]);
#pragma unroll
          for (int j = 0; j < 16; j++)
            s[j] += x[j];
        }
#pragma unroll
        for (int i = 1; i < UNROLLRS; i++) {
          half *x = reinterpret_cast<half *>(&val[i]);
#pragma unroll
          for (int j = 0; j < 16; j++)
            s[j] += x[j];
        }
        userptrmyrank[tempstart + line] = sum;
      }
@@ -1048,7 +2063,8 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
      }
    }
    if (blockIdx.x == 0 && threadIdx.x == 0)
      gpuflag[NVTE_GF_STATE + op] = basecounter + numblocks;
  } else {  // sync warp
    // reducethreads
    const int warp = blockIdx.x + (threadIdx.x >> 5);
@@ -1105,9 +2121,11 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
    for (int line = start_elem; line < end_aligned; line += myblockDim * gridDim.x * UNROLL) {
      int4 val[UNROLL];
#pragma unroll
      for (int i = 0; i < UNROLL; i++)
        val[i] = peerptr[line + i * myblockDim * gridDim.x];
#pragma unroll
      for (int i = 0; i < UNROLL; i++)
        myptr[line + i * myblockDim * gridDim.x] = val[i];
    }
    for (int line = end_aligned; line < end_elem; line += myblockDim * gridDim.x)
      myptr[line] = peerptr[line];
@@ -1140,10 +2158,14 @@ __global__ void userbuffers_fp16_sum_inplace_gpu_null(const int op, int *hostfla
  if (ar_nvsize == x) { \
    int numblocks = (elements * 2 + blocksize - 1) / blocksize; \
    int headstart = numblocks - 1; /*<3?numblocks-1:3;*/ \
    if (headstart > maxcredit) \
      headstart = maxcredit; \
    if (x == 1) \
      headstart = maxcredit; \
    if (headstart > numblocks) \
      headstart = numblocks; \
    if (headstart == 0) \
      headstart = 1; \
    userbuffers_fp16_sum_inplace_gpu_rr_blocked2<x><<<sms, warps * 32, 0, stream>>>( \
        op, maxcredit, headstart, my_node, num_nodes, \
        NVTE_REG0_OFFSET(comm) + NVTE_REG0_FLAGS + \
@@ -1158,10 +2180,14 @@ __global__ void userbuffers_fp16_sum_inplace_gpu_null(const int op, int *hostfla
  if (ar_nvsize == x) { \
    int numblocks = (elements * 2 + blocksize - 1) / blocksize; \
    int headstart = numblocks - 1; /*<3?numblocks-1:3;*/ \
    if (headstart > maxcredit) \
      headstart = maxcredit; \
    if (x == 1) \
      headstart = maxcredit; \
    if (headstart > numblocks) \
      headstart = numblocks; \
    if (headstart == 0) \
      headstart = 1; \
    userbuffers_fp16_sum_inplace_gpu_rr_blocked2_rs<x><<<sms, warps * 32, 0, stream>>>( \
        op, maxcredit, headstart, my_node, num_nodes, \
        NVTE_REG0_OFFSET(comm) + NVTE_REG0_FLAGS + \
@@ -1176,10 +2202,14 @@ __global__ void userbuffers_fp16_sum_inplace_gpu_null(const int op, int *hostfla
  if (ar_nvsize == x) { \
    int numblocks = (elements * 2 + blocksize - 1) / blocksize; \
    int headstart = numblocks - 1; /*<3?numblocks-1:3;*/ \
    if (headstart > maxcredit) \
      headstart = maxcredit; \
    if (x == 1) \
      headstart = maxcredit; \
    if (headstart > numblocks) \
      headstart = numblocks; \
    if (headstart == 0) \
      headstart = 1; \
    userbuffers_fp16_sum_inplace_gpu_rr_blocked2_ag<x><<<sms, warps * 32, 0, stream>>>( \
        op, maxcredit, headstart, my_node, num_nodes, \
        NVTE_REG0_OFFSET(comm) + NVTE_REG0_FLAGS + \
@@ -1212,6 +2242,26 @@ __global__ void userbuffers_fp16_sum_inplace_gpu_null(const int op, int *hostfla
        kernelArgs)); \
  }
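// Launch helper for the NVLS multicast allreduce kernel: picks the template instantiation for
// ar_nvsize == x, packs the kernel arguments, and launches via cudaLaunchKernelExC with the
// multicast pointer comm->mc_ptr[handler] as the last argument.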
#define callranksMC(x) \
if (ar_nvsize == x) { \
int arg1 = op - NVTE_MAX_OPS, \
arg2 = NVTE_REG0_OFFSET(comm) - \
(op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \
NVTE_MAX_OPS, \
arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg6 = offset / 8, \
arg7 = elements / 8; \
void **arg8 = reinterpret_cast<void **>(comm->gpu_ptrs); \
int arg9 = handler * comm->nvsize; \
void *arg10 = comm->mc_ptr[handler]; \
void *kernelArgs[] = {reinterpret_cast<void *>(&arg1), reinterpret_cast<void *>(&arg2), \
reinterpret_cast<void *>(&arg3), reinterpret_cast<void *>(&arg4), \
reinterpret_cast<void *>(&arg5), reinterpret_cast<void *>(&arg6), \
reinterpret_cast<void *>(&arg7), reinterpret_cast<void *>(&arg8), \
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10)}; \
CUDACHECK(cudaLaunchKernelExC( \
&cfg, reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_mc<x>), kernelArgs)); \
}
#define SETUP_LAUNCH_CONFIG(sms, threads, stream) \
  cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \
  cudaLaunchAttribute attribute_ub[2]; \
@@ -1232,10 +2282,12 @@ int allreduce_userbuff_inplace_gpu(const int handler, const int offset, const in
  const int ar_nvsize = comm->nvsize;
  const int ar_firstgpu = comm->ar_firstgpu;
  const int ar_nvrank = comm->ar_nvrank;
  if (elements < 8)
    return 0;
  int sms = comm->sms;
  int warps = comm->threads / 32;
  if (warps < comm->ar_nvsize)
    warps = comm->ar_nvsize;
  if (comm->launch_mode & NVTE_LAUNCH_GPU) {
    if (comm->ar_nvsize == 1)
@@ -1260,10 +2312,12 @@ int allreduce2_userbuff_inplace_gpu(const int maxcredit, const int handler, cons
  const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
  const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank;
  if (elements < 8)
    return 0;
  int sms = ar_nvsize == 1 ? 2 : comm->sms;
  int warps = comm->threads / 32;
  if (warps < ar_nvsize)
    warps = ar_nvsize;
  if (num_nodes > 1) {
    callranks2_block(1) callranks2_block(2) callranks2_block(4) callranks2_block(8)
  } else {
@@ -1295,6 +2349,26 @@ int allreduce2_userbuff_inplace_gpu(const int maxcredit, const int handler, cons
        kernelArgs)); \
  }
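// Launch helper for the multicast all-gather kernel; offsets and element counts are expressed in
// int4 (16-byte) units, with each rank owning elements / 8 / x lines starting at its own offset.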
#define callranks_agMC(x) \
if (ar_nvsize == x) { \
int arg1 = op - NVTE_MAX_OPS, \
arg2 = NVTE_REG0_OFFSET(comm) - \
(op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \
NVTE_MAX_OPS, \
arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \
arg6 = offset / 8 + arg4 * arg7; \
void **arg8 = reinterpret_cast<void **>(comm->gpu_ptrs); \
int arg9 = handler * comm->nvsize; \
uint4 *arg10 = reinterpret_cast<uint4 *>(comm->mc_ptr[handler]); \
void *kernelArgs[] = {reinterpret_cast<void *>(&arg1), reinterpret_cast<void *>(&arg2), \
reinterpret_cast<void *>(&arg3), reinterpret_cast<void *>(&arg4), \
reinterpret_cast<void *>(&arg5), reinterpret_cast<void *>(&arg6), \
reinterpret_cast<void *>(&arg7), reinterpret_cast<void *>(&arg8), \
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10)}; \
CUDACHECK(cudaLaunchKernelExC( \
&cfg, reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_mc_ag<x>), kernelArgs)); \
}
#define callranks_rs(x) \
  if (ar_nvsize == x) { \
    int arg1 = op - NVTE_MAX_OPS, \
@@ -1314,6 +2388,26 @@ int allreduce2_userbuff_inplace_gpu(const int maxcredit, const int handler, cons
        &cfg, reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_rr_rs<x>), kernelArgs)); \
  }
#define callranks_rsMC(x) \
if (ar_nvsize == x) { \
int arg1 = op - NVTE_MAX_OPS, \
arg2 = NVTE_REG0_OFFSET(comm) - \
(op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \
NVTE_MAX_OPS, \
arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \
arg6 = offset / 8 + arg4 * arg7; \
void **arg8 = reinterpret_cast<void **>(comm->gpu_ptrs); \
int arg9 = handler * comm->nvsize; \
void *arg10 = comm->mc_ptr[handler]; \
void *kernelArgs[] = {reinterpret_cast<void *>(&arg1), reinterpret_cast<void *>(&arg2), \
reinterpret_cast<void *>(&arg3), reinterpret_cast<void *>(&arg4), \
reinterpret_cast<void *>(&arg5), reinterpret_cast<void *>(&arg6), \
reinterpret_cast<void *>(&arg7), reinterpret_cast<void *>(&arg8), \
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10)}; \
CUDACHECK(cudaLaunchKernelExC( \
&cfg, reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_mc_rs<x>), kernelArgs)); \
}
#define callranks_rs_oop(x) \ #define callranks_rs_oop(x) \
if (ar_nvsize == x) { \ if (ar_nvsize == x) { \
int arg1 = op - NVTE_MAX_OPS, \ int arg1 = op - NVTE_MAX_OPS, \
@@ -1336,32 +2430,383 @@ int allreduce2_userbuff_inplace_gpu(const int maxcredit, const int handler, cons
        kernelArgs)); \
  }
#define callranks_rs_oop_fp8(x) \
if (ar_nvsize == x) { \
int arg1 = op - NVTE_MAX_OPS, \
arg2 = NVTE_REG0_OFFSET(comm) - \
(op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \
NVTE_MAX_OPS, \
arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 16 / x, \
arg6 = offset / 16 + arg4 * arg7, arg8 = rowelements / 8, arg9 = strideelements / 8; \
void **arg10 = reinterpret_cast<void **>(comm->gpu_ptrs); \
int arg11 = handler * comm->nvsize; \
void *arg12 = output; \
float *arg13 = scale; \
void *kernelArgs[] = {reinterpret_cast<void *>(&arg1), reinterpret_cast<void *>(&arg2), \
reinterpret_cast<void *>(&arg3), reinterpret_cast<void *>(&arg4), \
reinterpret_cast<void *>(&arg5), reinterpret_cast<void *>(&arg6), \
reinterpret_cast<void *>(&arg7), reinterpret_cast<void *>(&arg8), \
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10), \
reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12), \
reinterpret_cast<void *>(&arg13)}; \
CUDACHECK(cudaLaunchKernelExC( \
&cfg, \
reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_fp8<x, fp8type>), \
kernelArgs)); \
}
#define callranks_rs_oopMC(x) \
if (ar_nvsize == x) { \
int arg1 = op - NVTE_MAX_OPS, \
arg2 = NVTE_REG0_OFFSET(comm) - \
(op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \
NVTE_MAX_OPS, \
arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \
arg6 = offset / 8 + arg4 * arg7, arg8 = rowelements / 8, arg9 = strideelements / 8; \
void **arg10 = reinterpret_cast<void **>(comm->gpu_ptrs); \
int arg11 = handler * comm->nvsize; \
void *arg12 = output; \
void *arg13 = comm->mc_ptr[handler]; \
void *kernelArgs[] = {reinterpret_cast<void *>(&arg1), reinterpret_cast<void *>(&arg2), \
reinterpret_cast<void *>(&arg3), reinterpret_cast<void *>(&arg4), \
reinterpret_cast<void *>(&arg5), reinterpret_cast<void *>(&arg6), \
reinterpret_cast<void *>(&arg7), reinterpret_cast<void *>(&arg8), \
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10), \
reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12), \
reinterpret_cast<void *>(&arg13)}; \
CUDACHECK(cudaLaunchKernelExC( \
&cfg, reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_mc_rs_oop<x>), \
kernelArgs)); \
}
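// Launch helper for the atomic FP8 reduce-scatter consumer kernel: fp8 input offsets/strides are
// divided by 16 (one int4 holds 16 fp8 values) while fp16 output rows and strides are divided by 8
// (one int4 holds 8 half values); counters and numchunks wire up the producer GEMM handshake.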
#define callranks_rs_oop_atomic_fp8(x) \
if (ar_nvsize == x) { \
int arg1 = op - NVTE_MAX_OPS, \
arg2 = NVTE_REG0_OFFSET(comm) - \
(op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \
NVTE_MAX_OPS, \
arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 16 / x, \
arg6 = offset / 16, arg8 = rowelements / 8, arg9 = strideelements_out / 8, \
arg10 = strideelements_in / 16; \
void **arg11 = reinterpret_cast<void **>(comm->gpu_ptrs); \
int arg12 = handler * comm->nvsize; \
void *arg13 = output; \
float *arg14 = scale; \
void *arg15 = counters; \
int arg16 = numchunks, arg17 = atomicindex; \
void *kernelArgs[] = {reinterpret_cast<void *>(&arg1), reinterpret_cast<void *>(&arg2), \
reinterpret_cast<void *>(&arg3), reinterpret_cast<void *>(&arg4), \
reinterpret_cast<void *>(&arg5), reinterpret_cast<void *>(&arg6), \
reinterpret_cast<void *>(&arg7), reinterpret_cast<void *>(&arg8), \
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10), \
reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12), \
reinterpret_cast<void *>(&arg13), reinterpret_cast<void *>(&arg14), \
reinterpret_cast<void *>(&arg15), reinterpret_cast<void *>(&arg16), \
reinterpret_cast<void *>(&arg17)}; \
CUDACHECK(cudaLaunchKernelExC( \
&cfg, \
reinterpret_cast<void *>( \
userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_atomic_fp8<x, fp8type>), \
kernelArgs)); \
}
#define callranks_rs_oop_stride(x) \
if (ar_nvsize == x) { \
int arg1 = op - NVTE_MAX_OPS, \
arg2 = NVTE_REG0_OFFSET(comm) - \
(op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \
NVTE_MAX_OPS, \
arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \
arg6 = offset / 8, arg8 = rowelements / 8, arg9 = strideelements / 8; \
void **arg10 = reinterpret_cast<void **>(comm->gpu_ptrs); \
int arg11 = handler * comm->nvsize; \
void *arg12 = output; \
void *kernelArgs[] = {reinterpret_cast<void *>(&arg1), reinterpret_cast<void *>(&arg2), \
reinterpret_cast<void *>(&arg3), reinterpret_cast<void *>(&arg4), \
reinterpret_cast<void *>(&arg5), reinterpret_cast<void *>(&arg6), \
reinterpret_cast<void *>(&arg7), reinterpret_cast<void *>(&arg8), \
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10), \
reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12)}; \
CUDACHECK(cudaLaunchKernelExC( \
&cfg, reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride<x>), \
kernelArgs)); \
}
#if 0
#define callranks_rs_oop_stride_atomic_fp8(x) \
if (ar_nvsize == x) { \
int arg1 = op - NVTE_MAX_OPS, \
arg2 = NVTE_REG0_OFFSET(comm) - \
(op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \
NVTE_MAX_OPS, \
arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 16 / x, \
arg6 = offset / 16, arg8 = rowelements / 8, arg9 = strideelements / 8, arg10 = numchunks; \
void **arg11 = reinterpret_cast<void **>(comm->gpu_ptrs); \
int arg12 = handler * comm->nvsize; \
void *arg13 = output; \
void *arg14 = counters; \
float *arg15 = scale; \
void *kernelArgs[] = {reinterpret_cast<void *>(&arg1), reinterpret_cast<void *>(&arg2), \
reinterpret_cast<void *>(&arg3), reinterpret_cast<void *>(&arg4), \
reinterpret_cast<void *>(&arg5), reinterpret_cast<void *>(&arg6), \
reinterpret_cast<void *>(&arg7), reinterpret_cast<void *>(&arg8), \
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10), \
reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12), \
reinterpret_cast<void *>(&arg13), reinterpret_cast<void *>(&arg14), \
reinterpret_cast<void *>(&arg15)}; \
CUDACHECK(cudaLaunchKernelExC( \
&cfg, \
reinterpret_cast<void *>( \
userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_atomic_fp8<x, fp8type>), \
kernelArgs)); \
}
#endif
#define callranks_rs_oop_stride_atomic(x) \
if (ar_nvsize == x) { \
int arg1 = op - NVTE_MAX_OPS, \
arg2 = NVTE_REG0_OFFSET(comm) - \
(op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \
NVTE_MAX_OPS, \
arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \
arg6 = offset / 8, arg8 = rowelements / 8, arg9 = strideelements / 8, arg10 = numchunks; \
void **arg11 = reinterpret_cast<void **>(comm->gpu_ptrs); \
int arg12 = handler * comm->nvsize; \
void *arg13 = output; \
void *arg14 = counters; \
void *kernelArgs[] = {reinterpret_cast<void *>(&arg1), reinterpret_cast<void *>(&arg2), \
reinterpret_cast<void *>(&arg3), reinterpret_cast<void *>(&arg4), \
reinterpret_cast<void *>(&arg5), reinterpret_cast<void *>(&arg6), \
reinterpret_cast<void *>(&arg7), reinterpret_cast<void *>(&arg8), \
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10), \
reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12), \
reinterpret_cast<void *>(&arg13), reinterpret_cast<void *>(&arg14)}; \
CUDACHECK(cudaLaunchKernelExC( \
&cfg, \
reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_atomic<x>), \
kernelArgs)); \
}
#define callranks_rs_oop_stride_multiatomic(x) \
if (ar_nvsize == x) { \
int arg1 = op - NVTE_MAX_OPS, \
arg2 = NVTE_REG0_OFFSET(comm) - \
(op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \
NVTE_MAX_OPS, \
arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \
arg6 = offset / 8, arg8 = rowelements / 8, arg9 = strideelements / 8, arg10 = numchunks; \
void **arg11 = reinterpret_cast<void **>(comm->gpu_ptrs); \
int arg12 = handler * comm->nvsize; \
void *arg13 = output; \
void *arg14 = counters; \
void *kernelArgs[] = {reinterpret_cast<void *>(&arg1), reinterpret_cast<void *>(&arg2), \
reinterpret_cast<void *>(&arg3), reinterpret_cast<void *>(&arg4), \
reinterpret_cast<void *>(&arg5), reinterpret_cast<void *>(&arg6), \
reinterpret_cast<void *>(&arg7), reinterpret_cast<void *>(&arg8), \
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10), \
reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12), \
reinterpret_cast<void *>(&arg13), reinterpret_cast<void *>(&arg14)}; \
CUDACHECK( \
cudaLaunchKernelExC(&cfg, \
reinterpret_cast<void *>( \
userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_multiatomic<x>), \
kernelArgs)); \
}
int reducescatter2_userbuff_inplace_gpu(const int maxcredit, const int handler, const int offset,
                                        const int elements, const int blocksize, communicator *comm,
                                        cudaStream_t stream, int op) {
  // schedule GPU kernel only
  // CPU/SHARP part is responsibility of caller
  const int num_nodes = op == userbuffers_allreduceop_nonsharp ? comm->num_nodes : comm->num2_nodes;
  const int my_node = op == userbuffers_allreduceop_nonsharp ? comm->my_node : comm->my2_node;
const int ar_firstgpu =
op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu;
const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize;
const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank;
if (elements < 8)
return 0;
int sms = ar_nvsize == 1 ? 2 : comm->sms;
int warps = comm->threads / 32;
if (warps < ar_nvsize)
warps = ar_nvsize;
if (num_nodes > 1) {
callranks2_block_rs(1) callranks2_block_rs(2) callranks2_block_rs(4) callranks2_block_rs(8)
} else {
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
callranks_rs(2) callranks_rs(4) callranks_rs(8)
}
return sms;
}
void reducescatter2_userbuff_strided(void *output, const int handler, const int offset,
const int rowelements, const int colelements,
const int strideelements, communicator *comm,
cudaStream_t stream) {
const int elements = rowelements * colelements;
const int op = userbuffers_allreduceop_nonsharp2;
const int blocksize = elements * 2;
const int ar_firstgpu =
op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu;
const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize;
const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank;
if (elements < 64)
return;
int sms = ar_nvsize == 1 ? 2 : comm->sms;
int warps = comm->threads / 32;
if (warps < ar_nvsize)
warps = ar_nvsize;
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
callranks_rs_oop_stride(2) callranks_rs_oop_stride(4) callranks_rs_oop_stride(8)
}
void reducescatter2_userbuff_strided_atomic(void *output, const int handler, const int offset,
const int rowelements, const int colelements,
const int strideelements, const int numchunks,
void *counters, communicator *comm,
cudaStream_t stream) {
const int elements = rowelements * colelements;
const int op = userbuffers_allreduceop_nonsharp2;
const int blocksize = elements * 2;
const int ar_firstgpu =
op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu;
const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize;
const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank;
if (elements < 64)
return;
int sms = ar_nvsize == 1 ? 2 : comm->sms;
int warps = comm->threads / 32;
if (warps < ar_nvsize)
warps = ar_nvsize;
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
callranks_rs_oop_stride_atomic(2) callranks_rs_oop_stride_atomic(4)
callranks_rs_oop_stride_atomic(8)
}
#if 0
template<typename fp8type>
void reducescatter2_userbuff_strided_atomic_fp8(
void* output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements, const int numchunks, void *counters,
communicator* comm, cudaStream_t stream) {
const int elements = rowelements*colelements;
const int op = userbuffers_allreduceop_nonsharp2;
const int blocksize = elements;
const int ar_firstgpu = op == userbuffers_allreduceop_nonsharp ?
comm->ar_firstgpu : comm->ar2_firstgpu;
const int ar_step = op == userbuffers_allreduceop_nonsharp2 ?
1 : comm->ar2_nvsize;
const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ?
comm->ar_nvsize : comm->ar2_nvsize;
const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ?
comm->ar_nvrank : comm->ar2_nvrank;
assert(comm->sm_arch >= 9);
if (elements < 128) return;
int sms = ar_nvsize == 1 ? 2 : comm->sms;
int warps = comm->threads/32;
if (warps < ar_nvsize) warps = ar_nvsize;
SETUP_LAUNCH_CONFIG(sms, warps*32, stream);
callranks_rs_oop_stride_atomic_fp8(2)
callranks_rs_oop_stride_atomic_fp8(4)
callranks_rs_oop_stride_atomic_fp8(8)
}
#endif
template <typename fp8type>
void reducescatter2_userbuff_strided_universal_fp8(void *output, float *scale, const int handler,
const int offset, const int rowelements,
const int colelements,
const int strideelements_out,
const int strideelements_in, const int numchunks,
const int atomicindex, void *counters,
communicator *comm, cudaStream_t stream) {
const int elements = rowelements * colelements;
const int op = userbuffers_allreduceop_nonsharp2;
const int blocksize = elements;
const int ar_firstgpu =
op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu;
const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize;
const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank;
assert(comm->sm_arch >= 9);
if (elements < 128)
return;
int sms = ar_nvsize == 1 ? 2 : comm->sms;
int warps = comm->threads / 32;
if (warps < ar_nvsize)
warps = ar_nvsize;
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
callranks_rs_oop_atomic_fp8(2) callranks_rs_oop_atomic_fp8(4) callranks_rs_oop_atomic_fp8(8)
}
template <typename fp8type>
void reducescatter2_userbuff_strided_atomic_fp8(void *output, float *scale, const int handler,
const int offset, const int rowelements,
const int colelements, const int strideelements_out,
const int strideelements_in, const int numchunks,
void *counters, communicator *comm,
cudaStream_t stream) {
reducescatter2_userbuff_strided_universal_fp8<fp8type>(
output, scale, handler, offset, rowelements, colelements, strideelements_out,
strideelements_in, 1, numchunks, counters /*nullptr*/, comm, stream);
}
template <typename fp8type>
void reducescatter2_userbuff_strided_multiatomic_fp8(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator *comm, cudaStream_t stream) {
reducescatter2_userbuff_strided_universal_fp8<fp8type>(
output, scale, handler, offset, rowelements, colelements, strideelements_out,
strideelements_in, numchunks, 0, counters /*nullptr*/, comm, stream);
}
void reducescatter2_userbuff_strided_multiatomic(void *output, const int handler, const int offset,
const int rowelements, const int colelements,
const int strideelements, const int numchunks,
void *counters, communicator *comm,
cudaStream_t stream) {
const int elements = rowelements * colelements;
const int op = userbuffers_allreduceop_nonsharp2;
const int blocksize = elements * 2;
  const int ar_firstgpu =
      op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu;
  const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize;
  const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
  const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank;
  if (elements < 64)
    return;
  int sms = ar_nvsize == 1 ? 2 : comm->sms;
  int warps = comm->threads / 32;
  if (warps < ar_nvsize)
    warps = ar_nvsize;
  SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
  // if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) {
  //   callranks_rs_oopMC(2)
  //   callranks_rs_oopMC(4)
  //   callranks_rs_oopMC(8)
  // } else {
  //   if (comm->memflags[handler] & NVTE_UB_MEM_UC_CONTIG) {
  //     callranks_rs_oopUCPTR(2)
  //     callranks_rs_oopUCPTR(4)
  //     callranks_rs_oopUCPTR(8)
  //   } else {
  callranks_rs_oop_stride_multiatomic(2) callranks_rs_oop_stride_multiatomic(4)
      callranks_rs_oop_stride_multiatomic(8)
  //   }
  // }
}
int allgather2_userbuff_inplace_gpu(const int maxcredit, const int handler, const int offset,
@@ -1378,10 +2823,12 @@ int allgather2_userbuff_inplace_gpu(const int maxcredit, const int handler, cons
  const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
  const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank;
  if (elements < 8)
    return 0;
  int sms = ar_nvsize == 1 ? 2 : comm->sms;
  int warps = comm->threads / 32;
  if (warps < ar_nvsize)
    warps = ar_nvsize;
  if (num_nodes > 1) {
    callranks2_block_ag(1) callranks2_block_ag(2) callranks2_block_ag(4) callranks2_block_ag(8)
@@ -1402,10 +2849,12 @@ void allgather2_userbuff_inplace(const int handler, const int offset, const int
  const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
  const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank;
  if (elements < 64)
    return;
  int sms = ar_nvsize == 1 ? 2 : comm->sms;
  int warps = comm->threads / 32;
  if (warps < ar_nvsize)
    warps = ar_nvsize;
  SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
  callranks_ag(2) callranks_ag(4) callranks_ag(8)
@@ -1436,10 +2885,12 @@ void reducescatter2_userbuff_inplace(const int handler, const int offset, const
  const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
  const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank;
  if (elements < 64)
    return;
  int sms = ar_nvsize == 1 ? 2 : comm->sms;
  int warps = comm->threads / 32;
  if (warps < ar_nvsize)
    warps = ar_nvsize;
  SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
  callranks_rs(2) callranks_rs(4) callranks_rs(8)
@@ -1457,10 +2908,12 @@ void reducescatter2_userbuff_stridedoutput(void *output, const int handler, cons
  const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
  const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank;
  if (elements < 64)
    return;
  int sms = ar_nvsize == 1 ? 2 : comm->sms;
  int warps = comm->threads / 32;
  if (warps < ar_nvsize)
    warps = ar_nvsize;
  SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
  callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8)
@@ -1470,8 +2923,109 @@ void reducescatter2_userbuff(void *output, const int handler, const int offset,
  reducescatter2_userbuff_stridedoutput(output, handler, offset, elements, 1, 0, comm, stream);
}
template <typename fp8type>
void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const int handler,
const int offset, const int rowelements,
const int colelements, const int strideelements,
communicator *comm, cudaStream_t stream) {
const int elements = rowelements * colelements;
const int op = userbuffers_allreduceop_nonsharp2;
const int blocksize = elements;
const int ar_firstgpu =
op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu;
const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize;
const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank;
assert(comm->sm_arch >= 9);
if (elements < 128)
return;
int sms = ar_nvsize == 1 ? 2 : comm->sms;
int warps = comm->threads / 32;
if (warps < ar_nvsize)
warps = ar_nvsize;
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8)
}
template <typename fp8type>
void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset,
const int elements, communicator *comm, cudaStream_t stream) {
reducescatter2_userbuff_stridedoutput_fp8<fp8type>(output, scale, handler, offset, elements, 1, 0,
comm, stream);
}
template void reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(void *output, float *scale,
const int handler, const int offset,
const int elements, communicator *comm,
cudaStream_t stream = 0);
template void reducescatter2_userbuff_fp8<__nv_fp8_e4m3>(void *output, float *scale,
const int handler, const int offset,
const int elements, communicator *comm,
cudaStream_t stream = 0);
#if 0
template void reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e4m3>(
void* output, float *scale, const int handler, const int offset,
const int rowelements, const int colelements, const int strideelements,
const int numchunks, void *counters, communicator* comm, cudaStream_t stream = 0);
#endif
template void reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e4m3>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator *comm, cudaStream_t stream = 0);
template void reducescatter2_userbuff_strided_multiatomic_fp8<__nv_fp8_e4m3>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator *comm, cudaStream_t stream = 0);
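A minimal, hypothetical call site for the FP8 reduce-scatter path is sketched below for orientation; it is not part of this change. The handler, buffer pointers, and sizes are placeholders and assume the userbuffer region was registered elsewhere with this communicator (the kernel additionally requires sm_arch >= 9 and at least 128 elements, per the asserts above).

```cpp
// Hedged sketch, not from the PR. Assumes `comm` came from create_communicator(),
// `ub_handler` identifies a registered userbuffer region, and `rs_out`/`d_scale`
// are device pointers sized by the caller.
#include <cuda_fp8.h>

void example_fp8_reduce_scatter(communicator *comm, int ub_handler, void *rs_out,
                                float *d_scale, int elements, cudaStream_t stream) {
  // Each rank receives its elements / ar_nvsize share as FP8 (e4m3), quantized
  // with *d_scale by the reduce-scatter kernel.
  reducescatter2_userbuff_fp8<__nv_fp8_e4m3>(rs_out, d_scale, ub_handler,
                                             /*offset=*/0, elements, comm, stream);
}
```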
__global__ void __launch_bounds__(MAX_THREADS)
kuserbuffers_pullsendrecv(int myrank, int peer, int *recv_id, int *send_flagptr,
int *recv_flagptr, int4 *srcptr, int4 *dstptr, const int lines) {
if (blockIdx.x == 0 && threadIdx.x == 0) {
atomicAdd_system(send_flagptr, 1);
}
#define UNROLLCOPY 8
const int start_elem = threadIdx.x + blockDim.x * blockIdx.x;
const int end_elem = lines;
const int aligned_elem = (end_elem - start_elem) & (~(blockDim.x * gridDim.x * UNROLLCOPY - 1));
const int end_aligned = start_elem + aligned_elem;
if (threadIdx.x == 0) {
const int signal_id = (*recv_id) + 1;
volatile int *flag = (volatile int *)recv_flagptr;
clock_t s = clock64();
while (*flag < signal_id) {
if (clock64() - s > TIMEOUT) {
printf("[%d from %d] pullrecv: expected %d, stuck with %d\n", myrank, peer, signal_id,
*flag);
break;
}
}
if (lines == 0) {
*recv_id = signal_id;
return;
} // otherwise need an extra kernel
}
__syncthreads();
if (end_elem <= start_elem)
return;
for (int line = start_elem; line < end_aligned; line += blockDim.x * gridDim.x * UNROLLCOPY) {
int4 val[UNROLLCOPY];
#pragma unroll
for (int i = 0; i < UNROLLCOPY; i++)
val[i] = srcptr[line + i * blockDim.x * gridDim.x];
#pragma unroll
for (int i = 0; i < UNROLLCOPY; i++)
dstptr[line + i * blockDim.x * gridDim.x] = val[i];
}
for (int line = end_aligned; line < end_elem; line += blockDim.x * gridDim.x)
dstptr[line] = srcptr[line];
}
__global__ void kuserbuffers_pullsend(int myrank, int peer, int *send_id, int *flagptr) {
  atomicAdd_system(flagptr, 1);
}
__global__ void kuserbuffers_inc(int *id) {
@@ -1514,14 +3068,17 @@ __global__ void __launch_bounds__(MAX_THREADS)
  }
  __syncthreads();
  if (end_elem <= start_elem)
    return;
  for (int line = start_elem; line < end_aligned; line += blockDim.x * gridDim.x * UNROLLCOPY) {
    int4 val[UNROLLCOPY];
#pragma unroll
    for (int i = 0; i < UNROLLCOPY; i++)
      val[i] = srcptr[line + i * blockDim.x * gridDim.x];
#pragma unroll
    for (int i = 0; i < UNROLLCOPY; i++)
      dstptr[line + i * blockDim.x * gridDim.x] = val[i];
  }
  for (int line = end_aligned; line < end_elem; line += blockDim.x * gridDim.x)
    dstptr[line] = srcptr[line];
@@ -1539,17 +3096,21 @@ __global__ void __launch_bounds__(MAX_THREADS)
    for (int line = start_elem; line < end_aligned; line += blockDim.x * gridDim.x * UNROLLCOPY) {
      int4 val[UNROLLCOPY];
#pragma unroll
      for (int i = 0; i < UNROLLCOPY; i++)
        val[i] = srcptr[line + i * blockDim.x * gridDim.x];
#pragma unroll
      for (int i = 0; i < UNROLLCOPY; i++)
        dstptr[line + i * blockDim.x * gridDim.x] = val[i];
    }
    for (int line = end_aligned; line < end_elem; line += blockDim.x * gridDim.x)
      dstptr[line] = srcptr[line];
  }
  __syncthreads();
  if (threadIdx.x)
    return;
  __threadfence_system();
  atomicAdd_system(flagptr,
                   1);  // otherwise need local SM sync before sending flag
  } else {  // 0 bytes and 1 SM only
    atomicAdd_system(flagptr, 1);
  }
@@ -1559,7 +3120,8 @@ __global__ void kuserbuffers_pushrecv(int myrank, int peer, int *recv_id, int *f
  const int signal_id = (*recv_id) + adder;
  *recv_id = signal_id;
  volatile int *flag = (volatile int *)flagptr;
  if (*flag >= signal_id)
    return;
  clock_t s = clock64();
  while (atomicAdd_system(flagptr, 0) < signal_id) {
    if (clock64() - s > TIMEOUT) {
@@ -1569,6 +3131,196 @@ __global__ void kuserbuffers_pushrecv(int myrank, int peer, int *recv_id, int *f
  }
}
__global__ void __launch_bounds__(MAX_THREADS)
kuserbuffers_pushsendrecv(int *send_id, int *send_flagptr, int4 *srcptr, int4 *dstptr,
const int lines, int myrank, int peer, int *recv_id,
int *recv_flagptr, int adder) {
if (lines) {
const int start_elem = threadIdx.x + blockDim.x * blockIdx.x;
const int end_elem = lines;
const int aligned_elem =
((end_elem - start_elem) & (~(blockDim.x * gridDim.x * UNROLLCOPY - 1)));
const int end_aligned = start_elem + aligned_elem;
if (end_elem > start_elem) {
for (int line = start_elem; line < end_aligned; line += blockDim.x * gridDim.x * UNROLLCOPY) {
int4 val[UNROLLCOPY];
#pragma unroll
for (int i = 0; i < UNROLLCOPY; i++) {
val[i] = srcptr[line + i * blockDim.x * gridDim.x];
}
#pragma unroll
for (int i = 0; i < UNROLLCOPY; i++) {
dstptr[line + i * blockDim.x * gridDim.x] = val[i];
}
}
for (int line = end_aligned; line < end_elem; line += blockDim.x * gridDim.x) {
dstptr[line] = srcptr[line];
}
}
__syncthreads();
if (threadIdx.x)
return;
__threadfence_system();
atomicAdd_system(send_flagptr,
1); // otherwise need local SM sync before sending flag
} else { // 0 bytes and 1 SM only
atomicAdd_system(send_flagptr, 1);
}
if (blockIdx.x == 0 && threadIdx.x == 0) {
const int signal_id = (*recv_id) + adder;
*recv_id = signal_id;
volatile int *flag = (volatile int *)recv_flagptr;
if (*flag >= signal_id)
return;
clock_t s = clock64();
while (*flag < signal_id) {
if (clock64() - s > TIMEOUT) {
printf("%d from %d] pushrecv: expected %d, stuck with %d\n", myrank, peer, signal_id,
*flag);
return;
}
}
}
}
__global__ void __launch_bounds__(MAX_THREADS)
kuserbuffers_pushsendrecv_atomic(int *send_id, int *send_flagptr, int4 *srcptr, int4 *dstptr,
const int lines, int myrank, int peer, int *recv_id,
int *recv_flagptr, int adder, void *counters) {
if (lines) {
const int start_elem = threadIdx.x + blockDim.x * blockIdx.x;
const int end_elem = lines;
const int aligned_elem =
((end_elem - start_elem) & (~(blockDim.x * gridDim.x * UNROLLCOPY - 1)));
const int end_aligned = start_elem + aligned_elem;
if (end_elem > start_elem) {
for (int line = start_elem; line < end_aligned; line += blockDim.x * gridDim.x * UNROLLCOPY) {
int4 val[UNROLLCOPY];
#pragma unroll
for (int i = 0; i < UNROLLCOPY; i++) {
val[i] = srcptr[line + i * blockDim.x * gridDim.x];
}
#pragma unroll
for (int i = 0; i < UNROLLCOPY; i++) {
dstptr[line + i * blockDim.x * gridDim.x] = val[i];
}
}
for (int line = end_aligned; line < end_elem; line += blockDim.x * gridDim.x) {
dstptr[line] = srcptr[line];
}
}
__syncthreads();
if (threadIdx.x)
return;
__threadfence_system();
atomicAdd_system(send_flagptr,
1); // otherwise need local SM sync before sending flag
} else { // 0 bytes and 1 SM only
atomicAdd_system(send_flagptr, 1);
}
if (blockIdx.x == 0 && threadIdx.x == 0) {
const int signal_id = (*recv_id) + adder;
*recv_id = signal_id;
volatile int *flag = (volatile int *)recv_flagptr;
// if(*flag>=signal_id) return;
clock_t s = clock64();
while (*flag < signal_id) {
if (clock64() - s > TIMEOUT) {
printf("%d from %d] pushrecv: expected %d, stuck with %d\n", myrank, peer, signal_id,
*flag); /*return;*/
}
}
    // Reset the atomic flag to 0 to signal that the current output tile is complete
if (counters) {
((unsigned int *)counters)[0] = 0;
asm volatile("fence.sc.gpu;\n");
}
}
}
__global__ void __launch_bounds__(MAX_THREADS)
kuserbuffers_pushsendrecv_multiatomic(int *send_id, int *send_flagptr, int4 *srcptr,
int4 *dstptr, const int lines, int myrank, int peer,
int *recv_id, int *recv_flagptr, int adder,
void *counters, int nchunks, int send_stride,
int recv_stride, bool shuffle) {
for (int chunk_i = 0; chunk_i < nchunks - 1; chunk_i++) {
int send_chunk_id = shuffle ? chunk_i : (nchunks + myrank - chunk_i) % nchunks;
int recv_chunk_id = shuffle ? chunk_i + 1 : (nchunks + myrank - chunk_i - 1) % nchunks;
int send_offset = (send_chunk_id * send_stride) / 16;
int recv_offset = ((shuffle ? recv_chunk_id : send_chunk_id) * recv_stride) / 16;
if (lines) {
const int start_elem = threadIdx.x + blockDim.x * blockIdx.x;
const int end_elem = lines;
const int aligned_elem =
((end_elem - start_elem) & (~(blockDim.x * gridDim.x * UNROLLCOPY - 1)));
const int end_aligned = start_elem + aligned_elem;
if (end_elem > start_elem) {
for (int line = start_elem; line < end_aligned;
line += blockDim.x * gridDim.x * UNROLLCOPY) {
int4 val[UNROLLCOPY];
#pragma unroll
for (int i = 0; i < UNROLLCOPY; i++) {
val[i] = srcptr[send_offset + line + i * blockDim.x * gridDim.x];
}
#pragma unroll
for (int i = 0; i < UNROLLCOPY; i++) {
dstptr[recv_offset + line + i * blockDim.x * gridDim.x] = val[i];
}
}
for (int line = end_aligned; line < end_elem; line += blockDim.x * gridDim.x) {
dstptr[recv_offset + line] = srcptr[send_offset + line];
}
}
__syncthreads();
if (!threadIdx.x) {
__threadfence_system();
atomicAdd_system(send_flagptr,
1); // otherwise need local SM sync before sending flag
}
} else { // 0 bytes and 1 SM only
atomicAdd_system(send_flagptr, 1);
}
// wait for message to arrive.
if (blockIdx.x == 0 && threadIdx.x == 0) {
const int signal_id = (*recv_id) + adder;
*recv_id = signal_id;
volatile int *flag = (volatile int *)recv_flagptr;
// if(*flag>=signal_id) return;
clock_t s = clock64();
while (*flag < signal_id) {
if (clock64() - s > TIMEOUT) {
printf("%d from %d] pushrecv: expected %d, stuck with %d\n", myrank, peer, signal_id,
*flag); /*return;*/
}
}
}
// Producer must update counters.
if (blockIdx.x == 0 && threadIdx.x == 0) {
      // Reset the atomic flag to 0 to signal that the current output tile is complete
if (counters) {
((unsigned int *)counters)[recv_chunk_id /*chunk_i+1*/] = 0;
asm volatile("fence.sc.gpu;\n");
}
}
// sync all CTAs before moving to next chunk.
if (threadIdx.x == 0) {
int old_val2;
atomicInc(((unsigned int *)counters) + nchunks + chunk_i, gridDim.x - 1);
while (0 != (old_val2 = atomicCAS(((unsigned int *)counters) + nchunks + chunk_i, 0, 0))) {
}
}
__syncthreads();
}
}
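Reading aid (not part of the change): the multiatomic kernel appears to assume a counters buffer of at least 2 * nchunks unsigned ints. The first nchunks entries are per-chunk ready flags that the receiver clears to 0 once a chunk has landed, and the next nchunks entries are scratch slots used as an inter-CTA barrier between chunks. A hedged host-side sketch of how such a buffer could be allocated and armed:

```cpp
// Hedged sketch, not from the PR: allocate and initialize a counters buffer
// compatible with kuserbuffers_pushsendrecv_multiatomic. Flag convention
// inferred from the kernel: 1 = chunk not ready, 0 = chunk ready; the barrier
// slots in the second half start at 0.
#include <vector>
#include <cuda_runtime.h>

unsigned int *alloc_multiatomic_counters(int nchunks, cudaStream_t stream) {
  unsigned int *counters = nullptr;
  cudaMalloc(&counters, 2 * nchunks * sizeof(unsigned int));
  std::vector<unsigned int> init(2 * nchunks, 0u);
  for (int i = 0; i < nchunks; ++i) init[i] = 1u;  // chunks start "not ready"
  cudaMemcpyAsync(counters, init.data(), init.size() * sizeof(unsigned int),
                  cudaMemcpyHostToDevice, stream);
  return counters;
}
```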
#define CUDACHECK(cmd) \
  do { \
    cudaError_t e = cmd; \
@@ -1611,7 +3363,8 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds
                                          comm->hostflags + userbuffers_sendop);
    return;
  }
  if (!(comm->launch_mode & NVTE_LAUNCH_GPU))
    return;
  if (comm->push == 0) {
    kuserbuffers_pullsend<<<1, 1, 0, stream>>>(comm->myrank, peer, &(comm->send_id[peer]),
                                               reinterpret_cast<int *>(flagptr));
@@ -1633,10 +3386,145 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds
  }
}
void userbuffers_sendrecv(const int srchandler, const int dsthandler, const size_t send_offset,
const size_t recv_offset, const size_t bytes, communicator *comm,
const int send_peer, const int recv_peer, cudaStream_t stream) {
bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0);
int send_peerlocal = send_peer % comm->nvsize;
int recv_peerlocal = recv_peer % comm->nvsize;
void *flagptr_send =
(comm->peer_ptr[0][send_peerlocal]) +
((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + comm->myrank * NVTE_MAX_REGIONS + dsthandler) *
sizeof(int));
void *flagptr_recv =
(comm->mem_ptr[0]) +
((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + recv_peer * NVTE_MAX_REGIONS + dsthandler) *
sizeof(int));
void *send_srcptr = (comm->mem_ptr[srchandler]) + send_offset;
void *send_dstptr = (comm->peer_ptr[dsthandler][send_peerlocal]) + send_offset;
if (comm->use_ce)
CUDACHECK(cudaMemcpyAsync(send_dstptr, send_srcptr, bytes, cudaMemcpyDeviceToDevice, stream));
SETUP_LAUNCH_CONFIG(signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, stream);
int *arg1 = &comm->send_id[send_peer];
int *arg2 = reinterpret_cast<int *>(flagptr_send);
int4 *arg3 = reinterpret_cast<int4 *>(send_srcptr);
int4 *arg4 = reinterpret_cast<int4 *>(send_dstptr);
int arg5 = signalonly ? 0 : bytes / 16;
int arg6 = comm->myrank;
int arg7 = recv_peer;
int *arg8 = &comm->recv_id[recv_peer * NVTE_MAX_REGIONS + dsthandler];
int *arg9 = reinterpret_cast<int *>(flagptr_recv);
int arg10 = signalonly ? 1 : comm->sms;
void *kernelArgs[] = {reinterpret_cast<void *>(&arg1), reinterpret_cast<void *>(&arg2),
reinterpret_cast<void *>(&arg3), reinterpret_cast<void *>(&arg4),
reinterpret_cast<void *>(&arg5), reinterpret_cast<void *>(&arg6),
reinterpret_cast<void *>(&arg7), reinterpret_cast<void *>(&arg8),
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10)};
CUDACHECK(
cudaLaunchKernelExC(&cfg, reinterpret_cast<void *>(kuserbuffers_pushsendrecv), kernelArgs));
//}
}
void userbuffers_sendrecv_atomic(const int srchandler, const int dsthandler,
const size_t send_offset, const size_t recv_offset,
const size_t bytes, communicator *comm, const int send_peer,
const int recv_peer, void *counters, cudaStream_t stream) {
assert(comm->push && comm->use_ce == 0);
bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0);
int send_peerlocal = send_peer % comm->nvsize;
int recv_peerlocal = recv_peer % comm->nvsize;
void *flagptr_send =
(comm->peer_ptr[0][send_peerlocal]) +
((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + comm->myrank * NVTE_MAX_REGIONS + dsthandler) *
sizeof(int));
void *flagptr_recv =
(comm->mem_ptr[0]) +
((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + recv_peer * NVTE_MAX_REGIONS + dsthandler) *
sizeof(int));
void *send_srcptr = (comm->mem_ptr[srchandler]) + send_offset;
void *send_dstptr = (comm->peer_ptr[dsthandler][send_peerlocal]) + send_offset;
if (comm->use_ce) {
CUDACHECK(cudaMemcpyAsync(send_dstptr, send_srcptr, bytes, cudaMemcpyDeviceToDevice, stream));
}
SETUP_LAUNCH_CONFIG(signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, stream);
int *arg1 = &comm->send_id[send_peer];
int *arg2 = reinterpret_cast<int *>(flagptr_send);
int4 *arg3 = reinterpret_cast<int4 *>(send_srcptr);
int4 *arg4 = reinterpret_cast<int4 *>(send_dstptr);
int arg5 = signalonly ? 0 : bytes / 16;
int arg6 = comm->myrank;
int arg7 = recv_peer;
int *arg8 = &comm->recv_id[recv_peer * NVTE_MAX_REGIONS + dsthandler];
int *arg9 = reinterpret_cast<int *>(flagptr_recv);
int arg10 = signalonly ? 1 : comm->sms;
void *arg11 = counters;
void *kernelArgs[] = {reinterpret_cast<void *>(&arg1), reinterpret_cast<void *>(&arg2),
reinterpret_cast<void *>(&arg3), reinterpret_cast<void *>(&arg4),
reinterpret_cast<void *>(&arg5), reinterpret_cast<void *>(&arg6),
reinterpret_cast<void *>(&arg7), reinterpret_cast<void *>(&arg8),
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10),
reinterpret_cast<void *>(&arg11)};
CUDACHECK(cudaLaunchKernelExC(&cfg, reinterpret_cast<void *>(kuserbuffers_pushsendrecv_atomic),
kernelArgs));
}
void userbuffers_sendrecv_multiatomic(const int srchandler, const int dsthandler,
const size_t send_stride, const size_t recv_stride,
const size_t bytes, communicator *comm, const int send_peer,
const int recv_peer, const int nchunks, void *counters,
bool shuffle, cudaStream_t stream) {
assert(comm->push && comm->use_ce == 0);
int send_peerlocal = send_peer % comm->nvsize;
int recv_peerlocal = recv_peer % comm->nvsize;
void *flagptr_send =
(comm->peer_ptr[0][send_peerlocal]) +
((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + comm->myrank * NVTE_MAX_REGIONS + dsthandler) *
sizeof(int));
void *flagptr_recv =
(comm->mem_ptr[0]) +
((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + recv_peer * NVTE_MAX_REGIONS + dsthandler) *
sizeof(int));
SETUP_LAUNCH_CONFIG(comm->sms, 1024, stream);
int *arg1 = &comm->send_id[send_peer];
int *arg2 = reinterpret_cast<int *>(flagptr_send);
int4 *arg3 = reinterpret_cast<int4 *>((comm->mem_ptr[srchandler]));
int4 *arg4 = reinterpret_cast<int4 *>((comm->peer_ptr[dsthandler][send_peerlocal]));
int arg5 = bytes / 16;
int arg6 = comm->myrank;
int arg7 = recv_peer;
int *arg8 = &comm->recv_id[recv_peer * NVTE_MAX_REGIONS + dsthandler];
int *arg9 = reinterpret_cast<int *>(flagptr_recv);
int arg10 = comm->sms;
void *arg11 = counters;
int arg12 = nchunks;
int arg13 = send_stride;
int arg14 = recv_stride;
bool arg15 = shuffle;
void *kernelArgs[] = {reinterpret_cast<void *>(&arg1), reinterpret_cast<void *>(&arg2),
reinterpret_cast<void *>(&arg3), reinterpret_cast<void *>(&arg4),
reinterpret_cast<void *>(&arg5), reinterpret_cast<void *>(&arg6),
reinterpret_cast<void *>(&arg7), reinterpret_cast<void *>(&arg8),
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10),
reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12),
reinterpret_cast<void *>(&arg13), reinterpret_cast<void *>(&arg14),
reinterpret_cast<void *>(&arg15)};
CUDACHECK(cudaLaunchKernelExC(
&cfg, reinterpret_cast<void *>(kuserbuffers_pushsendrecv_multiatomic), kernelArgs));
}
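For orientation only, a hypothetical caller of the multiatomic variant might look like the sketch below; the kernel walks all nchunks chunks in a single launch, so the host issues one call with the ring neighbors as send/recv peers. Peer selection, strides, and the counters buffer are assumptions (the counters layout follows the allocation sketch above).

```cpp
// Hedged sketch (not in this PR): one launch drives the whole chunked exchange.
void example_ring_multiatomic(communicator *comm, int src_handler, int dst_handler,
                              size_t chunk_bytes, int nchunks,
                              unsigned int *counters, cudaStream_t comm_stream) {
  const int next = (comm->myrank + 1) % comm->nvsize;                 // send peer
  const int prev = (comm->myrank - 1 + comm->nvsize) % comm->nvsize;  // recv peer
  userbuffers_sendrecv_multiatomic(src_handler, dst_handler,
                                   /*send_stride=*/chunk_bytes, /*recv_stride=*/chunk_bytes,
                                   chunk_bytes, comm, next, prev, nchunks,
                                   counters, /*shuffle=*/false, comm_stream);
}
```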
__global__ void __launch_bounds__(MAX_THREADS)
    kuserbuffers_alltoall(void **baseflagptrs, int flagoffset, int4 *basesrcptr, void **dstptrs,
                          size_t dstoffset, const int lines, const int myrank) {
  if (blockIdx.x == myrank)
    return;
  int4 *dstptr = reinterpret_cast<int4 *>(dstptrs[blockIdx.x] + dstoffset);
  int *flagptr = reinterpret_cast<int *>(baseflagptrs[blockIdx.x] + flagoffset);
  const size_t myblockoffset = blockIdx.x * lines;
@@ -1652,14 +3540,18 @@ __global__ void __launch_bounds__(MAX_THREADS)
    for (int line = start_elem; line < end_aligned; line += blockDim.x * UNROLLCOPY) {
      int4 val[UNROLLCOPY];
#pragma unroll
      for (int i = 0; i < UNROLLCOPY; i++)
        val[i] = srcptr[line + i * blockDim.x];
#pragma unroll
      for (int i = 0; i < UNROLLCOPY; i++)
        dstptr[line + i * blockDim.x] = val[i];
    }
    for (int line = end_aligned; line < end_elem; line += blockDim.x)
      dstptr[line] = srcptr[line];
  }
  __syncthreads();
  if (threadIdx.x)
    return;
  __threadfence_system();
  atomicAdd(flagptr, 1);
@@ -1702,7 +3594,8 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds
                      sizeof(int));
  bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0);
  bool intranode = INTRANODE(peer);
  if (!(comm->launch_mode & NVTE_LAUNCH_GPU))
    return;
  if (comm->push == 0 && intranode) {
    void *dstptr = (comm->mem_ptr[dsthandler]) + dstoffset;
    void *srcptr = (comm->peer_ptr[srchandler][peerlocal]) + srcoffset;
@@ -1728,7 +3621,45 @@ void userbuffers_alltoall_recv(communicator *comm, cudaStream_t stream) {
      (comm->mem_ptr[0]) +
      ((NVTE_REG0_OFFSET(comm) + NVTE_REG0_OPFLAGS * userbuffers_alltoall) * sizeof(int));
  if (!(comm->launch_mode & NVTE_LAUNCH_GPU))
    return;
  kuserbuffers_pushrecv<<<1, 1, 0, stream>>>(comm->myrank, -1, reinterpret_cast<int *>(flagptr + 4),
                                             reinterpret_cast<int *>(flagptr), comm->nranks - 1);
}
// producer
static __global__ void producer_kernel(void *atomic_ptr, int chunk_i) {
  // Reset the atomic flag to 0 to signal that the current output tile is complete
if (blockIdx.x == 0 && threadIdx.x == 0) {
((unsigned int *)atomic_ptr)[chunk_i] = 0;
}
  // The COMM kernel needs to explicitly flush gmem: the GEMM kernel has already
  // been launched and cannot observe the flag change unless the COMM kernel
  // makes it visible.
asm volatile("fence.sc.gpu;\n");
}
// consumer
static __global__ void consumer_kernel(void *atomic_ptr, int chunk_i) {
  // Wait for the producer to set the value to 0, which signals that the data is ready
if (blockIdx.x == 0 && threadIdx.x == 0) {
int old_val;
while (0 != (old_val = atomicCAS((unsigned int *)atomic_ptr + chunk_i, 0, 0))) {
}
((unsigned int *)atomic_ptr)[chunk_i] = 1;
asm volatile("fence.sc.gpu;\n");
}
}
void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream) {
dim3 block(1);
dim3 grid(1);
producer_kernel<<<grid, block, 0, stream>>>(atomic_ptr, chunk_i);
}
void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream) {
dim3 block(1);
dim3 grid(1);
consumer_kernel<<<grid, block, 0, stream>>>(atomic_ptr, chunk_i);
}
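To illustrate the intended handshake (a sketch under assumptions, not the PR's actual call sites): the communication side calls producer() after a chunk's data is in place, and the GEMM side calls consumer() on its own stream before touching that chunk; both operate on the same counters buffer, with 0 meaning "ready" and consumer() re-arming the flag to 1.

```cpp
// Hedged sketch: chunk-by-chunk handshake between a comm stream and a GEMM
// stream through shared atomic counters. The work launches are placeholders.
void example_chunked_overlap(void *counters, int nchunks,
                             cudaStream_t comm_stream, cudaStream_t gemm_stream) {
  for (int chunk = 0; chunk < nchunks; ++chunk) {
    // ... enqueue communication for `chunk` on comm_stream ...
    producer(counters, chunk, comm_stream);  // flag chunk as ready (0)
    consumer(counters, chunk, gemm_stream);  // spin until ready, then re-arm to 1
    // ... enqueue the GEMM that consumes `chunk` on gemm_stream ...
  }
}
```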
@@ -24,6 +24,18 @@
#define NVTE_LAUNCH_CPU 2
#define NVTE_MAX_NVLINK 8
#define UB_MEM_UC_CONTIG 1
#define UB_MEM_MC_CREATED 2
#define UB_MEM_ALLOCATED 4
#define NVTE_UB_MEM_UC_CONTIG 1
#define NVTE_UB_MEM_MC_CREATED 2
#define NVTE_UB_MEM_ALLOCATED 4
#ifdef UCP
#include <ucp/api/ucp.h>
#endif
// region 0 flag offsets
#define NVTE_REG0_OPFLAGS 1024
#define NVTE_REG0_RECV (NVTE_REG0_OPFLAGS * userbuffers_op_types)
@@ -35,6 +47,10 @@
#define NVTE_REG0_IBRS 32
#define NVTE_REG0_IBAG 512
#if defined(UCP) || !defined(NOSHARP)
#undef REG0_COMMBUFFER
#define REG0_COMMBUFFER (1024*1024*16)
#endif
// gpuflags map offsets
#define NVTE_GF_STATE 16000
#define NVTE_GF_IBSHARPDONE 0
@@ -81,6 +97,19 @@ struct communicator {
  void *mem_ptr[NVTE_MAX_REGIONS];
  void **peer_ptr[NVTE_MAX_REGIONS];
int memflags[NVTE_MAX_REGIONS]; // UC,MC, user/lib allocated
CUmemGenericAllocationHandle *uchandles[NVTE_MAX_REGIONS];
void* ucbase_ptr[NVTE_MAX_REGIONS]; // only for cuMem allocated memory
size_t mem_size[NVTE_MAX_REGIONS];
void* mc_ptr[NVTE_MAX_REGIONS];
void* mc_baseptr;
CUmemGenericAllocationHandle mc_handle;
size_t mc_offset, mc_maxsize;
int use_mc; // 1: use MC if available, 0: override not to use MC
  int ar_nvsize, ar_firstgpu,
      ar_nvrank;  // number of gpus(and first gpu in a group) of gpus per node in reduction subgroup
  // (_splitar init used) would be equal to (nvsize,0) for regular comm_create
@@ -120,6 +149,8 @@ struct communicator {
};
typedef struct communicator communicator;
void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream);
void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream);
int create_communicator(communicator **comm);
/* creates communicator, allocates all internal buffers if necessary */
@@ -191,6 +222,45 @@ void reducescatter2_userbuff_stridedoutput(void *output, const int handler, cons
                                           const int rowelements, const int colelements,
                                           const int strideelements, communicator *comm,
                                           cudaStream_t stream = 0);
template<typename fp8type>
void reducescatter2_userbuff_stridedoutput_fp8(void* output, float* scale, const int handler,
const int offset, const int rowelements,
const int colelements, const int strideelements,
communicator* comm, cudaStream_t stream = 0);
template<typename fp8type>
void reducescatter2_userbuff_fp8(void* output, float* scale, const int handler, const int offset,
const int elements, communicator* comm, cudaStream_t stream = 0);
#if 0
template<typename fp8type>
void reducescatter2_userbuff_strided_atomic_fp8(void* output, float *scale, const int handler,
const int offset, const int rowelements,
const int colelements, const int strideelements,
const int numchunks, void *counters,
communicator* comm, cudaStream_t stream = 0);
#endif
template<typename fp8type>
void reducescatter2_userbuff_strided_atomic_fp8(void* output, float *scale, const int handler,
const int offset, const int rowelements,
const int colelements, const int strideelements_out,
const int strideelements_in, const int numchunks,
void *counters, communicator* comm,
cudaStream_t stream = 0);
template<typename fp8type>
void reducescatter2_userbuff_strided_multiatomic_fp8(
void* output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator* comm, cudaStream_t stream = 0);
void reducescatter2_userbuff_strided(
void* output, const int handler, const int offset, const int rowelements, const int colelements,
const int strideelements, communicator* comm, cudaStream_t stream = 0);
void reducescatter2_userbuff_strided_atomic(
void* output, const int handler , const int offset, const int rowelements, const int colelements,
const int strideelements, const int numchunks, void *counters, communicator* comm,
cudaStream_t stream = 0);
void reducescatter2_userbuff_strided_multiatomic(
void* output, const int handler, const int offset, const int rowelements, const int colelements,
const int strideelements, const int numchunks, void *counters, communicator* comm,
cudaStream_t stream = 0);
/* everything should be 16byte aligned = 8 elts aligned
output is strided: row starts separated by stride elements*/
@@ -208,6 +278,19 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds
void userbuffers_recv(const int srchandler, const size_t srcoffset, const int dsthandler,
                      const size_t dstoffset, const size_t bytes, communicator *comm,
                      const int peer, cudaStream_t stream = 0);
void userbuffers_sendrecv(
const int srchandler, const int dsthandler, const size_t send_offset, const size_t recv_offset,
const size_t bytes, communicator* comm, const int send_peer, const int recv_peer,
cudaStream_t stream = 0);
void userbuffers_sendrecv_atomic(
const int srchandler, const int dsthandler, const size_t send_offset, const size_t recv_offset,
const size_t bytes, communicator* comm, const int send_peer, const int recv_peer, void *counters,
cudaStream_t stream = 0);
void userbuffers_sendrecv_multiatomic(
const int srchandler, const int dsthandler, const size_t send_offset, const size_t recv_offset,
const size_t bytes, communicator* comm, const int send_peer, const int recv_peer,
const int nchunks, void *counters, bool shuffle, cudaStream_t stream = 0);
// alltoall split send and recv to allow for overlap
// send kicks in sending data to the destination - invoke on same stream as data generation
@@ -124,6 +124,8 @@ def initialize_ub(
    fp8_buf = [
        "qkv_fprop", "qkv_dgrad", "proj_dgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad"
    ]
if bool(int(os.getenv("NVTE_UB_FP8_RS", "0"))):
fp8_buf.append ("proj_fprop")
    # Default overlap methods for layers
    methods = {
        "ring_exchange":["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"],
@@ -153,8 +155,12 @@ def initialize_ub(
                sample_buffer,  # Sample userbuffer
                rank_id,  # Rank id
                tp_size,  # TP size
num_sm, # Number of communication SMs
cga_size, # CGA cluster size
set_sm_margin, # Set SM margin
aggregate, # Aggregate 2X GEMM chunks aggregate, # Aggregate 2X GEMM chunks
_NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams _NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams
torch.Tensor(), # empty tensor to pass to counters
            )
        else:
            ub_obj = tex.UbufCommOverlap(
@@ -166,6 +172,7 @@ def initialize_ub(
                num_splits,  # Number of communication splits
                set_sm_margin,  # Set SM margin
                _NUM_MAX_UB_STREAMS,  # Max concurrent GEMM streams
torch.Tensor(), # empty tensor to pass to counters
            )
        _ub_communicators[name] = ub_obj
@@ -676,10 +683,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        grad_output_mat = grad_output.view((-1, grad_output.shape[-1]))
        gather_grad_output = row_parallel_mode and ctx.sequence_parallel
if gather_grad_output:
ub_overlap_ag = ctx.ub_split_ag or ctx.ub_atomic_gemm_ag
        # No-FP8 case: bgrad is fused with wgrad for this case.
        if not ctx.fp8:
            if gather_grad_output:
                if not ub_overlap_ag:
                    grad_output_mat, _ = gather_along_first_dim(
                        grad_output_mat, ctx.tp_group
                    )
@@ -698,8 +707,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                and ctx.fp8_meta["recipe"].override_linear_precision.wgrad
            ):
                assert (
                    not ub_overlap_ag
                ), "override_linear_precision.wgrad not supported with UB AG overlap"
                grad_output_mat, _ = gather_along_first_dim(grad_output_mat, ctx.tp_group)
            # FP8 case with gather: unfused bgrad, cast, transpose for efficient gather
            elif gather_grad_output:
@@ -707,7 +716,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                    grad_bias = grad_output_mat.sum(dim=0)
                else:
                    grad_bias = None
                if ub_overlap_ag:
                    grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(0)
                else:
                    grad_output_c = torch.empty_like(grad_output_mat, dtype=torch.uint8)
@@ -718,7 +727,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                    fp8_dtype_backward,
                    out=grad_output_c,
                )
                if not ub_overlap_ag:
                    grad_output_c, _ = gather_along_first_dim(grad_output_c, ctx.tp_group)
                grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
            else:
@@ -83,6 +83,7 @@ class _LayerNormLinear(torch.autograd.Function):
        ub_bulk_dgrad: bool,
        ub_split_ag: bool,
        normalization: str,
ub_atomic_gemm_ag: bool,
    ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
        # Make sure input dimensions are compatible
        in_features = ln_weight.numel()
@@ -100,11 +101,12 @@ class _LayerNormLinear(torch.autograd.Function):
        if ln_bias is not None:
            ln_bias = cast_if_needed(ln_bias, activation_dtype)
        if ub_split_ag or ub_atomic_gemm_ag:
            tp_world_size = get_distributed_world_size(tp_group)
            if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output:
                ub_split_ag = False
                ub_atomic_gemm_ag = False
        if ub_split_ag or ub_atomic_gemm_ag:
            dim_size = list(inputmat.size())
            dim_size[0] = dim_size[0] * tp_world_size
            ub_obj_lnout = get_ub("qkv_fprop")
@@ -112,6 +114,8 @@ class _LayerNormLinear(torch.autograd.Function):
        else:
            ln_out_dtype = torch.uint8 if fp8 else inputmat.dtype
            ln_out = torch.empty_like(inputmat, dtype=ln_out_dtype)
if ub_atomic_gemm_ag:
assert fp8, "AtomicGemm overlap supported only for FP8 GEMM."
        fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
@@ -139,7 +143,7 @@ class _LayerNormLinear(torch.autograd.Function):
                fp8_dtype_forward,
            )
        # Column Parallel Linear
        if ub_split_ag or ub_atomic_gemm_ag:
            ln_out_total = ub_obj_lnout.get_ubuf_output(1)
            ln_out = torch.empty_like(ln_out)
        elif parallel_mode == "column" and sequence_parallel:
@@ -173,6 +177,8 @@ class _LayerNormLinear(torch.autograd.Function):
                    tex.FP8FwdTensors.GEMM1_WEIGHT,
                    fp8_dtype_forward)
ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None
ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ub_atomic_gemm_ag else ub_algo
            out, _ = tex.fp8_gemm(
                weight_fp8,
                fp8_meta["scaling_fwd"].scale_inv,
@@ -187,9 +193,9 @@ class _LayerNormLinear(torch.autograd.Function):
                bias=bias,
                use_bias=use_bias,
                use_split_accumulator=_2X_ACC_FPROP,
                ub_algo=ub_algo,
                ub=ub_obj_lnout if (ub_split_ag or ub_atomic_gemm_ag) else None,
                extra_output_tensor=ln_out if (ub_split_ag or ub_atomic_gemm_ag) else None,
            )
        else:
            # Cast for native AMP
@@ -339,6 +345,14 @@ class _LayerNormLinear(torch.autograd.Function):
            fp8_dtype_backward = get_fp8_te_dtype(
                ctx.fp8_meta["recipe"], fprop_tensor=False
            )
out_index, meta_tensor, out_te_type, out_type = (
None, None, None, ctx.activation_dtype)
if ctx.ub_bulk_wgrad and ub_obj_dgrad.is_fp8_ubuf():
out_index = tex.FP8BwdTensors.GRAD_INPUT1
meta_tensor = ctx.fp8_meta["scaling_bwd"]
out_te_type = fp8_dtype_backward
out_type = torch.uint8
ub_obj_dgrad.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index])
# DGRAD: Evaluated unconditionally to feed into Linear backward # DGRAD: Evaluated unconditionally to feed into Linear backward
_ = tex.fp8_gemm( _ = tex.fp8_gemm(
...@@ -350,12 +364,15 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -350,12 +364,15 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.fp8_meta["scaling_bwd"].scale_inv, ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT1, tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward, fp8_dtype_backward,
ctx.activation_dtype, out_type,
get_workspace(), get_workspace(),
out=dgrad, out=dgrad,
use_split_accumulator=_2X_ACC_DGRAD, use_split_accumulator=_2X_ACC_DGRAD,
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None, ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None,
ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None,
out_index=out_index,
fp8_meta_tensor = meta_tensor,
D_dtype = out_te_type,
) )
else: else:
# DGRAD: Evaluated unconditionally to feed into Linear backward # DGRAD: Evaluated unconditionally to feed into Linear backward
...@@ -387,6 +404,15 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -387,6 +404,15 @@ class _LayerNormLinear(torch.autograd.Function):
if weight.requires_grad: if weight.requires_grad:
if ctx.fp8: if ctx.fp8:
# WGRAD # WGRAD
extra_output_tensor = None
if ctx.ub_bulk_wgrad:
if ub_obj_dgrad.is_fp8_ubuf():
dim_size = list(ub_obj_dgrad.get_ubuf_output(0).size()) # RS output
extra_output_tensor = torch.empty(
dim_size, dtype=ctx.activation_dtype, device=dgrad.device)
dgrad = extra_output_tensor
else:
dgrad = ub_obj_dgrad.get_ubuf_output(0)
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward) ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
wgrad, _ = tex.fp8_gemm( wgrad, _ = tex.fp8_gemm(
...@@ -405,7 +431,8 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -405,7 +431,8 @@ class _LayerNormLinear(torch.autograd.Function):
use_split_accumulator=_2X_ACC_WGRAD, use_split_accumulator=_2X_ACC_WGRAD,
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS
if ctx.ub_bulk_wgrad else None, if ctx.ub_bulk_wgrad else None,
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None,
extra_output_tensor=extra_output_tensor
) )
else: else:
ln_out_total_c = tex.cast_from_fp8( ln_out_total_c = tex.cast_from_fp8(
...@@ -426,7 +453,8 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -426,7 +453,8 @@ class _LayerNormLinear(torch.autograd.Function):
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS
if ctx.ub_bulk_wgrad else None, if ctx.ub_bulk_wgrad else None,
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None,
extra_output_tensor=extra_output_tensor
) )
else: else:
# WGRAD # WGRAD
...@@ -443,11 +471,14 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -443,11 +471,14 @@ class _LayerNormLinear(torch.autograd.Function):
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None,
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None
) )
if ctx.ub_bulk_wgrad: if ctx.ub_bulk_wgrad:
dgrad = ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output dgrad = ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output
# Column Parallel Linear # Column Parallel Linear
elif ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None: if ((not ctx.ub_bulk_wgrad)
and ctx.parallel_mode == "column"
and ctx.tensor_parallel
and handle is not None):
handle.wait() handle.wait()
# LayerNorm gradient # LayerNorm gradient
...@@ -512,6 +543,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -512,6 +543,7 @@ class _LayerNormLinear(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
) )
...@@ -624,6 +656,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -624,6 +656,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
ub_bulk_dgrad: bool = False, ub_bulk_dgrad: bool = False,
ub_split_ag: bool = False, ub_split_ag: bool = False,
device: Union[torch.device, str] = "cuda", device: Union[torch.device, str] = "cuda",
ub_atomic_gemm_ag: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -650,12 +683,18 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -650,12 +683,18 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.ub_bulk_wgrad = ub_bulk_wgrad self.ub_bulk_wgrad = ub_bulk_wgrad
self.ub_bulk_dgrad = ub_bulk_dgrad self.ub_bulk_dgrad = ub_bulk_dgrad
self.ub_split_ag = ub_split_ag self.ub_split_ag = ub_split_ag
self.ub_atomic_gemm_ag = ub_atomic_gemm_ag
if ub_bulk_wgrad or ub_bulk_dgrad or ub_split_ag: if ub_bulk_wgrad or ub_bulk_dgrad or ub_split_ag or ub_atomic_gemm_ag:
assert ( assert (
tex.userbuf_comm_available() tex.userbuf_comm_available()
), "Userbuffer communication backend not available." ), "Userbuffer communication backend not available."
if ub_atomic_gemm_ag:
warnings.warn(
"Atomic gemm uses a beta API from cublas and is not tested for all use cases."
)
if tp_group is None: if tp_group is None:
self.tp_size = tp_size self.tp_size = tp_size
if tp_size == 1: if tp_size == 1:
...@@ -919,6 +958,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -919,6 +958,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.ub_bulk_dgrad, self.ub_bulk_dgrad,
self.ub_split_ag, self.ub_split_ag,
self.normalization, self.normalization,
self.ub_atomic_gemm_ag,
) )
out = fwd_fn(*args) out = fwd_fn(*args)
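Note: a minimal usage sketch of the new LayerNormLinear flag (not part of the diff). It assumes torch.distributed and a tensor-parallel group are already set up and that the userbuffers workspace backing get_ub("qkv_fprop") has been initialized beforehand; the helper name below is illustrative only.

import transformer_engine.pytorch as te
from transformer_engine.common import recipe

def qkv_with_atomic_ag(hidden_size: int, tp_group) -> te.LayerNormLinear:
    # ub_atomic_gemm_ag is used instead of ub_split_ag; it requires FP8 execution
    # (see the assert in the forward above) and a registered "qkv_fprop" userbuffer.
    return te.LayerNormLinear(
        hidden_size,
        3 * hidden_size,
        sequence_parallel=True,
        parallel_mode="column",
        tp_group=tp_group,
        ub_atomic_gemm_ag=True,
    )

# Usage (assuming x is this rank's sequence-parallel input shard):
# with te.fp8_autocast(enabled=True, fp8_recipe=recipe.DelayedScaling()):
#     out = qkv_with_atomic_ag(4096, tp_group)(x)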
@@ -4,6 +4,7 @@
"""LayerNormMLP API"""
import os
+ import warnings
from typing import Union, Optional, Callable, Tuple, List, Dict, Any
import torch

@@ -107,7 +108,9 @@ class _LayerNormMLP(torch.autograd.Function):
ub_bulk_wgrad: bool,
ub_bulk_dgrad: bool,
ub_split_rs: bool,
+ ub_atomic_gemm_rs: bool,
ub_split_ag: bool,
+ ub_atomic_gemm_ag: bool,
activation: str,
normalization: str,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:

@@ -130,20 +133,25 @@ class _LayerNormMLP(torch.autograd.Function):
if ln_bias is not None:
ln_bias = cast_if_needed(ln_bias, activation_dtype)
- if ub_split_ag:
+ if ub_split_ag or ub_atomic_gemm_ag:
tp_world_size = get_distributed_world_size(tp_group)
if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output:
ub_split_ag = False
- if ub_split_ag:
+ ub_atomic_gemm_ag = False
+ ub_overlap_ag = ub_split_ag or ub_atomic_gemm_ag
+ if ub_overlap_ag:
ub_obj_lnout = get_ub("fc1_fprop")
ln_out = ub_obj_lnout.get_ubuf_output(0)
else:
ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype
ln_out = torch.empty_like(inputmat, dtype=ln_out_dtype)
- if ub_split_rs:
+ if ub_split_rs or ub_atomic_gemm_rs:
tp_world_size = get_distributed_world_size(tp_group)
if tp_world_size == 1:
ub_split_rs = False
+ ub_atomic_gemm_rs = False
+ if ub_atomic_gemm_rs or ub_atomic_gemm_ag:
+ assert fp8, "AtomicGemm overlap supported only for FP8 GEMM."
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)

@@ -171,7 +179,7 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_dtype_forward,
)
# Column Parallel Linear
- if ub_split_ag:
+ if ub_overlap_ag:
ln_out_total = ub_obj_lnout.get_ubuf_output(1)
ln_out = torch.empty_like(ln_out)
elif set_parallel_mode and sequence_parallel:

@@ -223,6 +231,8 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_dtype_forward,
)
+ ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None
+ ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ub_atomic_gemm_ag else ub_algo
fc1_out, _ = tex.fp8_gemm(
fc1_weight_fp8,
fp8_meta["scaling_fwd"].scale_inv,

@@ -237,9 +247,9 @@ class _LayerNormMLP(torch.autograd.Function):
bias=fc1_bias,
use_bias=use_fc1_bias,
use_split_accumulator=_2X_ACC_FPROP,
- ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None,
- ub=ub_obj_lnout if ub_split_ag else None,
- extra_output_tensor=ln_out if ub_split_ag else None,
+ ub_algo=ub_algo,
+ ub=ub_obj_lnout if ub_overlap_ag else None,
+ extra_output_tensor=ln_out if ub_overlap_ag else None,
)
gelu_out = activation_func(

@@ -249,18 +259,29 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_dtype_forward,
)
- if ub_split_rs:
+ fc2_out_index, fc2_meta_tensor, fc2_te_type, out_type = (
+ None, None, None, activation_dtype)
+ if ub_split_rs or ub_atomic_gemm_rs:
ub_obj_fc2out = get_ub("fc2_fprop")
fc2_out = ub_obj_fc2out.get_ubuf_output(1)
dim_size = list(gelu_out.size())
dim_size[0] = dim_size[0] // tp_world_size
dim_size[1] = fc2_weight.size(0)
rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)
+ if ub_obj_fc2out.is_fp8_ubuf():
+ fc2_out_index = tex.FP8FwdTensors.GEMM2_OUTPUT
+ fc2_meta_tensor = fp8_meta["scaling_fwd"]
+ fc2_te_type = fp8_dtype_forward
+ out_type = torch.uint8
+ ub_obj_fc2out.set_ubuf_scale_inv(fc2_meta_tensor.scale_inv[fc2_out_index])
else:
dim_size = list(gelu_out.size())
dim_size[1] = fc2_weight.size(0)
fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)
+ ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS if ub_atomic_gemm_rs else None
+ ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else ub_algo
_ = tex.fp8_gemm(
fc2_weight_fp8,
fp8_meta["scaling_fwd"].scale_inv,

@@ -270,15 +291,18 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM2_INPUT,
fp8_dtype_forward,
- activation_dtype,
+ out_type,
get_workspace(),
bias=fc2_bias,
use_bias=use_fc2_bias,
use_split_accumulator=_2X_ACC_FPROP,
out=fc2_out,
- ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None,
- ub=ub_obj_fc2out if ub_split_rs else None,
- extra_output_tensor=rs_out if ub_split_rs else None,
+ ub_algo=ub_algo,
+ ub=ub_obj_fc2out if ub_split_rs or ub_atomic_gemm_rs else None,
+ extra_output_tensor=rs_out if ub_split_rs or ub_atomic_gemm_rs else None,
+ out_index=fc2_out_index,
+ fp8_meta_tensor = fc2_meta_tensor,
+ D_dtype = fc2_te_type,
)
else:
# Cast for native AMP

@@ -394,11 +418,12 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.ub_bulk_wgrad = ub_bulk_wgrad
ctx.ub_bulk_dgrad = ub_bulk_dgrad
ctx.ub_split_ag = ub_split_ag
+ ctx.ub_atomic_gemm_ag = ub_atomic_gemm_ag
ctx.requires_dgrad = inp.requires_grad
ctx.normalization = normalization
# Row Parallel Linear
- if ub_split_rs:
+ if ub_split_rs or ub_atomic_gemm_rs:
fc2_out = rs_out
elif set_parallel_mode and sequence_parallel:
fc2_out, _ = reduce_scatter_along_first_dim(fc2_out, tp_group)

@@ -447,11 +472,15 @@ class _LayerNormMLP(torch.autograd.Function):
dim_size[0] = dim_size[0] * tp_world_size
ub_obj_lnout = get_ub("fc1_dgrad")
ub_obj_lnout.copy_input_to_ubuf(ln_out, 1)
- if ctx.ub_split_ag:
+ ub_overlap_ag = ctx.ub_split_ag or ctx.ub_atomic_gemm_ag
+ if ub_overlap_ag:
tp_world_size = get_distributed_world_size(ctx.tp_group)
if tp_world_size == 1:
ctx.ub_split_ag = False
- if ctx.ub_split_ag:
+ ctx.ub_overlap_ag = False
+ ub_overlap_ag = ctx.ub_split_ag or ctx.ub_atomic_gemm_ag
+ if ub_overlap_ag:
dim_size = list(grad_outputs[0].size())
dim_size[0] = dim_size[0] * tp_world_size
ctx.ub_obj_gradout = get_ub("fc2_dgrad")

@@ -497,6 +526,8 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.fp8_meta["recipe"], fprop_tensor=False
)
+ ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None
+ ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ctx.ub_atomic_gemm_ag else ub_algo
# FC2 DGRAD; Unconditional
fc2_dgrad, _ = tex.fp8_gemm(
fc2_weight_t_fp8,

@@ -510,10 +541,10 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.activation_dtype,
get_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
- ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None,
- ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None,
+ ub_algo=ub_algo,
+ ub=ctx.ub_obj_gradout if ub_overlap_ag else None,
)
- if ctx.ub_split_ag:
+ if ub_overlap_ag:
grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
# FC2 WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:

@@ -595,11 +626,19 @@ class _LayerNormMLP(torch.autograd.Function):
)
dgelu_t = None
+ out_index, meta_tensor, out_te_type, out_type = (
+ None, None, None, ctx.activation_dtype)
fc1_dgrad_size = list(dgelu.size())
fc1_dgrad_size[1] = fc1_weight.size(1)
if ctx.ub_bulk_wgrad: # allocate dgrad output
ub_obj_dgrad = get_ub("fc1_wgrad")
fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output
+ if ub_obj_dgrad.is_fp8_ubuf():
+ out_index = tex.FP8BwdTensors.GRAD_INPUT2
+ meta_tensor = ctx.fp8_meta["scaling_bwd"]
+ out_te_type = fp8_dtype_backward
+ out_type = torch.uint8
+ ub_obj_dgrad.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index])
else:
fc1_dgrad = torch.empty(
fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device

@@ -614,12 +653,15 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT2,
fp8_dtype_backward,
- ctx.activation_dtype,
+ out_type,
get_workspace(),
out=fc1_dgrad,
use_split_accumulator=_2X_ACC_DGRAD,
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None,
- ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None
+ ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None,
+ out_index=out_index,
+ fp8_meta_tensor = meta_tensor,
+ D_dtype = out_te_type,
)
else:
# FC2 DGRAD; Unconditional

@@ -703,6 +745,15 @@ class _LayerNormMLP(torch.autograd.Function):
if fc1_weight.requires_grad:
if ctx.fp8:
# FC1 WGRAD
+ extra_output_tensor = None
+ if ctx.ub_bulk_wgrad:
+ if ub_obj_dgrad.is_fp8_ubuf():
+ dim_size = list(ub_obj_dgrad.get_ubuf_output(0).size()) # RS output
+ extra_output_tensor = torch.empty(
+ dim_size, dtype=ctx.activation_dtype, device=fc1_dgrad.device)
+ fc1_dgrad = extra_output_tensor
+ else:
+ fc1_dgrad = ub_obj_dgrad.get_ubuf_output(0)
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
fc1_wgrad, _ = tex.fp8_gemm(

@@ -724,6 +775,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS
if ctx.ub_bulk_wgrad else None,
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None,
+ extra_output_tensor=extra_output_tensor,
)
else:
ln_out_total_c = tex.cast_from_fp8(

@@ -747,6 +799,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS
if ctx.ub_bulk_wgrad else None,
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None,
+ extra_output_tensor=extra_output_tensor,
)
else:
# FC1 WGRAD

@@ -768,11 +821,14 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_wgrad, _, _ = fc1_wgrad_outputs
else:
fc1_wgrad, fc1_bias_grad, _ = fc1_wgrad_outputs
- # Column Parallel Linear
if ctx.ub_bulk_wgrad:
fc1_dgrad = ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output
- elif ctx.set_parallel_mode and ctx.tensor_parallel and handle is not None:
+ # Column Parallel Linear
+ if ((not ctx.ub_bulk_wgrad)
+ and ctx.set_parallel_mode
+ and ctx.tensor_parallel
+ and handle is not None):
handle.wait()
# LayerNorm gradient

@@ -850,6 +906,8 @@ class _LayerNormMLP(torch.autograd.Function):
None,
None,
None,
+ None,
+ None,
)

@@ -965,8 +1023,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
ub_bulk_wgrad: bool = False,
ub_bulk_dgrad: bool = False,
ub_split_rs: bool = False,
+ ub_atomic_gemm_rs: bool = False,
ub_split_ag: bool = False,
device: Union[torch.device, str] = "cuda",
+ ub_atomic_gemm_ag: bool = False,
) -> None:
super().__init__()

@@ -987,12 +1047,24 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.ub_bulk_dgrad = ub_bulk_dgrad
self.ub_split_rs = ub_split_rs
self.ub_split_ag = ub_split_ag
+ self.ub_atomic_gemm_rs = ub_atomic_gemm_rs
+ self.ub_atomic_gemm_ag = ub_atomic_gemm_ag
- if ub_bulk_wgrad or ub_bulk_dgrad or ub_split_rs or ub_split_ag:
+ if (ub_bulk_wgrad # pylint: disable=too-many-boolean-expressions
+ or ub_bulk_dgrad
+ or ub_split_rs
+ or ub_split_ag
+ or ub_atomic_gemm_rs
+ or ub_atomic_gemm_ag):
assert (
tex.userbuf_comm_available()
), "Userbuffer communication backend not available."
+ if ub_atomic_gemm_rs or ub_atomic_gemm_ag:
+ warnings.warn(
+ "Atomic gemm uses a beta API from cublas and is not tested for all use cases."
+ )
if tp_group is None:
self.tp_size = tp_size
if tp_size == 1:

@@ -1210,7 +1282,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.ub_bulk_wgrad,
self.ub_bulk_dgrad,
self.ub_split_rs,
+ self.ub_atomic_gemm_rs,
self.ub_split_ag,
+ self.ub_atomic_gemm_ag,
self.activation,
self.normalization,
)
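Note: a minimal sketch of the corresponding LayerNormMLP usage (not part of the diff). The "fc1_fprop" and "fc2_fprop" userbuffers used above are assumed to be registered beforehand; the helper below is illustrative only.

import transformer_engine.pytorch as te

def mlp_with_atomic_overlap(hidden_size: int, ffn_hidden_size: int, tp_group) -> te.LayerNormMLP:
    # FC1 overlaps its all-gather with an Atomic GEMM, FC2 its reduce-scatter.
    # Both flags assert FP8 in the forward pass (run under te.fp8_autocast) and
    # fall back to no overlap when tp_world_size == 1, as in the split-pipelined paths.
    return te.LayerNormMLP(
        hidden_size,
        ffn_hidden_size,
        sequence_parallel=True,
        set_parallel_mode=True,
        tp_group=tp_group,
        ub_atomic_gemm_ag=True,   # instead of ub_split_ag
        ub_atomic_gemm_rs=True,   # instead of ub_split_rs
    )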
@@ -77,6 +77,8 @@ class _Linear(torch.autograd.Function):
is_grad_enabled: bool,
ub_split_rs: bool,
ub_split_ag: bool,
+ ub_atomic_gemm_rs: bool,
+ ub_atomic_gemm_ag: bool,
) -> torch.Tensor:
# Make sure input dimensions are compatible
in_features = weight.shape[-1]

@@ -88,10 +90,13 @@ class _Linear(torch.autograd.Function):
update_fp8_weights = is_first_microbatch is None or is_first_microbatch
- if ub_split_rs:
+ if ub_split_rs or ub_atomic_gemm_rs:
tp_world_size = get_distributed_world_size(tp_group)
if tp_world_size == 1:
ub_split_rs = False
+ ub_atomic_gemm_rs = False
+ if ub_atomic_gemm_rs or ub_atomic_gemm_ag:
+ assert fp8, "AtomicGemm overlap supported only for FP8 GEMM."
# Cast for native AMP
inputmat = cast_if_needed(inputmat, activation_dtype)
inputmat_no_fp8 = inputmat

@@ -155,18 +160,29 @@ class _Linear(torch.autograd.Function):
fp8_dtype_forward,
)
- if ub_split_rs:
+ proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = (
+ None, None, None, activation_dtype)
+ if ub_split_rs or ub_atomic_gemm_rs:
ub_obj_projout = get_ub("proj_fprop")
out = ub_obj_projout.get_ubuf_output(1)
dim_size = list(inputmat_total.size())
dim_size[0] = dim_size[0] // tp_world_size
dim_size[1] = weight.size(0)
rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
+ if ub_obj_projout.is_fp8_ubuf():
+ proj_out_index = tex.FP8FwdTensors.GEMM1_OUTPUT
+ meta_tensor = fp8_meta["scaling_fwd"]
+ proj_out_tetype = fp8_dtype_forward
+ proj_out_pttype = torch.uint8
+ ub_obj_projout.set_ubuf_scale_inv(meta_tensor.scale_inv[proj_out_index])
else:
dim_size = list(inputmat_total.size())
dim_size[1] = weight.size(0)
out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
+ ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS if ub_atomic_gemm_rs else None
+ ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else ub_algo
_ = fp8_gemm(
weight_fp8,
fp8_meta["scaling_fwd"].scale_inv,

@@ -176,15 +192,18 @@ class _Linear(torch.autograd.Function):
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
- activation_dtype,
+ proj_out_pttype,
get_workspace(),
bias=bias,
use_bias=use_bias,
use_split_accumulator=_2X_ACC_FPROP,
out=out,
- ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None,
- ub=ub_obj_projout if ub_split_rs else None,
- extra_output_tensor=rs_out if ub_split_rs else None,
+ ub_algo=ub_algo,
+ ub=ub_obj_projout if (ub_split_rs or ub_atomic_gemm_rs) else None,
+ extra_output_tensor=rs_out if (ub_split_rs or ub_atomic_gemm_rs) else None,
+ out_index=proj_out_index,
+ fp8_meta_tensor = meta_tensor,
+ D_dtype = proj_out_tetype,
)
else:
# Cast for native AMP

@@ -245,11 +264,12 @@ class _Linear(torch.autograd.Function):
ctx.parallel_mode = parallel_mode
ctx.tp_group = tp_group
ctx.ub_split_ag = ub_split_ag
+ ctx.ub_atomic_gemm_ag = ub_atomic_gemm_ag
ctx.tp_size = tp_size
ctx.requires_dgrad = inp.requires_grad
# Row Parallel Linear
- if ub_split_rs:
+ if ub_split_rs or ub_atomic_gemm_rs:
out = rs_out
elif parallel_mode == "row" and sequence_parallel:
out, _ = reduce_scatter_along_first_dim(out, tp_group)

@@ -275,11 +295,12 @@ class _Linear(torch.autograd.Function):
fwd_scale_inverses,
) = ctx.saved_tensors
- if ctx.ub_split_ag:
+ if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag:
tp_world_size = get_distributed_world_size(ctx.tp_group)
if tp_world_size == 1:
ctx.ub_split_ag = False
- if ctx.ub_split_ag:
+ ctx.ub_atomic_gemm_ag = False
+ if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag:
dim_size = list(grad_output.size())
dim_size[0] = dim_size[0] * tp_world_size
ctx.ub_obj_gradout = get_ub("proj_dgrad")

@@ -323,6 +344,8 @@ class _Linear(torch.autograd.Function):
ctx.fp8_meta["recipe"], fprop_tensor=False
)
+ ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None
+ ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ctx.ub_atomic_gemm_ag else ub_algo
if ctx.requires_dgrad:
if ctx.fp8:
dgrad, _ = fp8_gemm(

@@ -337,8 +360,8 @@ class _Linear(torch.autograd.Function):
ctx.activation_dtype,
get_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
- ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None,
- ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None,
+ ub_algo=ub_algo,
+ ub=ctx.ub_obj_gradout if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag else None,
)
else:
dgrad, _, _ = gemm(

@@ -366,7 +389,7 @@ class _Linear(torch.autograd.Function):
if ctx.fp8:
# WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
- if ctx.ub_split_ag:
+ if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag:
grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
wgrad, _ = fp8_gemm(
inputmat_t_total,

@@ -445,6 +468,8 @@ class _Linear(torch.autograd.Function):
None,
None,
None,
+ None,
+ None,
)

@@ -538,6 +563,8 @@ class Linear(TransformerEngineBaseModule):
ub_split_rs: bool = False,
ub_split_ag: bool = False,
device: Union[torch.device, str] = "cuda",
+ ub_atomic_gemm_rs: bool = False,
+ ub_atomic_gemm_ag: bool = False,
) -> None:
super().__init__()

@@ -559,12 +586,19 @@ class Linear(TransformerEngineBaseModule):
self.parameters_split = parameters_split
self.ub_split_rs = ub_split_rs
self.ub_split_ag = ub_split_ag
+ self.ub_atomic_gemm_rs = ub_atomic_gemm_rs
+ self.ub_atomic_gemm_ag = ub_atomic_gemm_ag
- if ub_split_rs or ub_split_ag:
+ if ub_split_rs or ub_split_ag or ub_atomic_gemm_rs:
assert (
tex.userbuf_comm_available()
), "Userbuffer communication backend not available."
+ if ub_atomic_gemm_rs or ub_atomic_gemm_ag:
+ warnings.warn(
+ "Atomic gemm uses a beta API from cublas and is not tested for all use cases."
+ )
if tp_group is None:
self.tp_size = tp_size
if tp_size == 1:

@@ -785,6 +819,8 @@ class Linear(TransformerEngineBaseModule):
torch.is_grad_enabled(),
self.ub_split_rs,
self.ub_split_ag,
+ self.ub_atomic_gemm_rs,
+ self.ub_atomic_gemm_ag,
)
out = linear_fn(*args)
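Note: a minimal sketch of the Linear usage with the new reduce-scatter overlap flag (not part of the diff). It assumes an existing tensor-parallel group and a registered "proj_fprop" userbuffer; with an FP8 userbuffer the GEMM writes its FP8 output into the communication buffer (out_index / D_dtype above) and the reduce-scattered result is returned through rs_out.

import transformer_engine.pytorch as te

def row_parallel_proj(hidden_size: int, tp_group) -> te.Linear:
    # Illustrative row-parallel projection with Atomic GEMM + RS overlap (FP8-only).
    return te.Linear(
        hidden_size,
        hidden_size,
        sequence_parallel=True,
        parallel_mode="row",
        tp_group=tp_group,
        ub_atomic_gemm_rs=True,   # instead of ub_split_rs
    )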
@@ -263,6 +263,22 @@ class TransformerLayer(torch.nn.Module):
ub_bulk_dgrad = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_BULK_DGRAD", "1")))
ub_split_ag = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_SPLIT_AG", "1")))
ub_split_rs = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_SPLIT_RS", "1")))
+ ub_atomic_gemm_rs = (ub_tp_comm_overlap
+ and bool(int(os.getenv("NVTE_UB_ATOMIC_GEMM_RS", "0"))))
+ assert (
+ not (ub_split_rs and ub_atomic_gemm_rs)
+ ), "Only one type of RS overlap NVTE_UB_SPLIT_RS/NVTE_UB_ATOMIC_GEMM_RS should be enabled."
+ ub_atomic_gemm_ag = (ub_tp_comm_overlap
+ and bool(int(os.getenv("NVTE_UB_ATOMIC_GEMM_AG", "0"))))
+ assert (
+ not (ub_split_ag and ub_atomic_gemm_ag)
+ ), "Only one type of AG overlap NVTE_UB_SPLIT_AG/NVTE_UB_ATOMIC_GEMM_AG should be enabled."
+ if ub_atomic_gemm_rs or ub_atomic_gemm_ag:
+ warnings.warn(
+ "Atomic gemm uses a beta API from cublas and is not tested for all use cases."
+ )
bias_dropout_fusion = bool(int(os.getenv("NVTE_BIAS_DROPOUT_FUSION", "1")))
self.layer_number = layer_number
self.output_layernorm = output_layernorm

@@ -323,6 +339,8 @@ class TransformerLayer(torch.nn.Module):
"ub_bulk_dgrad" : ub_bulk_dgrad,
"ub_split_ag" : ub_split_ag,
"ub_split_rs" : ub_split_rs,
+ "ub_atomic_gemm_rs" : ub_atomic_gemm_rs,
+ "ub_atomic_gemm_ag" : ub_atomic_gemm_ag,
}
self.self_attention = MultiheadAttention(

@@ -377,6 +395,8 @@ class TransformerLayer(torch.nn.Module):
ub_bulk_dgrad=ub_bulk_dgrad,
ub_split_rs=ub_split_rs,
ub_split_ag=ub_split_ag,
+ ub_atomic_gemm_rs=ub_atomic_gemm_rs,
+ ub_atomic_gemm_ag=ub_atomic_gemm_ag,
activation=activation,
normalization=normalization,
device=device,
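Note: a minimal sketch of how the Atomic GEMM overlap variants might be selected at the TransformerLayer level (not part of the diff). The environment variable names come from the hunk above; the split-pipelined variants default to "1", so they have to be disabled explicitly or the asserts fire. The tp_group argument is an assumed, already-created tensor-parallel process group.

import os

os.environ["NVTE_UB_ATOMIC_GEMM_RS"] = "1"   # enable Atomic GEMM reduce-scatter overlap
os.environ["NVTE_UB_SPLIT_RS"] = "0"         # disable split-pipelined RS (defaults to "1")
os.environ["NVTE_UB_ATOMIC_GEMM_AG"] = "1"   # enable Atomic GEMM all-gather overlap
os.environ["NVTE_UB_SPLIT_AG"] = "0"         # disable split-pipelined AG (defaults to "1")

import transformer_engine.pytorch as te

def layer_with_atomic_overlap(tp_group) -> te.TransformerLayer:
    # The env vars are read in __init__ and only take effect with ub_tp_comm_overlap=True.
    return te.TransformerLayer(
        hidden_size=4096,
        ffn_hidden_size=16384,
        num_attention_heads=32,
        sequence_parallel=True,
        set_parallel_mode=True,
        tp_group=tp_group,
        ub_tp_comm_overlap=True,
    )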