Unverified Commit 1203099a authored by Masaki Kozuki, committed by GitHub

Remove `THCState` from `apex/contrib/multihead_attn` (#1239)

* pass `self.mask_additive`

* clang-format

* removing THCState
parent 3c8f5161
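The substance of the change: the `THCState *state` parameter that was still being threaded through the CUDA helpers, together with the `extern THCState *state;` declaration, is removed, since the cuBLAS handle and CUDA stream already come from ATen. A minimal sketch of the pattern, for orientation (not itself part of the diff):

// The legacy THC global is no longer declared or passed through call sites.
// CUDA resources are queried from ATen at the point of use instead:
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);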
#include <iostream>
#include <vector>

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <cuda_runtime.h>

//#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/wmma_gemm_traits.h"
cublasOperation_t convertTransToCublasOperation(char trans) {
  if (trans == 't')
    return CUBLAS_OP_T;
  else if (trans == 'n')
    return CUBLAS_OP_N;
  else if (trans == 'c')
    return CUBLAS_OP_C;
  else {
    AT_ERROR("trans must be one of: t, n, c");
    return CUBLAS_OP_T;
  }
}
void CublasStridedBatchedGemm(
    char transa, char transb, long m, long n, long k, float alpha,
    const half *a, long lda, long strideA, const half *b, long ldb,
    long strideB, float beta, half *c, long ldc, long strideC, long batchCount,
    cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP) {
  cublasOperation_t opa = convertTransToCublasOperation(transa);
  cublasOperation_t opb = convertTransToCublasOperation(transb);

  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);
  float fAlpha = alpha;
  float fBeta = beta;
  // THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
  TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedEx(
      handle, opa, opb, (int)m, (int)n, (int)k, (void *)&fAlpha, a, CUDA_R_16F,
      (int)lda, strideA, b, CUDA_R_16F, (int)ldb, strideB, (void *)&fBeta, c,
      CUDA_R_16F, (int)ldc, strideC, (int)batchCount, CUDA_R_32F, algo));
  // THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
}
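For orientation, a hypothetical call site for the helper above: FP16 inputs and outputs with FP32 accumulation via cublasGemmStridedBatchedEx, leaving the algorithm at its CUBLAS_GEMM_DEFAULT_TENSOR_OP default. Buffer names and sizes are illustrative only:

// Hypothetical example: multiply `batch` pairs of column-major half matrices
// A (m x k) and B (k x n) into C (m x n); d_a/d_b/d_c are assumed to be
// device buffers allocated elsewhere (e.g. with cudaMalloc).
const long m = 64, n = 64, k = 64, batch = 8;
half *d_a, *d_b, *d_c;
CublasStridedBatchedGemm('n', 'n', m, n, k,
                         /*alpha=*/1.0f,
                         d_a, /*lda=*/m, /*strideA=*/m * k,
                         d_b, /*ldb=*/k, /*strideB=*/k * n,
                         /*beta=*/0.0f,
                         d_c, /*ldc=*/m, /*strideC=*/m * n,
                         batch);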
template <cutlass::MatrixLayout::Kind A_LAYOUT,
          cutlass::MatrixLayout::Kind B_LAYOUT, int SRC_A, int SRC_B, int DST_C>
void CutlassGemm_FP32Accum(cudaStream_t stream, long m, long n, long k,
                           float alpha, const half *a, long lda, long strideA,
                           const half *b, long ldb, long strideB, float beta,
                           half *c, long ldc, long strideC, long batchCount) {
  // printf("CUTLASS-> %c%c M: %ld N: %ld K: %ld %d%d%d LDA: %ld LDB: %ld LDC:
  // %ld strideA: %ld strideB: %ld strideC: %ld Alpha: %f Beta: %f\n",
  // ((int)A_LAYOUT == 0 ? 'T' : 'N'), ((int)B_LAYOUT ==0 ? 'T' : 'N'), m, n, k,
  // SRC_A,SRC_B,DST_C, lda, ldb, ldc, strideA, strideB, strideC, alpha, beta);
  typedef cutlass::gemm::WmmaGemmTraits<
      A_LAYOUT, B_LAYOUT, cutlass::Shape<32, 16, 16>, half, half, half,
      cutlass::gemm::LinearScaling<float>, float,
      typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<
          typename cutlass::Shape<32, 16, 16>>::Shape,
      typename cutlass::Shape<16, 16, 16>,
      SRC_A,     // kScalarsPerLdgA_
      SRC_B,     // kScalarsPerLdgB_
      SRC_A,     // KScalarsPerLdsA_
      SRC_B,     // KScalarsPerLdsB_
      DST_C,     // kScalarsPerLdgCAndStgD_
      DST_C / 2, // kScalarsPerStsD_
      DST_C / 2  // kScalarsPerLdsD_
      >
      WmmaGemmTraits;

  typedef cutlass::gemm::Gemm<WmmaGemmTraits> Gemm;
  typename Gemm::Params params;

  int result = params.initialize(
      m,       // M dimension for each batch
      n,       // N dimension for each batch
      k,       // K dimension for each batch
      alpha,   // scalar alpha
      a, lda,
      strideA, // distance in memory between the first element of neighboring batch
      b, ldb,
      strideB, // distance in memory between the first element of neighboring batch
      beta,    // scalar beta
      c,       // source matrix C
      ldc,
      strideC, // distance in memory between the first element of neighboring batch
      c,       // destination matrix C (may be different memory than source C matrix)
      ldc,
      strideC, // distance in memory between the first element of neighboring batch
      batchCount);

  AT_ASSERTM(result == 0, "Failed to initialize CUTLASS Gemm::Params object.");

  // batchCount in cutlass batched GEMM kernels maps to gridDim.z, which is
  // limited to 16 bits. To implement batched GEMM with larger batch size, we
  // fragment it into smaller batched GEMMs of gridDim.z <= 64k
  long batchesLeft = batchCount;
  long iterBatchCount = std::min(batchesLeft, static_cast<long>((1 << 16) - 1));

  do {
    // printf("CUTLASS-> %c%c M: %ld N: %ld K: %ld %d%d%d LDA: %ld LDB: %ld LDC:
    // %ld strideA: %ld strideB: %ld strideC: %ld Alpha: %f Beta: %f
    // TotalBatches: %ld iterBatchCount %ld\n", ((int)A_LAYOUT == 0 ? 'T' : 'N'),
    // ((int)B_LAYOUT ==0 ? 'T' : 'N'), m, n, k, SRC_A,SRC_B,DST_C, lda, ldb,
    // ldc, strideA, strideB, strideC, alpha, beta, batchesLeft, iterBatchCount);
    int result = params.initialize(
        m,       // M dimension for each batch
        n,       // N dimension for each batch
        k,       // K dimension for each batch
        alpha,   // scalar alpha
        a, lda,
        strideA, // distance in memory between the first element of neighboring batch
        b, ldb,
        strideB, // distance in memory between the first element of neighboring batch
        beta,    // scalar beta
        c,       // source matrix C
        ldc,
        strideC, // distance in memory between the first element of neighboring batch
        c,       // destination matrix C (may be different memory than source C matrix)
        ldc,
        strideC, // distance in memory between the first element of neighboring batch
        iterBatchCount);
    AT_ASSERTM(result == 0,
               "Failed to initialize CUTLASS Gemm::Params object.");

    // Launch the CUTLASS GEMM kernel.
    C10_CUDA_CHECK(Gemm::launch(params, stream));

    ...

    batchesLeft = batchesLeft - iterBatchCount;
    a += iterBatchCount * strideA;
    b += iterBatchCount * strideB;
    c += iterBatchCount * strideC;

    iterBatchCount = std::min(batchesLeft, static_cast<long>((1 << 16) - 1));
  } while (batchesLeft > 0);
}
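The comment about gridDim.z explains the loop structure: CUTLASS maps the batch index to gridDim.z, which only has 16 bits, so a large batchCount is consumed in chunks of at most 65535 batches per launch. A standalone illustration of the same chunking (values are made up):

#include <algorithm>
#include <cstdio>

int main() {
  long batchCount = 200000; // illustrative; anything above 65535 gets split
  long batchesLeft = batchCount;
  long iterBatchCount = std::min(batchesLeft, static_cast<long>((1 << 16) - 1));
  do {
    // a real run would launch one batched GEMM over iterBatchCount batches here
    std::printf("launch chunk of %ld batches\n", iterBatchCount); // 65535, 65535, 65535, 3395
    batchesLeft -= iterBatchCount;
    iterBatchCount = std::min(batchesLeft, static_cast<long>((1 << 16) - 1));
  } while (batchesLeft > 0);
  return 0;
}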
void gemm_switch_fp32accum(char transa, char transb, long m,
long n, long k, float alpha, const half *a, long lda,
long strideA, const half *b, long ldb, long strideB,
float beta, half *c, long ldc, long strideC,
long batchCount) {
auto stream = c10::cuda::getCurrentCUDAStream();
// printf("GEMM -> %c%c M: %i N: %i K: %i Alpha: %f Beta: %f\n", (transa ==
// 't' ? 'T' : 'N'), (transb =='t' ? 'T' : 'N'), m, n, k, alpha, beta);
if ((transa == 't') && (transb == 'n')) {
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {
CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda,
strideA, b, ldb, strideB, beta, c, ldc, strideC,
batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);
}
else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor, 8, 8, 4>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor, 8, 8, 2>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor, 8, 4, 8>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor, 8, 4, 4>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor, 8, 4, 2>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor, 8, 2, 8>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor, 8, 2, 4>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor, 8, 2, 2>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor, 4, 8, 8>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor, 4, 8, 4>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor, 4, 8, 2>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor, 4, 4, 8>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor, 4, 4, 4>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x1)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor, 4, 4, 2>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor, 4, 2, 8>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor, 4, 2, 4>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor, 4, 2, 2>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor, 2, 8, 8>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor, 2, 8, 4>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor, 2, 8, 2>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor, 2, 4, 8>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor, 2, 4, 4>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor, 2, 4, 2>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor, 2, 2, 8>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor, 2, 2, 4>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor, 2, 2, 2>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else {
CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda,
strideA, b, ldb, strideB, beta, c, ldc, strideC,
batchCount);
}
} else if ((transa == 'n') && (transb == 'n')) {
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {
CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda,
strideA, b, ldb, strideB, beta, c, ldc, strideC,
batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);
}
else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor, 8, 8, 4>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor, 8, 8, 2>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor, 8, 4, 8>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor, 8, 4, 4>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor, 8, 4, 2>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor, 8, 2, 8>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor, 8, 2, 4>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor, 8, 2, 2>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor, 4, 8, 8>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor, 4, 8, 4>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor, 4, 8, 2>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor, 4, 4, 8>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor, 4, 4, 4>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x1)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor, 4, 4, 2>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor, 4, 2, 8>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor, 4, 2, 4>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor, 4, 2, 2>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor, 2, 8, 8>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor, 2, 8, 4>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor, 2, 8, 2>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor, 2, 4, 8>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor, 2, 4, 4>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor, 2, 4, 2>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor, 2, 2, 8>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor, 2, 2, 4>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor, 2, 2, 2>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else {
CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda,
strideA, b, ldb, strideB, beta, c, ldc, strideC,
batchCount);
}
} else if ((transa == 'n') && (transb == 't')) {
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {
CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda,
strideA, b, ldb, strideB, beta, c, ldc, strideC,
batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);
}
else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, 8, 8, 4>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, 8, 8, 2>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, 8, 4, 8>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, 8, 4, 4>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, 8, 4, 2>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, 8, 2, 8>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, 8, 2, 4>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, 8, 2, 2>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, 4, 8, 8>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, 4, 8, 4>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, 4, 8, 2>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, 4, 4, 8>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, 4, 4, 4>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, 4, 2, 8>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, 4, 2, 4>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, 4, 2, 2>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, 2, 8, 8>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, 2, 8, 4>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, 2, 8, 2>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, 2, 4, 8>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, 2, 4, 4>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, 2, 4, 2>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, 2, 2, 8>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, 2, 2, 4>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, 2, 2, 2>(
stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
ldc, strideC, batchCount);
} else {
CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda,
strideA, b, ldb, strideB, beta, c, ldc, strideC,
batchCount);
}
} else {
AT_ASSERTM(false, "TransA and TransB are invalid");
}
}
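A note on the dispatch above: `!(ld & 0x7)`, `!(ld & 0x3)` and `!(ld & 0x1)` test whether a leading dimension is a multiple of 8, 4 or 2 half elements, and the corresponding CutlassGemm_FP32Accum template arguments (8/4/2) select how many scalars each load/store moves, i.e. the widest vectorized access the operand's alignment permits; anything odd falls back to cuBLAS. A hypothetical helper that captures the same mapping (not part of the diff):

// Map a leading dimension, counted in half elements, to the vector width the
// dispatch above would select for that operand.
static int scalarsPerAccess(long ld) {
  if (!(ld & 0x7)) return 8; // multiple of 8 halves -> 128-bit accesses
  if (!(ld & 0x3)) return 4; // multiple of 4 halves -> 64-bit accesses
  if (!(ld & 0x1)) return 2; // multiple of 2 halves -> 32-bit accesses
  return 1;                  // odd: the switch falls back to cuBLAS instead
}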
void adjustLdLevel3(char transa, char transb, int64_t m, int64_t n, int64_t k,
                    int64_t *lda, int64_t *ldb, int64_t *ldc) {
  int transa_ = ((transa == 't') || (transa == 'T'));
  int transb_ = ((transb == 't') || (transb == 'T'));

  // Note: leading dimensions generally are checked that they are > 0 and at
  // least as big the result requires (even if the value won't be used).
  if (n <= 1)
    *ldc = std::max<int64_t>(m, 1);

  if (transa_) {
    if (m <= 1)
      *lda = std::max<int64_t>(k, 1);
  } else {
    if (k <= 1)
      *lda = std::max<int64_t>(m, 1);
  }

  if (transb_) {
    if (k <= 1)
      *ldb = std::max<int64_t>(n, 1);
  } else {
    if (n <= 1)
      *ldb = std::max<int64_t>(k, 1);
  }
}
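adjustLdLevel3 only rewrites leading dimensions for degenerate shapes, where a caller may legitimately pass an ld that cuBLAS's "ld >= max(1, rows)" check would reject even though the value is never dereferenced. A worked example with illustrative values:

// With n == 1 and no transposes: ldc is clamped to max(m, 1) and ldb to
// max(k, 1); lda is left alone because k > 1.
int64_t lda = 128, ldb = 0, ldc = 0;
adjustLdLevel3('n', 'n', /*m=*/128, /*n=*/1, /*k=*/64, &lda, &ldb, &ldc);
// afterwards: lda == 128, ldb == 64, ldc == 128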
void HgemmStridedBatched(char transa, char transb, long m, long n, long k,
                         float alpha, const half *a, long lda, long strideA,
                         const half *b, long ldb, long strideB, float beta,
                         half *c, long ldc, long strideC, long batchCount) {
  if ((m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) ||
      (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX)) {
    AT_ERROR("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, "
             "batchCount with the bound [val] <= %d",
             INT_MAX);
  }

  adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  gemm_switch_fp32accum(transa, transb, m, n, k, alpha, a, lda, strideA, b,
                        ldb, strideB, beta, c, ldc, strideC, batchCount);
}
/******
at::Tensor strided_batched_gemm_cuda(
float beta,
at::Tensor in_result,
float alpha,
at::Tensor batch1,
at::Tensor batch2) {
bool transpose_result;
char transpose_batch1, transpose_batch2;
int64_t lda, ldb, ldc;
at::Tensor result, input1, input2;
if (in_result.stride(1) == 1)
{
transpose_result = false;
result = in_result;
ldc = result.stride(2);
}
else if (in_result.stride(2) == 1)
{
transpose_result = true;
at::Tensor swap = batch2;
batch2 = batch1;
batch1 = swap;
result = in_result;
ldc = result.stride(1);
} else {
AT_ASSERTM(false, "result should be contiguous");
}
if (batch1.stride(transpose_result ? 2 : 1) == 1 &&
batch1.stride(transpose_result ? 1 : 2) != 0) {
transpose_batch1 = 'n';
input1 = batch1;
lda = input1.stride(transpose_result ? 1 : 2);
} else if (batch1.stride(transpose_result ? 1 : 2) == 1 &&
batch1.stride(transpose_result ? 2 : 1) != 0) {
transpose_batch1 = 't';
input1 = batch1;
lda = input1.stride(transpose_result ? 2 : 1);
} else {
AT_ASSERTM(false, "input1 should be contiguous");
}
if (batch2.stride(transpose_result ? 2 : 1) == 1 &&
batch2.stride(transpose_result ? 1 : 2) != 0) {
transpose_batch2 = 'n';
input2 = batch2;
ldb = input2.stride(transpose_result ? 1 : 2);
} else if (batch2.stride(transpose_result ? 1 : 2) == 1 &&
batch2.stride(transpose_result ? 2 : 1) != 0) {
transpose_batch2 = 't';
input2 = batch2;
ldb = input2.stride(transpose_result ? 2 : 1);
} else {
AT_ASSERTM(false, "input2 should be contiguous");
}
int64_t num_batches = result.size(0);
HgemmStridedBatched(
state,
transpose_batch1,
transpose_batch2,
result.size(transpose_result ? 2 : 1),
result.size(transpose_result ? 1 : 2),
input1.size(transpose_result ? 1 : 2),
alpha,
static_cast<const half*>(input1.data_ptr()), lda, input1.stride(0),
static_cast<const half*>(input2.data_ptr()), ldb, input2.stride(0),
beta,
static_cast<half*>(result.data_ptr()), ldc, result.stride(0),
num_batches);
return in_result;
}
***/
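The commented-out strided_batched_gemm_cuda wrapper above still shows the pre-commit calling convention, with `state` passed as the first argument to HgemmStridedBatched. After this change an equivalent call simply drops that argument; a sketch reusing the wrapper's own variable names (hypothetical, since the wrapper itself remains commented out):

// Hypothetical post-commit call site, mirroring the commented-out wrapper
// above but without the THCState* argument (all other arguments unchanged):
HgemmStridedBatched(
    transpose_batch1, transpose_batch2,
    result.size(transpose_result ? 2 : 1),
    result.size(transpose_result ? 1 : 2),
    input1.size(transpose_result ? 1 : 2),
    alpha,
    static_cast<const half *>(input1.data_ptr()), lda, input1.stride(0),
    static_cast<const half *>(input2.data_ptr()), ldb, input2.stride(0),
    beta,
    static_cast<half *>(result.data_ptr()), ldc, result.stride(0),
    num_batches);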
@@ -160,7 +160,7 @@ class SelfMultiheadAttn(nn.Module):
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, lyr_nrm_results,
                         input_weights, self.out_proj_weight,
                         input_bias, self.out_proj_bias,
                         mask, self.mask_additive, self.dropout)
if is_training:
    outputs = jit_dropout_add(outputs, query, self.dropout, is_training)
else:
...