Commit 958e1889 authored by vasunvidia, committed by GitHub

Atomic gemm and FP8 Reduce Scatter (#449)



* Initial commit
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Repro for RS output mismatch with Single GEMM + Split pipelined RS
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* minor changes for AG->GEMM pipelined overlap
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Add Atomic Gemm cublasApi attributes and initial implementation of AG->Atomic GEMM
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* AtomicGemm+RS functional with workaround
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* add amax update to layernorm_linear for FP8 unit test accuracy
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Enable reducescatter2_userbuff_strided variants
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Bug fix
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* AG+AtomicGemm overlap functional but gemm doesnt overlap with comm
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Add userbuffers_sendrecv kernel variants
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* TransformerLayer API changes to enable AtomicGemm+RS overlap
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Code cleanup
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Code cleanup2
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* [UB] AllGather Atomic GEMM overlap using userbuffer_sendrecv kernels
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Code cleanup + bug fix for multiatomic sendrecv kernel
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* cleanup
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Bug fixes
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* [UB] Add shuffling for better AG AtomicGEMM overlap
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Bug fix for AG AtomicGemm overlap
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Bug fix for multiAtomicAG and singleAtomicAG
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Use chunk_i+1 as recv_chunk for multiatomic_AG with shuffling
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Launch AtomicGEMM after first-chunk AG
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Rebase to main
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Add FP8 ReduceScatter kernels, AtomicGEMM+FP8 RS not functional
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Revert "Add FP8 ReduceScatter kernels, AtomicGEMM+FP8 RS not functional"

This reverts commit 80a47a76355440cd5fb4314c96fe9fda632d87f9.
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Add support for NVLS-MC and FP8 Reduce Scatter
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Bug fix
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Atomic and Multiatomic FP8 RS functional
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Remove debug print
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* UB comm initialization hang fix
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Code cleanup
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Create new GEMM API for Atomic GEMM
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* CI ready
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* more fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* license
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Bug fix
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Revert NVLS-MC
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Check cu* versions for running atomic gemms
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Cleanup
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Add experimental warning
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Better wording
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add warning to c api
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix wording
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent be67f219
......@@ -506,7 +506,7 @@ def test_export_gemm(
self.fp8_tensor_weight,
self.weights_type)
-ret = fp8_gemm(
+ret, _ = fp8_gemm(
weight_fp8,
self.meta_weight.scale_inv,
self.fp8_tensor_weight,
......@@ -1323,7 +1323,7 @@ def test_export_gemm_layernorm(
self.fp8_tensor_weight,
self.weights_type)
-ret = fp8_gemm(
+ret, _ = fp8_gemm(
weight_fp8,
self.meta_weight.scale_inv,
self.fp8_tensor_weight,
......
......@@ -7,6 +7,7 @@
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/logging.h>
#include <transformer_engine/gemm.h>
#include <cuda.h>
#include <cublasLt.h>
#include <cublas_v2.h>
#include "../common.h"
......@@ -50,6 +51,10 @@ void cublas_gemm(const Tensor *inputA,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
int m_split,
int n_split,
bool gemm_producer,
const Tensor *inputCounter,
cudaStream_t stream
) {
void *A = inputA->data.dptr;
......@@ -63,6 +68,10 @@ void cublas_gemm(const Tensor *inputA,
void *bias_ptr = inputBias->data.dptr;
const bool bias = bias_ptr != nullptr;
void *pre_gelu_out = outputPreGelu->data.dptr;
void *counter = nullptr;
if (inputCounter != nullptr) {
counter = inputCounter->data.dptr;
}
const bool gelu = pre_gelu_out != nullptr;
const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
is_fp8_dtype(inputB->data.dtype);
......@@ -223,6 +232,27 @@ void cublas_gemm(const Tensor *inputA,
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_EPILOGUE,
&epilogue, sizeof(epilogue)));
#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205
if (counter != nullptr) {
if (m_split == 0) m_split=1;
if (n_split == 0) n_split=1;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_ROWS,
&m_split, sizeof(m_split)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_COLS,
&n_split, sizeof(n_split)));
if (gemm_producer) {
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_OUT_COUNTERS_POINTER,
&counter, sizeof(counter)));
} else {
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_IN_COUNTERS_POINTER,
&counter, sizeof(counter)));
}
}
#endif
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference));
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
......@@ -254,7 +284,6 @@ void cublas_gemm(const Tensor *inputA,
workspaceSize,
stream)); /* stream */
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceDestroy(preference));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Ddesc));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Cdesc));
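The `#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205` guard in the earlier hunk compares packed integer versions. As a rough illustration of how those integers map to dotted versions, here is a minimal Python sketch. The function names are hypothetical and not part of the commit; the encodings assumed are 1000 * major + 10 * minor for CUDA and 10000 * major + 100 * minor + patch for cuBLAS, which matches the 12.2 and 12.2.5 wording used in the error messages elsewhere in this commit.

```python
def decode_cuda_version(v):
    # Assumed CUDA encoding: 1000 * major + 10 * minor (12020 -> 12.2).
    return v // 1000, (v % 1000) // 10


def decode_cublas_version(v):
    # Assumed cuBLAS encoding: 10000 * major + 100 * minor + patch (120205 -> 12.2.5).
    return v // 10000, (v % 10000) // 100, v % 100


def atomic_gemm_supported(cuda_version, cublas_version):
    # Same thresholds as the preprocessor guard in the diff.
    return cuda_version >= 12020 and cublas_version >= 120205
```

Comparing the packed integers directly, as the commit does, avoids any string parsing; decoding is only useful for producing readable messages.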
......@@ -320,5 +349,82 @@ void nvte_cublas_gemm(const NVTETensor A,
wspace->data.shape[0],
accumulate, use_split_accumulator,
math_sm_count,
0,
0,
false,
nullptr,
stream);
}
void nvte_cublas_atomic_gemm(const NVTETensor A,
const NVTETensor B,
NVTETensor D,
const NVTETensor bias,
NVTETensor pre_gelu_out,
bool transa,
bool transb,
bool grad,
NVTETensor workspace,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
int m_split,
int n_split,
bool gemm_producer,
const NVTETensor counter,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cublas_atomic_gemm);
int cudart_version;
NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&cudart_version));
NVTE_CHECK(cudart_version >= 12020, "Cuda version 12.2 is required for atomic gemm.");
NVTE_CHECK(cublasLtGetVersion() >= 120205, "Cublas version 12.2.5 is required for atomic gemm.");
using namespace transformer_engine;
const Tensor *inputA = reinterpret_cast<const Tensor*>(A);
const Tensor *inputB = reinterpret_cast<const Tensor*>(B);
Tensor *outputD = reinterpret_cast<Tensor*>(D);
const Tensor *biasTensor = reinterpret_cast<const Tensor*>(bias);
Tensor *outputGelu = reinterpret_cast<Tensor*>(pre_gelu_out);
const Tensor *inputCounter = reinterpret_cast<const Tensor*>(counter);
Tensor *wspace = reinterpret_cast<Tensor*>(workspace);
const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1];
const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0];
const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0];
int lda, ldb, ldd;
if (transa && !transb) { // TN
lda = k;
ldb = k;
ldd = m;
} else if (!transa && !transb) { // NN
lda = m;
ldb = k;
ldd = m;
} else if (!transa && transb) { // NT
lda = m;
ldb = n;
ldd = m;
} else { // TT
NVTE_ERROR("TT layout not allowed.");
}
cublas_gemm(inputA,
inputB,
outputD,
biasTensor,
outputGelu,
m, n, k,
lda, ldb, ldd,
(transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
(transb) ? CUBLAS_OP_T : CUBLAS_OP_N,
grad, wspace->data.dptr,
wspace->data.shape[0],
accumulate, use_split_accumulator,
math_sm_count,
m_split,
n_split,
gemm_producer,
inputCounter,
stream);
}
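The layout handling in `nvte_cublas_atomic_gemm` above derives `m`, `n`, `k` and the leading dimensions from the two transpose flags. The following Python sketch restates that selection so the TN, NN and NT cases can be checked by hand. The function name `atomic_gemm_dims` is hypothetical, and the shapes in the usage note are made-up values.

```python
def atomic_gemm_dims(a_shape, b_shape, transa, transb):
    """Mirror the m/n/k and lda/ldb/ldd selection shown in the C++ wrapper."""
    m = a_shape[0] if transa else a_shape[1]
    k = a_shape[1] if transa else a_shape[0]
    n = b_shape[1] if transb else b_shape[0]
    if transa and not transb:        # TN
        lda, ldb, ldd = k, k, m
    elif not transa and not transb:  # NN
        lda, ldb, ldd = m, k, m
    elif not transa and transb:      # NT
        lda, ldb, ldd = m, n, m
    else:                            # TT is rejected, as in the diff
        raise ValueError("TT layout not allowed.")
    return m, n, k, lda, ldb, ldd
```

For example, `atomic_gemm_dims((64, 32), (16, 32), True, False)` yields `m=64, n=16, k=32` with `lda=ldb=32` and `ldd=64`.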
......@@ -54,6 +54,52 @@ void nvte_cublas_gemm(const NVTETensor A,
cudaStream_t stream
);
/*! \brief Compute matrix multiplication of 2 matrices with chunking and atomic counters.
*
* \warning Cublas atomic gemm uses a beta API and is not tested for all use cases.
*
* Computes:
* - `D = AB` if both `bias` and `pre_gelu_out` are empty tensors
* - `D = AB + bias` if `pre_gelu_out` is empty and `bias` is not empty
* - `D = GELU(AB + bias)` if both `bias` and `pre_gelu_out` are not empty tensors
*
* \param[in] A The A matrix.
* \param[in] B The B matrix.
* \param[in,out] D Output matrix.
* \param[in] bias Bias tensor.
* \param[in,out] pre_gelu_out Output matrix before GELU activation.
* \param[in] transa Whether A matrix is transposed.
* \param[in] transb Whether B matrix is transposed.
* \param[in] grad Whether this operation is part of the
* gradient computation.
* \param[out] workspace Workspace tensor.
* \param[in] accumulate Whether to accumulate the result into the D matrix.
* \param[in] use_split_accumulator Whether to use split accumulator in the FP8 GEMM.
* \param[in] math_sm_count Number of GPU SMs to use (default=0: use cuBLAS heuristics)
* \param[in] m_split Number of chunks/splits along m-dimension for Atomic GEMM.
* \param[in] n_split Number of chunks/splits along n-dimension for Atomic GEMM.
* \param[in] gemm_producer Whether Atomic GEMM is the producer or consumer.
* \param[in,out] counter counter[chunk_i]=0 indicates chunk_i has been produced.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_cublas_atomic_gemm(const NVTETensor A,
const NVTETensor B,
NVTETensor D,
const NVTETensor bias,
NVTETensor pre_gelu_out,
bool transa,
bool transb,
bool grad,
NVTETensor workspace,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
int m_split,
int n_split,
bool gemm_producer,
const NVTETensor counter,
cudaStream_t stream
);
#ifdef __cplusplus
} // extern "C"
#endif
......
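The doc comment above states that `counter[chunk_i]=0` indicates that `chunk_i` has been produced, and that the GEMM acts as either producer or consumer of those counters. The sketch below simulates that handshake on the CPU with two Python threads. It is only an analogy under stated assumptions: the initial counter value of 1, the `run_pipeline` name and the squared-index payload are all hypothetical, and nothing here calls cuBLAS.

```python
import threading


def run_pipeline(num_chunks):
    # Convention from the doc comment: counter[i] == 0 means chunk i has been produced.
    # Starting every counter at 1 is an assumption of this sketch.
    counters = [1] * num_chunks
    cond = threading.Condition()
    produced = [None] * num_chunks
    consumed = []

    def producer():
        for i in range(num_chunks):
            produced[i] = i * i  # stand-in for computing output chunk i
            with cond:
                counters[i] = 0  # signal that chunk i is ready
                cond.notify_all()

    def consumer():
        for i in range(num_chunks):
            with cond:
                # Block until the producer has cleared this chunk's counter.
                cond.wait_for(lambda: counters[i] == 0)
            consumed.append(produced[i])

    threads = [threading.Thread(target=consumer), threading.Thread(target=producer)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return consumed
```

The point of the per-chunk counters is that the consumer can start on chunk 0 while the producer is still working on later chunks, which is what allows compute and communication to overlap.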
......@@ -2262,6 +2262,8 @@ class MultiheadAttention(torch.nn.Module):
ub_bulk_dgrad: bool = False,
ub_split_rs: bool = False,
ub_split_ag: bool = False,
ub_atomic_gemm_rs: bool = False,
ub_atomic_gemm_ag: bool = False,
bias: bool = True,
normalization: str = "LayerNorm",
device: Union[torch.device, str] = "cuda",
......@@ -2342,6 +2344,7 @@ class MultiheadAttention(torch.nn.Module):
ub_bulk_dgrad=ub_bulk_dgrad,
ub_split_ag=ub_split_ag,
normalization=normalization,
ub_atomic_gemm_ag=ub_atomic_gemm_ag,
**common_gemm_kwargs,
)
else:
......@@ -2372,6 +2375,7 @@ class MultiheadAttention(torch.nn.Module):
ub_bulk_dgrad=ub_bulk_dgrad,
ub_split_ag=ub_split_ag,
normalization=normalization,
ub_atomic_gemm_ag=ub_atomic_gemm_ag,
**common_gemm_kwargs,
)
else:
......@@ -2418,6 +2422,8 @@ class MultiheadAttention(torch.nn.Module):
parallel_mode="row" if set_parallel_mode else None,
ub_split_rs=ub_split_rs,
ub_split_ag=ub_split_ag,
ub_atomic_gemm_rs=ub_atomic_gemm_rs,
ub_atomic_gemm_ag=ub_atomic_gemm_ag,
**common_gemm_kwargs,
)
......
......@@ -91,22 +91,40 @@ def fp8_gemm(
assert ub is not None, 'ub object is None!'
if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG:
fn = ub.bulk_overlap
-args = tuple(args + (1,))
+extra_output_tensor = (
+empty_tensor if extra_output_tensor is None else extra_output_tensor
+)
+args = tuple(args + (1, extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS:
fn = ub.bulk_overlap
-args = tuple(args + (0,))
+extra_output_tensor = (
+empty_tensor if extra_output_tensor is None else extra_output_tensor
+)
+args = tuple(args + (0, extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG:
fn = ub.split_overlap_ag
extra_output_tensor = (
empty_tensor if extra_output_tensor is None else extra_output_tensor
)
args = tuple(args + (extra_output_tensor,))
+elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG:
+fn = ub.atomic_gemm_overlap_ag
+extra_output_tensor = (
+empty_tensor if extra_output_tensor is None else extra_output_tensor
+)
+args = tuple(args + (extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS:
fn = ub.split_overlap_rs
assert (
extra_output_tensor is not None
), 'SPLIT_PIPELINED_RS requires extra output tensor'
args = tuple(args + (True, extra_output_tensor,))
+elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_RS:
+fn = ub.atomic_gemm_overlap_rs
+assert (
+extra_output_tensor is not None
+), 'ATOMIC_GEMM_RS requires extra output tensor'
+args = tuple(args + (True, extra_output_tensor,))
_ = fn(*args)
return out, gelu_input
......@@ -195,10 +213,10 @@ def gemm(
assert ub is not None, 'ub object is None!'
if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG:
fn = ub.bulk_overlap
-args = tuple(args + (1,))
+args = tuple(args + (1, empty_tensor))
elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS:
fn = ub.bulk_overlap
-args = tuple(args + (0,))
+args = tuple(args + (0, empty_tensor))
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG:
fn = ub.split_overlap_ag
extra_output_tensor = (
......
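The `fp8_gemm` branches above pick a userbuffers method and append algorithm-specific trailing arguments. A compact way to read that logic is as a pure function from the algorithm to a method name and argument tuple. The sketch below does that with stand-ins only: the local `UbufOverlapAlgo` enum imitates `tex.UbufOverlapAlgo`, `EMPTY` imitates `empty_tensor`, and `select_overlap_call` is a hypothetical helper that does not exist in the commit.

```python
from enum import Enum, auto


class UbufOverlapAlgo(Enum):  # local stand-in for tex.UbufOverlapAlgo
    BULK_OVERLAP_AG = auto()
    BULK_OVERLAP_RS = auto()
    SPLIT_PIPELINED_AG = auto()
    SPLIT_PIPELINED_RS = auto()
    ATOMIC_GEMM_AG = auto()
    ATOMIC_GEMM_RS = auto()


EMPTY = object()  # stand-in for the module-level empty_tensor

_AG_METHODS = {
    UbufOverlapAlgo.SPLIT_PIPELINED_AG: "split_overlap_ag",
    UbufOverlapAlgo.ATOMIC_GEMM_AG: "atomic_gemm_overlap_ag",
}
_RS_METHODS = {
    UbufOverlapAlgo.SPLIT_PIPELINED_RS: "split_overlap_rs",
    UbufOverlapAlgo.ATOMIC_GEMM_RS: "atomic_gemm_overlap_rs",
}


def select_overlap_call(algo, args, extra_output_tensor=None):
    """Return (method name, full argument tuple) as the fp8_gemm branches build them."""
    if algo in (UbufOverlapAlgo.BULK_OVERLAP_AG, UbufOverlapAlgo.BULK_OVERLAP_RS):
        flag = 1 if algo is UbufOverlapAlgo.BULK_OVERLAP_AG else 0
        extra = EMPTY if extra_output_tensor is None else extra_output_tensor
        return "bulk_overlap", args + (flag, extra)
    if algo in _AG_METHODS:
        extra = EMPTY if extra_output_tensor is None else extra_output_tensor
        return _AG_METHODS[algo], args + (extra,)
    if algo in _RS_METHODS:
        # Reduce-scatter variants need a real destination tensor.
        assert extra_output_tensor is not None, f"{algo.name} requires extra output tensor"
        return _RS_METHODS[algo], args + (True, extra_output_tensor)
    raise ValueError(f"unsupported algorithm: {algo}")
```

Seen this way, the two atomic variants reuse the argument shape of their split-pipelined counterparts and differ only in the method that is called.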
......@@ -179,6 +179,32 @@ void te_gemm(at::Tensor A,
int math_sm_count
);
void te_atomic_gemm(at::Tensor A,
at::Tensor A_scale_inverse,
transformer_engine::DType A_type,
bool transa,
at::Tensor B,
at::Tensor B_scale_inverse,
transformer_engine::DType B_type,
bool transb,
at::Tensor D,
at::Tensor D_scale,
transformer_engine::DType D_type,
at::Tensor D_amax,
at::Tensor bias,
transformer_engine::DType bias_type,
at::Tensor pre_gelu_out,
bool grad,
at::Tensor workspace,
size_t workspaceSize,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
int m_split,
int n_split,
bool gemm_producer,
at::Tensor counter
);
void fused_cast_transpose(at::Tensor input,
at::Tensor scale,
......
......@@ -6,6 +6,7 @@
#include "extensions.h"
void te_gemm(at::Tensor A,
at::Tensor A_scale_inverse,
transformer_engine::DType A_type,
......@@ -73,3 +74,82 @@ void te_gemm(at::Tensor A,
math_sm_count,
at::cuda::getCurrentCUDAStream());
}
void te_atomic_gemm(at::Tensor A,
at::Tensor A_scale_inverse,
transformer_engine::DType A_type,
bool transa,
at::Tensor B,
at::Tensor B_scale_inverse,
transformer_engine::DType B_type,
bool transb,
at::Tensor D,
at::Tensor D_scale,
transformer_engine::DType D_type,
at::Tensor D_amax,
at::Tensor bias,
transformer_engine::DType bias_type,
at::Tensor pre_gelu_out,
bool grad,
at::Tensor workspace,
size_t workspaceSize,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
int m_split,
int n_split,
bool gemm_producer,
at::Tensor counter
) {
using namespace transformer_engine;
auto te_A = makeTransformerEngineTensor(A.data_ptr(),
{static_cast<size_t>(A.size(0)),
static_cast<size_t>(A.size(1))},
A_type, nullptr, nullptr,
A_scale_inverse.data_ptr());
auto te_B = makeTransformerEngineTensor(B.data_ptr(),
{static_cast<size_t>(B.size(0)),
static_cast<size_t>(B.size(1))},
B_type, nullptr, nullptr,
B_scale_inverse.data_ptr());
auto te_D = makeTransformerEngineTensor(D.data_ptr(),
{static_cast<size_t>(D.size(0)),
static_cast<size_t>(D.size(1))},
D_type, D_amax.data_ptr(),
D_scale.data_ptr(), nullptr);
auto te_bias = makeTransformerEngineTensor(bias.data_ptr(), {static_cast<size_t>(bias.size(0))},
bias_type);
auto te_counter = makeTransformerEngineTensor(counter.data_ptr(),
{static_cast<size_t>(counter.size(0))},
DType::kInt32);
const auto gelu_shape = pre_gelu_out.data_ptr() == nullptr
? std::vector<size_t>{static_cast<size_t>(pre_gelu_out.size(0))}
: std::vector<size_t>{static_cast<size_t>(pre_gelu_out.size(0)),
static_cast<size_t>(pre_gelu_out.size(1))};
auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out.data_ptr(),
gelu_shape,
GetTransformerEngineDType(
pre_gelu_out.scalar_type()));
auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(),
{workspaceSize},
DType::kByte);
nvte_cublas_atomic_gemm(te_A.data(),
te_B.data(),
te_D.data(),
te_bias.data(),
te_pre_gelu_out.data(),
transa,
transb,
grad,
te_workspace.data(),
accumulate,
use_split_accumulator,
math_sm_count,
m_split,
n_split,
gemm_producer,
te_counter.data(),
at::cuda::getCurrentCUDAStream());
}
......@@ -91,18 +91,24 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.value("BULK_OVERLAP_AG", ubuf::UBOverlapAlgo::BULK_OVERLAP_AG)
.value("BULK_OVERLAP_RS", ubuf::UBOverlapAlgo::BULK_OVERLAP_RS)
.value("SPLIT_PIPELINED_RS", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_RS)
-.value("SPLIT_PIPELINED_AG", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_AG);
+.value("SPLIT_PIPELINED_AG", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_AG)
+.value("ATOMIC_GEMM_RS", ubuf::UBOverlapAlgo::ATOMIC_GEMM_RS)
+.value("ATOMIC_GEMM_AG", ubuf::UBOverlapAlgo::ATOMIC_GEMM_AG);
py::class_<ubuf::UbufCommOverlap>(m, "UbufCommOverlap")
-.def(py::init<torch::Tensor&, int, int, int, int, int, bool, int>())
+.def(py::init<torch::Tensor&, int, int, int, int, int, bool, int, torch::Tensor>())
.def("bulk_overlap", &ubuf::UbufCommOverlap::bulk_overlap)
.def("split_overlap_rs", &ubuf::UbufCommOverlap::split_overlap_rs)
+.def("set_ubuf_scale_inv", &ubuf::UbufCommOverlap::set_ubuf_scale_inv)
+.def("atomic_gemm_overlap_rs", &ubuf::UbufCommOverlap::atomic_gemm_overlap_rs)
+.def("is_fp8_ubuf", &ubuf::UbufCommOverlap::is_fp8_ubuf)
.def("copy_input_to_ubuf", &ubuf::UbufCommOverlap::copy_input_to_ubuf)
.def("get_ubuf_output", &ubuf::UbufCommOverlap::get_ubuf_output);
py::class_<ubuf::UbufP2PCommOverlap>(m, "UbufP2PCommOverlap")
-.def(py::init<torch::Tensor&, int, int, bool, int>())
+.def(py::init<torch::Tensor&, int, int, int, int, bool, bool, int, torch::Tensor>())
.def("split_overlap_ag", &ubuf::UbufP2PCommOverlap::split_overlap_ag)
+.def("atomic_gemm_overlap_ag", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_ag)
.def("copy_input_to_ubuf", &ubuf::UbufP2PCommOverlap::copy_input_to_ubuf)
.def("get_ubuf_output", &ubuf::UbufP2PCommOverlap::get_ubuf_output);
#else // NVTE_WITH_USERBUFFERS
......
......@@ -4,10 +4,13 @@
* See LICENSE for license information.
************************************************************************/
#include "userbuffers.h"
#include <assert.h>
#include <chrono>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <immintrin.h>
#include <iostream>
#include <math.h>
#include <mpi.h>
#include <sched.h>
......@@ -15,9 +18,6 @@
#include <string.h>
#include <unistd.h>
#include <x86intrin.h>
#include <chrono>
#include <iostream>
#include "userbuffers.h"
static int oob_bcast(void *comm_context, void *buf, int size, int root) {
MPI_Bcast(buf, size, MPI_BYTE, root,
......@@ -38,20 +38,31 @@ static int oob_gather(void *comm_context, int root, void *sbuf, void *rbuf, int
int stringCmp(const void *a, const void *b) { return strcmp((const char *)a, (const char *)b); }
-#define CUDACHECK(cmd) \
-do { \
-cudaError_t e = cmd; \
-if (e != cudaSuccess) { \
-printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \
-exit(EXIT_FAILURE); \
-} \
+#define CUDACHECK(cmd) \
+do { \
+cudaError_t e = cmd; \
+if (e != cudaSuccess) { \
+printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \
+exit(EXIT_FAILURE); \
+} \
} while (0)
-#define NVTE_UB_ERROR(x) \
-do { \
-throw std::runtime_error(std::string(__FILE__ ":") + std::to_string(__LINE__) + \
-" in function " + __func__ + ": " + x); \
-} while (false)
+#define CUCHECK(cmd) \
+do { \
+CUresult retval = cmd; \
+if (retval != CUDA_SUCCESS) { \
+const char *error_string; \
+cuGetErrorString(retval, &error_string); \
+printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, error_string); \
+exit(EXIT_FAILURE); \
+} \
+} while (0);
+#define NVTE_UB_ERROR(x) \
+do { \
+throw std::runtime_error(std::string(__FILE__ ":") + std::to_string(__LINE__) + \
+" in function " + __func__ + ": " + x); \
+} while (false)
int pipe_rank(communicator *comm, int step) {
int mynode = comm->myrank / comm->nvsize;
......@@ -89,12 +100,14 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
(*comm)->push = 1;
(*comm)->use_ce = 0;
(*comm)->cga_size = 2;
for (int i = 0; i < userbuffers_op_types; i++) (*comm)->basecounter[i] = 0;
for (int i = 0; i < userbuffers_op_types; i++)
(*comm)->basecounter[i] = 0;
(*comm)->head = 0;
(*comm)->tail = 0;
(*comm)->activeproxy = 1;
(*comm)->active_nreqs = 0;
for (int i = 0; i < userbuffers_op_types; i++) (*comm)->active_req[i].active = -1;
for (int i = 0; i < userbuffers_op_types; i++)
(*comm)->active_req[i].active = -1;
int ret = 0;
// split communicator
......@@ -112,8 +125,10 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
color = 0;
for (int n = 0; n < size; n++) {
if (n > 0 && strcmp(host_names[n - 1], host_names[n])) color++;
if (strcmp(host_name, host_names[n]) == 0) break;
if (n > 0 && strcmp(host_names[n - 1], host_names[n]))
color++;
if (strcmp(host_name, host_names[n]) == 0)
break;
}
free(host_names);
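The loop above assigns each rank a node index (`color`) by walking the gathered host names and counting how many times the name changes before the rank's own host appears. A minimal Python sketch of the same idea follows. It assumes the list is already sorted (the diff defines a `stringCmp` comparator, presumably for that purpose), and the name `node_color` is hypothetical.

```python
def node_color(sorted_host_names, host_name):
    """Index of host_name among the distinct names in an already-sorted list."""
    color = 0
    for n, name in enumerate(sorted_host_names):
        # A change from the previous entry means a new physical node starts here.
        if n > 0 and sorted_host_names[n - 1] != name:
            color += 1
        if name == host_name:
            break
    return color
```

With eight ranks per node, every rank on the same host ends up with the same color, which is what a communicator split by node needs.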
......@@ -128,14 +143,22 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
cpu_set_t cpuset;
CPU_ZERO(&cpuset);
int core;
-if (mylocal == 0) core = 50;
-if (mylocal == 1) core = 58;
-if (mylocal == 2) core = 18;
-if (mylocal == 3) core = 26;
-if (mylocal == 4) core = 114;
-if (mylocal == 5) core = 122;
-if (mylocal == 6) core = 82;
-if (mylocal == 7) core = 90;
+if (mylocal == 0)
+core = 50;
+if (mylocal == 1)
+core = 58;
+if (mylocal == 2)
+core = 18;
+if (mylocal == 3)
+core = 26;
+if (mylocal == 4)
+core = 114;
+if (mylocal == 5)
+core = 122;
+if (mylocal == 6)
+core = 82;
+if (mylocal == 7)
+core = 90;
CPU_SET(core, &cpuset);
if (!getenv("NVTE_NODOUBLE")) {
......@@ -144,7 +167,8 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
else
CPU_SET(core + 128, &cpuset);
}
if (getenv("NVTE_DOPIN")) pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpuset);
if (getenv("NVTE_DOPIN"))
pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpuset);
if (ndev == numlocal) { // all visible devices
if (cur_dev != mylocal)
......@@ -175,7 +199,8 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
int datanodegroup_id =
myrank / numlocal / datanodes; // data reduction group node belongs, equals 0 for all if both
// pipenodes=1 and tensornodes=1
// mpi communicator only needed for SHARP which is always allreduce1/data-parallel
// mpi communicator only needed for SHARP which is always
// allreduce1/data-parallel
MPI_Comm_split(MPI_COMM_WORLD, mylocal + numlocal * datanodegroup_id, rank, &(*comm)->comm_inter);
// different rails from same group are in different subcommunicators
......@@ -192,19 +217,37 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
char *ib_dev_list;
int ZIONROCE = getenv("NVTE_ZIONROCE") ? atoi(getenv("NVTE_ZIONROCE")) : 0;
int ROCE = getenv("NVTE_ROCE") ? atoi(getenv("NVTE_ROCE")) : 0;
-if (ZIONROCE) ROCE = 1;
+if (ZIONROCE)
+ROCE = 1;
int DGX_H100 = device_prop.major == 9;
switch (mylocal) {
-case 0:ib_dev_list = "mlx5_0:1"; break; // NOLINT(*)
-case 1:ib_dev_list = (char*)(DGX_H100?"mlx5_3:1":"mlx5_1:1"); break; // NOLINT(*)
-case 2:ib_dev_list = (char*)(ZIONROCE?"mlx5_4:1":DGX_H100?"mlx5_4:1":"mlx5_2:1"); break; // NOLINT(*)
-case 3:ib_dev_list = (char*)(DGX_H100?"mlx5_5:1":"mlx5_3:1"); break; // NOLINT(*)
-case 4:ib_dev_list = (char*)(DGX_H100?"mlx5_6:1":"mlx5_6:1"); break; // NOLINT(*)
-case 5:ib_dev_list = (char*)(DGX_H100?"mlx5_9:1":"mlx5_7:1"); break; // NOLINT(*)
-case 6:ib_dev_list = (char*)(ZIONROCE?"mlx5_10:1":DGX_H100?"mlx5_10:1":"mlx5_8:1"); break; // NOLINT(*)
-case 7:ib_dev_list = (char*)(DGX_H100?"mlx5_11:1":"mlx5_9:1"); break; // NOLINT(*)
-default: break;
+case 0:
+ib_dev_list = "mlx5_0:1";
+break; // NOLINT(*)
+case 1:
+ib_dev_list = (char *)(DGX_H100 ? "mlx5_3:1" : "mlx5_1:1"); // NOLINT(*)
+break; // NOLINT(*)
+case 2:
+ib_dev_list = (char *)(ZIONROCE ? "mlx5_4:1" : DGX_H100 ? "mlx5_4:1" : "mlx5_2:1"); // NOLINT(*)
+break; // NOLINT(*)
+case 3:
+ib_dev_list = (char *)(DGX_H100 ? "mlx5_5:1" : "mlx5_3:1"); // NOLINT(*)
+break; // NOLINT(*)
+case 4:
+ib_dev_list = (char *)(DGX_H100 ? "mlx5_6:1" : "mlx5_6:1"); // NOLINT(*)
+break; // NOLINT(*)
+case 5:
+ib_dev_list = (char *)(DGX_H100 ? "mlx5_9:1" : "mlx5_7:1"); // NOLINT(*)
+break; // NOLINT(*)
+case 6:
+ib_dev_list = (char *)(ZIONROCE ? "mlx5_10:1" : DGX_H100 ? "mlx5_10:1" : "mlx5_8:1"); // NOLINT(*)
+break; // NOLINT(*)
+case 7:
+ib_dev_list = (char *)(DGX_H100 ? "mlx5_11:1" : "mlx5_9:1"); // NOLINT(*)
+break; // NOLINT(*)
+default:
+break;
}
(*comm)->fifo = reinterpret_cast<ub_request *>(malloc(sizeof(ub_request) * NVTE_MAX_REQUESTS));
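The `switch (mylocal)` above is effectively a lookup table from the local GPU index to an InfiniBand device name, with overrides for DGX H100 and for the `NVTE_ZIONROCE` setting. The same mapping can be written as data, which makes the per-platform differences easier to compare. In this Python sketch the device strings are copied from the diff, while the function name `ib_device` and its keyword arguments are hypothetical.

```python
# Device names copied from the switch statement; one table per platform.
_DEFAULT = {0: "mlx5_0:1", 1: "mlx5_1:1", 2: "mlx5_2:1", 3: "mlx5_3:1",
            4: "mlx5_6:1", 5: "mlx5_7:1", 6: "mlx5_8:1", 7: "mlx5_9:1"}
_DGX_H100 = {0: "mlx5_0:1", 1: "mlx5_3:1", 2: "mlx5_4:1", 3: "mlx5_5:1",
             4: "mlx5_6:1", 5: "mlx5_9:1", 6: "mlx5_10:1", 7: "mlx5_11:1"}
# ZIONROCE is checked first in the ternaries, but only for local ranks 2 and 6.
_ZIONROCE = {2: "mlx5_4:1", 6: "mlx5_10:1"}


def ib_device(mylocal, dgx_h100=False, zionroce=False):
    """Return the device string for a local rank, or None for ranks outside 0..7."""
    if zionroce and mylocal in _ZIONROCE:
        return _ZIONROCE[mylocal]
    table = _DGX_H100 if dgx_h100 else _DEFAULT
    # None mirrors the `default: break` case, which leaves ib_dev_list unset.
    return table.get(mylocal)
```

Laid out this way it is visible that local rank 4 maps to `mlx5_6:1` on both platforms, so the ternary in that case is redundant.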
......@@ -215,7 +258,8 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
CUDACHECK(cudaMallocHost((void **)&(*comm)->hostflags, // NOLINT(*)
(NVTE_MAX_SMS + 100) * sizeof(int)));
for (int i = 0; i < 100 + NVTE_MAX_SMS; i++) (*comm)->hostflags[i] = 0;
for (int i = 0; i < 100 + NVTE_MAX_SMS; i++)
(*comm)->hostflags[i] = 0;
_mm_mfence();
sleep(1);
......@@ -223,13 +267,16 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
(*comm)->ibnvsize = (*comm)->nvsize;
#define NBUF 2
#define LOCALSIZE 4 * (NVTE_REG0_OFFSET(*comm) + NVTE_REG0_FLAGS + NVTE_REG0_COMMBUFFER * NBUF)
// peer pointers + op flags + comm buffer
CUDACHECK(cudaMalloc(&(*comm)->gpu_ptrs, LOCALSIZE)); // flags and pointers, no block data yet
CUDACHECK(cudaMalloc(&(*comm)->gpu_ptrs,
LOCALSIZE)); // flags and pointers, no block data yet
CUDACHECK(cudaMemset((*comm)->gpu_ptrs, 0, LOCALSIZE));
CUDACHECK(cudaDeviceSynchronize());
register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE, *comm); // will use handler 0
register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE,
*comm); // will use handler 0
CUDACHECK(cudaMalloc(&(*comm)->send_id, (*comm)->nranks * sizeof(int)));
CUDACHECK(cudaMalloc(&(*comm)->recv_id, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int)));
CUDACHECK(cudaMemset((*comm)->send_id, 0, (*comm)->nranks * sizeof(int)));
......@@ -243,7 +290,6 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
#define GPU_PAGE_MASK (~GPU_PAGE_OFFSET)
CUDACHECK(cudaMalloc(&(*comm)->flags, 2 * GPU_PAGE_SIZE));
unsigned int flag = 1;
// cuPointerSetAttribute(&flag, CU_POINTER_ATTRIBUTE_SYNC_MEMOPS, (CUdeviceptr)(*comm)->flags);
CUDACHECK(cudaMemset((*comm)->flags, 0, 2 * GPU_PAGE_SIZE));
(*comm)->flags =
reinterpret_cast<int *>(((CUdeviceptr)(*comm)->flags + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK);
......@@ -275,7 +321,8 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
pthread_attr_setschedparam(&attr, &param);
if (getenv("NVTE_UBDEBUG"))
printf("%d/%d:(%d x %d): DP %d x %d TP %d x %d, DPGROUP %dx%d TPGROUP %dx%d PIPE_ID %d/%d\n",
printf("%d/%d:(%d x %d): DP %d x %d TP %d x %d, DPGROUP %dx%d TPGROUP "
"%dx%d PIPE_ID %d/%d\n",
myrank, nranks, myrank / numlocal, myrank % numlocal, (*comm)->my_node,
(*comm)->ar_nvrank, (*comm)->my2_node, (*comm)->ar2_nvrank, (*comm)->num_nodes,
(*comm)->ar_nvsize, (*comm)->num2_nodes, (*comm)->ar2_nvsize, (*comm)->pipe_id,
......@@ -300,9 +347,9 @@ void destroy_communicator(communicator *comm) {
}
int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm, bool alloc) {
if (comm->free_region > NVTE_MAX_REGIONS) return -1;
if (comm->free_region > NVTE_MAX_REGIONS)
return -1;
int hndl = comm->free_region;
// printf("%d register %d size %lld\n",comm->myrank,hndl,bytes);fflush(NULL);
comm->peer_ptr[hndl] = reinterpret_cast<void **>(malloc(sizeof(void *) * (comm->nvsize)));
if (alloc) {
......@@ -313,25 +360,22 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
reinterpret_cast<cudaIpcMemHandle_t *>(malloc(sizeof(cudaIpcMemHandle_t) * (comm->nvsize)));
CUDACHECK(cudaIpcGetMemHandle(&memhndl[comm->nvrank], *gpubuff));
MPI_Allgather(&memhndl[comm->nvrank], sizeof(cudaIpcMemHandle_t), MPI_BYTE, memhndl,
sizeof(cudaIpcMemHandle_t), MPI_BYTE, comm->comm_intra);
for (int i = 0; i < comm->nvsize; i++)
if (i != comm->nvrank)
CUDACHECK(cudaIpcOpenMemHandle((void **)&(comm->peer_ptr[hndl][i]), // NOLINT(*)
memhndl[i], cudaIpcMemLazyEnablePeerAccess));
comm->peer_ptr[hndl][comm->nvrank] = *gpubuff;
CUDACHECK(cudaDeviceSynchronize());
CUDACHECK(
cudaMemcpy(reinterpret_cast<char *>(comm->gpu_ptrs) + (hndl * comm->nvsize * sizeof(void *)),
comm->peer_ptr[hndl], comm->nvsize * sizeof(void *), cudaMemcpyHostToDevice));
CUDACHECK(cudaDeviceSynchronize());
free(memhndl);
comm->mem_ptr[hndl] = *gpubuff;
return comm->free_region++;
}
......@@ -352,8 +396,10 @@ int allgather2_userbuff_inplace_gpu(const int maxcredit, const int handler, cons
void allreduce_nonsharp_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream, int op) {
if (elements < 64) NVTE_UB_ERROR("Userbuffer comm for given config not implemented.");
// if(comm->myrank==0) fprintf(stderr,"AR2(%d) user call launch_mode=%d\n",op,comm->launch_mode);
if (elements < 64)
NVTE_UB_ERROR("Userbuffer comm for given config not implemented.");
// if(comm->myrank==0) fprintf(stderr,"AR2(%d) user call
// launch_mode=%d\n",op,comm->launch_mode);
const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
int blocksize = elements * 2;
int maxcredit = 0;
......@@ -361,19 +407,19 @@ void allreduce_nonsharp_inplace(const int handler, const int offset, const int e
blocksize = (comm->nblocks - 1 + (comm->alignblock - 1 + elements * 2) / comm->alignblock) /
comm->nblocks; // FIXME TUNING
blocksize *= comm->alignblock;
if (blocksize < comm->minblock) blocksize = comm->minblock;
if (blocksize < comm->minblock)
blocksize = comm->minblock;
maxcredit = (elements * 2 + blocksize - 1) / blocksize;
// if(maxcredit>4) maxcredit=4;
// if(maxcredit>4 && ar_nvsize==1) maxcredit=4;
size_t peerblock = sizeof(int) * NVTE_REG0_COMMBUFFER / maxcredit; // max size we can fit
if (blocksize > peerblock * ar_nvsize) blocksize = peerblock * ar_nvsize;
// blocksize=elements*2;
if (blocksize > peerblock * ar_nvsize)
blocksize = peerblock * ar_nvsize;
int sms = allreduce2_userbuff_inplace_gpu(maxcredit, handler, offset, elements, blocksize, comm,
stream, op);
if (num_nodes > 1 && comm->launch_mode & NVTE_LAUNCH_CPU) {
if (!sms) return;
if (!sms)
return;
comm->fifo[comm->head].optype = op;
comm->fifo[comm->head].basecounter = comm->basecounter[op];
comm->fifo[comm->head].blocksize = blocksize;
@@ -399,7 +445,8 @@ void allreduce2_userbuff_inplace(const int handler, const int offset, const int
void allreduce_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) {
if (elements < 64) NVTE_UB_ERROR("Userbuffer comm for given config not implemented.");
if (elements < 64)
NVTE_UB_ERROR("Userbuffer comm for given config not implemented.");
allreduce_nonsharp_inplace(handler, offset, elements, comm, stream,
userbuffers_allreduceop_nonsharp);
return;
@@ -407,7 +454,8 @@ void allreduce_userbuff_inplace(const int handler, const int offset, const int e
void reducescatter_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) {
if (elements < 64) NVTE_UB_ERROR("Userbuffer comm for given config not implemented.");
if (elements < 64)
NVTE_UB_ERROR("Userbuffer comm for given config not implemented.");
int op = userbuffers_allreduceop_nonsharp;
const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
@@ -418,17 +466,20 @@ void reducescatter_userbuff_inplace(const int handler, const int offset, const i
blocksize = (comm->nblocks - 1 + (comm->alignblock - 1 + elements * 2) / comm->alignblock) /
comm->nblocks; // FIXME TUNING
blocksize *= comm->alignblock;
if (blocksize < comm->minblock) blocksize = comm->minblock;
if (blocksize < comm->minblock)
blocksize = comm->minblock;
maxcredit = (elements * 2 + blocksize - 1) / blocksize;
size_t peerblock = sizeof(int) * NVTE_REG0_COMMBUFFER / maxcredit; // max size we can fit
if (blocksize > peerblock * ar_nvsize) blocksize = peerblock * ar_nvsize;
if (blocksize > peerblock * ar_nvsize)
blocksize = peerblock * ar_nvsize;
int sms = reducescatter2_userbuff_inplace_gpu(maxcredit, handler, offset, elements, blocksize,
comm, stream, op);
if (num_nodes > 1 && comm->launch_mode & NVTE_LAUNCH_CPU) {
if (!sms) return;
if (!sms)
return;
comm->fifo[comm->head].optype = op;
comm->fifo[comm->head].basecounter = comm->basecounter[op];
comm->fifo[comm->head].blocksize = blocksize;
@@ -448,7 +499,8 @@ void reducescatter_userbuff_inplace(const int handler, const int offset, const i
void allgather_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) {
if (elements < 64) NVTE_UB_ERROR("Userbuffer comm for given config not implemented.");
if (elements < 64)
NVTE_UB_ERROR("Userbuffer comm for given config not implemented.");
int op = userbuffers_allreduceop_nonsharp;
const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
int blocksize = elements * 2;
@@ -458,11 +510,13 @@ void allgather_userbuff_inplace(const int handler, const int offset, const int e
blocksize = (comm->nblocks - 1 + (comm->alignblock - 1 + elements * 2) / comm->alignblock) /
comm->nblocks; // FIXME TUNING
blocksize *= comm->alignblock;
if (blocksize < comm->minblock) blocksize = comm->minblock;
if (blocksize < comm->minblock)
blocksize = comm->minblock;
maxcredit = (elements * 2 + blocksize - 1) / blocksize;
size_t peerblock = sizeof(int) * NVTE_REG0_COMMBUFFER / maxcredit; // max size we can fit
if (blocksize > peerblock * ar_nvsize) blocksize = peerblock * ar_nvsize;
if (blocksize > peerblock * ar_nvsize)
blocksize = peerblock * ar_nvsize;
int sms = allgather2_userbuff_inplace_gpu(maxcredit, handler, offset, elements, blocksize, comm,
stream, op);
......
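Review note: the block-size arithmetic repeated in the three hunks above (allreduce, reduce-scatter, all-gather) can be sketched in Python as follows. This is an illustration only, not code from the commit. The function name `plan_blocks`, its parameter names, and the numeric values in the usage lines are hypothetical; `reg0_commbuffer` stands in for `NVTE_REG0_COMMBUFFER`, whose value is not shown in the diff, and `sizeof(int)` is assumed to be 4. The hunks elide the lines between the `blocksize`/`maxcredit` initialisation and this computation, so any surrounding condition is not modelled.

```python
def plan_blocks(elements, nblocks, alignblock, minblock,
                reg0_commbuffer, ar_nvsize, sizeof_int=4):
    """Mirror the visible integer arithmetic from the *_userbuff_inplace hunks.

    All parameters are illustrative stand-ins for fields of `comm` and for
    NVTE_REG0_COMMBUFFER; only the order of operations follows the diff.
    """
    nbytes = elements * 2  # the diff always works with `elements * 2`

    # Round the payload up to whole alignblock units, then spread those
    # units over nblocks (both divisions round up, as in the C code).
    aligned_units = (alignblock - 1 + nbytes) // alignblock
    blocksize = (nblocks - 1 + aligned_units) // nblocks
    blocksize *= alignblock

    # Never go below the configured minimum block size.
    blocksize = max(blocksize, minblock)

    # Number of blocks needed to cover the payload.
    maxcredit = (nbytes + blocksize - 1) // blocksize

    # Largest per-peer block that fits in the communication buffer.
    peerblock = sizeof_int * reg0_commbuffer // maxcredit
    blocksize = min(blocksize, peerblock * ar_nvsize)

    return blocksize, maxcredit


# Usage with made-up values (not taken from the library):
small = plan_blocks(1024, nblocks=8, alignblock=1024, minblock=2048,
                    reg0_commbuffer=65536, ar_nvsize=8)
large = plan_blocks(1 << 20, nblocks=8, alignblock=1024, minblock=2048,
                    reg0_commbuffer=65536, ar_nvsize=4)
print(small, large)
```

With these made-up inputs the small payload is clamped up to `minblock`, while the large payload is capped by `peerblock * ar_nvsize`, which are the two clamps the reformatted `if` statements in the hunks implement.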
@@ -263,6 +263,22 @@ class TransformerLayer(torch.nn.Module):
ub_bulk_dgrad = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_BULK_DGRAD", "1")))
ub_split_ag = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_SPLIT_AG", "1")))
ub_split_rs = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_SPLIT_RS", "1")))
ub_atomic_gemm_rs = (ub_tp_comm_overlap
and bool(int(os.getenv("NVTE_UB_ATOMIC_GEMM_RS", "0"))))
assert (
not (ub_split_rs and ub_atomic_gemm_rs)
), "Only one type of RS overlap NVTE_UB_SPLIT_RS/NVTE_UB_ATOMIC_GEMM_RS should be enabled."
ub_atomic_gemm_ag = (ub_tp_comm_overlap
and bool(int(os.getenv("NVTE_UB_ATOMIC_GEMM_AG", "0"))))
assert (
not (ub_split_ag and ub_atomic_gemm_ag)
), "Only one type of AG overlap NVTE_UB_SPLIT_AG/NVTE_UB_ATOMIC_GEMM_AG should be enabled."
if ub_atomic_gemm_rs or ub_atomic_gemm_ag:
warnings.warn(
"Atomic gemm uses a beta API from cublas and is not tested for all use cases."
)
bias_dropout_fusion = bool(int(os.getenv("NVTE_BIAS_DROPOUT_FUSION", "1")))
self.layer_number = layer_number
self.output_layernorm = output_layernorm
@@ -323,6 +339,8 @@ class TransformerLayer(torch.nn.Module):
"ub_bulk_dgrad" : ub_bulk_dgrad,
"ub_split_ag" : ub_split_ag,
"ub_split_rs" : ub_split_rs,
"ub_atomic_gemm_rs" : ub_atomic_gemm_rs,
"ub_atomic_gemm_ag" : ub_atomic_gemm_ag,
}
self.self_attention = MultiheadAttention(
@@ -377,6 +395,8 @@ class TransformerLayer(torch.nn.Module):
ub_bulk_dgrad=ub_bulk_dgrad,
ub_split_rs=ub_split_rs,
ub_split_ag=ub_split_ag,
ub_atomic_gemm_rs=ub_atomic_gemm_rs,
ub_atomic_gemm_ag=ub_atomic_gemm_ag,
activation=activation,
normalization=normalization,
device=device,
......
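Review note: the flag handling added to `TransformerLayer` in the hunks above can be exercised in isolation with a sketch like the one below. The helper `resolve_ub_overlap_flags` and its `env` parameter are hypothetical and exist only for illustration; the environment variable names, defaults, assertion messages, and warning text are copied from the diff.

```python
import os
import warnings


def resolve_ub_overlap_flags(ub_tp_comm_overlap, env=None):
    """Hypothetical helper mirroring the flag logic added in this commit."""
    env = os.environ if env is None else env

    def flag(name, default):
        # Same parsing as the diff: "0"/"1" string -> int -> bool,
        # and every flag is gated on ub_tp_comm_overlap.
        return ub_tp_comm_overlap and bool(int(env.get(name, default)))

    flags = {
        "ub_split_ag": flag("NVTE_UB_SPLIT_AG", "1"),
        "ub_split_rs": flag("NVTE_UB_SPLIT_RS", "1"),
        "ub_atomic_gemm_rs": flag("NVTE_UB_ATOMIC_GEMM_RS", "0"),
        "ub_atomic_gemm_ag": flag("NVTE_UB_ATOMIC_GEMM_AG", "0"),
    }

    # Split and atomic variants of the same overlap are mutually exclusive.
    assert not (flags["ub_split_rs"] and flags["ub_atomic_gemm_rs"]), (
        "Only one type of RS overlap NVTE_UB_SPLIT_RS/NVTE_UB_ATOMIC_GEMM_RS "
        "should be enabled."
    )
    assert not (flags["ub_split_ag"] and flags["ub_atomic_gemm_ag"]), (
        "Only one type of AG overlap NVTE_UB_SPLIT_AG/NVTE_UB_ATOMIC_GEMM_AG "
        "should be enabled."
    )

    if flags["ub_atomic_gemm_rs"] or flags["ub_atomic_gemm_ag"]:
        warnings.warn(
            "Atomic gemm uses a beta API from cublas and is not tested "
            "for all use cases."
        )
    return flags


# Because NVTE_UB_SPLIT_RS defaults to "1", enabling the atomic RS path
# requires turning the split path off explicitly.
example_env = {"NVTE_UB_SPLIT_RS": "0", "NVTE_UB_ATOMIC_GEMM_RS": "1"}
print(resolve_ub_overlap_flags(True, example_env))
```

One consequence worth noting for users: setting only `NVTE_UB_ATOMIC_GEMM_RS=1` (or only `NVTE_UB_ATOMIC_GEMM_AG=1`) trips the assertion, since the corresponding split flag is on by default.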