Unverified Commit 958e1889 authored by vasunvidia's avatar vasunvidia Committed by GitHub

Atomic gemm and FP8 Reduce Scatter (#449)



* Initial commit
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Repro for RS output mismatch with single GEMM + split pipelined RS

* Minor changes for AG->GEMM pipelined overlap

* Add Atomic GEMM cuBLAS API attributes and initial implementation of AG->Atomic GEMM

* AtomicGemm+RS functional with workaround

* Add amax update to layernorm_linear for FP8 unit test accuracy

* Enable reducescatter2_userbuff_strided variants

* Bug fix

* AG+AtomicGemm overlap functional, but GEMM doesn't overlap with comm

* Add userbuffers_sendrecv kernel variants

* TransformerLayer API changes to enable AtomicGemm+RS overlap

* Code cleanup

* Code cleanup 2

* [UB] AllGather Atomic GEMM overlap using userbuffers_sendrecv kernels

* Code cleanup + bug fix for multiatomic sendrecv kernel

* Cleanup

* Bug fixes

* [UB] Add shuffling for better AG AtomicGEMM overlap

* Bug fix for AG AtomicGemm overlap

* Bug fix for multiAtomicAG and singleAtomicAG

* Use chunk_i+1 as recv_chunk for multiatomic_AG with shuffling

* Launch AtomicGEMM after first-chunk AG

* Rebase to main

* Add FP8 ReduceScatter kernels, AtomicGEMM+FP8 RS not functional

* Revert "Add FP8 ReduceScatter kernels, AtomicGEMM+FP8 RS not functional"

This reverts commit 80a47a76355440cd5fb4314c96fe9fda632d87f9.

* Add support for NVLS-MC and FP8 Reduce Scatter

* Bug fix

* Atomic and multiatomic FP8 RS functional

* Remove debug print

* UB comm initialization hang fix

* Code cleanup

* Create new GEMM API for Atomic GEMM

* CI ready

* More fixes

* License

* Bug fix

* Revert NVLS-MC

* Check cu* versions for running atomic GEMMs

* Lint

* Fixes

* Cleanup

* Add experimental warning

* Better wording

* Add warning to C API

* Fix wording

---------
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent be67f219
@@ -506,7 +506,7 @@ def test_export_gemm(
self.fp8_tensor_weight,
self.weights_type)
-ret = fp8_gemm(
+ret, _ = fp8_gemm(
weight_fp8,
self.meta_weight.scale_inv,
self.fp8_tensor_weight,
@@ -1323,7 +1323,7 @@ def test_export_gemm_layernorm(
self.fp8_tensor_weight,
self.weights_type)
-ret = fp8_gemm(
+ret, _ = fp8_gemm(
weight_fp8,
self.meta_weight.scale_inv,
self.fp8_tensor_weight,
...
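The test updates above reflect an API change in this PR: `fp8_gemm` now returns a pair instead of a single tensor. A minimal sketch of adapting call sites; the `fp8_gemm` stub below is a hypothetical stand-in that only models the new signature, not the real cuBLASLt-backed extension:

```python
# Hypothetical stand-in: the real fp8_gemm performs an FP8 GEMM via the
# Transformer Engine C extension; here we only model its return shape.
def fp8_gemm(weight, scale_inv, tensor_index, dtype):
    out = [[1.0]]           # GEMM output placeholder
    gelu_input = None       # pre-GELU buffer (None when unused)
    return out, gelu_input  # new API: a 2-tuple instead of a bare tensor

# Old call sites:  ret = fp8_gemm(...)
# must become:     ret, _ = fp8_gemm(...)
ret, _ = fp8_gemm("w", "s", 0, "fp8")
print(ret)
```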
@@ -7,6 +7,7 @@
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/logging.h>
#include <transformer_engine/gemm.h>
#include <cuda.h>
#include <cublasLt.h>
#include <cublas_v2.h>
#include "../common.h"
@@ -50,6 +51,10 @@ void cublas_gemm(const Tensor *inputA,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
int m_split,
int n_split,
bool gemm_producer,
const Tensor *inputCounter,
cudaStream_t stream
) {
void *A = inputA->data.dptr;
@@ -63,6 +68,10 @@ void cublas_gemm(const Tensor *inputA,
void *bias_ptr = inputBias->data.dptr;
const bool bias = bias_ptr != nullptr;
void *pre_gelu_out = outputPreGelu->data.dptr;
void *counter = nullptr;
if (inputCounter != nullptr) {
counter = inputCounter->data.dptr;
}
const bool gelu = pre_gelu_out != nullptr;
const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
is_fp8_dtype(inputB->data.dtype);
@@ -223,6 +232,27 @@ void cublas_gemm(const Tensor *inputA,
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_EPILOGUE,
&epilogue, sizeof(epilogue)));
#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205
if (counter != nullptr) {
if (m_split == 0) m_split = 1;
if (n_split == 0) n_split = 1;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_ROWS,
&m_split, sizeof(m_split)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_COLS,
&n_split, sizeof(n_split)));
if (gemm_producer) {
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_OUT_COUNTERS_POINTER,
&counter, sizeof(counter)));
} else {
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_IN_COUNTERS_POINTER,
&counter, sizeof(counter)));
}
}
#endif
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference));
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
@@ -254,7 +284,6 @@ void cublas_gemm(const Tensor *inputA,
workspaceSize,
stream));  /* stream */
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceDestroy(preference));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Ddesc));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Cdesc));
@@ -320,5 +349,82 @@ void nvte_cublas_gemm(const NVTETensor A,
wspace->data.shape[0],
accumulate, use_split_accumulator,
math_sm_count,
0,
0,
false,
nullptr,
stream);
}
void nvte_cublas_atomic_gemm(const NVTETensor A,
const NVTETensor B,
NVTETensor D,
const NVTETensor bias,
NVTETensor pre_gelu_out,
bool transa,
bool transb,
bool grad,
NVTETensor workspace,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
int m_split,
int n_split,
bool gemm_producer,
const NVTETensor counter,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cublas_atomic_gemm);
int cudart_version;
NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&cudart_version));
NVTE_CHECK(cudart_version >= 12020, "CUDA version 12.2 is required for atomic gemm.");
NVTE_CHECK(cublasLtGetVersion() >= 120205, "cuBLAS version 12.2.5 is required for atomic gemm.");
using namespace transformer_engine;
const Tensor *inputA = reinterpret_cast<const Tensor*>(A);
const Tensor *inputB = reinterpret_cast<const Tensor*>(B);
Tensor *outputD = reinterpret_cast<Tensor*>(D);
const Tensor *biasTensor = reinterpret_cast<const Tensor*>(bias);
Tensor *outputGelu = reinterpret_cast<Tensor*>(pre_gelu_out);
const Tensor *inputCounter = reinterpret_cast<const Tensor*>(counter);
Tensor *wspace = reinterpret_cast<Tensor*>(workspace);
const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1];
const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0];
const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0];
int lda, ldb, ldd;
if (transa && !transb) { // TN
lda = k;
ldb = k;
ldd = m;
} else if (!transa && !transb) { // NN
lda = m;
ldb = k;
ldd = m;
} else if (!transa && transb) { // NT
lda = m;
ldb = n;
ldd = m;
} else { // TT
NVTE_ERROR("TT layout not allowed.");
}
cublas_gemm(inputA,
inputB,
outputD,
biasTensor,
outputGelu,
m, n, k,
lda, ldb, ldd,
(transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
(transb) ? CUBLAS_OP_T : CUBLAS_OP_N,
grad, wspace->data.dptr,
wspace->data.shape[0],
accumulate, use_split_accumulator,
math_sm_count,
m_split,
n_split,
gemm_producer,
inputCounter,
stream);
}
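The layout handling in `nvte_cublas_atomic_gemm` above derives the leading dimensions from the transpose flags and rejects the TT layout. A small sketch of that selection logic, as pure Python mirroring the C++ branch structure (the function name here is illustrative, not part of the library):

```python
def leading_dims(m, n, k, transa, transb):
    """Mirror of the TN/NN/NT leading-dimension selection above.

    Returns (lda, ldb, ldd); the TT layout is rejected, as in the source.
    """
    if transa and not transb:      # TN
        return k, k, m
    if not transa and not transb:  # NN
        return m, k, m
    if not transa and transb:      # NT
        return m, n, m
    raise ValueError("TT layout not allowed.")

print(leading_dims(4, 3, 2, True, False))   # TN case
```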
@@ -54,6 +54,52 @@ void nvte_cublas_gemm(const NVTETensor A,
cudaStream_t stream
);
/*! \brief Compute matrix multiplication of 2 matrices with chunking and atomic counters.
*
 * \warning cuBLAS atomic GEMM uses a beta API and is not tested for all use cases.
*
* Computes:
* - `D = AB` if both `bias` and `pre_gelu_out` are empty tensors
* - `D = AB + bias` if `pre_gelu_out` is empty and `bias` is not empty
* - `D = GELU(AB + bias)` if both `bias` and `pre_gelu_out` are not empty tensors
*
* \param[in] A The A matrix.
* \param[in] B The B matrix.
* \param[in,out] D Output matrix.
* \param[in] bias Bias tensor.
* \param[in,out] pre_gelu_out Output matrix before GELU activation.
* \param[in] transa Whether A matrix is transposed.
* \param[in] transb Whether B matrix is transposed.
* \param[in] grad Whether this operation is part of the
* gradient computation.
* \param[out] workspace Workspace tensor.
* \param[in] accumulate Whether to accumulate the result into the D matrix.
* \param[in] use_split_accumulator Whether to use split accumulator in the FP8 GEMM.
* \param[in] math_sm_count Number of GPU SMs to use (default=0: use cuBLAS heuristics)
* \param[in] m_split Number of chunks/splits along m-dimension for Atomic GEMM.
* \param[in] n_split Number of chunks/splits along n-dimension for Atomic GEMM.
* \param[in] gemm_producer Whether Atomic GEMM is the producer or consumer.
* \param[in,out] counter counter[chunk_i]=0 indicates chunk_i has been produced.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_cublas_atomic_gemm(const NVTETensor A,
const NVTETensor B,
NVTETensor D,
const NVTETensor bias,
NVTETensor pre_gelu_out,
bool transa,
bool transb,
bool grad,
NVTETensor workspace,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
int m_split,
int n_split,
bool gemm_producer,
const NVTETensor counter,
cudaStream_t stream
);
#ifdef __cplusplus
}  // extern "C"
#endif
...
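Per the doc comment above, `counter[chunk_i] == 0` marks `chunk_i` as produced: a producer GEMM flips counters as output chunks complete, and the consumer (e.g. the communication kernel) spins on them. A threaded Python toy model of that handshake; all names here are illustrative, not the library's:

```python
import threading
import time

NUM_CHUNKS = 4
counters = [1] * NUM_CHUNKS   # 1 = not yet produced; 0 = produced, as documented
result = []

def producer():
    # Stands in for the atomic GEMM writing its output chunk by chunk.
    for chunk in range(NUM_CHUNKS):
        time.sleep(0.01)      # "compute" the chunk
        counters[chunk] = 0   # signal: chunk produced

def consumer():
    # Stands in for the comm kernel consuming produced chunks in order.
    for chunk in range(NUM_CHUNKS):
        while counters[chunk] != 0:   # spin until the producer flips the flag
            time.sleep(0.001)
        result.append(chunk)          # "send/reduce" the chunk

t1 = threading.Thread(target=producer)
t2 = threading.Thread(target=consumer)
t1.start(); t2.start(); t1.join(); t2.join()
print(result)
```

In the real path the spinning happens on the GPU (cuBLASLt polls the in-counters, or zeroes the out-counters, via the `ATOMIC_SYNC_*_COUNTERS_POINTER` attributes set earlier in this diff); the host-thread version above only models the ordering contract.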
@@ -2262,6 +2262,8 @@ class MultiheadAttention(torch.nn.Module):
ub_bulk_dgrad: bool = False,
ub_split_rs: bool = False,
ub_split_ag: bool = False,
ub_atomic_gemm_rs: bool = False,
ub_atomic_gemm_ag: bool = False,
bias: bool = True,
normalization: str = "LayerNorm",
device: Union[torch.device, str] = "cuda",
@@ -2342,6 +2344,7 @@ class MultiheadAttention(torch.nn.Module):
ub_bulk_dgrad=ub_bulk_dgrad,
ub_split_ag=ub_split_ag,
normalization=normalization,
ub_atomic_gemm_ag=ub_atomic_gemm_ag,
**common_gemm_kwargs,
)
else:
@@ -2372,6 +2375,7 @@ class MultiheadAttention(torch.nn.Module):
ub_bulk_dgrad=ub_bulk_dgrad,
ub_split_ag=ub_split_ag,
normalization=normalization,
ub_atomic_gemm_ag=ub_atomic_gemm_ag,
**common_gemm_kwargs,
)
else:
@@ -2418,6 +2422,8 @@ class MultiheadAttention(torch.nn.Module):
parallel_mode="row" if set_parallel_mode else None,
ub_split_rs=ub_split_rs,
ub_split_ag=ub_split_ag,
ub_atomic_gemm_rs=ub_atomic_gemm_rs,
ub_atomic_gemm_ag=ub_atomic_gemm_ag,
**common_gemm_kwargs,
)
...
@@ -91,22 +91,40 @@ def fp8_gemm(
assert ub is not None, 'ub object is None!'
if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG:
fn = ub.bulk_overlap
-args = tuple(args + (1,))
+extra_output_tensor = (
+    empty_tensor if extra_output_tensor is None else extra_output_tensor
+)
+args = tuple(args + (1, extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS:
fn = ub.bulk_overlap
-args = tuple(args + (0,))
+extra_output_tensor = (
+    empty_tensor if extra_output_tensor is None else extra_output_tensor
+)
+args = tuple(args + (0, extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG:
fn = ub.split_overlap_ag
extra_output_tensor = (
    empty_tensor if extra_output_tensor is None else extra_output_tensor
)
args = tuple(args + (extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG:
fn = ub.atomic_gemm_overlap_ag
extra_output_tensor = (
empty_tensor if extra_output_tensor is None else extra_output_tensor
)
args = tuple(args + (extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS:
fn = ub.split_overlap_rs
assert (
    extra_output_tensor is not None
), 'SPLIT_PIPELINED_RS requires extra output tensor'
args = tuple(args + (True, extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_RS:
fn = ub.atomic_gemm_overlap_rs
assert (
extra_output_tensor is not None
), 'ATOMIC_GEMM_RS requires extra output tensor'
args = tuple(args + (True, extra_output_tensor,))
_ = fn(*args)
return out, gelu_input
@@ -195,10 +213,10 @@ def gemm(
assert ub is not None, 'ub object is None!'
if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG:
fn = ub.bulk_overlap
-args = tuple(args + (1,))
+args = tuple(args + (1, empty_tensor))
elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS:
fn = ub.bulk_overlap
-args = tuple(args + (0,))
+args = tuple(args + (0, empty_tensor))
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG:
fn = ub.split_overlap_ag
extra_output_tensor = (
...
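The `fp8_gemm`/`gemm` changes above route each `UbufOverlapAlgo` value to a bound userbuffer method and append algorithm-specific trailing arguments. A self-contained sketch of that trailing-argument policy; the enum and placeholder names below are simplified stand-ins for `tex.UbufOverlapAlgo` and the empty tensor, not the real bindings:

```python
from enum import Enum, auto

class UbufOverlapAlgo(Enum):      # simplified stand-in for tex.UbufOverlapAlgo
    BULK_OVERLAP_AG = auto()
    BULK_OVERLAP_RS = auto()
    ATOMIC_GEMM_RS = auto()

EMPTY = "empty_tensor"            # placeholder for the empty tensor default

def build_args(args, ub_algo, extra_output_tensor=None):
    """Mirror the policy above: bulk overlaps default the extra output
    tensor to an empty tensor, while ATOMIC_GEMM_RS requires a real one."""
    if ub_algo is UbufOverlapAlgo.BULK_OVERLAP_AG:
        extra = EMPTY if extra_output_tensor is None else extra_output_tensor
        return args + (1, extra)
    if ub_algo is UbufOverlapAlgo.BULK_OVERLAP_RS:
        extra = EMPTY if extra_output_tensor is None else extra_output_tensor
        return args + (0, extra)
    if ub_algo is UbufOverlapAlgo.ATOMIC_GEMM_RS:
        assert extra_output_tensor is not None, \
            'ATOMIC_GEMM_RS requires extra output tensor'
        return args + (True, extra_output_tensor)
    raise ValueError(ub_algo)

print(build_args(("A", "B"), UbufOverlapAlgo.BULK_OVERLAP_AG))
```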
@@ -179,6 +179,32 @@ void te_gemm(at::Tensor A,
int math_sm_count
);
void te_atomic_gemm(at::Tensor A,
at::Tensor A_scale_inverse,
transformer_engine::DType A_type,
bool transa,
at::Tensor B,
at::Tensor B_scale_inverse,
transformer_engine::DType B_type,
bool transb,
at::Tensor D,
at::Tensor D_scale,
transformer_engine::DType D_type,
at::Tensor D_amax,
at::Tensor bias,
transformer_engine::DType bias_type,
at::Tensor pre_gelu_out,
bool grad,
at::Tensor workspace,
size_t workspaceSize,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
int m_split,
int n_split,
bool gemm_producer,
at::Tensor counter
);
void fused_cast_transpose(at::Tensor input,
at::Tensor scale,
...
@@ -6,6 +6,7 @@
#include "extensions.h"
void te_gemm(at::Tensor A,
at::Tensor A_scale_inverse,
transformer_engine::DType A_type,
@@ -73,3 +74,82 @@ void te_gemm(at::Tensor A,
math_sm_count,
at::cuda::getCurrentCUDAStream());
}
void te_atomic_gemm(at::Tensor A,
at::Tensor A_scale_inverse,
transformer_engine::DType A_type,
bool transa,
at::Tensor B,
at::Tensor B_scale_inverse,
transformer_engine::DType B_type,
bool transb,
at::Tensor D,
at::Tensor D_scale,
transformer_engine::DType D_type,
at::Tensor D_amax,
at::Tensor bias,
transformer_engine::DType bias_type,
at::Tensor pre_gelu_out,
bool grad,
at::Tensor workspace,
size_t workspaceSize,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
int m_split,
int n_split,
bool gemm_producer,
at::Tensor counter
) {
using namespace transformer_engine;
auto te_A = makeTransformerEngineTensor(A.data_ptr(),
{static_cast<size_t>(A.size(0)),
static_cast<size_t>(A.size(1))},
A_type, nullptr, nullptr,
A_scale_inverse.data_ptr());
auto te_B = makeTransformerEngineTensor(B.data_ptr(),
{static_cast<size_t>(B.size(0)),
static_cast<size_t>(B.size(1))},
B_type, nullptr, nullptr,
B_scale_inverse.data_ptr());
auto te_D = makeTransformerEngineTensor(D.data_ptr(),
{static_cast<size_t>(D.size(0)),
static_cast<size_t>(D.size(1))},
D_type, D_amax.data_ptr(),
D_scale.data_ptr(), nullptr);
auto te_bias = makeTransformerEngineTensor(bias.data_ptr(), {static_cast<size_t>(bias.size(0))},
bias_type);
auto te_counter = makeTransformerEngineTensor(counter.data_ptr(),
{static_cast<size_t>(counter.size(0))},
DType::kInt32);
const auto gelu_shape = pre_gelu_out.data_ptr() == nullptr
? std::vector<size_t>{static_cast<size_t>(pre_gelu_out.size(0))}
: std::vector<size_t>{static_cast<size_t>(pre_gelu_out.size(0)),
static_cast<size_t>(pre_gelu_out.size(1))};
auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out.data_ptr(),
gelu_shape,
GetTransformerEngineDType(
pre_gelu_out.scalar_type()));
auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(),
{workspaceSize},
DType::kByte);
nvte_cublas_atomic_gemm(te_A.data(),
te_B.data(),
te_D.data(),
te_bias.data(),
te_pre_gelu_out.data(),
transa,
transb,
grad,
te_workspace.data(),
accumulate,
use_split_accumulator,
math_sm_count,
m_split,
n_split,
gemm_producer,
te_counter.data(),
at::cuda::getCurrentCUDAStream());
}
@@ -91,18 +91,24 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.value("BULK_OVERLAP_AG", ubuf::UBOverlapAlgo::BULK_OVERLAP_AG)
.value("BULK_OVERLAP_RS", ubuf::UBOverlapAlgo::BULK_OVERLAP_RS)
.value("SPLIT_PIPELINED_RS", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_RS)
-.value("SPLIT_PIPELINED_AG", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_AG);
+.value("SPLIT_PIPELINED_AG", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_AG)
.value("ATOMIC_GEMM_RS", ubuf::UBOverlapAlgo::ATOMIC_GEMM_RS)
.value("ATOMIC_GEMM_AG", ubuf::UBOverlapAlgo::ATOMIC_GEMM_AG);
py::class_<ubuf::UbufCommOverlap>(m, "UbufCommOverlap")
-.def(py::init<torch::Tensor&, int, int, int, int, int, bool, int>())
+.def(py::init<torch::Tensor&, int, int, int, int, int, bool, int, torch::Tensor>())
.def("bulk_overlap", &ubuf::UbufCommOverlap::bulk_overlap)
.def("split_overlap_rs", &ubuf::UbufCommOverlap::split_overlap_rs)
.def("set_ubuf_scale_inv", &ubuf::UbufCommOverlap::set_ubuf_scale_inv)
.def("atomic_gemm_overlap_rs", &ubuf::UbufCommOverlap::atomic_gemm_overlap_rs)
.def("is_fp8_ubuf", &ubuf::UbufCommOverlap::is_fp8_ubuf)
.def("copy_input_to_ubuf", &ubuf::UbufCommOverlap::copy_input_to_ubuf)
.def("get_ubuf_output", &ubuf::UbufCommOverlap::get_ubuf_output);
py::class_<ubuf::UbufP2PCommOverlap>(m, "UbufP2PCommOverlap")
-.def(py::init<torch::Tensor&, int, int, bool, int>())
+.def(py::init<torch::Tensor&, int, int, int, int, bool, bool, int, torch::Tensor>())
.def("split_overlap_ag", &ubuf::UbufP2PCommOverlap::split_overlap_ag)
.def("atomic_gemm_overlap_ag", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_ag)
.def("copy_input_to_ubuf", &ubuf::UbufP2PCommOverlap::copy_input_to_ubuf)
.def("get_ubuf_output", &ubuf::UbufP2PCommOverlap::get_ubuf_output);
#else  // NVTE_WITH_USERBUFFERS
...
@@ -4,10 +4,13 @@
 * See LICENSE for license information.
 ************************************************************************/
+#include "userbuffers.h"
#include <assert.h>
+#include <chrono>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <immintrin.h>
+#include <iostream>
#include <math.h>
#include <mpi.h>
#include <sched.h>
@@ -15,9 +18,6 @@
#include <string.h>
#include <unistd.h>
#include <x86intrin.h>
-#include <chrono>
-#include <iostream>
-#include "userbuffers.h"
static int oob_bcast(void *comm_context, void *buf, int size, int root) {
MPI_Bcast(buf, size, MPI_BYTE, root,
@@ -47,6 +47,17 @@ int stringCmp(const void *a, const void *b) { return strcmp((const char *)a, (co
} \
} while (0)
#define CUCHECK(cmd) \
do { \
CUresult retval = cmd; \
if (retval != CUDA_SUCCESS) { \
const char *error_string; \
cuGetErrorString(retval, &error_string); \
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, error_string); \
exit(EXIT_FAILURE); \
} \
  } while (0)
#define NVTE_UB_ERROR(x) \
do { \
throw std::runtime_error(std::string(__FILE__ ":") + std::to_string(__LINE__) + \
@@ -89,12 +100,14 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
(*comm)->push = 1;
(*comm)->use_ce = 0;
(*comm)->cga_size = 2;
-for (int i = 0; i < userbuffers_op_types; i++) (*comm)->basecounter[i] = 0;
+for (int i = 0; i < userbuffers_op_types; i++)
+  (*comm)->basecounter[i] = 0;
(*comm)->head = 0;
(*comm)->tail = 0;
(*comm)->activeproxy = 1;
(*comm)->active_nreqs = 0;
-for (int i = 0; i < userbuffers_op_types; i++) (*comm)->active_req[i].active = -1;
+for (int i = 0; i < userbuffers_op_types; i++)
+  (*comm)->active_req[i].active = -1;
int ret = 0;
// split communicator
@@ -112,8 +125,10 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
color = 0;
for (int n = 0; n < size; n++) {
-if (n > 0 && strcmp(host_names[n - 1], host_names[n])) color++;
-if (strcmp(host_name, host_names[n]) == 0) break;
+if (n > 0 && strcmp(host_names[n - 1], host_names[n]))
+  color++;
+if (strcmp(host_name, host_names[n]) == 0)
+  break;
}
free(host_names);
@@ -128,14 +143,22 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
cpu_set_t cpuset;
CPU_ZERO(&cpuset);
int core;
-if (mylocal == 0) core = 50;
-if (mylocal == 1) core = 58;
-if (mylocal == 2) core = 18;
-if (mylocal == 3) core = 26;
-if (mylocal == 4) core = 114;
-if (mylocal == 5) core = 122;
-if (mylocal == 6) core = 82;
-if (mylocal == 7) core = 90;
+if (mylocal == 0)
+  core = 50;
+if (mylocal == 1)
+  core = 58;
+if (mylocal == 2)
+  core = 18;
+if (mylocal == 3)
+  core = 26;
+if (mylocal == 4)
+  core = 114;
+if (mylocal == 5)
+  core = 122;
+if (mylocal == 6)
+  core = 82;
+if (mylocal == 7)
+  core = 90;
CPU_SET(core, &cpuset);
if (!getenv("NVTE_NODOUBLE")) {
@@ -144,7 +167,8 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
else
CPU_SET(core + 128, &cpuset);
}
-if (getenv("NVTE_DOPIN")) pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpuset);
+if (getenv("NVTE_DOPIN"))
+  pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpuset);
if (ndev == numlocal) {  // all visible devices
if (cur_dev != mylocal)
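The affinity code in this file pins each local rank to a fixed host core, a hardware-specific table for the DGX-class nodes this path targets. A table-driven equivalent is easier to audit; this is a sketch, not the library's code, and the core numbers are simply the hardcoded values from the source:

```python
# Table-driven equivalent of the per-rank core pinning in the diff above.
# Core numbers are the platform-specific values hardcoded in the source;
# on other machines they would need retuning.
CORE_FOR_LOCAL_RANK = {0: 50, 1: 58, 2: 18, 3: 26, 4: 114, 5: 122, 6: 82, 7: 90}

def pick_core(mylocal):
    """Return the host core to pin for a given local GPU rank."""
    return CORE_FOR_LOCAL_RANK[mylocal]

print([pick_core(r) for r in range(8)])
```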
@@ -175,7 +199,8 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
int datanodegroup_id =
myrank / numlocal / datanodes;  // data reduction group node belongs, equals 0 for all if both
// pipenodes=1 and tensornodes=1
-// mpi communicator only needed for SHARP which is always allreduce1/data-parallel
+// mpi communicator only needed for SHARP which is always
+// allreduce1/data-parallel
MPI_Comm_split(MPI_COMM_WORLD, mylocal + numlocal * datanodegroup_id, rank, &(*comm)->comm_inter);
// different rails from same group are in different subcommunicators
@@ -192,19 +217,37 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
char *ib_dev_list;
int ZIONROCE = getenv("NVTE_ZIONROCE") ? atoi(getenv("NVTE_ZIONROCE")) : 0;
int ROCE = getenv("NVTE_ROCE") ? atoi(getenv("NVTE_ROCE")) : 0;
-if (ZIONROCE) ROCE = 1;
+if (ZIONROCE)
+  ROCE = 1;
int DGX_H100 = device_prop.major == 9;
switch (mylocal) {
-  case 0: ib_dev_list = "mlx5_0:1"; break;  // NOLINT(*)
-  case 1: ib_dev_list = (char*)(DGX_H100 ? "mlx5_3:1" : "mlx5_1:1"); break;  // NOLINT(*)
-  case 2: ib_dev_list = (char*)(ZIONROCE ? "mlx5_4:1" : DGX_H100 ? "mlx5_4:1" : "mlx5_2:1"); break;  // NOLINT(*)
-  case 3: ib_dev_list = (char*)(DGX_H100 ? "mlx5_5:1" : "mlx5_3:1"); break;  // NOLINT(*)
-  case 4: ib_dev_list = (char*)(DGX_H100 ? "mlx5_6:1" : "mlx5_6:1"); break;  // NOLINT(*)
-  case 5: ib_dev_list = (char*)(DGX_H100 ? "mlx5_9:1" : "mlx5_7:1"); break;  // NOLINT(*)
-  case 6: ib_dev_list = (char*)(ZIONROCE ? "mlx5_10:1" : DGX_H100 ? "mlx5_10:1" : "mlx5_8:1"); break;  // NOLINT(*)
-  case 7: ib_dev_list = (char*)(DGX_H100 ? "mlx5_11:1" : "mlx5_9:1"); break;  // NOLINT(*)
-  default: break;
+  case 0:
+    ib_dev_list = "mlx5_0:1";
+    break;  // NOLINT(*)
+  case 1:
+    ib_dev_list = (char *)(DGX_H100 ? "mlx5_3:1" : "mlx5_1:1");  // NOLINT(*)
+    break;
+  case 2:
+    ib_dev_list = (char *)(ZIONROCE ? "mlx5_4:1" : DGX_H100 ? "mlx5_4:1" : "mlx5_2:1");  // NOLINT(*)
+    break;
+  case 3:
+    ib_dev_list = (char *)(DGX_H100 ? "mlx5_5:1" : "mlx5_3:1");  // NOLINT(*)
+    break;
+  case 4:
+    ib_dev_list = (char *)(DGX_H100 ? "mlx5_6:1" : "mlx5_6:1");  // NOLINT(*)
+    break;
+  case 5:
+    ib_dev_list = (char *)(DGX_H100 ? "mlx5_9:1" : "mlx5_7:1");  // NOLINT(*)
+    break;
+  case 6:
+    ib_dev_list = (char *)(ZIONROCE ? "mlx5_10:1" : DGX_H100 ? "mlx5_10:1" : "mlx5_8:1");  // NOLINT(*)
+    break;
+  case 7:
+    ib_dev_list = (char *)(DGX_H100 ? "mlx5_11:1" : "mlx5_9:1");  // NOLINT(*)
+    break;
+  default:
+    break;
}
(*comm)->fifo = reinterpret_cast<ub_request *>(malloc(sizeof(ub_request) * NVTE_MAX_REQUESTS));
...@@ -215,7 +258,8 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
CUDACHECK(cudaMallocHost((void **)&(*comm)->hostflags,  // NOLINT(*)
(NVTE_MAX_SMS + 100) * sizeof(int)));
for (int i = 0; i < 100 + NVTE_MAX_SMS; i++)
(*comm)->hostflags[i] = 0;
_mm_mfence();
sleep(1);
...@@ -223,13 +267,16 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
(*comm)->ibnvsize = (*comm)->nvsize;
#define NBUF 2
#define LOCALSIZE 4 * (NVTE_REG0_OFFSET(*comm) + NVTE_REG0_FLAGS + NVTE_REG0_COMMBUFFER * NBUF)
// peer pointers + op flags + comm buffer
CUDACHECK(cudaMalloc(&(*comm)->gpu_ptrs,
LOCALSIZE));  // flags and pointers, no block data yet
CUDACHECK(cudaMemset((*comm)->gpu_ptrs, 0, LOCALSIZE));
CUDACHECK(cudaDeviceSynchronize());
register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE,
*comm);  // will use handler 0
CUDACHECK(cudaMalloc(&(*comm)->send_id, (*comm)->nranks * sizeof(int)));
CUDACHECK(cudaMalloc(&(*comm)->recv_id, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int)));
CUDACHECK(cudaMemset((*comm)->send_id, 0, (*comm)->nranks * sizeof(int)));
...@@ -243,7 +290,6 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
#define GPU_PAGE_MASK (~GPU_PAGE_OFFSET)
CUDACHECK(cudaMalloc(&(*comm)->flags, 2 * GPU_PAGE_SIZE));
unsigned int flag = 1;
CUDACHECK(cudaMemset((*comm)->flags, 0, 2 * GPU_PAGE_SIZE));
(*comm)->flags =
reinterpret_cast<int *>(((CUdeviceptr)(*comm)->flags + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK);
...@@ -275,7 +321,8 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
pthread_attr_setschedparam(&attr, &param);
if (getenv("NVTE_UBDEBUG"))
printf("%d/%d:(%d x %d): DP %d x %d TP %d x %d, DPGROUP %dx%d TPGROUP "
"%dx%d PIPE_ID %d/%d\n",
myrank, nranks, myrank / numlocal, myrank % numlocal, (*comm)->my_node,
(*comm)->ar_nvrank, (*comm)->my2_node, (*comm)->ar2_nvrank, (*comm)->num_nodes,
(*comm)->ar_nvsize, (*comm)->num2_nodes, (*comm)->ar2_nvsize, (*comm)->pipe_id,
...@@ -300,9 +347,9 @@ void destroy_communicator(communicator *comm) {
}
int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm, bool alloc) {
if (comm->free_region > NVTE_MAX_REGIONS)
return -1;
int hndl = comm->free_region;
comm->peer_ptr[hndl] = reinterpret_cast<void **>(malloc(sizeof(void *) * (comm->nvsize)));
if (alloc) {
...@@ -313,25 +360,22 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
reinterpret_cast<cudaIpcMemHandle_t *>(malloc(sizeof(cudaIpcMemHandle_t) * (comm->nvsize)));
CUDACHECK(cudaIpcGetMemHandle(&memhndl[comm->nvrank], *gpubuff));
MPI_Allgather(&memhndl[comm->nvrank], sizeof(cudaIpcMemHandle_t), MPI_BYTE, memhndl,
sizeof(cudaIpcMemHandle_t), MPI_BYTE, comm->comm_intra);
for (int i = 0; i < comm->nvsize; i++)
if (i != comm->nvrank)
CUDACHECK(cudaIpcOpenMemHandle((void **)&(comm->peer_ptr[hndl][i]),  // NOLINT(*)
memhndl[i], cudaIpcMemLazyEnablePeerAccess));
comm->peer_ptr[hndl][comm->nvrank] = *gpubuff;
CUDACHECK(cudaDeviceSynchronize());
CUDACHECK(
cudaMemcpy(reinterpret_cast<char *>(comm->gpu_ptrs) + (hndl * comm->nvsize * sizeof(void *)),
comm->peer_ptr[hndl], comm->nvsize * sizeof(void *), cudaMemcpyHostToDevice));
CUDACHECK(cudaDeviceSynchronize());
free(memhndl);
comm->mem_ptr[hndl] = *gpubuff;
return comm->free_region++;
}
...@@ -352,8 +396,10 @@ int allgather2_userbuff_inplace_gpu(const int maxcredit, const int handler, cons
void allreduce_nonsharp_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream, int op) {
if (elements < 64)
NVTE_UB_ERROR("Userbuffer comm for given config not implemented.");
// if(comm->myrank==0) fprintf(stderr,"AR2(%d) user call
// launch_mode=%d\n",op,comm->launch_mode);
const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
int blocksize = elements * 2;
int maxcredit = 0;
...@@ -361,19 +407,19 @@ void allreduce_nonsharp_inplace(const int handler, const int offset, const int e
blocksize = (comm->nblocks - 1 + (comm->alignblock - 1 + elements * 2) / comm->alignblock) /
comm->nblocks;  // FIXME TUNING
blocksize *= comm->alignblock;
if (blocksize < comm->minblock)
blocksize = comm->minblock;
maxcredit = (elements * 2 + blocksize - 1) / blocksize;
size_t peerblock = sizeof(int) * NVTE_REG0_COMMBUFFER / maxcredit;  // max size we can fit
if (blocksize > peerblock * ar_nvsize)
blocksize = peerblock * ar_nvsize;
int sms = allreduce2_userbuff_inplace_gpu(maxcredit, handler, offset, elements, blocksize, comm,
stream, op);
if (num_nodes > 1 && comm->launch_mode & NVTE_LAUNCH_CPU) {
if (!sms)
return;
comm->fifo[comm->head].optype = op;
comm->fifo[comm->head].basecounter = comm->basecounter[op];
comm->fifo[comm->head].blocksize = blocksize;
...@@ -399,7 +445,8 @@ void allreduce2_userbuff_inplace(const int handler, const int offset, const int
void allreduce_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) {
if (elements < 64)
NVTE_UB_ERROR("Userbuffer comm for given config not implemented.");
allreduce_nonsharp_inplace(handler, offset, elements, comm, stream,
userbuffers_allreduceop_nonsharp);
return;
...@@ -407,7 +454,8 @@ void allreduce_userbuff_inplace(const int handler, const int offset, const int e
void reducescatter_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) {
if (elements < 64)
NVTE_UB_ERROR("Userbuffer comm for given config not implemented.");
int op = userbuffers_allreduceop_nonsharp;
const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
...@@ -418,17 +466,20 @@ void reducescatter_userbuff_inplace(const int handler, const int offset, const i
blocksize = (comm->nblocks - 1 + (comm->alignblock - 1 + elements * 2) / comm->alignblock) /
comm->nblocks;  // FIXME TUNING
blocksize *= comm->alignblock;
if (blocksize < comm->minblock)
blocksize = comm->minblock;
maxcredit = (elements * 2 + blocksize - 1) / blocksize;
size_t peerblock = sizeof(int) * NVTE_REG0_COMMBUFFER / maxcredit;  // max size we can fit
if (blocksize > peerblock * ar_nvsize)
blocksize = peerblock * ar_nvsize;
int sms = reducescatter2_userbuff_inplace_gpu(maxcredit, handler, offset, elements, blocksize,
comm, stream, op);
if (num_nodes > 1 && comm->launch_mode & NVTE_LAUNCH_CPU) {
if (!sms)
return;
comm->fifo[comm->head].optype = op;
comm->fifo[comm->head].basecounter = comm->basecounter[op];
comm->fifo[comm->head].blocksize = blocksize;
...@@ -448,7 +499,8 @@ void reducescatter_userbuff_inplace(const int handler, const int offset, const i
void allgather_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) {
if (elements < 64)
NVTE_UB_ERROR("Userbuffer comm for given config not implemented.");
int op = userbuffers_allreduceop_nonsharp;
const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
int blocksize = elements * 2;
...@@ -458,11 +510,13 @@ void allgather_userbuff_inplace(const int handler, const int offset, const int e
blocksize = (comm->nblocks - 1 + (comm->alignblock - 1 + elements * 2) / comm->alignblock) /
comm->nblocks;  // FIXME TUNING
blocksize *= comm->alignblock;
if (blocksize < comm->minblock)
blocksize = comm->minblock;
maxcredit = (elements * 2 + blocksize - 1) / blocksize;
size_t peerblock = sizeof(int) * NVTE_REG0_COMMBUFFER / maxcredit;  // max size we can fit
if (blocksize > peerblock * ar_nvsize)
blocksize = peerblock * ar_nvsize;
int sms = allgather2_userbuff_inplace_gpu(maxcredit, handler, offset, elements, blocksize, comm,
stream, op);
...
...@@ -24,6 +24,18 @@
#define NVTE_LAUNCH_CPU 2
#define NVTE_MAX_NVLINK 8
#define UB_MEM_UC_CONTIG 1
#define UB_MEM_MC_CREATED 2
#define UB_MEM_ALLOCATED 4
#define NVTE_UB_MEM_UC_CONTIG 1
#define NVTE_UB_MEM_MC_CREATED 2
#define NVTE_UB_MEM_ALLOCATED 4
#ifdef UCP
#include <ucp/api/ucp.h>
#endif
// region 0 flag offsets
#define NVTE_REG0_OPFLAGS 1024
#define NVTE_REG0_RECV (NVTE_REG0_OPFLAGS * userbuffers_op_types)
...@@ -35,6 +47,10 @@
#define NVTE_REG0_IBRS 32
#define NVTE_REG0_IBAG 512
#if defined(UCP) || !defined(NOSHARP)
#undef REG0_COMMBUFFER
#define REG0_COMMBUFFER (1024*1024*16)
#endif
// gpuflags map offsets
#define NVTE_GF_STATE 16000
#define NVTE_GF_IBSHARPDONE 0
...@@ -81,6 +97,19 @@ struct communicator {
void *mem_ptr[NVTE_MAX_REGIONS];
void **peer_ptr[NVTE_MAX_REGIONS];
int memflags[NVTE_MAX_REGIONS]; // UC,MC, user/lib allocated
CUmemGenericAllocationHandle *uchandles[NVTE_MAX_REGIONS];
void* ucbase_ptr[NVTE_MAX_REGIONS]; // only for cuMem allocated memory
size_t mem_size[NVTE_MAX_REGIONS];
void* mc_ptr[NVTE_MAX_REGIONS];
void* mc_baseptr;
CUmemGenericAllocationHandle mc_handle;
size_t mc_offset, mc_maxsize;
int use_mc; // 1: use MC if available, 0: override not to use MC
int ar_nvsize, ar_firstgpu,
ar_nvrank;  // number of gpus(and first gpu in a group) of gpus per node in reduction subgroup
// (_splitar init used) would be equal to (nvsize,0) for regular comm_create
...@@ -120,6 +149,8 @@ struct communicator {
};
typedef struct communicator communicator;
void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream);
void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream);
int create_communicator(communicator **comm);
/* creates communicator, allocates all internal buffers if necessary */
...@@ -191,6 +222,45 @@ void reducescatter2_userbuff_stridedoutput(void *output, const int handler, cons
const int rowelements, const int colelements,
const int strideelements, communicator *comm,
cudaStream_t stream = 0);
template<typename fp8type>
void reducescatter2_userbuff_stridedoutput_fp8(void* output, float* scale, const int handler,
const int offset, const int rowelements,
const int colelements, const int strideelements,
communicator* comm, cudaStream_t stream = 0);
template<typename fp8type>
void reducescatter2_userbuff_fp8(void* output, float* scale, const int handler, const int offset,
const int elements, communicator* comm, cudaStream_t stream = 0);
#if 0
template<typename fp8type>
void reducescatter2_userbuff_strided_atomic_fp8(void* output, float *scale, const int handler,
const int offset, const int rowelements,
const int colelements, const int strideelements,
const int numchunks, void *counters,
communicator* comm, cudaStream_t stream = 0);
#endif
template<typename fp8type>
void reducescatter2_userbuff_strided_atomic_fp8(void* output, float *scale, const int handler,
const int offset, const int rowelements,
const int colelements, const int strideelements_out,
const int strideelements_in, const int numchunks,
void *counters, communicator* comm,
cudaStream_t stream = 0);
template<typename fp8type>
void reducescatter2_userbuff_strided_multiatomic_fp8(
void* output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator* comm, cudaStream_t stream = 0);
void reducescatter2_userbuff_strided(
void* output, const int handler, const int offset, const int rowelements, const int colelements,
const int strideelements, communicator* comm, cudaStream_t stream = 0);
void reducescatter2_userbuff_strided_atomic(
    void* output, const int handler, const int offset, const int rowelements, const int colelements,
    const int strideelements, const int numchunks, void *counters, communicator* comm,
    cudaStream_t stream = 0);
void reducescatter2_userbuff_strided_multiatomic(
void* output, const int handler, const int offset, const int rowelements, const int colelements,
const int strideelements, const int numchunks, void *counters, communicator* comm,
cudaStream_t stream = 0);
/* everything should be 16byte aligned = 8 elts aligned
output is strided: row starts separated by stride elements*/
...@@ -208,6 +278,19 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds
void userbuffers_recv(const int srchandler, const size_t srcoffset, const int dsthandler,
const size_t dstoffset, const size_t bytes, communicator *comm,
const int peer, cudaStream_t stream = 0);
void userbuffers_sendrecv(
const int srchandler, const int dsthandler, const size_t send_offset, const size_t recv_offset,
const size_t bytes, communicator* comm, const int send_peer, const int recv_peer,
cudaStream_t stream = 0);
void userbuffers_sendrecv_atomic(
const int srchandler, const int dsthandler, const size_t send_offset, const size_t recv_offset,
const size_t bytes, communicator* comm, const int send_peer, const int recv_peer, void *counters,
cudaStream_t stream = 0);
void userbuffers_sendrecv_multiatomic(
const int srchandler, const int dsthandler, const size_t send_offset, const size_t recv_offset,
const size_t bytes, communicator* comm, const int send_peer, const int recv_peer,
const int nchunks, void *counters, bool shuffle, cudaStream_t stream = 0);
// alltoall split send and recv to allow for overlap
// send kicks in sending data to the destination - invoke on same stream as data generation
...
...@@ -124,6 +124,8 @@ def initialize_ub(
fp8_buf = [
"qkv_fprop", "qkv_dgrad", "proj_dgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad"
]
if bool(int(os.getenv("NVTE_UB_FP8_RS", "0"))):
fp8_buf.append("proj_fprop")
# Default overlap methods for layers
methods = {
"ring_exchange":["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"],
...@@ -153,8 +155,12 @@ def initialize_ub(
sample_buffer, # Sample userbuffer
rank_id, # Rank id
tp_size, # TP size
num_sm, # Number of communication SMs
cga_size, # CGA cluster size
set_sm_margin, # Set SM margin
aggregate, # Aggregate 2X GEMM chunks
_NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams
torch.Tensor(), # empty tensor to pass to counters
)
else:
ub_obj = tex.UbufCommOverlap(
...@@ -166,6 +172,7 @@ def initialize_ub(
num_splits, # Number of communication splits
set_sm_margin, # Set SM margin
_NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams
torch.Tensor(), # empty tensor to pass to counters
)
_ub_communicators[name] = ub_obj
...@@ -676,10 +683,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
grad_output_mat = grad_output.view((-1, grad_output.shape[-1]))
gather_grad_output = row_parallel_mode and ctx.sequence_parallel
if gather_grad_output:
ub_overlap_ag = ctx.ub_split_ag or ctx.ub_atomic_gemm_ag
# No-FP8 case: bgrad is fused with wgrad for this case.
if not ctx.fp8:
if gather_grad_output:
if not ub_overlap_ag:
grad_output_mat, _ = gather_along_first_dim(
grad_output_mat, ctx.tp_group
)
...@@ -698,8 +707,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
and ctx.fp8_meta["recipe"].override_linear_precision.wgrad
):
assert (
not ub_overlap_ag
), "override_linear_precision.wgrad not supported with UB AG overlap"
grad_output_mat, _ = gather_along_first_dim(grad_output_mat, ctx.tp_group)
# FP8 case with gather: unfused bgrad, cast, transpose for efficient gather
elif gather_grad_output:
...@@ -707,7 +716,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
grad_bias = grad_output_mat.sum(dim=0)
else:
grad_bias = None
if ub_overlap_ag:
grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(0)
else:
grad_output_c = torch.empty_like(grad_output_mat, dtype=torch.uint8)
...@@ -718,7 +727,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
fp8_dtype_backward,
out=grad_output_c,
)
if not ub_overlap_ag:
grad_output_c, _ = gather_along_first_dim(grad_output_c, ctx.tp_group)
grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
else:
...
...@@ -83,6 +83,7 @@ class _LayerNormLinear(torch.autograd.Function):
ub_bulk_dgrad: bool,
ub_split_ag: bool,
normalization: str,
ub_atomic_gemm_ag: bool,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
...@@ -100,11 +101,12 @@ class _LayerNormLinear(torch.autograd.Function):
if ln_bias is not None:
ln_bias = cast_if_needed(ln_bias, activation_dtype)
if ub_split_ag or ub_atomic_gemm_ag:
tp_world_size = get_distributed_world_size(tp_group)
if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output:
ub_split_ag = False
ub_atomic_gemm_ag = False
if ub_split_ag or ub_atomic_gemm_ag:
dim_size = list(inputmat.size())
dim_size[0] = dim_size[0] * tp_world_size
ub_obj_lnout = get_ub("qkv_fprop")
...@@ -112,6 +114,8 @@ class _LayerNormLinear(torch.autograd.Function):
else:
ln_out_dtype = torch.uint8 if fp8 else inputmat.dtype
ln_out = torch.empty_like(inputmat, dtype=ln_out_dtype)
if ub_atomic_gemm_ag:
assert fp8, "AtomicGemm overlap supported only for FP8 GEMM."
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
...@@ -139,7 +143,7 @@ class _LayerNormLinear(torch.autograd.Function):
fp8_dtype_forward,
)
# Column Parallel Linear
if ub_split_ag or ub_atomic_gemm_ag:
ln_out_total = ub_obj_lnout.get_ubuf_output(1)
ln_out = torch.empty_like(ln_out)
elif parallel_mode == "column" and sequence_parallel:
...@@ -173,6 +177,8 @@ class _LayerNormLinear(torch.autograd.Function):
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward)
ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None
ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ub_atomic_gemm_ag else ub_algo
out, _ = tex.fp8_gemm( out, _ = tex.fp8_gemm(
weight_fp8, weight_fp8,
fp8_meta["scaling_fwd"].scale_inv, fp8_meta["scaling_fwd"].scale_inv,
...@@ -187,9 +193,9 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -187,9 +193,9 @@ class _LayerNormLinear(torch.autograd.Function):
bias=bias, bias=bias,
use_bias=use_bias, use_bias=use_bias,
use_split_accumulator=_2X_ACC_FPROP, use_split_accumulator=_2X_ACC_FPROP,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None, ub_algo=ub_algo,
ub=ub_obj_lnout if ub_split_ag else None, ub=ub_obj_lnout if (ub_split_ag or ub_atomic_gemm_ag) else None,
extra_output_tensor=ln_out if ub_split_ag else None, extra_output_tensor=ln_out if (ub_split_ag or ub_atomic_gemm_ag) else None,
) )
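The two-line `ub_algo` selection introduced above gives the atomic-GEMM flag priority when both allgather-overlap flags are set. A minimal stand-alone sketch of that selection, using a local enum as a stand-in for `tex.UbufOverlapAlgo` (the names mirror the diff; nothing here is the real Transformer Engine binding):

```python
from enum import Enum, auto

class UbufOverlapAlgo(Enum):
    # Stand-in for tex.UbufOverlapAlgo; only the two AG variants are modeled.
    SPLIT_PIPELINED_AG = auto()
    ATOMIC_GEMM_AG = auto()

def select_ag_algo(ub_split_ag: bool, ub_atomic_gemm_ag: bool):
    # Same two-line selection as in the diff: ATOMIC_GEMM_AG wins if both flags are set.
    ub_algo = UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None
    ub_algo = UbufOverlapAlgo.ATOMIC_GEMM_AG if ub_atomic_gemm_ag else ub_algo
    return ub_algo
```

With neither flag set the GEMM runs without a userbuffer overlap algorithm (`None`), which matches how the `ub_algo=` argument falls through below.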
else:
    # Cast for native AMP
@@ -339,6 +345,14 @@ class _LayerNormLinear(torch.autograd.Function):
fp8_dtype_backward = get_fp8_te_dtype(
    ctx.fp8_meta["recipe"], fprop_tensor=False
)
+ out_index, meta_tensor, out_te_type, out_type = (
+     None, None, None, ctx.activation_dtype)
+ if ctx.ub_bulk_wgrad and ub_obj_dgrad.is_fp8_ubuf():
+     out_index = tex.FP8BwdTensors.GRAD_INPUT1
+     meta_tensor = ctx.fp8_meta["scaling_bwd"]
+     out_te_type = fp8_dtype_backward
+     out_type = torch.uint8
+     ub_obj_dgrad.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index])
# DGRAD: Evaluated unconditionally to feed into Linear backward
_ = tex.fp8_gemm(
@@ -350,12 +364,15 @@ class _LayerNormLinear(torch.autograd.Function):
    ctx.fp8_meta["scaling_bwd"].scale_inv,
    tex.FP8BwdTensors.GRAD_OUTPUT1,
    fp8_dtype_backward,
-     ctx.activation_dtype,
+     out_type,
    get_workspace(),
    out=dgrad,
    use_split_accumulator=_2X_ACC_DGRAD,
    ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None,
-     ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None
+     ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None,
+     out_index=out_index,
+     fp8_meta_tensor=meta_tensor,
+     D_dtype=out_te_type,
)
else:
    # DGRAD: Evaluated unconditionally to feed into Linear backward
@@ -387,6 +404,15 @@ class _LayerNormLinear(torch.autograd.Function):
if weight.requires_grad:
    if ctx.fp8:
        # WGRAD
+         extra_output_tensor = None
+         if ctx.ub_bulk_wgrad:
+             if ub_obj_dgrad.is_fp8_ubuf():
+                 dim_size = list(ub_obj_dgrad.get_ubuf_output(0).size())  # RS output
+                 extra_output_tensor = torch.empty(
+                     dim_size, dtype=ctx.activation_dtype, device=dgrad.device)
+                 dgrad = extra_output_tensor
+             else:
+                 dgrad = ub_obj_dgrad.get_ubuf_output(0)
        if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
            ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
            wgrad, _ = tex.fp8_gemm(
@@ -405,7 +431,8 @@ class _LayerNormLinear(torch.autograd.Function):
                use_split_accumulator=_2X_ACC_WGRAD,
                ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS
                if ctx.ub_bulk_wgrad else None,
-                 ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None
+                 ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None,
+                 extra_output_tensor=extra_output_tensor
            )
        else:
            ln_out_total_c = tex.cast_from_fp8(
@@ -426,7 +453,8 @@ class _LayerNormLinear(torch.autograd.Function):
                out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS
                if ctx.ub_bulk_wgrad else None,
-                 ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None
+                 ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None,
+                 extra_output_tensor=extra_output_tensor
            )
    else:
        # WGRAD
@@ -443,11 +471,14 @@ class _LayerNormLinear(torch.autograd.Function):
            ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None,
            ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None
        )
if ctx.ub_bulk_wgrad:
    dgrad = ub_obj_dgrad.get_ubuf_output(0)  # Reduce-scatter output
# Column Parallel Linear
- elif ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
+ if ((not ctx.ub_bulk_wgrad)
+         and ctx.parallel_mode == "column"
+         and ctx.tensor_parallel
+         and handle is not None):
    handle.wait()
# LayerNorm gradient
@@ -512,6 +543,7 @@ class _LayerNormLinear(torch.autograd.Function):
    None,
    None,
    None,
+     None,
)
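The bulk-WGRAD change above follows one pattern: when the userbuffer holds FP8 data, the reduce-scatter result has to come back in the activation dtype through a freshly allocated `extra_output_tensor`, which then serves as `dgrad`; otherwise `dgrad` reuses the RS chunk of the userbuffer directly. A small sketch of that buffer choice with placeholder objects (the names are stand-ins, not the Transformer Engine API):

```python
def pick_dgrad_buffer(is_fp8_ubuf, ubuf_rs_output, alloc_fn):
    """Choose where the bulk-WGRAD reduce-scatter result lands.

    is_fp8_ubuf: whether the userbuffer holds FP8 data (stand-in flag).
    ubuf_rs_output: the RS chunk of the userbuffer (stand-in object).
    alloc_fn: allocator for a higher-precision output tensor (stand-in).
    """
    extra_output_tensor = None
    if is_fp8_ubuf:
        # FP8 ubuf: the RS result must be captured in activation dtype via
        # extra_output_tensor, so allocate a separate buffer and use it as dgrad.
        extra_output_tensor = alloc_fn()
        dgrad = extra_output_tensor
    else:
        # Non-FP8 ubuf: reuse the RS chunk of the userbuffer directly.
        dgrad = ubuf_rs_output
    return dgrad, extra_output_tensor
```

The same shape of logic reappears later in the FC1 WGRAD path of `_LayerNormMLP`.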
@@ -624,6 +656,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
ub_bulk_dgrad: bool = False,
ub_split_ag: bool = False,
device: Union[torch.device, str] = "cuda",
+ ub_atomic_gemm_ag: bool = False,
) -> None:
super().__init__()
@@ -650,12 +683,18 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.ub_bulk_wgrad = ub_bulk_wgrad
self.ub_bulk_dgrad = ub_bulk_dgrad
self.ub_split_ag = ub_split_ag
+ self.ub_atomic_gemm_ag = ub_atomic_gemm_ag
- if ub_bulk_wgrad or ub_bulk_dgrad or ub_split_ag:
+ if ub_bulk_wgrad or ub_bulk_dgrad or ub_split_ag or ub_atomic_gemm_ag:
    assert (
        tex.userbuf_comm_available()
    ), "Userbuffer communication backend not available."
+ if ub_atomic_gemm_ag:
+     warnings.warn(
+         "Atomic gemm uses a beta API from cublas and is not tested for all use cases."
+     )
if tp_group is None:
    self.tp_size = tp_size
    if tp_size == 1:
@@ -919,6 +958,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
    self.ub_bulk_dgrad,
    self.ub_split_ag,
    self.normalization,
+     self.ub_atomic_gemm_ag,
)
out = fwd_fn(*args)
...
@@ -4,6 +4,7 @@
"""LayerNormMLP API"""
import os
+ import warnings
from typing import Union, Optional, Callable, Tuple, List, Dict, Any
import torch
@@ -107,7 +108,9 @@ class _LayerNormMLP(torch.autograd.Function):
ub_bulk_wgrad: bool,
ub_bulk_dgrad: bool,
ub_split_rs: bool,
+ ub_atomic_gemm_rs: bool,
ub_split_ag: bool,
+ ub_atomic_gemm_ag: bool,
activation: str,
normalization: str,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
@@ -130,20 +133,25 @@ class _LayerNormMLP(torch.autograd.Function):
if ln_bias is not None:
    ln_bias = cast_if_needed(ln_bias, activation_dtype)
- if ub_split_ag:
+ if ub_split_ag or ub_atomic_gemm_ag:
    tp_world_size = get_distributed_world_size(tp_group)
    if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output:
        ub_split_ag = False
- if ub_split_ag:
+         ub_atomic_gemm_ag = False
+ ub_overlap_ag = ub_split_ag or ub_atomic_gemm_ag
+ if ub_overlap_ag:
    ub_obj_lnout = get_ub("fc1_fprop")
    ln_out = ub_obj_lnout.get_ubuf_output(0)
else:
    ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype
    ln_out = torch.empty_like(inputmat, dtype=ln_out_dtype)
- if ub_split_rs:
+ if ub_split_rs or ub_atomic_gemm_rs:
    tp_world_size = get_distributed_world_size(tp_group)
    if tp_world_size == 1:
        ub_split_rs = False
+         ub_atomic_gemm_rs = False
+ if ub_atomic_gemm_rs or ub_atomic_gemm_ag:
+     assert fp8, "AtomicGemm overlap supported only for FP8 GEMM."
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
@@ -171,7 +179,7 @@ class _LayerNormMLP(torch.autograd.Function):
    fp8_dtype_forward,
)
# Column Parallel Linear
- if ub_split_ag:
+ if ub_overlap_ag:
    ln_out_total = ub_obj_lnout.get_ubuf_output(1)
    ln_out = torch.empty_like(ln_out)
elif set_parallel_mode and sequence_parallel:
@@ -223,6 +231,8 @@ class _LayerNormMLP(torch.autograd.Function):
    fp8_dtype_forward,
)
+ ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None
+ ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ub_atomic_gemm_ag else ub_algo
fc1_out, _ = tex.fp8_gemm(
    fc1_weight_fp8,
    fp8_meta["scaling_fwd"].scale_inv,
@@ -237,9 +247,9 @@ class _LayerNormMLP(torch.autograd.Function):
    bias=fc1_bias,
    use_bias=use_fc1_bias,
    use_split_accumulator=_2X_ACC_FPROP,
-     ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None,
-     ub=ub_obj_lnout if ub_split_ag else None,
-     extra_output_tensor=ln_out if ub_split_ag else None,
+     ub_algo=ub_algo,
+     ub=ub_obj_lnout if ub_overlap_ag else None,
+     extra_output_tensor=ln_out if ub_overlap_ag else None,
)
gelu_out = activation_func(
@@ -249,18 +259,29 @@ class _LayerNormMLP(torch.autograd.Function):
    fp8_dtype_forward,
)
- if ub_split_rs:
+ fc2_out_index, fc2_meta_tensor, fc2_te_type, out_type = (
+     None, None, None, activation_dtype)
+ if ub_split_rs or ub_atomic_gemm_rs:
    ub_obj_fc2out = get_ub("fc2_fprop")
    fc2_out = ub_obj_fc2out.get_ubuf_output(1)
    dim_size = list(gelu_out.size())
    dim_size[0] = dim_size[0] // tp_world_size
    dim_size[1] = fc2_weight.size(0)
    rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)
+     if ub_obj_fc2out.is_fp8_ubuf():
+         fc2_out_index = tex.FP8FwdTensors.GEMM2_OUTPUT
+         fc2_meta_tensor = fp8_meta["scaling_fwd"]
+         fc2_te_type = fp8_dtype_forward
+         out_type = torch.uint8
+         ub_obj_fc2out.set_ubuf_scale_inv(fc2_meta_tensor.scale_inv[fc2_out_index])
else:
    dim_size = list(gelu_out.size())
    dim_size[1] = fc2_weight.size(0)
    fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)
+ ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS if ub_atomic_gemm_rs else None
+ ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else ub_algo
_ = tex.fp8_gemm(
    fc2_weight_fp8,
    fp8_meta["scaling_fwd"].scale_inv,
@@ -270,15 +291,18 @@ class _LayerNormMLP(torch.autograd.Function):
    fp8_meta["scaling_fwd"].scale_inv,
    tex.FP8FwdTensors.GEMM2_INPUT,
    fp8_dtype_forward,
-     activation_dtype,
+     out_type,
    get_workspace(),
    bias=fc2_bias,
    use_bias=use_fc2_bias,
    use_split_accumulator=_2X_ACC_FPROP,
    out=fc2_out,
-     ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None,
-     ub=ub_obj_fc2out if ub_split_rs else None,
-     extra_output_tensor=rs_out if ub_split_rs else None,
+     ub_algo=ub_algo,
+     ub=ub_obj_fc2out if ub_split_rs or ub_atomic_gemm_rs else None,
+     extra_output_tensor=rs_out if ub_split_rs or ub_atomic_gemm_rs else None,
+     out_index=fc2_out_index,
+     fp8_meta_tensor=fc2_meta_tensor,
+     D_dtype=fc2_te_type,
)
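The `fc2_out_index`/`fc2_meta_tensor`/`fc2_te_type`/`out_type` quadruple added above encodes one decision: if the FC2 userbuffer is an FP8 buffer, the GEMM writes an FP8 (`uint8`) output tagged with the `GEMM2_OUTPUT` scaling index; otherwise everything stays `None` and the output keeps the activation dtype. A stand-alone sketch of that selection with string stand-ins for the torch/TE types (hypothetical names, not the real API):

```python
def fc2_output_params(is_fp8_ubuf,
                      activation_dtype="bfloat16",      # stand-in for ctx activation dtype
                      fp8_dtype_forward="E4M3",          # stand-in for the TE FP8 dtype
                      gemm2_output_index="GEMM2_OUTPUT"):  # stand-in for tex.FP8FwdTensors
    # Default: no FP8 output metadata, GEMM result in activation dtype.
    out_index, meta_key, te_dtype, out_dtype = None, None, None, activation_dtype
    if is_fp8_ubuf:
        # FP8 ubuf: route the output through the forward scaling metadata and
        # store the raw bytes as uint8.
        out_index = gemm2_output_index
        meta_key = "scaling_fwd"
        te_dtype = fp8_dtype_forward
        out_dtype = "uint8"  # stand-in for torch.uint8
    return out_index, meta_key, te_dtype, out_dtype
```

The backward pass applies the same pattern with `GRAD_INPUT1`/`GRAD_INPUT2` indices and the backward scaling metadata.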
else:
    # Cast for native AMP
@@ -394,11 +418,12 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.ub_bulk_wgrad = ub_bulk_wgrad
ctx.ub_bulk_dgrad = ub_bulk_dgrad
ctx.ub_split_ag = ub_split_ag
+ ctx.ub_atomic_gemm_ag = ub_atomic_gemm_ag
ctx.requires_dgrad = inp.requires_grad
ctx.normalization = normalization
# Row Parallel Linear
- if ub_split_rs:
+ if ub_split_rs or ub_atomic_gemm_rs:
    fc2_out = rs_out
elif set_parallel_mode and sequence_parallel:
    fc2_out, _ = reduce_scatter_along_first_dim(fc2_out, tp_group)
@@ -447,11 +472,15 @@ class _LayerNormMLP(torch.autograd.Function):
    dim_size[0] = dim_size[0] * tp_world_size
    ub_obj_lnout = get_ub("fc1_dgrad")
    ub_obj_lnout.copy_input_to_ubuf(ln_out, 1)
- if ctx.ub_split_ag:
+ ub_overlap_ag = ctx.ub_split_ag or ctx.ub_atomic_gemm_ag
+ if ub_overlap_ag:
    tp_world_size = get_distributed_world_size(ctx.tp_group)
    if tp_world_size == 1:
        ctx.ub_split_ag = False
- if ctx.ub_split_ag:
+         ctx.ub_overlap_ag = False
+ ub_overlap_ag = ctx.ub_split_ag or ctx.ub_atomic_gemm_ag
+ if ub_overlap_ag:
    dim_size = list(grad_outputs[0].size())
    dim_size[0] = dim_size[0] * tp_world_size
    ctx.ub_obj_gradout = get_ub("fc2_dgrad")
@@ -497,6 +526,8 @@ class _LayerNormMLP(torch.autograd.Function):
    ctx.fp8_meta["recipe"], fprop_tensor=False
)
+ ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None
+ ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ctx.ub_atomic_gemm_ag else ub_algo
# FC2 DGRAD; Unconditional
fc2_dgrad, _ = tex.fp8_gemm(
    fc2_weight_t_fp8,
@@ -510,10 +541,10 @@ class _LayerNormMLP(torch.autograd.Function):
    ctx.activation_dtype,
    get_workspace(),
    use_split_accumulator=_2X_ACC_DGRAD,
-     ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None,
-     ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None,
+     ub_algo=ub_algo,
+     ub=ctx.ub_obj_gradout if ub_overlap_ag else None,
)
- if ctx.ub_split_ag:
+ if ub_overlap_ag:
    grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
# FC2 WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
@@ -595,11 +626,19 @@ class _LayerNormMLP(torch.autograd.Function):
)
dgelu_t = None
+ out_index, meta_tensor, out_te_type, out_type = (
+     None, None, None, ctx.activation_dtype)
fc1_dgrad_size = list(dgelu.size())
fc1_dgrad_size[1] = fc1_weight.size(1)
if ctx.ub_bulk_wgrad:  # allocate dgrad output
    ub_obj_dgrad = get_ub("fc1_wgrad")
    fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1)  # AllGather output
+     if ub_obj_dgrad.is_fp8_ubuf():
+         out_index = tex.FP8BwdTensors.GRAD_INPUT2
+         meta_tensor = ctx.fp8_meta["scaling_bwd"]
+         out_te_type = fp8_dtype_backward
+         out_type = torch.uint8
+         ub_obj_dgrad.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index])
else:
    fc1_dgrad = torch.empty(
        fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device
@@ -614,12 +653,15 @@ class _LayerNormMLP(torch.autograd.Function):
    ctx.fp8_meta["scaling_bwd"].scale_inv,
    tex.FP8BwdTensors.GRAD_OUTPUT2,
    fp8_dtype_backward,
-     ctx.activation_dtype,
+     out_type,
    get_workspace(),
    out=fc1_dgrad,
    use_split_accumulator=_2X_ACC_DGRAD,
    ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None,
-     ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None
+     ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None,
+     out_index=out_index,
+     fp8_meta_tensor=meta_tensor,
+     D_dtype=out_te_type,
)
else:
    # FC2 DGRAD; Unconditional
@@ -703,6 +745,15 @@ class _LayerNormMLP(torch.autograd.Function):
if fc1_weight.requires_grad:
    if ctx.fp8:
        # FC1 WGRAD
+         extra_output_tensor = None
+         if ctx.ub_bulk_wgrad:
+             if ub_obj_dgrad.is_fp8_ubuf():
+                 dim_size = list(ub_obj_dgrad.get_ubuf_output(0).size())  # RS output
+                 extra_output_tensor = torch.empty(
+                     dim_size, dtype=ctx.activation_dtype, device=fc1_dgrad.device)
+                 fc1_dgrad = extra_output_tensor
+             else:
+                 fc1_dgrad = ub_obj_dgrad.get_ubuf_output(0)
        if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
            ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
            fc1_wgrad, _ = tex.fp8_gemm(
@@ -724,6 +775,7 @@ class _LayerNormMLP(torch.autograd.Function):
                ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS
                if ctx.ub_bulk_wgrad else None,
                ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None,
+                 extra_output_tensor=extra_output_tensor,
            )
        else:
            ln_out_total_c = tex.cast_from_fp8(
@@ -747,6 +799,7 @@ class _LayerNormMLP(torch.autograd.Function):
                ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS
                if ctx.ub_bulk_wgrad else None,
                ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None,
+                 extra_output_tensor=extra_output_tensor,
            )
    else:
        # FC1 WGRAD
@@ -768,11 +821,14 @@ class _LayerNormMLP(torch.autograd.Function):
    fc1_wgrad, _, _ = fc1_wgrad_outputs
else:
    fc1_wgrad, fc1_bias_grad, _ = fc1_wgrad_outputs
- # Column Parallel Linear
if ctx.ub_bulk_wgrad:
    fc1_dgrad = ub_obj_dgrad.get_ubuf_output(0)  # Reduce-scatter output
- elif ctx.set_parallel_mode and ctx.tensor_parallel and handle is not None:
+ # Column Parallel Linear
+ if ((not ctx.ub_bulk_wgrad)
+         and ctx.set_parallel_mode
+         and ctx.tensor_parallel
+         and handle is not None):
    handle.wait()
# LayerNorm gradient
@@ -850,6 +906,8 @@ class _LayerNormMLP(torch.autograd.Function):
    None,
    None,
    None,
+     None,
+     None,
)
@@ -965,8 +1023,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
ub_bulk_wgrad: bool = False,
ub_bulk_dgrad: bool = False,
ub_split_rs: bool = False,
+ ub_atomic_gemm_rs: bool = False,
ub_split_ag: bool = False,
device: Union[torch.device, str] = "cuda",
+ ub_atomic_gemm_ag: bool = False,
) -> None:
super().__init__()
@@ -987,12 +1047,24 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.ub_bulk_dgrad = ub_bulk_dgrad
self.ub_split_rs = ub_split_rs
self.ub_split_ag = ub_split_ag
+ self.ub_atomic_gemm_rs = ub_atomic_gemm_rs
+ self.ub_atomic_gemm_ag = ub_atomic_gemm_ag
- if ub_bulk_wgrad or ub_bulk_dgrad or ub_split_rs or ub_split_ag:
+ if (ub_bulk_wgrad  # pylint: disable=too-many-boolean-expressions
+         or ub_bulk_dgrad
+         or ub_split_rs
+         or ub_split_ag
+         or ub_atomic_gemm_rs
+         or ub_atomic_gemm_ag):
    assert (
        tex.userbuf_comm_available()
    ), "Userbuffer communication backend not available."
+ if ub_atomic_gemm_rs or ub_atomic_gemm_ag:
+     warnings.warn(
+         "Atomic gemm uses a beta API from cublas and is not tested for all use cases."
+     )
if tp_group is None:
    self.tp_size = tp_size
    if tp_size == 1:
@@ -1210,7 +1282,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
    self.ub_bulk_wgrad,
    self.ub_bulk_dgrad,
    self.ub_split_rs,
+     self.ub_atomic_gemm_rs,
    self.ub_split_ag,
+     self.ub_atomic_gemm_ag,
    self.activation,
    self.normalization,
)
...
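Both constructors now gate the new flags the same way: any userbuffer overlap flag requires the userbuffer backend, and either atomic-gemm flag additionally emits a beta-API warning. A stand-alone sketch of those checks, with a plain boolean standing in for `tex.userbuf_comm_available()` (the function and its parameters are illustrative, not the library API):

```python
import warnings

def validate_ub_flags(userbuf_available, **flags):
    """Mirror of the constructor checks; flags are the booleans from the diff,
    e.g. ub_bulk_wgrad, ub_split_rs, ub_atomic_gemm_ag (stand-in names)."""
    if any(flags.values()):
        # Any overlap flag requires the userbuffer communication backend.
        assert userbuf_available, "Userbuffer communication backend not available."
    if flags.get("ub_atomic_gemm_rs") or flags.get("ub_atomic_gemm_ag"):
        # Atomic gemm relies on a beta cublas API, so the module warns the user.
        warnings.warn(
            "Atomic gemm uses a beta API from cublas and is not tested for all use cases."
        )
```

Keeping the validation in one place like this makes it easy to see that the split-pipelined flags pass silently while the atomic-gemm flags always warn.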