Unverified commit 958e1889, authored by vasunvidia and committed by GitHub

Atomic gemm and FP8 Reduce Scatter (#449)



* Initial commit
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Repro for RS output mismatch with Single GEMM + Split pipelined RS
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* minor changes for AG->GEMM pipelined overlap
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Add Atomic Gemm cublasApi attributes and initial implementation of AG->Atomic GEMM
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* AtomicGemm+RS functional with workaround
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* add amax update to layernorm_linear for FP8 unit test accuracy
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Enable reducescatter2_userbuff_strided variants
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Bug fix
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* AG+AtomicGemm overlap functional but gemm doesnt overlap with comm
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Add userbuffers_sendrecv kernel variants
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* TransformerLayer API changes to enable AtomicGemm+RS overlap
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Code cleanup
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Code cleanup2
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* [UB] AllGather Atomic GEMM overlap using userbuffer_sendrecv kernels
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Code cleanup + bug fix for multiatomic sendrecv kernel
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* cleanup
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Bug fixes
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* [UB] Add shuffling for better AG AtomicGEMM overlap
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Bug fix for AG AtomicGemm overlap
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Bug fix for multiAtomicAG and singleAtomicAG
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Use chunk_i+1 as recv_chunk for multiatomic_AG with shuffling
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Launch AtomicGEMM after first-chunk AG
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Rebase to main
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Add FP8 ReduceScatter kernels, AtomicGEMM+FP8 RS not functional
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Revert "Add FP8 ReduceScatter kernels, AtomicGEMM+FP8 RS not functional"

This reverts commit 80a47a76355440cd5fb4314c96fe9fda632d87f9.
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Add support for NVLS-MC and FP8 Reduce Scatter
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Bug fix
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Atomic and Multiatomic FP8 RS functional
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Remove debug print
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* UB comm initialization hang fix
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Code cleanup
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Create new GEMM API for Atomic GEMM
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* CI ready
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* more fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* license
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Bug fix
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Revert NVLS-MC
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Check cu* versions for running atomic gemms
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Cleanup
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Add experimental warning
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Better wording
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add warning to c api
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix wording
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent be67f219
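
For orientation before the diff: the new overlap modes are exposed as boolean flags on the PyTorch modules (ub_atomic_gemm_ag and ub_atomic_gemm_rs, alongside the existing ub_split_ag/ub_split_rs flags), and atomic-GEMM overlap is only supported for FP8 GEMMs on a userbuffers-enabled build. A minimal usage sketch, assuming the userbuffers workspace has already been set up with initialize_ub and that tp_group and inp come from the surrounding training script:

```python
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# tp_group: tensor-parallel process group; inp: sequence-parallel input shard.
# Requires a userbuffers-enabled build and a prior initialize_ub(...) call.
mlp = te.LayerNormMLP(
    hidden_size=4096,
    ffn_hidden_size=16384,
    tp_group=tp_group,
    set_parallel_mode=True,
    sequence_parallel=True,
    ub_atomic_gemm_ag=True,   # AllGather -> atomic GEMM overlap (FC1 forward)
    ub_atomic_gemm_rs=True,   # atomic GEMM -> ReduceScatter overlap (FC2 forward)
)

with te.fp8_autocast(enabled=True, fp8_recipe=recipe.DelayedScaling()):
    out = mlp(inp)
```

Setting either flag triggers the runtime warning added in this PR that atomic GEMM relies on a beta cuBLAS API.
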
......@@ -506,7 +506,7 @@ def test_export_gemm(
self.fp8_tensor_weight,
self.weights_type)
ret = fp8_gemm(
ret, _ = fp8_gemm(
weight_fp8,
self.meta_weight.scale_inv,
self.fp8_tensor_weight,
......@@ -1323,7 +1323,7 @@ def test_export_gemm_layernorm(
self.fp8_tensor_weight,
self.weights_type)
ret = fp8_gemm(
ret, _ = fp8_gemm(
weight_fp8,
self.meta_weight.scale_inv,
self.fp8_tensor_weight,
......
......@@ -7,6 +7,7 @@
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/logging.h>
#include <transformer_engine/gemm.h>
#include <cuda.h>
#include <cublasLt.h>
#include <cublas_v2.h>
#include "../common.h"
......@@ -50,6 +51,10 @@ void cublas_gemm(const Tensor *inputA,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
int m_split,
int n_split,
bool gemm_producer,
const Tensor *inputCounter,
cudaStream_t stream
) {
void *A = inputA->data.dptr;
......@@ -63,6 +68,10 @@ void cublas_gemm(const Tensor *inputA,
void *bias_ptr = inputBias->data.dptr;
const bool bias = bias_ptr != nullptr;
void *pre_gelu_out = outputPreGelu->data.dptr;
void *counter = nullptr;
if (inputCounter != nullptr) {
counter = inputCounter->data.dptr;
}
const bool gelu = pre_gelu_out != nullptr;
const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
is_fp8_dtype(inputB->data.dtype);
......@@ -223,6 +232,27 @@ void cublas_gemm(const Tensor *inputA,
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_EPILOGUE,
&epilogue, sizeof(epilogue)));
#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205
if (counter != nullptr) {
if (m_split == 0) m_split=1;
if (n_split == 0) n_split=1;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_ROWS,
&m_split, sizeof(m_split)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_COLS,
&n_split, sizeof(n_split)));
if (gemm_producer) {
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_OUT_COUNTERS_POINTER,
&counter, sizeof(counter)));
} else {
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_IN_COUNTERS_POINTER,
&counter, sizeof(counter)));
}
}
#endif
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference));
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
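
The `#if CUDA_VERSION` block above is what turns a regular cublasLt matmul into an "atomic" one: D is treated as an m_split x n_split grid of chunks, and the GEMM either writes out-counters as a producer (each chunk of D flips its counter when complete, so an overlapped reduce-scatter can start on it) or polls in-counters as a consumer (waiting for an overlapped all-gather to deliver each input chunk). Per the documentation added later in this PR, counter[chunk_i] == 0 means chunk_i has been produced. A torch-level sketch of that convention, illustration only; in TE the flags are set and polled by the cuBLAS and userbuffers kernels, not by Python:

```python
import torch

num_chunks = 8
# Assumed initialization: counters start non-zero ("not yet produced").
counters = torch.ones(num_chunks, dtype=torch.int32, device="cuda")

def produce(chunk_i: int) -> None:
    # Producer role (ATOMIC_SYNC_OUT_COUNTERS_POINTER): flag the chunk as ready.
    counters[chunk_i] = 0

def wait_for(chunk_i: int) -> None:
    # Consumer role (ATOMIC_SYNC_IN_COUNTERS_POINTER): spin until it is flagged.
    while int(counters[chunk_i]) != 0:
        pass
```
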
......@@ -254,7 +284,6 @@ void cublas_gemm(const Tensor *inputA,
workspaceSize,
stream)); /* stream */
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceDestroy(preference));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Ddesc));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Cdesc));
......@@ -320,5 +349,82 @@ void nvte_cublas_gemm(const NVTETensor A,
wspace->data.shape[0],
accumulate, use_split_accumulator,
math_sm_count,
0,
0,
false,
nullptr,
stream);
}
void nvte_cublas_atomic_gemm(const NVTETensor A,
const NVTETensor B,
NVTETensor D,
const NVTETensor bias,
NVTETensor pre_gelu_out,
bool transa,
bool transb,
bool grad,
NVTETensor workspace,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
int m_split,
int n_split,
bool gemm_producer,
const NVTETensor counter,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cublas_atomic_gemm);
int cudart_version;
NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&cudart_version));
NVTE_CHECK(cudart_version >= 12020, "Cuda version 12.2 is required for atomic gemm.");
NVTE_CHECK(cublasLtGetVersion() >= 120205, "Cublas version 12.2.5 is required for atomic gemm.");
using namespace transformer_engine;
const Tensor *inputA = reinterpret_cast<const Tensor*>(A);
const Tensor *inputB = reinterpret_cast<const Tensor*>(B);
Tensor *outputD = reinterpret_cast<Tensor*>(D);
const Tensor *biasTensor = reinterpret_cast<const Tensor*>(bias);
Tensor *outputGelu = reinterpret_cast<Tensor*>(pre_gelu_out);
const Tensor *inputCounter = reinterpret_cast<const Tensor*>(counter);
Tensor *wspace = reinterpret_cast<Tensor*>(workspace);
const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1];
const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0];
const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0];
int lda, ldb, ldd;
if (transa && !transb) { // TN
lda = k;
ldb = k;
ldd = m;
} else if (!transa && !transb) { // NN
lda = m;
ldb = k;
ldd = m;
} else if (!transa && transb) { // NT
lda = m;
ldb = n;
ldd = m;
} else { // TT
NVTE_ERROR("TT layout not allowed.");
}
cublas_gemm(inputA,
inputB,
outputD,
biasTensor,
outputGelu,
m, n, k,
lda, ldb, ldd,
(transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
(transb) ? CUBLAS_OP_T : CUBLAS_OP_N,
grad, wspace->data.dptr,
wspace->data.shape[0],
accumulate, use_split_accumulator,
math_sm_count,
m_split,
n_split,
gemm_producer,
inputCounter,
stream);
}
......@@ -54,6 +54,52 @@ void nvte_cublas_gemm(const NVTETensor A,
cudaStream_t stream
);
/*! \brief Compute matrix multiplication of 2 matrices with chunking and atomic counters.
*
* \warning Cublas atomic gemm uses a beta API and is not tested for all use cases.
*
* Computes:
* - `D = AB` if both `bias` and `pre_gelu_out` are empty tensors
* - `D = AB + bias` if `pre_gelu_out` is empty and `bias` is not empty
* - `D = GELU(AB + bias)` if both `bias` and `pre_gelu_out` are not empty tensors
*
* \param[in] A The A matrix.
* \param[in] B The B matrix.
* \param[in,out] D Output matrix.
* \param[in] bias Bias tensor.
* \param[in,out] pre_gelu_out Output matrix before GELU activation.
* \param[in] transa Whether A matrix is transposed.
* \param[in] transb Whether B matrix is transposed.
* \param[in] grad Whether this operation is part of the
* gradient computation.
* \param[out] workspace Workspace tensor.
* \param[in] accumulate Whether to accumulate the result into the D matrix.
* \param[in] use_split_accumulator Whether to use split accumulator in the FP8 GEMM.
* \param[in] math_sm_count Number of GPU SMs to use (default=0: use cuBLAS heuristics)
* \param[in] m_split Number of chunks/splits along m-dimension for Atomic GEMM.
* \param[in] n_split Number of chunks/splits along n-dimension for Atomic GEMM.
* \param[in] gemm_producer Whether Atomic GEMM is the producer or consumer.
* \param[in,out] counter counter[chunk_i]=0 indicates chunk_i has been produced.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_cublas_atomic_gemm(const NVTETensor A,
const NVTETensor B,
NVTETensor D,
const NVTETensor bias,
NVTETensor pre_gelu_out,
bool transa,
bool transb,
bool grad,
NVTETensor workspace,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
int m_split,
int n_split,
bool gemm_producer,
const NVTETensor counter,
cudaStream_t stream
);
#ifdef __cplusplus
} // extern "C"
#endif
......
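
One note on the chunking parameters documented above: m_split and n_split define how many synchronization chunks D is divided into along its rows and columns (the CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_ROWS/COLS attributes), and the counter tensor needs one entry per chunk. A small arithmetic sketch with assumed values, e.g. one row-chunk per tensor-parallel rank for a GEMM feeding a reduce-scatter:

```python
m, n = 8192, 4096            # dimensions of D
m_split, n_split = 8, 1      # assumed: one chunk per TP rank along the rows
num_chunks = m_split * n_split
chunk_rows = m // m_split

for chunk_i in range(num_chunks):
    rows = slice(chunk_i * chunk_rows, (chunk_i + 1) * chunk_rows)
    # As a producer, the GEMM sets counter[chunk_i] = 0 once D[rows, :] is done,
    # so the overlapped communication can begin on that slice immediately.
```
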
......@@ -2262,6 +2262,8 @@ class MultiheadAttention(torch.nn.Module):
ub_bulk_dgrad: bool = False,
ub_split_rs: bool = False,
ub_split_ag: bool = False,
ub_atomic_gemm_rs: bool = False,
ub_atomic_gemm_ag: bool = False,
bias: bool = True,
normalization: str = "LayerNorm",
device: Union[torch.device, str] = "cuda",
......@@ -2342,6 +2344,7 @@ class MultiheadAttention(torch.nn.Module):
ub_bulk_dgrad=ub_bulk_dgrad,
ub_split_ag=ub_split_ag,
normalization=normalization,
ub_atomic_gemm_ag=ub_atomic_gemm_ag,
**common_gemm_kwargs,
)
else:
......@@ -2372,6 +2375,7 @@ class MultiheadAttention(torch.nn.Module):
ub_bulk_dgrad=ub_bulk_dgrad,
ub_split_ag=ub_split_ag,
normalization=normalization,
ub_atomic_gemm_ag=ub_atomic_gemm_ag,
**common_gemm_kwargs,
)
else:
......@@ -2418,6 +2422,8 @@ class MultiheadAttention(torch.nn.Module):
parallel_mode="row" if set_parallel_mode else None,
ub_split_rs=ub_split_rs,
ub_split_ag=ub_split_ag,
ub_atomic_gemm_rs=ub_atomic_gemm_rs,
ub_atomic_gemm_ag=ub_atomic_gemm_ag,
**common_gemm_kwargs,
)
......
......@@ -91,22 +91,40 @@ def fp8_gemm(
assert ub is not None, 'ub object is None!'
if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG:
fn = ub.bulk_overlap
args = tuple(args + (1,))
extra_output_tensor = (
empty_tensor if extra_output_tensor is None else extra_output_tensor
)
args = tuple(args + (1, extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS:
fn = ub.bulk_overlap
args = tuple(args + (0,))
extra_output_tensor = (
empty_tensor if extra_output_tensor is None else extra_output_tensor
)
args = tuple(args + (0, extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG:
fn = ub.split_overlap_ag
extra_output_tensor = (
empty_tensor if extra_output_tensor is None else extra_output_tensor
)
args = tuple(args + (extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG:
fn = ub.atomic_gemm_overlap_ag
extra_output_tensor = (
empty_tensor if extra_output_tensor is None else extra_output_tensor
)
args = tuple(args + (extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS:
fn = ub.split_overlap_rs
assert (
extra_output_tensor is not None
), 'SPLIT_PIPELINED_RS requires extra output tensor'
args = tuple(args + (True, extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_RS:
fn = ub.atomic_gemm_overlap_rs
assert (
extra_output_tensor is not None
), 'ATOMIC_GEMM_RS requires extra output tensor'
args = tuple(args + (True, extra_output_tensor,))
_ = fn(*args)
return out, gelu_input
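
The dispatch above also pins down what extra_output_tensor means per algorithm: for the AG variants (including the new ATOMIC_GEMM_AG) it is optional and defaults to an empty tensor (the modules pass e.g. their local ln_out copy), whereas for SPLIT_PIPELINED_RS and the new ATOMIC_GEMM_RS it is mandatory and receives the reduce-scattered result while the GEMM's main output stays in the userbuffer. A hedged caller-side sketch for the RS case, reusing the helper names (get_ub, get_workspace) and argument order that appear in the module code later in this diff; tensor names, shapes, and the FP8 meta indices are placeholders:

```python
rs_out = torch.empty((seqlen // tp_size, hidden_size),
                     dtype=torch.bfloat16, device="cuda")
ub_obj = get_ub("proj_fprop")             # userbuffers object for this GEMM

out, _ = tex.fp8_gemm(
    weight_fp8, scale_inv, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward,
    input_fp8,  scale_inv, tex.FP8FwdTensors.GEMM1_INPUT,  fp8_dtype_forward,
    torch.bfloat16,                       # GEMM output dtype (non-FP8 ubuf case)
    get_workspace(),
    out=ub_obj.get_ubuf_output(1),        # GEMM writes into the userbuffer
    ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS,
    ub=ub_obj,
    extra_output_tensor=rs_out,           # required: receives the RS result
)
```
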
......@@ -195,10 +213,10 @@ def gemm(
assert ub is not None, 'ub object is None!'
if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG:
fn = ub.bulk_overlap
args = tuple(args + (1,))
args = tuple(args + (1, empty_tensor))
elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS:
fn = ub.bulk_overlap
args = tuple(args + (0,))
args = tuple(args + (0, empty_tensor))
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG:
fn = ub.split_overlap_ag
extra_output_tensor = (
......
......@@ -179,6 +179,32 @@ void te_gemm(at::Tensor A,
int math_sm_count
);
void te_atomic_gemm(at::Tensor A,
at::Tensor A_scale_inverse,
transformer_engine::DType A_type,
bool transa,
at::Tensor B,
at::Tensor B_scale_inverse,
transformer_engine::DType B_type,
bool transb,
at::Tensor D,
at::Tensor D_scale,
transformer_engine::DType D_type,
at::Tensor D_amax,
at::Tensor bias,
transformer_engine::DType bias_type,
at::Tensor pre_gelu_out,
bool grad,
at::Tensor workspace,
size_t workspaceSize,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
int m_split,
int n_split,
bool gemm_producer,
at::Tensor counter
);
void fused_cast_transpose(at::Tensor input,
at::Tensor scale,
......
......@@ -6,6 +6,7 @@
#include "extensions.h"
void te_gemm(at::Tensor A,
at::Tensor A_scale_inverse,
transformer_engine::DType A_type,
......@@ -73,3 +74,82 @@ void te_gemm(at::Tensor A,
math_sm_count,
at::cuda::getCurrentCUDAStream());
}
void te_atomic_gemm(at::Tensor A,
at::Tensor A_scale_inverse,
transformer_engine::DType A_type,
bool transa,
at::Tensor B,
at::Tensor B_scale_inverse,
transformer_engine::DType B_type,
bool transb,
at::Tensor D,
at::Tensor D_scale,
transformer_engine::DType D_type,
at::Tensor D_amax,
at::Tensor bias,
transformer_engine::DType bias_type,
at::Tensor pre_gelu_out,
bool grad,
at::Tensor workspace,
size_t workspaceSize,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
int m_split,
int n_split,
bool gemm_producer,
at::Tensor counter
) {
using namespace transformer_engine;
auto te_A = makeTransformerEngineTensor(A.data_ptr(),
{static_cast<size_t>(A.size(0)),
static_cast<size_t>(A.size(1))},
A_type, nullptr, nullptr,
A_scale_inverse.data_ptr());
auto te_B = makeTransformerEngineTensor(B.data_ptr(),
{static_cast<size_t>(B.size(0)),
static_cast<size_t>(B.size(1))},
B_type, nullptr, nullptr,
B_scale_inverse.data_ptr());
auto te_D = makeTransformerEngineTensor(D.data_ptr(),
{static_cast<size_t>(D.size(0)),
static_cast<size_t>(D.size(1))},
D_type, D_amax.data_ptr(),
D_scale.data_ptr(), nullptr);
auto te_bias = makeTransformerEngineTensor(bias.data_ptr(), {static_cast<size_t>(bias.size(0))},
bias_type);
auto te_counter = makeTransformerEngineTensor(counter.data_ptr(),
{static_cast<size_t>(counter.size(0))},
DType::kInt32);
const auto gelu_shape = pre_gelu_out.data_ptr() == nullptr
? std::vector<size_t>{static_cast<size_t>(pre_gelu_out.size(0))}
: std::vector<size_t>{static_cast<size_t>(pre_gelu_out.size(0)),
static_cast<size_t>(pre_gelu_out.size(1))};
auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out.data_ptr(),
gelu_shape,
GetTransformerEngineDType(
pre_gelu_out.scalar_type()));
auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(),
{workspaceSize},
DType::kByte);
nvte_cublas_atomic_gemm(te_A.data(),
te_B.data(),
te_D.data(),
te_bias.data(),
te_pre_gelu_out.data(),
transa,
transb,
grad,
te_workspace.data(),
accumulate,
use_split_accumulator,
math_sm_count,
m_split,
n_split,
gemm_producer,
te_counter.data(),
at::cuda::getCurrentCUDAStream());
}
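
te_atomic_gemm mirrors te_gemm but adds the chunking parameters and an int32 GPU counter tensor (one element per chunk) and forwards them to nvte_cublas_atomic_gemm. The pybind registration for it is not part of the hunks shown here, so the call below is a hypothetical sketch of invoking such a binding from Python, with every tensor argument a placeholder prepared the same way as for te_gemm:

```python
n_chunks = 8
counter = torch.ones(n_chunks, dtype=torch.int32, device="cuda")  # 0 = produced

tex.te_atomic_gemm(                     # hypothetical binding name
    A, A_scale_inverse, A_type, True,   # transa
    B, B_scale_inverse, B_type, False,  # transb
    D, D_scale, D_type, D_amax,
    bias, bias_type, pre_gelu_out,
    False,                              # grad
    workspace, workspace.numel(),       # workspace buffer and its size in bytes
    False, True,                        # accumulate, use_split_accumulator
    0,                                  # math_sm_count (0 = cuBLAS heuristics)
    n_chunks, 1,                        # m_split, n_split
    True,                               # gemm_producer
    counter,
)
```
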
......@@ -91,18 +91,24 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.value("BULK_OVERLAP_AG", ubuf::UBOverlapAlgo::BULK_OVERLAP_AG)
.value("BULK_OVERLAP_RS", ubuf::UBOverlapAlgo::BULK_OVERLAP_RS)
.value("SPLIT_PIPELINED_RS", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_RS)
.value("SPLIT_PIPELINED_AG", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_AG);
.value("SPLIT_PIPELINED_AG", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_AG)
.value("ATOMIC_GEMM_RS", ubuf::UBOverlapAlgo::ATOMIC_GEMM_RS)
.value("ATOMIC_GEMM_AG", ubuf::UBOverlapAlgo::ATOMIC_GEMM_AG);
py::class_<ubuf::UbufCommOverlap>(m, "UbufCommOverlap")
.def(py::init<torch::Tensor&, int, int, int, int, int, bool, int>())
.def(py::init<torch::Tensor&, int, int, int, int, int, bool, int, torch::Tensor>())
.def("bulk_overlap", &ubuf::UbufCommOverlap::bulk_overlap)
.def("split_overlap_rs", &ubuf::UbufCommOverlap::split_overlap_rs)
.def("set_ubuf_scale_inv", &ubuf::UbufCommOverlap::set_ubuf_scale_inv)
.def("atomic_gemm_overlap_rs", &ubuf::UbufCommOverlap::atomic_gemm_overlap_rs)
.def("is_fp8_ubuf", &ubuf::UbufCommOverlap::is_fp8_ubuf)
.def("copy_input_to_ubuf", &ubuf::UbufCommOverlap::copy_input_to_ubuf)
.def("get_ubuf_output", &ubuf::UbufCommOverlap::get_ubuf_output);
py::class_<ubuf::UbufP2PCommOverlap>(m, "UbufP2PCommOverlap")
.def(py::init<torch::Tensor&, int, int, bool, int>())
.def(py::init<torch::Tensor&, int, int, int, int, bool, bool, int, torch::Tensor>())
.def("split_overlap_ag", &ubuf::UbufP2PCommOverlap::split_overlap_ag)
.def("atomic_gemm_overlap_ag", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_ag)
.def("copy_input_to_ubuf", &ubuf::UbufP2PCommOverlap::copy_input_to_ubuf)
.def("get_ubuf_output", &ubuf::UbufP2PCommOverlap::get_ubuf_output);
#else // NVTE_WITH_USERBUFFERS
......
......@@ -4,10 +4,13 @@
* See LICENSE for license information.
************************************************************************/
#include "userbuffers.h"
#include <assert.h>
#include <chrono>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <immintrin.h>
#include <iostream>
#include <math.h>
#include <mpi.h>
#include <sched.h>
......@@ -15,9 +18,6 @@
#include <string.h>
#include <unistd.h>
#include <x86intrin.h>
#include <chrono>
#include <iostream>
#include "userbuffers.h"
static int oob_bcast(void *comm_context, void *buf, int size, int root) {
MPI_Bcast(buf, size, MPI_BYTE, root,
......@@ -47,6 +47,17 @@ int stringCmp(const void *a, const void *b) { return strcmp((const char *)a, (co
} \
} while (0)
#define CUCHECK(cmd) \
do { \
CUresult retval = cmd; \
if (retval != CUDA_SUCCESS) { \
const char *error_string; \
cuGetErrorString(retval, &error_string); \
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, error_string); \
exit(EXIT_FAILURE); \
} \
} while (0);
#define NVTE_UB_ERROR(x) \
do { \
throw std::runtime_error(std::string(__FILE__ ":") + std::to_string(__LINE__) + \
......@@ -89,12 +100,14 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
(*comm)->push = 1;
(*comm)->use_ce = 0;
(*comm)->cga_size = 2;
for (int i = 0; i < userbuffers_op_types; i++) (*comm)->basecounter[i] = 0;
for (int i = 0; i < userbuffers_op_types; i++)
(*comm)->basecounter[i] = 0;
(*comm)->head = 0;
(*comm)->tail = 0;
(*comm)->activeproxy = 1;
(*comm)->active_nreqs = 0;
for (int i = 0; i < userbuffers_op_types; i++) (*comm)->active_req[i].active = -1;
for (int i = 0; i < userbuffers_op_types; i++)
(*comm)->active_req[i].active = -1;
int ret = 0;
// split communicator
......@@ -112,8 +125,10 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
color = 0;
for (int n = 0; n < size; n++) {
if (n > 0 && strcmp(host_names[n - 1], host_names[n])) color++;
if (strcmp(host_name, host_names[n]) == 0) break;
if (n > 0 && strcmp(host_names[n - 1], host_names[n]))
color++;
if (strcmp(host_name, host_names[n]) == 0)
break;
}
free(host_names);
......@@ -128,14 +143,22 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
cpu_set_t cpuset;
CPU_ZERO(&cpuset);
int core;
if (mylocal == 0) core = 50;
if (mylocal == 1) core = 58;
if (mylocal == 2) core = 18;
if (mylocal == 3) core = 26;
if (mylocal == 4) core = 114;
if (mylocal == 5) core = 122;
if (mylocal == 6) core = 82;
if (mylocal == 7) core = 90;
if (mylocal == 0)
core = 50;
if (mylocal == 1)
core = 58;
if (mylocal == 2)
core = 18;
if (mylocal == 3)
core = 26;
if (mylocal == 4)
core = 114;
if (mylocal == 5)
core = 122;
if (mylocal == 6)
core = 82;
if (mylocal == 7)
core = 90;
CPU_SET(core, &cpuset);
if (!getenv("NVTE_NODOUBLE")) {
......@@ -144,7 +167,8 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
else
CPU_SET(core + 128, &cpuset);
}
if (getenv("NVTE_DOPIN")) pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpuset);
if (getenv("NVTE_DOPIN"))
pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpuset);
if (ndev == numlocal) { // all visible devices
if (cur_dev != mylocal)
......@@ -175,7 +199,8 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
int datanodegroup_id =
myrank / numlocal / datanodes; // data reduction group node belongs, equals 0 for all if both
// pipenodes=1 and tensornodes=1
// mpi communicator only needed for SHARP which is always allreduce1/data-parallel
// mpi communicator only needed for SHARP which is always
// allreduce1/data-parallel
MPI_Comm_split(MPI_COMM_WORLD, mylocal + numlocal * datanodegroup_id, rank, &(*comm)->comm_inter);
// different rails from same group are in different subcommunicators
......@@ -192,19 +217,37 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
char *ib_dev_list;
int ZIONROCE = getenv("NVTE_ZIONROCE") ? atoi(getenv("NVTE_ZIONROCE")) : 0;
int ROCE = getenv("NVTE_ROCE") ? atoi(getenv("NVTE_ROCE")) : 0;
if (ZIONROCE) ROCE = 1;
if (ZIONROCE)
ROCE = 1;
int DGX_H100 = device_prop.major == 9;
switch (mylocal) {
case 0:ib_dev_list = "mlx5_0:1"; break; // NOLINT(*)
case 1:ib_dev_list = (char*)(DGX_H100?"mlx5_3:1":"mlx5_1:1"); break; // NOLINT(*)
case 2:ib_dev_list = (char*)(ZIONROCE?"mlx5_4:1":DGX_H100?"mlx5_4:1":"mlx5_2:1"); break; // NOLINT(*)
case 3:ib_dev_list = (char*)(DGX_H100?"mlx5_5:1":"mlx5_3:1"); break; // NOLINT(*)
case 4:ib_dev_list = (char*)(DGX_H100?"mlx5_6:1":"mlx5_6:1"); break; // NOLINT(*)
case 5:ib_dev_list = (char*)(DGX_H100?"mlx5_9:1":"mlx5_7:1"); break; // NOLINT(*)
case 6:ib_dev_list = (char*)(ZIONROCE?"mlx5_10:1":DGX_H100?"mlx5_10:1":"mlx5_8:1"); break; // NOLINT(*)
case 7:ib_dev_list = (char*)(DGX_H100?"mlx5_11:1":"mlx5_9:1"); break; // NOLINT(*)
default: break;
case 0:
ib_dev_list = "mlx5_0:1";
break; // NOLINT(*)
case 1:
ib_dev_list = (char *)(DGX_H100 ? "mlx5_3:1" : "mlx5_1:1"); // NOLINT(*)
break; // NOLINT(*)
case 2:
ib_dev_list = (char *)(ZIONROCE ? "mlx5_4:1" : DGX_H100 ? "mlx5_4:1" : "mlx5_2:1"); // NOLINT(*)
break; // NOLINT(*)
case 3:
ib_dev_list = (char *)(DGX_H100 ? "mlx5_5:1" : "mlx5_3:1"); // NOLINT(*)
break; // NOLINT(*)
case 4:
ib_dev_list = (char *)(DGX_H100 ? "mlx5_6:1" : "mlx5_6:1"); // NOLINT(*)
break; // NOLINT(*)
case 5:
ib_dev_list = (char *)(DGX_H100 ? "mlx5_9:1" : "mlx5_7:1"); // NOLINT(*)
break; // NOLINT(*)
case 6:
ib_dev_list = (char *)(ZIONROCE ? "mlx5_10:1" : DGX_H100 ? "mlx5_10:1" : "mlx5_8:1"); // NOLINT(*)
break; // NOLINT(*)
case 7:
ib_dev_list = (char *)(DGX_H100 ? "mlx5_11:1" : "mlx5_9:1"); // NOLINT(*)
break; // NOLINT(*)
default:
break;
}
(*comm)->fifo = reinterpret_cast<ub_request *>(malloc(sizeof(ub_request) * NVTE_MAX_REQUESTS));
......@@ -215,7 +258,8 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
CUDACHECK(cudaMallocHost((void **)&(*comm)->hostflags, // NOLINT(*)
(NVTE_MAX_SMS + 100) * sizeof(int)));
for (int i = 0; i < 100 + NVTE_MAX_SMS; i++) (*comm)->hostflags[i] = 0;
for (int i = 0; i < 100 + NVTE_MAX_SMS; i++)
(*comm)->hostflags[i] = 0;
_mm_mfence();
sleep(1);
......@@ -223,13 +267,16 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
(*comm)->ibnvsize = (*comm)->nvsize;
#define NBUF 2
#define LOCALSIZE 4 * (NVTE_REG0_OFFSET(*comm) + NVTE_REG0_FLAGS + NVTE_REG0_COMMBUFFER * NBUF)
// peer pointers + op flags + comm buffer
CUDACHECK(cudaMalloc(&(*comm)->gpu_ptrs, LOCALSIZE)); // flags and pointers, no block data yet
CUDACHECK(cudaMalloc(&(*comm)->gpu_ptrs,
LOCALSIZE)); // flags and pointers, no block data yet
CUDACHECK(cudaMemset((*comm)->gpu_ptrs, 0, LOCALSIZE));
CUDACHECK(cudaDeviceSynchronize());
register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE, *comm); // will use handler 0
register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE,
*comm); // will use handler 0
CUDACHECK(cudaMalloc(&(*comm)->send_id, (*comm)->nranks * sizeof(int)));
CUDACHECK(cudaMalloc(&(*comm)->recv_id, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int)));
CUDACHECK(cudaMemset((*comm)->send_id, 0, (*comm)->nranks * sizeof(int)));
......@@ -243,7 +290,6 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
#define GPU_PAGE_MASK (~GPU_PAGE_OFFSET)
CUDACHECK(cudaMalloc(&(*comm)->flags, 2 * GPU_PAGE_SIZE));
unsigned int flag = 1;
// cuPointerSetAttribute(&flag, CU_POINTER_ATTRIBUTE_SYNC_MEMOPS, (CUdeviceptr)(*comm)->flags);
CUDACHECK(cudaMemset((*comm)->flags, 0, 2 * GPU_PAGE_SIZE));
(*comm)->flags =
reinterpret_cast<int *>(((CUdeviceptr)(*comm)->flags + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK);
......@@ -275,7 +321,8 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
pthread_attr_setschedparam(&attr, &param);
if (getenv("NVTE_UBDEBUG"))
printf("%d/%d:(%d x %d): DP %d x %d TP %d x %d, DPGROUP %dx%d TPGROUP %dx%d PIPE_ID %d/%d\n",
printf("%d/%d:(%d x %d): DP %d x %d TP %d x %d, DPGROUP %dx%d TPGROUP "
"%dx%d PIPE_ID %d/%d\n",
myrank, nranks, myrank / numlocal, myrank % numlocal, (*comm)->my_node,
(*comm)->ar_nvrank, (*comm)->my2_node, (*comm)->ar2_nvrank, (*comm)->num_nodes,
(*comm)->ar_nvsize, (*comm)->num2_nodes, (*comm)->ar2_nvsize, (*comm)->pipe_id,
......@@ -300,9 +347,9 @@ void destroy_communicator(communicator *comm) {
}
int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm, bool alloc) {
if (comm->free_region > NVTE_MAX_REGIONS) return -1;
if (comm->free_region > NVTE_MAX_REGIONS)
return -1;
int hndl = comm->free_region;
// printf("%d register %d size %lld\n",comm->myrank,hndl,bytes);fflush(NULL);
comm->peer_ptr[hndl] = reinterpret_cast<void **>(malloc(sizeof(void *) * (comm->nvsize)));
if (alloc) {
......@@ -313,25 +360,22 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
reinterpret_cast<cudaIpcMemHandle_t *>(malloc(sizeof(cudaIpcMemHandle_t) * (comm->nvsize)));
CUDACHECK(cudaIpcGetMemHandle(&memhndl[comm->nvrank], *gpubuff));
MPI_Allgather(&memhndl[comm->nvrank], sizeof(cudaIpcMemHandle_t), MPI_BYTE, memhndl,
sizeof(cudaIpcMemHandle_t), MPI_BYTE, comm->comm_intra);
for (int i = 0; i < comm->nvsize; i++)
if (i != comm->nvrank)
CUDACHECK(cudaIpcOpenMemHandle((void **)&(comm->peer_ptr[hndl][i]), // NOLINT(*)
memhndl[i], cudaIpcMemLazyEnablePeerAccess));
comm->peer_ptr[hndl][comm->nvrank] = *gpubuff;
CUDACHECK(cudaDeviceSynchronize());
CUDACHECK(
cudaMemcpy(reinterpret_cast<char *>(comm->gpu_ptrs) + (hndl * comm->nvsize * sizeof(void *)),
comm->peer_ptr[hndl], comm->nvsize * sizeof(void *), cudaMemcpyHostToDevice));
CUDACHECK(cudaDeviceSynchronize());
free(memhndl);
comm->mem_ptr[hndl] = *gpubuff;
return comm->free_region++;
}
......@@ -352,8 +396,10 @@ int allgather2_userbuff_inplace_gpu(const int maxcredit, const int handler, cons
void allreduce_nonsharp_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream, int op) {
if (elements < 64) NVTE_UB_ERROR("Userbuffer comm for given config not implemented.");
// if(comm->myrank==0) fprintf(stderr,"AR2(%d) user call launch_mode=%d\n",op,comm->launch_mode);
if (elements < 64)
NVTE_UB_ERROR("Userbuffer comm for given config not implemented.");
// if(comm->myrank==0) fprintf(stderr,"AR2(%d) user call
// launch_mode=%d\n",op,comm->launch_mode);
const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
int blocksize = elements * 2;
int maxcredit = 0;
......@@ -361,19 +407,19 @@ void allreduce_nonsharp_inplace(const int handler, const int offset, const int e
blocksize = (comm->nblocks - 1 + (comm->alignblock - 1 + elements * 2) / comm->alignblock) /
comm->nblocks; // FIXME TUNING
blocksize *= comm->alignblock;
if (blocksize < comm->minblock) blocksize = comm->minblock;
if (blocksize < comm->minblock)
blocksize = comm->minblock;
maxcredit = (elements * 2 + blocksize - 1) / blocksize;
// if(maxcredit>4) maxcredit=4;
// if(maxcredit>4 && ar_nvsize==1) maxcredit=4;
size_t peerblock = sizeof(int) * NVTE_REG0_COMMBUFFER / maxcredit; // max size we can fit
if (blocksize > peerblock * ar_nvsize) blocksize = peerblock * ar_nvsize;
// blocksize=elements*2;
if (blocksize > peerblock * ar_nvsize)
blocksize = peerblock * ar_nvsize;
int sms = allreduce2_userbuff_inplace_gpu(maxcredit, handler, offset, elements, blocksize, comm,
stream, op);
if (num_nodes > 1 && comm->launch_mode & NVTE_LAUNCH_CPU) {
if (!sms) return;
if (!sms)
return;
comm->fifo[comm->head].optype = op;
comm->fifo[comm->head].basecounter = comm->basecounter[op];
comm->fifo[comm->head].blocksize = blocksize;
......@@ -399,7 +445,8 @@ void allreduce2_userbuff_inplace(const int handler, const int offset, const int
void allreduce_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) {
if (elements < 64) NVTE_UB_ERROR("Userbuffer comm for given config not implemented.");
if (elements < 64)
NVTE_UB_ERROR("Userbuffer comm for given config not implemented.");
allreduce_nonsharp_inplace(handler, offset, elements, comm, stream,
userbuffers_allreduceop_nonsharp);
return;
......@@ -407,7 +454,8 @@ void allreduce_userbuff_inplace(const int handler, const int offset, const int e
void reducescatter_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) {
if (elements < 64) NVTE_UB_ERROR("Userbuffer comm for given config not implemented.");
if (elements < 64)
NVTE_UB_ERROR("Userbuffer comm for given config not implemented.");
int op = userbuffers_allreduceop_nonsharp;
const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
......@@ -418,17 +466,20 @@ void reducescatter_userbuff_inplace(const int handler, const int offset, const i
blocksize = (comm->nblocks - 1 + (comm->alignblock - 1 + elements * 2) / comm->alignblock) /
comm->nblocks; // FIXME TUNING
blocksize *= comm->alignblock;
if (blocksize < comm->minblock) blocksize = comm->minblock;
if (blocksize < comm->minblock)
blocksize = comm->minblock;
maxcredit = (elements * 2 + blocksize - 1) / blocksize;
size_t peerblock = sizeof(int) * NVTE_REG0_COMMBUFFER / maxcredit; // max size we can fit
if (blocksize > peerblock * ar_nvsize) blocksize = peerblock * ar_nvsize;
if (blocksize > peerblock * ar_nvsize)
blocksize = peerblock * ar_nvsize;
int sms = reducescatter2_userbuff_inplace_gpu(maxcredit, handler, offset, elements, blocksize,
comm, stream, op);
if (num_nodes > 1 && comm->launch_mode & NVTE_LAUNCH_CPU) {
if (!sms) return;
if (!sms)
return;
comm->fifo[comm->head].optype = op;
comm->fifo[comm->head].basecounter = comm->basecounter[op];
comm->fifo[comm->head].blocksize = blocksize;
......@@ -448,7 +499,8 @@ void reducescatter_userbuff_inplace(const int handler, const int offset, const i
void allgather_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) {
if (elements < 64) NVTE_UB_ERROR("Userbuffer comm for given config not implemented.");
if (elements < 64)
NVTE_UB_ERROR("Userbuffer comm for given config not implemented.");
int op = userbuffers_allreduceop_nonsharp;
const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
int blocksize = elements * 2;
......@@ -458,11 +510,13 @@ void allgather_userbuff_inplace(const int handler, const int offset, const int e
blocksize = (comm->nblocks - 1 + (comm->alignblock - 1 + elements * 2) / comm->alignblock) /
comm->nblocks; // FIXME TUNING
blocksize *= comm->alignblock;
if (blocksize < comm->minblock) blocksize = comm->minblock;
if (blocksize < comm->minblock)
blocksize = comm->minblock;
maxcredit = (elements * 2 + blocksize - 1) / blocksize;
size_t peerblock = sizeof(int) * NVTE_REG0_COMMBUFFER / maxcredit; // max size we can fit
if (blocksize > peerblock * ar_nvsize) blocksize = peerblock * ar_nvsize;
if (blocksize > peerblock * ar_nvsize)
blocksize = peerblock * ar_nvsize;
int sms = allgather2_userbuff_inplace_gpu(maxcredit, handler, offset, elements, blocksize, comm,
stream, op);
......
......@@ -24,6 +24,18 @@
#define NVTE_LAUNCH_CPU 2
#define NVTE_MAX_NVLINK 8
#define UB_MEM_UC_CONTIG 1
#define UB_MEM_MC_CREATED 2
#define UB_MEM_ALLOCATED 4
#define NVTE_UB_MEM_UC_CONTIG 1
#define NVTE_UB_MEM_MC_CREATED 2
#define NVTE_UB_MEM_ALLOCATED 4
#ifdef UCP
#include <ucp/api/ucp.h>
#endif
// region 0 flag offsets
#define NVTE_REG0_OPFLAGS 1024
#define NVTE_REG0_RECV (NVTE_REG0_OPFLAGS * userbuffers_op_types)
......@@ -35,6 +47,10 @@
#define NVTE_REG0_IBRS 32
#define NVTE_REG0_IBAG 512
#if defined(UCP) || !defined(NOSHARP)
#undef REG0_COMMBUFFER
#define REG0_COMMBUFFER (1024*1024*16)
#endif
// gpuflags map offsets
#define NVTE_GF_STATE 16000
#define NVTE_GF_IBSHARPDONE 0
......@@ -81,6 +97,19 @@ struct communicator {
void *mem_ptr[NVTE_MAX_REGIONS];
void **peer_ptr[NVTE_MAX_REGIONS];
int memflags[NVTE_MAX_REGIONS]; // UC,MC, user/lib allocated
CUmemGenericAllocationHandle *uchandles[NVTE_MAX_REGIONS];
void* ucbase_ptr[NVTE_MAX_REGIONS]; // only for cuMem allocated memory
size_t mem_size[NVTE_MAX_REGIONS];
void* mc_ptr[NVTE_MAX_REGIONS];
void* mc_baseptr;
CUmemGenericAllocationHandle mc_handle;
size_t mc_offset, mc_maxsize;
int use_mc; // 1: use MC if available, 0: override not to use MC
int ar_nvsize, ar_firstgpu,
ar_nvrank; // number of gpus(and first gpu in a group) of gpus per node in reduction subgroup
// (_splitar init used) would be equal to (nvsize,0) for regular comm_create
......@@ -120,6 +149,8 @@ struct communicator {
};
typedef struct communicator communicator;
void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream);
void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream);
int create_communicator(communicator **comm);
/* creates communicator, allocates all internal buffers if necessary */
......@@ -191,6 +222,45 @@ void reducescatter2_userbuff_stridedoutput(void *output, const int handler, cons
const int rowelements, const int colelements,
const int strideelements, communicator *comm,
cudaStream_t stream = 0);
template<typename fp8type>
void reducescatter2_userbuff_stridedoutput_fp8(void* output, float* scale, const int handler,
const int offset, const int rowelements,
const int colelements, const int strideelements,
communicator* comm, cudaStream_t stream = 0);
template<typename fp8type>
void reducescatter2_userbuff_fp8(void* output, float* scale, const int handler, const int offset,
const int elements, communicator* comm, cudaStream_t stream = 0);
#if 0
template<typename fp8type>
void reducescatter2_userbuff_strided_atomic_fp8(void* output, float *scale, const int handler,
const int offset, const int rowelements,
const int colelements, const int strideelements,
const int numchunks, void *counters,
communicator* comm, cudaStream_t stream = 0);
#endif
template<typename fp8type>
void reducescatter2_userbuff_strided_atomic_fp8(void* output, float *scale, const int handler,
const int offset, const int rowelements,
const int colelements, const int strideelements_out,
const int strideelements_in, const int numchunks,
void *counters, communicator* comm,
cudaStream_t stream = 0);
template<typename fp8type>
void reducescatter2_userbuff_strided_multiatomic_fp8(
void* output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator* comm, cudaStream_t stream = 0);
void reducescatter2_userbuff_strided(
void* output, const int handler, const int offset, const int rowelements, const int colelements,
const int strideelements, communicator* comm, cudaStream_t stream = 0);
void reducescatter2_userbuff_strided_atomic(
void* output, const int handler , const int offset, const int rowelements, const int colelements,
const int strideelements, const int numchunks, void *counters, communicator* comm,
cudaStream_t stream = 0);
void reducescatter2_userbuff_strided_multiatomic(
void* output, const int handler, const int offset, const int rowelements, const int colelements,
const int strideelements, const int numchunks, void *counters, communicator* comm,
cudaStream_t stream = 0);
/* everything should be 16byte aligned = 8 elts aligned
output is strided: row starts separated by stride elements*/
......@@ -208,6 +278,19 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds
void userbuffers_recv(const int srchandler, const size_t srcoffset, const int dsthandler,
const size_t dstoffset, const size_t bytes, communicator *comm,
const int peer, cudaStream_t stream = 0);
void userbuffers_sendrecv(
const int srchandler, const int dsthandler, const size_t send_offset, const size_t recv_offset,
const size_t bytes, communicator* comm, const int send_peer, const int recv_peer,
cudaStream_t stream = 0);
void userbuffers_sendrecv_atomic(
const int srchandler, const int dsthandler, const size_t send_offset, const size_t recv_offset,
const size_t bytes, communicator* comm, const int send_peer, const int recv_peer, void *counters,
cudaStream_t stream = 0);
void userbuffers_sendrecv_multiatomic(
const int srchandler, const int dsthandler, const size_t send_offset, const size_t recv_offset,
const size_t bytes, communicator* comm, const int send_peer, const int recv_peer,
const int nchunks, void *counters, bool shuffle, cudaStream_t stream = 0);
// alltoall split send and recv to allow for overlap
// send kicks in sending data to the destination - invoke on same stream as data generation
......
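
The sendrecv declarations above are the AllGather half of the overlap: userbuffers_sendrecv_atomic performs one ring step and flips the counter of the chunk it just delivered, while userbuffers_sendrecv_multiatomic walks all nchunks steps in a single launch; the shuffle flag reorders chunks (see the "Use chunk_i+1 as recv_chunk" and "Launch AtomicGEMM after first-chunk AG" commits above) so the consumer atomic GEMM can start as soon as the first exchange lands. A rank-level sketch of one common ring schedule, purely illustrative; the real exchange and counter updates happen inside the CUDA kernels:

```python
tp_size, my_rank = 8, 3
counters = [1] * tp_size        # 0 = chunk visible to the consumer GEMM
counters[my_rank] = 0           # the locally held chunk is ready immediately

for step in range(tp_size - 1):
    send_chunk = (my_rank - step) % tp_size        # forwarded to the next peer
    recv_chunk = (my_rank - step - 1) % tp_size    # arrives from the previous peer
    # userbuffers_sendrecv_atomic(...) copies recv_chunk into the userbuffer and
    # then sets counters[recv_chunk] = 0, unblocking the GEMM for that chunk.
    counters[recv_chunk] = 0
```
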
......@@ -124,6 +124,8 @@ def initialize_ub(
fp8_buf = [
"qkv_fprop", "qkv_dgrad", "proj_dgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad"
]
if bool(int(os.getenv("NVTE_UB_FP8_RS", "0"))):
fp8_buf.append ("proj_fprop")
# Default overlap methods for layers
methods = {
"ring_exchange":["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"],
......@@ -153,8 +155,12 @@ def initialize_ub(
sample_buffer, # Sample userbuffer
rank_id, # Rank id
tp_size, # TP size
num_sm, # Number of communication SMs
cga_size, # CGA cluster size
set_sm_margin, # Set SM margin
aggregate, # Aggregate 2X GEMM chunks
_NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams
torch.Tensor(), # empty tensor to pass to counters
)
else:
ub_obj = tex.UbufCommOverlap(
......@@ -166,6 +172,7 @@ def initialize_ub(
num_splits, # Number of communication splits
set_sm_margin, # Set SM margin
_NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams
torch.Tensor(), # empty tensor to pass to counters
)
_ub_communicators[name] = ub_obj
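
Related to the buffer setup above: making the row-parallel projection's forward output an FP8 userbuffer (so the overlapped reduce-scatter runs on FP8 data) is opt-in through an environment variable read by initialize_ub. A minimal sketch; initialize_ub's full argument list is not shown in this hunk, so pass whatever arguments your existing userbuffers setup already uses:

```python
import os
from transformer_engine.pytorch.module.base import initialize_ub

# With NVTE_UB_FP8_RS=1, "proj_fprop" is added to the list of FP8 userbuffers,
# enabling FP8 reduce-scatter for the attention projection GEMM.
os.environ["NVTE_UB_FP8_RS"] = "1"

initialize_ub(ub_shape, tp_size)   # ub_shape, tp_size: placeholders from your setup
```
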
......@@ -676,10 +683,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
grad_output_mat = grad_output.view((-1, grad_output.shape[-1]))
gather_grad_output = row_parallel_mode and ctx.sequence_parallel
if gather_grad_output:
ub_overlap_ag = ctx.ub_split_ag or ctx.ub_atomic_gemm_ag
# No-FP8 case: bgrad is fused with wgrad for this case.
if not ctx.fp8:
if gather_grad_output:
if not ctx.ub_split_ag:
if not ub_overlap_ag:
grad_output_mat, _ = gather_along_first_dim(
grad_output_mat, ctx.tp_group
)
......@@ -698,8 +707,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
and ctx.fp8_meta["recipe"].override_linear_precision.wgrad
):
assert (
not ctx.ub_split_ag
), "override_linear_precision.wgrad not supported with ub_split_ag"
not ub_overlap_ag
), "override_linear_precision.wgrad not supported with UB AG overlap"
grad_output_mat, _ = gather_along_first_dim(grad_output_mat, ctx.tp_group)
# FP8 case with gather: unfused bgrad, cast, transpose for efficient gather
elif gather_grad_output:
......@@ -707,7 +716,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
grad_bias = grad_output_mat.sum(dim=0)
else:
grad_bias = None
if ctx.ub_split_ag:
if ub_overlap_ag:
grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(0)
else:
grad_output_c = torch.empty_like(grad_output_mat, dtype=torch.uint8)
......@@ -718,7 +727,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
fp8_dtype_backward,
out=grad_output_c,
)
if not ctx.ub_split_ag:
if not ub_overlap_ag:
grad_output_c, _ = gather_along_first_dim(grad_output_c, ctx.tp_group)
grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
else:
......
......@@ -83,6 +83,7 @@ class _LayerNormLinear(torch.autograd.Function):
ub_bulk_dgrad: bool,
ub_split_ag: bool,
normalization: str,
ub_atomic_gemm_ag: bool,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
......@@ -100,11 +101,12 @@ class _LayerNormLinear(torch.autograd.Function):
if ln_bias is not None:
ln_bias = cast_if_needed(ln_bias, activation_dtype)
if ub_split_ag:
if ub_split_ag or ub_atomic_gemm_ag:
tp_world_size = get_distributed_world_size(tp_group)
if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output:
ub_split_ag = False
if ub_split_ag:
ub_atomic_gemm_ag = False
if ub_split_ag or ub_atomic_gemm_ag:
dim_size = list(inputmat.size())
dim_size[0] = dim_size[0] * tp_world_size
ub_obj_lnout = get_ub("qkv_fprop")
......@@ -112,6 +114,8 @@ class _LayerNormLinear(torch.autograd.Function):
else:
ln_out_dtype = torch.uint8 if fp8 else inputmat.dtype
ln_out = torch.empty_like(inputmat, dtype=ln_out_dtype)
if ub_atomic_gemm_ag:
assert fp8, "AtomicGemm overlap supported only for FP8 GEMM."
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
......@@ -139,7 +143,7 @@ class _LayerNormLinear(torch.autograd.Function):
fp8_dtype_forward,
)
# Column Parallel Linear
if ub_split_ag:
if ub_split_ag or ub_atomic_gemm_ag:
ln_out_total = ub_obj_lnout.get_ubuf_output(1)
ln_out = torch.empty_like(ln_out)
elif parallel_mode == "column" and sequence_parallel:
......@@ -173,6 +177,8 @@ class _LayerNormLinear(torch.autograd.Function):
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward)
ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None
ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ub_atomic_gemm_ag else ub_algo
out, _ = tex.fp8_gemm(
weight_fp8,
fp8_meta["scaling_fwd"].scale_inv,
......@@ -187,9 +193,9 @@ class _LayerNormLinear(torch.autograd.Function):
bias=bias,
use_bias=use_bias,
use_split_accumulator=_2X_ACC_FPROP,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None,
ub=ub_obj_lnout if ub_split_ag else None,
extra_output_tensor=ln_out if ub_split_ag else None,
ub_algo=ub_algo,
ub=ub_obj_lnout if (ub_split_ag or ub_atomic_gemm_ag) else None,
extra_output_tensor=ln_out if (ub_split_ag or ub_atomic_gemm_ag) else None,
)
else:
# Cast for native AMP
......@@ -339,6 +345,14 @@ class _LayerNormLinear(torch.autograd.Function):
fp8_dtype_backward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False
)
out_index, meta_tensor, out_te_type, out_type = (
None, None, None, ctx.activation_dtype)
if ctx.ub_bulk_wgrad and ub_obj_dgrad.is_fp8_ubuf():
out_index = tex.FP8BwdTensors.GRAD_INPUT1
meta_tensor = ctx.fp8_meta["scaling_bwd"]
out_te_type = fp8_dtype_backward
out_type = torch.uint8
ub_obj_dgrad.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index])
# DGRAD: Evaluated unconditionally to feed into Linear backward
_ = tex.fp8_gemm(
......@@ -350,12 +364,15 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
ctx.activation_dtype,
out_type,
get_workspace(),
out=dgrad,
use_split_accumulator=_2X_ACC_DGRAD,
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None,
ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None
ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None,
out_index=out_index,
fp8_meta_tensor = meta_tensor,
D_dtype = out_te_type,
)
else:
# DGRAD: Evaluated unconditionally to feed into Linear backward
......@@ -387,6 +404,15 @@ class _LayerNormLinear(torch.autograd.Function):
if weight.requires_grad:
if ctx.fp8:
# WGRAD
extra_output_tensor = None
if ctx.ub_bulk_wgrad:
if ub_obj_dgrad.is_fp8_ubuf():
dim_size = list(ub_obj_dgrad.get_ubuf_output(0).size()) # RS output
extra_output_tensor = torch.empty(
dim_size, dtype=ctx.activation_dtype, device=dgrad.device)
dgrad = extra_output_tensor
else:
dgrad = ub_obj_dgrad.get_ubuf_output(0)
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
wgrad, _ = tex.fp8_gemm(
......@@ -405,7 +431,8 @@ class _LayerNormLinear(torch.autograd.Function):
use_split_accumulator=_2X_ACC_WGRAD,
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS
if ctx.ub_bulk_wgrad else None,
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None,
extra_output_tensor=extra_output_tensor
)
else:
ln_out_total_c = tex.cast_from_fp8(
......@@ -426,7 +453,8 @@ class _LayerNormLinear(torch.autograd.Function):
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS
if ctx.ub_bulk_wgrad else None,
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None,
extra_output_tensor=extra_output_tensor
)
else:
# WGRAD
......@@ -443,11 +471,14 @@ class _LayerNormLinear(torch.autograd.Function):
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None,
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None
)
if ctx.ub_bulk_wgrad:
dgrad = ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output
# Column Parallel Linear
elif ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
if ((not ctx.ub_bulk_wgrad)
and ctx.parallel_mode == "column"
and ctx.tensor_parallel
and handle is not None):
handle.wait()
# LayerNorm gradient
......@@ -512,6 +543,7 @@ class _LayerNormLinear(torch.autograd.Function):
None,
None,
None,
None,
)
......@@ -624,6 +656,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
ub_bulk_dgrad: bool = False,
ub_split_ag: bool = False,
device: Union[torch.device, str] = "cuda",
ub_atomic_gemm_ag: bool = False,
) -> None:
super().__init__()
......@@ -650,12 +683,18 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.ub_bulk_wgrad = ub_bulk_wgrad
self.ub_bulk_dgrad = ub_bulk_dgrad
self.ub_split_ag = ub_split_ag
self.ub_atomic_gemm_ag = ub_atomic_gemm_ag
if ub_bulk_wgrad or ub_bulk_dgrad or ub_split_ag:
if ub_bulk_wgrad or ub_bulk_dgrad or ub_split_ag or ub_atomic_gemm_ag:
assert (
tex.userbuf_comm_available()
), "Userbuffer communication backend not available."
if ub_atomic_gemm_ag:
warnings.warn(
"Atomic gemm uses a beta API from cublas and is not tested for all use cases."
)
if tp_group is None:
self.tp_size = tp_size
if tp_size == 1:
......@@ -919,6 +958,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.ub_bulk_dgrad,
self.ub_split_ag,
self.normalization,
self.ub_atomic_gemm_ag,
)
out = fwd_fn(*args)
......
......@@ -4,6 +4,7 @@
"""LayerNormMLP API"""
import os
import warnings
from typing import Union, Optional, Callable, Tuple, List, Dict, Any
import torch
......@@ -107,7 +108,9 @@ class _LayerNormMLP(torch.autograd.Function):
ub_bulk_wgrad: bool,
ub_bulk_dgrad: bool,
ub_split_rs: bool,
ub_atomic_gemm_rs: bool,
ub_split_ag: bool,
ub_atomic_gemm_ag: bool,
activation: str,
normalization: str,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
......@@ -130,20 +133,25 @@ class _LayerNormMLP(torch.autograd.Function):
if ln_bias is not None:
ln_bias = cast_if_needed(ln_bias, activation_dtype)
if ub_split_ag:
if ub_split_ag or ub_atomic_gemm_ag:
tp_world_size = get_distributed_world_size(tp_group)
if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output:
ub_split_ag = False
if ub_split_ag:
ub_atomic_gemm_ag = False
ub_overlap_ag = ub_split_ag or ub_atomic_gemm_ag
if ub_overlap_ag:
ub_obj_lnout = get_ub("fc1_fprop")
ln_out = ub_obj_lnout.get_ubuf_output(0)
else:
ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype
ln_out = torch.empty_like(inputmat, dtype=ln_out_dtype)
if ub_split_rs:
if ub_split_rs or ub_atomic_gemm_rs:
tp_world_size = get_distributed_world_size(tp_group)
if tp_world_size == 1:
ub_split_rs = False
ub_atomic_gemm_rs = False
if ub_atomic_gemm_rs or ub_atomic_gemm_ag:
assert fp8, "AtomicGemm overlap supported only for FP8 GEMM."
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
......@@ -171,7 +179,7 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_dtype_forward,
)
# Column Parallel Linear
if ub_split_ag:
if ub_overlap_ag:
ln_out_total = ub_obj_lnout.get_ubuf_output(1)
ln_out = torch.empty_like(ln_out)
elif set_parallel_mode and sequence_parallel:
......@@ -223,6 +231,8 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_dtype_forward,
)
ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None
ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ub_atomic_gemm_ag else ub_algo
fc1_out, _ = tex.fp8_gemm(
fc1_weight_fp8,
fp8_meta["scaling_fwd"].scale_inv,
......@@ -237,9 +247,9 @@ class _LayerNormMLP(torch.autograd.Function):
bias=fc1_bias,
use_bias=use_fc1_bias,
use_split_accumulator=_2X_ACC_FPROP,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None,
ub=ub_obj_lnout if ub_split_ag else None,
extra_output_tensor=ln_out if ub_split_ag else None,
ub_algo=ub_algo,
ub=ub_obj_lnout if ub_overlap_ag else None,
extra_output_tensor=ln_out if ub_overlap_ag else None,
)
gelu_out = activation_func(
......@@ -249,18 +259,29 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_dtype_forward,
)
if ub_split_rs:
fc2_out_index, fc2_meta_tensor, fc2_te_type, out_type = (
None, None, None, activation_dtype)
if ub_split_rs or ub_atomic_gemm_rs:
ub_obj_fc2out = get_ub("fc2_fprop")
fc2_out = ub_obj_fc2out.get_ubuf_output(1)
dim_size = list(gelu_out.size())
dim_size[0] = dim_size[0] // tp_world_size
dim_size[1] = fc2_weight.size(0)
rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)
if ub_obj_fc2out.is_fp8_ubuf():
fc2_out_index = tex.FP8FwdTensors.GEMM2_OUTPUT
fc2_meta_tensor = fp8_meta["scaling_fwd"]
fc2_te_type = fp8_dtype_forward
out_type = torch.uint8
ub_obj_fc2out.set_ubuf_scale_inv(fc2_meta_tensor.scale_inv[fc2_out_index])
else:
dim_size = list(gelu_out.size())
dim_size[1] = fc2_weight.size(0)
fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)
ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS if ub_atomic_gemm_rs else None
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else ub_algo
_ = tex.fp8_gemm(
fc2_weight_fp8,
fp8_meta["scaling_fwd"].scale_inv,
......@@ -270,15 +291,18 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM2_INPUT,
fp8_dtype_forward,
activation_dtype,
out_type,
get_workspace(),
bias=fc2_bias,
use_bias=use_fc2_bias,
use_split_accumulator=_2X_ACC_FPROP,
out=fc2_out,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None,
ub=ub_obj_fc2out if ub_split_rs else None,
extra_output_tensor=rs_out if ub_split_rs else None,
ub_algo=ub_algo,
ub=ub_obj_fc2out if ub_split_rs or ub_atomic_gemm_rs else None,
extra_output_tensor=rs_out if ub_split_rs or ub_atomic_gemm_rs else None,
out_index=fc2_out_index,
fp8_meta_tensor = fc2_meta_tensor,
D_dtype = fc2_te_type,
)
else:
# Cast for native AMP
......@@ -394,11 +418,12 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.ub_bulk_wgrad = ub_bulk_wgrad
ctx.ub_bulk_dgrad = ub_bulk_dgrad
ctx.ub_split_ag = ub_split_ag
ctx.ub_atomic_gemm_ag = ub_atomic_gemm_ag
ctx.requires_dgrad = inp.requires_grad
ctx.normalization = normalization
# Row Parallel Linear
if ub_split_rs:
if ub_split_rs or ub_atomic_gemm_rs:
fc2_out = rs_out
elif set_parallel_mode and sequence_parallel:
fc2_out, _ = reduce_scatter_along_first_dim(fc2_out, tp_group)
......@@ -447,11 +472,15 @@ class _LayerNormMLP(torch.autograd.Function):
dim_size[0] = dim_size[0] * tp_world_size
ub_obj_lnout = get_ub("fc1_dgrad")
ub_obj_lnout.copy_input_to_ubuf(ln_out, 1)
if ctx.ub_split_ag:
ub_overlap_ag = ctx.ub_split_ag or ctx.ub_atomic_gemm_ag
if ub_overlap_ag:
tp_world_size = get_distributed_world_size(ctx.tp_group)
if tp_world_size == 1:
ctx.ub_split_ag = False
if ctx.ub_split_ag:
ctx.ub_overlap_ag = False
ub_overlap_ag = ctx.ub_split_ag or ctx.ub_atomic_gemm_ag
if ub_overlap_ag:
dim_size = list(grad_outputs[0].size())
dim_size[0] = dim_size[0] * tp_world_size
ctx.ub_obj_gradout = get_ub("fc2_dgrad")
......@@ -497,6 +526,8 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.fp8_meta["recipe"], fprop_tensor=False
)
ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None
ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ctx.ub_atomic_gemm_ag else ub_algo
# FC2 DGRAD; Unconditional
fc2_dgrad, _ = tex.fp8_gemm(
fc2_weight_t_fp8,
......@@ -510,10 +541,10 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.activation_dtype,
get_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None,
ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None,
ub_algo=ub_algo,
ub=ctx.ub_obj_gradout if ub_overlap_ag else None,
)
if ctx.ub_split_ag:
if ub_overlap_ag:
grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
# FC2 WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
......@@ -595,11 +626,19 @@ class _LayerNormMLP(torch.autograd.Function):
)
dgelu_t = None
out_index, meta_tensor, out_te_type, out_type = (
None, None, None, ctx.activation_dtype)
fc1_dgrad_size = list(dgelu.size())
fc1_dgrad_size[1] = fc1_weight.size(1)
if ctx.ub_bulk_wgrad: # allocate dgrad output
ub_obj_dgrad = get_ub("fc1_wgrad")
fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output
if ub_obj_dgrad.is_fp8_ubuf():
out_index = tex.FP8BwdTensors.GRAD_INPUT2
meta_tensor = ctx.fp8_meta["scaling_bwd"]
out_te_type = fp8_dtype_backward
out_type = torch.uint8
ub_obj_dgrad.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index])
else:
fc1_dgrad = torch.empty(
fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device
......@@ -614,12 +653,15 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT2,
fp8_dtype_backward,
ctx.activation_dtype,
out_type,
get_workspace(),
out=fc1_dgrad,
use_split_accumulator=_2X_ACC_DGRAD,
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None,
ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None
ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None,
out_index=out_index,
fp8_meta_tensor = meta_tensor,
D_dtype = out_te_type,
)
else:
# FC2 DGRAD; Unconditional
......@@ -703,6 +745,15 @@ class _LayerNormMLP(torch.autograd.Function):
if fc1_weight.requires_grad:
if ctx.fp8:
# FC1 WGRAD
extra_output_tensor = None
if ctx.ub_bulk_wgrad:
if ub_obj_dgrad.is_fp8_ubuf():
dim_size = list(ub_obj_dgrad.get_ubuf_output(0).size()) # RS output
extra_output_tensor = torch.empty(
dim_size, dtype=ctx.activation_dtype, device=fc1_dgrad.device)
fc1_dgrad = extra_output_tensor
else:
fc1_dgrad = ub_obj_dgrad.get_ubuf_output(0)
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
fc1_wgrad, _ = tex.fp8_gemm(
......@@ -724,6 +775,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS
if ctx.ub_bulk_wgrad else None,
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None,
extra_output_tensor=extra_output_tensor,
)
else:
ln_out_total_c = tex.cast_from_fp8(
......@@ -747,6 +799,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS
if ctx.ub_bulk_wgrad else None,
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None,
extra_output_tensor=extra_output_tensor,
)
else:
# FC1 WGRAD
......@@ -768,11 +821,14 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_wgrad, _, _ = fc1_wgrad_outputs
else:
fc1_wgrad, fc1_bias_grad, _ = fc1_wgrad_outputs
# Column Parallel Linear
if ctx.ub_bulk_wgrad:
fc1_dgrad = ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output
elif ctx.set_parallel_mode and ctx.tensor_parallel and handle is not None:
# Column Parallel Linear
if ((not ctx.ub_bulk_wgrad)
and ctx.set_parallel_mode
and ctx.tensor_parallel
and handle is not None):
handle.wait()
# LayerNorm gradient
......@@ -850,6 +906,8 @@ class _LayerNormMLP(torch.autograd.Function):
None,
None,
None,
None,
None,
)
......@@ -965,8 +1023,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
ub_bulk_wgrad: bool = False,
ub_bulk_dgrad: bool = False,
ub_split_rs: bool = False,
ub_atomic_gemm_rs: bool = False,
ub_split_ag: bool = False,
device: Union[torch.device, str] = "cuda",
ub_atomic_gemm_ag: bool = False,
) -> None:
super().__init__()
......@@ -987,12 +1047,24 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.ub_bulk_dgrad = ub_bulk_dgrad
self.ub_split_rs = ub_split_rs
self.ub_split_ag = ub_split_ag
if ub_bulk_wgrad or ub_bulk_dgrad or ub_split_rs or ub_split_ag:
self.ub_atomic_gemm_rs = ub_atomic_gemm_rs
self.ub_atomic_gemm_ag = ub_atomic_gemm_ag
if (ub_bulk_wgrad # pylint: disable=too-many-boolean-expressions
or ub_bulk_dgrad
or ub_split_rs
or ub_split_ag
or ub_atomic_gemm_rs
or ub_atomic_gemm_ag):
assert (
tex.userbuf_comm_available()
), "Userbuffer communication backend not available."
if ub_atomic_gemm_rs or ub_atomic_gemm_ag:
warnings.warn(
"Atomic gemm uses a beta API from cublas and is not tested for all use cases."
)
if tp_group is None:
self.tp_size = tp_size
if tp_size == 1:
......@@ -1210,7 +1282,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.ub_bulk_wgrad,
self.ub_bulk_dgrad,
self.ub_split_rs,
self.ub_atomic_gemm_rs,
self.ub_split_ag,
self.ub_atomic_gemm_ag,
self.activation,
self.normalization,
)
......