"vscode:/vscode.git/clone" did not exist on "d198770e96a6eac4d7b6233e6f411e339b32ce3d"
Commit 47077129 authored by yuguo

[DCU] remove redundant gemm

parent aa62d24c
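
In short, as the hunks below show, this commit deletes the duplicate batch-GEMM entry point nvte_cublas_batchgemm_v2 and the unused strided-batch INT8 wrapper hipblaslt_batchgemm_tensorwise_int8, renames nvte_cublas_batchgemm_v3 to nvte_cublas_batchgemm_tensorwise_int8 (whose body now simply raises NVTE_ERROR), and points the Python bindings at the remaining entry points. A minimal sketch of the resulting call-site migration, assuming the nvte_cublas_batchgemm declaration visible in the header hunk; the wrapper name, include path, and argument values below are illustrative, not repository code:

#include <transformer_engine/gemm.h>  // include path assumed

// Hypothetical caller: after this commit the plain nvte_cublas_batchgemm is
// used directly; the removed _v2 variant took exactly the same argument list.
void run_batched_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
                      const NVTETensor bias, NVTETensor pre_gelu_out,
                      NVTETensor workspace, int num_math_sms, int batch_count,
                      cudaStream_t stream) {
  // was: nvte_cublas_batchgemm_v2(A, B, D, bias, pre_gelu_out, ...);
  nvte_cublas_batchgemm(A, B, D, bias, pre_gelu_out,
                        /*transa=*/false, /*transb=*/true, /*grad=*/false,
                        workspace, /*accumulate=*/false,
                        /*use_split_accumulator=*/false,
                        num_math_sms, batch_count, stream);
}
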
@@ -1166,82 +1166,13 @@ void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
stream);
}
// add for batchgemm
void nvte_cublas_batchgemm_v2(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
NVTETensor workspace, bool accumulate, bool use_split_accumulator,
int math_sm_count, int batch_count, cudaStream_t stream) {
NVTE_API_CALL(nvte_cublas_batchgemm_v2);
using namespace transformer_engine;
const Tensor *inputA = convertNVTETensorCheck(A);
const Tensor *inputB = convertNVTETensorCheck(B);
Tensor *outputD = convertNVTETensor(D);
const Tensor *biasTensor = convertNVTETensor(bias);
Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
Tensor *wspace = convertNVTETensor(workspace);
if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr)) {
NVTE_ERROR("MOE batchgemm not surpport bias or gelu.");
}
int m, n, k;
if (!transa && transb) {
// for NT
m = transa ? inputA->data.shape[0]/batch_count : inputA->data.shape[1];
k = transa ? inputA->data.shape[1] : inputA->data.shape[0]/batch_count;
n = transb ? inputB->data.shape[1] : inputB->data.shape[0]/batch_count;
} else if(transa && !transb){
// for TN
m = transa ? inputA->data.shape[0]/batch_count: inputA->data.shape[1];
k = transa ? inputA->data.shape[1] : inputA->data.shape[0]/batch_count;
n = transb ? inputB->data.shape[1] : inputB->data.shape[0]/batch_count;
} else if(!transa && !transb){
// for NN
m = transa ? inputA->data.shape[0]/batch_count : inputA->data.shape[1];
k = transa ? inputA->data.shape[1] : inputA->data.shape[0]/batch_count;
n = transb ? inputB->data.shape[1] : inputB->data.shape[0]/batch_count; }
int lda, ldb, ldd;
if (transa && !transb) { // TN
lda = k;
ldb = k;
ldd = m;
} else if (!transa && !transb) { // NN
lda = m;
ldb = k;
ldd = m;
} else if (!transa && transb) { // NT
lda = m;
ldb = n;
ldd = m;
} else { // TT
NVTE_ERROR("TT layout not allowed.");
}
hipblas_batchgemm(inputA,
inputB,
outputD,
biasTensor,
outputGelu,
m, n, k,
lda, ldb, ldd,
(transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
(transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
grad, wspace->data.dptr,
wspace->data.shape[0],
accumulate, use_split_accumulator,
math_sm_count,
0,
0,
false,
nullptr,
batch_count,
stream);
}
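
The NT, TN, and NN branches of the removed function above all evaluate m, n, and k with identical ternary expressions, so the per-layout switch added nothing over the plain nvte_cublas_batchgemm path, which is presumably the redundancy the commit message refers to. A self-contained sketch of the equivalent computation; Dims2 and infer_batchgemm_mnk are hypothetical names, not repository code:

// Sketch only: collapses the identical NT/TN/NN branches of the removed
// nvte_cublas_batchgemm_v2 into one computation.
struct Dims2 { int rows, cols; };  // stand-in for data.shape[0] / data.shape[1]

inline void infer_batchgemm_mnk(Dims2 a, Dims2 b, bool transa, bool transb,
                                int batch_count, int &m, int &n, int &k) {
  m = transa ? a.rows / batch_count : a.cols;
  k = transa ? a.cols : a.rows / batch_count;
  n = transb ? b.cols : b.rows / batch_count;
}
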
// add for batchgemm
-void nvte_cublas_batchgemm_v3(const NVTETensor A, const NVTETensor B, const NVTETensor A_scales, const NVTETensor B_scales, NVTETensor D, const NVTETensor bias,
+void nvte_cublas_batchgemm_tensorwise_int8(const NVTETensor A, const NVTETensor B, const NVTETensor A_scales, const NVTETensor B_scales, NVTETensor D, const NVTETensor bias,
NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
NVTETensor workspace, bool accumulate, bool use_split_accumulator,
int math_sm_count, int batch_count, cudaStream_t stream) {
-NVTE_API_CALL(nvte_cublas_batchgemm_v3);
+NVTE_API_CALL(nvte_cublas_batchgemm_tensorwise_int8);
using namespace transformer_engine;
const Tensor *inputA = convertNVTETensorCheck(A);
const Tensor *inputB = convertNVTETensorCheck(B);
@@ -1297,16 +1228,7 @@ void nvte_cublas_batchgemm_v3(const NVTETensor A, const NVTETensor B, const NVTE
handle = hipblaslt_handles[0];
-hipblaslt_batchgemm_tensorwise_int8(inputA, inputB, inputA_scales, inputB_scales, outputD, biasTensor, outputGelu,
+NVTE_ERROR("Remove nvte_cublas_batchgemm_tensorwise_int8 for now.");
m, n, k, lda, ldb, ldd,
(transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
(transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
grad,
wspace->data.dptr,
wspace->data.shape[0], accumulate, use_split_accumulator,
math_sm_count, 0, 0,
false, nullptr, batch_count, stream,
handle);
}
#endif
...
@@ -1352,411 +1352,6 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc));
}
void hipblaslt_batchgemm_tensorwise_int8(const Tensor *inputA,
const Tensor *inputB,
const Tensor *inputA_scales,
const Tensor *inputB_scales,
Tensor *outputD,
const Tensor *inputBias,
Tensor *outputPreGelu,
int m, int n, int k,
int lda, int ldb, int ldd,
hipblasOperation_t transa,
hipblasOperation_t transb,
bool grad,
void* workspace,
size_t workspaceSize,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
int m_split,
int n_split,
bool gemm_producer,
const Tensor *inputCounter,
size_t batch_count,
hipStream_t stream,
hipblasLtHandle_t handle
) {
void *A = inputA->data.dptr;
void *A_scale_inverse = inputA_scales->data.dptr;
float *A_scale_inverse_float = (float*)(inputA_scales->data.dptr);
void *B = inputB->data.dptr;
void *B_scale_inverse = inputB_scales->data.dptr;
float *B_scale_inverse_float = (float*)(inputB_scales->data.dptr);
void *D = outputD->data.dptr;
void *bias_ptr = inputBias->data.dptr;
const bool bias = bias_ptr != nullptr;
void *pre_gelu_out = outputPreGelu->data.dptr;
const bool gelu = pre_gelu_out != nullptr;
const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
is_fp8_dtype(inputB->data.dtype);
const bool use_int8 = is_int8_dtype(inputA->data.dtype) ||
is_int8_dtype(inputB->data.dtype);
const hipDataType A_type = get_hipblaslt_dtype(inputA->data.dtype);
const hipDataType B_type = get_hipblaslt_dtype(inputB->data.dtype);
const hipDataType D_type = get_hipblaslt_dtype(outputD->data.dtype);
const hipDataType bias_type = get_hipblaslt_dtype(inputBias->data.dtype);
NVTE_CHECK(!is_fp8_dtype(inputA->data.dtype) || A_scale_inverse != nullptr,
"FP8 input to GEMM requires inverse of scale!");
NVTE_CHECK(!is_fp8_dtype(inputB->data.dtype) || B_scale_inverse != nullptr,
"FP8 input to GEMM requires inverse of scale!");
NVTE_CHECK(!is_int8_dtype(inputA->data.dtype) || A_scale_inverse != nullptr,
"INT8 input to GEMM requires inverse of scale!");
NVTE_CHECK(!is_int8_dtype(inputB->data.dtype) || B_scale_inverse != nullptr,
"INT8 input to GEMM requires inverse of scale!");
bool tensorwise_int8 = 0;;
const char *NVTE_INT8_SIM_FP8_TENSORWISE = std::getenv("NVTE_INT8_SIM_FP8_TENSORWISE");
if (NVTE_INT8_SIM_FP8_TENSORWISE != nullptr && NVTE_INT8_SIM_FP8_TENSORWISE[0] == '1' && use_int8) tensorwise_int8 = 1;
// check consistency of arguments:
// if fp8 is desired, context cannot be null
// fp8 + gelu fusion + fp8 aux is unavailable right now.
if (use_fp8 || use_int8) {
NVTE_CHECK(!gelu, "fp8 gemm + gelu fusion is unavailable right now!");
}
float one = 1.0;
float zero = 0.0;
float beta = (accumulate) ? one : zero;
int device_id;
NVTE_CHECK_CUDA(hipGetDevice(&device_id));
if (handle == nullptr) {
handle = cached_handles.get(device_id);
if (handle == nullptr)
{
handle = cached_handles.obtain(device_id);
}
}
hipblasLtMatmulDesc_t operationDesc = nullptr;
hipblasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr;
hipblasLtMatmulPreference_t preference = nullptr;
hipblasLtEpilogue_t epilogue = HIPBLASLT_EPILOGUE_DEFAULT;
int64_t ld_gelumat = (int64_t) ldd;
// default to tf32 except for e5m2 inputs where the config is not supported
hipblasComputeType_t gemm_compute_type = HIPBLAS_COMPUTE_32F;
// Create matrix descriptors. Not setting any extra attributes.
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Adesc, A_type,
transa == HIPBLAS_OP_N ? m : k,
transa == HIPBLAS_OP_N ? k : m,
lda));
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Bdesc, B_type,
transb == HIPBLAS_OP_N ? k : n,
transb == HIPBLAS_OP_N ? n : k,
ldb));
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd));
if (tensorwise_int8) {
size_t strideA = m*k;
size_t strideB = k*n;
size_t strideD = m*n;
hipblasLtMatrixLayoutSetAttribute(Adesc, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t));
hipblasLtMatrixLayoutSetAttribute(Adesc, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideA, sizeof(int64_t));
hipblasLtMatrixLayoutSetAttribute(Bdesc, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t));
hipblasLtMatrixLayoutSetAttribute(Bdesc, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideB, sizeof(int64_t));
hipblasLtMatrixLayoutSetAttribute(Ddesc, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t));
hipblasLtMatrixLayoutSetAttribute(Ddesc, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideD, sizeof(int64_t));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescCreate(&operationDesc, gemm_compute_type, HIP_R_32F));
} else {
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescCreate(&operationDesc, gemm_compute_type, HIP_R_32F));
}
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSA,
&transa, sizeof(transa)));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSB,
&transb, sizeof(transb)));
// set fp8 attributes -- input and output types should already be set to fp8 as appropriate
// Note: gelu fusion isn't available right now, and we don't need
// amax(D) either (next op is high precision).
if (use_fp8) {
// Split accumulator.
const int8_t fastAccuMode = (use_split_accumulator) ? 0 : 1;
/*
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
HIPBLASLT_MATMUL_DESC_FAST_ACCUM, //TODO: We don't have fast accum mode yet
&fastAccuMode,
sizeof(fastAccuMode)));
*/
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER,
&A_scale_inverse,
sizeof(A_scale_inverse)));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&B_scale_inverse,
sizeof(B_scale_inverse)));
if (bias) {
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE,
&bias_type, sizeof(bias_type)));
}
}
if (tensorwise_int8) {
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER,
(void*)&A_scale_inverse_float,
sizeof(void*)));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER,
(void*)&B_scale_inverse_float,
sizeof(void*)));
if (bias) {
NVTE_CHECK(false, "tensorwise_int8 not surpport bias!");
}
}
if (bias && gelu) {
if (grad) {
epilogue = HIPBLASLT_EPILOGUE_DGELU_BGRAD;
} else {
epilogue = HIPBLASLT_EPILOGUE_GELU_AUX_BIAS;
}
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
HIPBLASLT_MATMUL_DESC_BIAS_POINTER,
&bias_ptr, sizeof(bias_ptr)));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
&pre_gelu_out, sizeof(pre_gelu_out)));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
&ld_gelumat, sizeof(ld_gelumat)));
} else if (bias) {
if (grad) {
// grad output is always input B
epilogue = HIPBLASLT_EPILOGUE_BGRADB;
} else {
epilogue = HIPBLASLT_EPILOGUE_BIAS;
}
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
HIPBLASLT_MATMUL_DESC_BIAS_POINTER,
&bias_ptr, sizeof(bias_ptr)));
} else if (gelu) {
if (grad) {
epilogue = HIPBLASLT_EPILOGUE_DGELU;
} else {
epilogue = HIPBLASLT_EPILOGUE_GELU_AUX;
}
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
&pre_gelu_out, sizeof(pre_gelu_out)));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
&ld_gelumat, sizeof(ld_gelumat)));
}
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
HIPBLASLT_MATMUL_DESC_EPILOGUE,
&epilogue, sizeof(epilogue)));
GemmAlgoCache::Key gemm_cfg(algoCache.device_cap(device_id), A_type, B_type, D_type,
use_fp8 ? bias_type : (hipDataType)-1,
m, n, k, lda, ldb, ldd, transa, transb, epilogue );
GemmAlgoCache::Algo cached_algo;
if (algoCache.find(gemm_cfg, workspaceSize, cached_algo) == 0 || !cached_algo.algo.has_value())
{
int firstAlgo = getIntEnv("TE_HIPBLASLT_ALGO_SELECTION", 0, 0);
int tuneLoopCount = getIntEnv("TE_HIPBLASLT_TUNING_RUN_COUNT", 0, 0);
int algoTuneCount = 1;
std::vector<hipblasLtMatmulHeuristicResult_t> algoArr;
bool logTuning = getIntEnv("TE_HIPBLASLT_LOG_TUNING", 0, 0) != 0;
if (tuneLoopCount)
{
/* HIPBLASLT may return hundreds of algos for some configs
* Limit amount by default. User may override with env
*/
static const int defaultAlgoCount = 16;
algoTuneCount = getIntEnv("TE_HIPBLASLT_TUNING_ALGO_COUNT", defaultAlgoCount, 1);
}
algoTuneCount += firstAlgo;
int algoTotalCount = cached_algo.hasId() ? std::max(algoTuneCount, (cached_algo.index + 1)) : algoTuneCount;
algoArr.resize(algoTotalCount);
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulPreferenceCreate(&preference));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulPreferenceSetAttribute(
preference, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspaceSize, sizeof(workspaceSize)));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Ddesc,
Ddesc, preference, algoTotalCount, algoArr.data(),
&algoTotalCount));
algoArr.resize(algoTotalCount);
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulPreferenceDestroy(preference));
//If cached algo exists in persistent storage we just need to find matching hipblasLtMatmulAlgo_t
if (cached_algo.hasId())
{
int idx = (cached_algo.index < algoTotalCount) ? cached_algo.index : 0;
for (int i=0; i<algoTotalCount; i++)
{
const auto &algo = algoArr[idx];
if (algo.state == HIPBLAS_STATUS_SUCCESS)
{
if (cached_algo.algoId == cached_algo.getAlgoId(algo.algo))
{
cached_algo.algo = algo.algo;
if (algo.workspaceSize != cached_algo.ws_size_min || idx != cached_algo.index)
{
cached_algo.ws_size_min = algo.workspaceSize;
cached_algo.index = idx;
algoCache.store(gemm_cfg, cached_algo);
}
break;
}
}
idx = (idx + 1) % algoTotalCount;
}
if (logTuning && !cached_algo.algo.has_value())
{
std::cout << "[WARNING] Cannot find cached algoId " << cached_algo.algoId << " in hipBLASLt results" << std::endl;
}
}
//No suitable entry in autotune cache or could not find matched algo in hipBLASLt results
if (!cached_algo.algo.has_value())
{
int bestAlgo = -1;
algoTuneCount = std::min(algoTuneCount, algoTotalCount);
if (tuneLoopCount > 0)
{
if (logTuning)
std::cout << "[INFO] Perform hipBLASLt algo selection on GPU" << device_id
<< " in range [" << firstAlgo << "-" << (algoTuneCount - 1) << "] with "
<< tuneLoopCount << " loops " << std::endl;
NVTE_CHECK_CUDA(hipStreamSynchronize(stream));
hipStream_t profilingStream;
NVTE_CHECK_CUDA(hipStreamCreateWithFlags(&profilingStream, hipStreamNonBlocking));
using tuning_clock = std::chrono::steady_clock;
tuning_clock::now(); //the first call takes little longer so do it outside the loop
tuning_clock::duration bestTime = tuning_clock::duration::max();
for (int algo=firstAlgo; algo<algoTuneCount; algo++)
{
if (algoArr[algo].state != HIPBLAS_STATUS_SUCCESS)
{
continue;
}
// Warm-up call
NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle,
operationDesc,
static_cast<const void*>(&one), /* alpha */
A, /* A */
Adesc,
B, /* B */
Bdesc,
static_cast<const void*>(&beta), /* beta */
D, /* C */
Ddesc,
D, /* D */
Ddesc,
&algoArr[algo].algo, /* algo */
workspace, /* workspace */
workspaceSize,
profilingStream)); /* stream */
NVTE_CHECK_CUDA(hipStreamSynchronize(profilingStream));
//Profiling loop
tuning_clock::time_point startTime = tuning_clock::now();
for (int loop=0; loop<tuneLoopCount; loop++)
{
NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle,
operationDesc,
static_cast<const void*>(&one), /* alpha */
A, /* A */
Adesc,
B, /* B */
Bdesc,
static_cast<const void*>(&beta), /* beta */
D, /* C */
Ddesc,
D, /* D */
Ddesc,
&algoArr[algo].algo, /* algo */
workspace, /* workspace */
workspaceSize,
profilingStream)); /* stream */
}
NVTE_CHECK_CUDA(hipStreamSynchronize(profilingStream));
tuning_clock::duration algoTime = tuning_clock::now() - startTime;
if (algoTime < bestTime)
{
bestAlgo = algo;
bestTime = algoTime;
}
}
NVTE_CHECK_CUDA(hipStreamDestroy(profilingStream));
if (bestAlgo >= 0)
{
if (logTuning)
std::cout << "[INFO] Select hipBLASLt algo " << bestAlgo << " with time "
<< std::chrono::duration_cast<std::chrono::nanoseconds>(bestTime).count() / tuneLoopCount
<< " ns" << std::endl;
}
}
else if (firstAlgo < algoTuneCount)
{
bestAlgo = firstAlgo;
}
if (bestAlgo < 0) {
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Ddesc));
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Bdesc));
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Adesc));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc));
throw std::runtime_error("Unable to find any suitable algorithms");
}
cached_algo.algo = algoArr[bestAlgo].algo;
cached_algo.index = bestAlgo;
cached_algo.algoId = cached_algo.getAlgoId(algoArr[bestAlgo].algo);
cached_algo.ws_size_min = algoArr[bestAlgo].workspaceSize;
cached_algo.ws_size_max = workspaceSize;
if (logTuning)
std::cout << "[INFO] Use hipBLASLt algo [" << bestAlgo << "] " << cached_algo.algoId << std::endl;
algoCache.store(gemm_cfg, cached_algo);
}
}
// D = alpha * (A * B) + beta * C
NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle,
operationDesc,
static_cast<const void*>(&one), /* alpha */
A, /* A */
Adesc,
B, /* B */
Bdesc,
static_cast<const void*>(&beta), /* beta */
D, /* C */
Ddesc,
D, /* D */
Ddesc,
&cached_algo.algo.value(), /* algo */
workspace, /* workspace */
workspaceSize,
stream)); /* stream */
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Ddesc));
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Bdesc));
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Adesc));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc));
}
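
The routine removed above largely mirrors the surviving hipblaslt_gemm path; its batched part reduces to tagging the A, B, and D layouts as strided batches before the matmul. A minimal sketch of that step, built only from hipBLASLt calls that already appear in the removed code (the helper name and include path are assumptions):

#include <hipblaslt/hipblaslt.h>  // include path assumed

// Hypothetical helper: mark an existing matrix layout as a strided batch,
// mirroring what the removed code did for Adesc, Bdesc and Ddesc.
inline hipblasStatus_t set_strided_batch(hipblasLtMatrixLayout_t layout,
                                         int32_t batch_count, int64_t stride) {
  hipblasStatus_t st = hipblasLtMatrixLayoutSetAttribute(
      layout, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(batch_count));
  if (st != HIPBLAS_STATUS_SUCCESS) return st;
  return hipblasLtMatrixLayoutSetAttribute(
      layout, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride, sizeof(stride));
}

Unlike the removed code, which passed size_t variables with sizeof(int32_t) and sizeof(int64_t), the sketch uses explicitly sized integers so the attribute sizes and the variables agree.
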
class userArgsManager {
public:
...
@@ -152,12 +152,7 @@ void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
NVTETensor workspace, bool accumulate, bool use_split_accumulator,
int math_sm_count, int batch_count, cudaStream_t stream);
-void nvte_cublas_batchgemm_v2(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
+void nvte_cublas_batchgemm_tensorwise_int8(const NVTETensor A, const NVTETensor B, const NVTETensor A_scales, const NVTETensor B_scales, NVTETensor D, const NVTETensor bias,
NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
NVTETensor workspace, bool accumulate, bool use_split_accumulator,
int math_sm_count, int batch_count, cudaStream_t stream);
void nvte_cublas_batchgemm_v3(const NVTETensor A, const NVTETensor B, const NVTETensor A_scales, const NVTETensor B_scales, NVTETensor D, const NVTETensor bias,
NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
NVTETensor workspace, bool accumulate, bool use_split_accumulator,
int math_sm_count, int batch_count, cudaStream_t stream);
...
@@ -588,7 +588,7 @@ std::vector<py::object> generic_batchgemm(py::handle A, bool transa, py::handle
} else {
// Launch GEMM
NVTE_SCOPED_GIL_RELEASE({
-nvte_cublas_batchgemm_v2(A_tensor.data(), B_tensor.data(), D_tensor.data(), bias_tensor.data(),
+nvte_cublas_batchgemm(A_tensor.data(), B_tensor.data(), D_tensor.data(), bias_tensor.data(),
te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(),
accumulate, use_split_accumulator, num_math_sms, batch_count, main_stream);
});
@@ -724,7 +724,7 @@ std::vector<py::object> tensorwise_int8_batchgemm(py::handle A, bool transa, py:
} else {
// Launch GEMM
NVTE_SCOPED_GIL_RELEASE({
-nvte_cublas_batchgemm_v3(A_tensor.data(), B_tensor.data(), A_scales_tensor.data(), B_scales_tensor.data(), D_tensor.data(), bias_tensor.data(),
+nvte_cublas_batchgemm_tensorwise_int8(A_tensor.data(), B_tensor.data(), A_scales_tensor.data(), B_scales_tensor.data(), D_tensor.data(), bias_tensor.data(),
te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(),
accumulate, use_split_accumulator, num_math_sms, batch_count, main_stream);
});
...