".github/git@developer.sourcefind.cn:gaoqiong/pybind11.git" did not exist on "bb80e68f755fe11c8744bde7a88aa82fb776ed59"
Commit 47077129 authored by yuguo's avatar yuguo
Browse files

[DCU] remove redundant gemm

parent aa62d24c
...@@ -1166,82 +1166,13 @@ void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D, ...@@ -1166,82 +1166,13 @@ void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
stream); stream);
} }
// add for batchgemm
// Batched GEMM entry point for MOE workloads on the HIP/DCU backend.
//
// A and B hold `batch_count` sub-problems stacked along dim 0 of their
// non-transposed dimension (hence the shape[0] / batch_count below).
// Bias and pre-GELU epilogue fusion are not supported on this path, and the
// TT layout is rejected. Interface is identical to the original function.
//
// Parameters mirror nvte_cublas_batchgemm, plus:
//   batch_count - number of stacked GEMMs in A/B/D.
void nvte_cublas_batchgemm_v2(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
                              NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                              NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                              int math_sm_count, int batch_count, cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_batchgemm_v2);
  using namespace transformer_engine;
  const Tensor *inputA = convertNVTETensorCheck(A);
  const Tensor *inputB = convertNVTETensorCheck(B);
  Tensor *outputD = convertNVTETensor(D);
  const Tensor *biasTensor = convertNVTETensor(bias);
  Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
  Tensor *wspace = convertNVTETensor(workspace);

  // This MOE path has no epilogue fusion; reject fused bias/GELU up front.
  if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr)) {
    NVTE_ERROR("MOE batchgemm does not support bias or gelu.");
  }
  // Reject TT early. The original only caught it in the leading-dimension
  // dispatch, leaving m/n/k uninitialized on that path; the observable
  // behavior (this exact error) is unchanged.
  if (transa && transb) {
    NVTE_ERROR("TT layout not allowed.");
  }

  // The original NT/TN/NN branches computed byte-identical m/n/k expressions,
  // so compute them once. The batch-stacked dimension is shape[0].
  const int m = transa ? inputA->data.shape[0] / batch_count : inputA->data.shape[1];
  const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0] / batch_count;
  const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0] / batch_count;

  // Leading dimensions (column-major BLAS convention), matching the original
  // per-layout table: TN -> (k, k, m), NN -> (m, k, m), NT -> (m, n, m).
  const int lda = transa ? k : m;
  const int ldb = transb ? n : k;
  const int ldd = m;

  hipblas_batchgemm(inputA, inputB, outputD, biasTensor, outputGelu,
                    m, n, k,
                    lda, ldb, ldd,
                    transa ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                    transb ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                    grad, wspace->data.dptr,
                    wspace->data.shape[0],
                    accumulate, use_split_accumulator,
                    math_sm_count,
                    0,        // NOTE(review): pass-through defaults kept from the
                    0,        // original call site — confirm their semantics against
                    false,    // the hipblas_batchgemm declaration.
                    nullptr,
                    batch_count,
                    stream);
}
// add for batchgemm // add for batchgemm
void nvte_cublas_batchgemm_v3(const NVTETensor A, const NVTETensor B, const NVTETensor A_scales, const NVTETensor B_scales, NVTETensor D, const NVTETensor bias, void nvte_cublas_batchgemm_tensorwise_int8(const NVTETensor A, const NVTETensor B, const NVTETensor A_scales, const NVTETensor B_scales, NVTETensor D, const NVTETensor bias,
NVTETensor pre_gelu_out, bool transa, bool transb, bool grad, NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
NVTETensor workspace, bool accumulate, bool use_split_accumulator, NVTETensor workspace, bool accumulate, bool use_split_accumulator,
int math_sm_count, int batch_count, cudaStream_t stream) { int math_sm_count, int batch_count, cudaStream_t stream) {
NVTE_API_CALL(nvte_cublas_batchgemm_v3); NVTE_API_CALL(nvte_cublas_batchgemm_tensorwise_int8);
using namespace transformer_engine; using namespace transformer_engine;
const Tensor *inputA = convertNVTETensorCheck(A); const Tensor *inputA = convertNVTETensorCheck(A);
const Tensor *inputB = convertNVTETensorCheck(B); const Tensor *inputB = convertNVTETensorCheck(B);
...@@ -1297,16 +1228,7 @@ void nvte_cublas_batchgemm_v3(const NVTETensor A, const NVTETensor B, const NVTE ...@@ -1297,16 +1228,7 @@ void nvte_cublas_batchgemm_v3(const NVTETensor A, const NVTETensor B, const NVTE
handle = hipblaslt_handles[0]; handle = hipblaslt_handles[0];
hipblaslt_batchgemm_tensorwise_int8(inputA, inputB, inputA_scales, inputB_scales, outputD, biasTensor, outputGelu, NVTE_ERROR("Remove nvte_cublas_batchgemm_tensorwise_int8 for now.");
m, n, k, lda, ldb, ldd,
(transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
(transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
grad,
wspace->data.dptr,
wspace->data.shape[0], accumulate, use_split_accumulator,
math_sm_count, 0, 0,
false, nullptr, batch_count, stream,
handle);
} }
#endif #endif
......
This diff is collapsed.
...@@ -152,12 +152,7 @@ void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D, ...@@ -152,12 +152,7 @@ void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
NVTETensor workspace, bool accumulate, bool use_split_accumulator, NVTETensor workspace, bool accumulate, bool use_split_accumulator,
int math_sm_count, int batch_count, cudaStream_t stream); int math_sm_count, int batch_count, cudaStream_t stream);
void nvte_cublas_batchgemm_v2(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias, void nvte_cublas_batchgemm_tensorwise_int8(const NVTETensor A, const NVTETensor B, const NVTETensor A_scales, const NVTETensor B_scales, NVTETensor D, const NVTETensor bias,
NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
NVTETensor workspace, bool accumulate, bool use_split_accumulator,
int math_sm_count, int batch_count, cudaStream_t stream);
void nvte_cublas_batchgemm_v3(const NVTETensor A, const NVTETensor B, const NVTETensor A_scales, const NVTETensor B_scales, NVTETensor D, const NVTETensor bias,
NVTETensor pre_gelu_out, bool transa, bool transb, bool grad, NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
NVTETensor workspace, bool accumulate, bool use_split_accumulator, NVTETensor workspace, bool accumulate, bool use_split_accumulator,
int math_sm_count, int batch_count, cudaStream_t stream); int math_sm_count, int batch_count, cudaStream_t stream);
......
...@@ -588,7 +588,7 @@ std::vector<py::object> generic_batchgemm(py::handle A, bool transa, py::handle ...@@ -588,7 +588,7 @@ std::vector<py::object> generic_batchgemm(py::handle A, bool transa, py::handle
} else { } else {
// Launch GEMM // Launch GEMM
NVTE_SCOPED_GIL_RELEASE({ NVTE_SCOPED_GIL_RELEASE({
nvte_cublas_batchgemm_v2(A_tensor.data(), B_tensor.data(), D_tensor.data(), bias_tensor.data(), nvte_cublas_batchgemm(A_tensor.data(), B_tensor.data(), D_tensor.data(), bias_tensor.data(),
te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(), te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(),
accumulate, use_split_accumulator, num_math_sms, batch_count, main_stream); accumulate, use_split_accumulator, num_math_sms, batch_count, main_stream);
}); });
...@@ -724,7 +724,7 @@ std::vector<py::object> tensorwise_int8_batchgemm(py::handle A, bool transa, py: ...@@ -724,7 +724,7 @@ std::vector<py::object> tensorwise_int8_batchgemm(py::handle A, bool transa, py:
} else { } else {
// Launch GEMM // Launch GEMM
NVTE_SCOPED_GIL_RELEASE({ NVTE_SCOPED_GIL_RELEASE({
nvte_cublas_batchgemm_v3(A_tensor.data(), B_tensor.data(), A_scales_tensor.data(), B_scales_tensor.data(), D_tensor.data(), bias_tensor.data(), nvte_cublas_batchgemm_tensorwise_int8(A_tensor.data(), B_tensor.data(), A_scales_tensor.data(), B_scales_tensor.data(), D_tensor.data(), bias_tensor.data(),
te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(), te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(),
accumulate, use_split_accumulator, num_math_sms, batch_count, main_stream); accumulate, use_split_accumulator, num_math_sms, batch_count, main_stream);
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment