Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
1203099a
Unverified
Commit
1203099a
authored
Dec 09, 2021
by
Masaki Kozuki
Committed by
GitHub
Dec 09, 2021
Browse files
Remove `THCState` from `apex/contrib/multihead_attn` (#1239)
* pass `self.mask_additive`
* clang-format
* removing THCState
parent
3c8f5161
Changes
22
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
562 additions
and
335 deletions
+562
-335
apex/contrib/csrc/multihead_attn/strided_batched_gemm.h
apex/contrib/csrc/multihead_attn/strided_batched_gemm.h
+561
-334
apex/contrib/multihead_attn/self_multihead_attn.py
apex/contrib/multihead_attn/self_multihead_attn.py
+1
-1
No files found.
apex/contrib/csrc/multihead_attn/strided_batched_gemm.h
View file @
1203099a
#include <vector>
#include <iostream>
#include <iostream>
#include <vector>
#include <cuda.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
//#include <ATen/ATen.h>
//#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/Exceptions.h>
#include "cutlass/cutlass.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/wmma_gemm_traits.h"
#include "cutlass/gemm/wmma_gemm_traits.h"
// symbol to be automatically resolved by PyTorch libs
extern
THCState
*
state
;
// Translate a BLAS-style transpose flag ('t', 'n' or 'c') into the matching
// cuBLAS operation enum. Any other character raises an ATen error; the
// trailing return after AT_ERROR is unreachable and only silences compilers.
cublasOperation_t convertTransToCublasOperation(char trans) {
  if (trans == 't') {
    return CUBLAS_OP_T;
  } else if (trans == 'n') {
    return CUBLAS_OP_N;
  } else if (trans == 'c') {
    return CUBLAS_OP_C;
  } else {
    AT_ERROR("trans must be one of: t, n, c");
    return CUBLAS_OP_T;
  }
}
void
CublasStridedBatchedGemm
(
THCState
*
state
,
char
transa
,
char
transb
,
long
m
,
long
n
,
long
k
,
void
CublasStridedBatchedGemm
(
float
alpha
,
const
half
*
a
,
long
lda
,
long
strideA
,
const
half
*
b
,
long
ldb
,
long
strideB
,
char
transa
,
char
transb
,
long
m
,
long
n
,
long
k
,
float
beta
,
half
*
c
,
long
ld
c
,
long
stride
C
,
long
batchCount
,
cublasGemmAlgo_t
algo
=
CUBLAS_GEMM_DEFAULT_TENSOR_OP
)
{
float
alpha
,
const
half
*
a
,
long
ld
a
,
long
stride
A
,
const
half
*
b
,
long
ldb
,
cublasOperation_t
opa
=
convertTransToCublasOperation
(
transa
);
long
strideB
,
float
beta
,
half
*
c
,
long
ldc
,
long
strideC
,
long
batchCount
,
cublas
Operation_t
opb
=
convertTransToCublasOperation
(
transb
);
cublas
GemmAlgo_t
algo
=
CUBLAS_GEMM_DEFAULT_TENSOR_OP
)
{
cublasOperation_t
opa
=
convertTransToCublasOperation
(
transa
);
cublas
Handle_t
handle
=
at
::
cuda
::
getCurrentCUDABlasHandle
(
);
cublas
Operation_t
opb
=
convertTransToCublasOperation
(
transb
);
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
cublas
SetStream
(
handle
,
stream
);
cublas
Handle_t
handle
=
at
::
cuda
::
getCurrentCUDABlasHandle
(
);
float
fAlpha
=
alpha
;
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
()
;
float
fBeta
=
beta
;
cublasSetStream
(
handle
,
stream
)
;
//THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH))
;
float
fAlpha
=
alpha
;
TORCH_CUDABLAS_CHECK
(
cublasGemmStridedBatchedEx
(
handle
,
float
fBeta
=
beta
;
opa
,
opb
,
(
int
)
m
,
(
int
)
n
,
(
int
)
k
,
// THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
(
void
*
)
&
fAlpha
,
a
,
CUDA_R_16F
,
(
int
)
lda
,
strideA
,
TORCH_CUDABLAS_CHECK
(
cublasGemmStridedBatchedEx
(
b
,
CUDA_R_16F
,
(
int
)
ldb
,
strideB
,
handle
,
opa
,
opb
,
(
int
)
m
,
(
int
)
n
,
(
int
)
k
,
(
void
*
)
&
fAlpha
,
a
,
CUDA_R_16F
,
(
void
*
)
&
fBeta
,
c
,
CUDA_R_16F
,
(
int
)
ld
c
,
stride
C
,
(
int
)
lda
,
strideA
,
b
,
CUDA_R_16F
,
(
int
)
ld
b
,
stride
B
,
(
void
*
)
&
fBeta
,
c
,
(
int
)
batchCount
,
CUDA_R_32F
,
algo
));
CUDA_R_16F
,
(
int
)
ldc
,
strideC
,
(
int
)
batchCount
,
CUDA_R_32F
,
algo
));
//THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
//
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
}
}
template
<
cutlass
::
MatrixLayout
::
Kind
A_LAYOUT
,
cutlass
::
MatrixLayout
::
Kind
B_LAYOUT
,
int
SRC_A
,
int
SRC_B
,
int
DST_C
>
template
<
cutlass
::
MatrixLayout
::
Kind
A_LAYOUT
,
cutlass
::
MatrixLayout
::
Kind
B_LAYOUT
,
int
SRC_A
,
int
SRC_B
,
int
DST_C
>
void
CutlassGemm_FP32Accum
(
cudaStream_t
stream
,
long
m
,
long
n
,
long
k
,
void
CutlassGemm_FP32Accum
(
cudaStream_t
stream
,
long
m
,
long
n
,
long
k
,
float
alpha
,
const
half
*
a
,
long
lda
,
long
strideA
,
const
half
*
b
,
long
ldb
,
long
strideB
,
float
alpha
,
const
half
*
a
,
long
lda
,
long
strideA
,
float
beta
,
half
*
c
,
long
ldc
,
long
strideC
,
long
batchCount
)
{
const
half
*
b
,
long
ldb
,
long
strideB
,
float
beta
,
//printf("CUTLASS-> %c%c M: %ld N: %ld K: %ld %d%d%d LDA: %ld LDB: %ld LDC: %ld strideA: %ld strideB: %ld strideC: %ld Alpha: %f Beta: %f\n", ((int)A_LAYOUT == 0 ? 'T' : 'N'), ((int)B_LAYOUT ==0 ? 'T' : 'N'), m, n, k, SRC_A,SRC_B,DST_C, lda, ldb, ldc, strideA, strideB, strideC, alpha, beta);
half
*
c
,
long
ldc
,
long
strideC
,
long
batchCount
)
{
// printf("CUTLASS-> %c%c M: %ld N: %ld K: %ld %d%d%d LDA: %ld LDB: %ld LDC:
// %ld strideA: %ld strideB: %ld strideC: %ld Alpha: %f Beta: %f\n",
// ((int)A_LAYOUT == 0 ? 'T' : 'N'), ((int)B_LAYOUT ==0 ? 'T' : 'N'), m, n, k,
// SRC_A,SRC_B,DST_C, lda, ldb, ldc, strideA, strideB, strideC, alpha, beta);
typedef
cutlass
::
gemm
::
WmmaGemmTraits
<
typedef
cutlass
::
gemm
::
WmmaGemmTraits
<
A_LAYOUT
,
A_LAYOUT
,
B_LAYOUT
,
cutlass
::
Shape
<
32
,
16
,
16
>
,
half
,
half
,
half
,
B_LAYOUT
,
cutlass
::
gemm
::
LinearScaling
<
float
>
,
float
,
cutlass
::
Shape
<
32
,
16
,
16
>
,
typename
cutlass
::
gemm
::
WmmaGemmAccumulatorsPerWarp
<
half
,
typename
cutlass
::
Shape
<
32
,
16
,
16
>>::
Shape
,
half
,
typename
cutlass
::
Shape
<
16
,
16
,
16
>
,
half
,
SRC_A
,
// kScalarsPerLdgA_
cutlass
::
gemm
::
LinearScaling
<
float
>
,
SRC_B
,
// kScalarsPerLdgB_
float
,
SRC_A
,
// KScalarsPerLdsA_
typename
cutlass
::
gemm
::
WmmaGemmAccumulatorsPerWarp
<
typename
cutlass
::
Shape
<
32
,
16
,
16
>
>::
Shape
,
SRC_B
,
// KScalarsPerLdsB_
typename
cutlass
::
Shape
<
16
,
16
,
16
>
,
DST_C
,
// kScalarsPerLdgCAndStgD_
SRC_A
,
//kScalarsPerLdgA_
DST_C
/
2
,
// kScalarsPerStsD_
SRC_B
,
//kScalarsPerLdgB_
DST_C
/
2
// kScalarsPerLdsD_
SRC_A
,
//KScalarsPerLdsA_
>
SRC_B
,
//KScalarsPerLdsB_
WmmaGemmTraits
;
DST_C
,
//kScalarsPerLdgCAndStgD_
DST_C
/
2
,
//kScalarsPerStsD_
DST_C
/
2
//kScalarsPerLdsD_
>
WmmaGemmTraits
;
typedef
cutlass
::
gemm
::
Gemm
<
WmmaGemmTraits
>
Gemm
;
typedef
cutlass
::
gemm
::
Gemm
<
WmmaGemmTraits
>
Gemm
;
typename
Gemm
::
Params
params
;
typename
Gemm
::
Params
params
;
int
result
=
params
.
initialize
(
int
result
=
params
.
initialize
(
m
,
// M dimension for each batch
m
,
// M dimension for each batch
n
,
// N dimension for each batch
n
,
// N dimension for each batch
k
,
// K dimension for each batch
k
,
// K dimension for each batch
alpha
,
// scalar alpha
alpha
,
// scalar alpha
a
,
a
,
lda
,
lda
,
strideA
,
// distance in memory between the first element of neighboring
strideA
,
// distance in memory between the first element of neighboring batch
// batch
b
,
b
,
ldb
,
ldb
,
strideB
,
// distance in memory between the first element of neighboring
strideB
,
// distance in memory between the first element of neighboring batch
// batch
beta
,
// scalar beta
beta
,
// scalar beta
c
,
// source matrix C
c
,
// source matrix C
ldc
,
ldc
,
strideC
,
// distance in memory between the first element of neighboring batch
strideC
,
// distance in memory between the first element of neighboring
c
,
// destination matrix C (may be different memory than source C matrix)
// batch
ldc
,
c
,
// destination matrix C (may be different memory than source C matrix)
strideC
,
// distance in memory between the first element of neighboring batch
ldc
,
batchCount
strideC
,
// distance in memory between the first element of neighboring
);
// batch
batchCount
);
AT_ASSERTM
(
result
==
0
,
"Failed to initialize CUTLASS Gemm::Params object."
);
AT_ASSERTM
(
result
==
0
,
"Failed to initialize CUTLASS Gemm::Params object."
);
// batchCount in cutlass batched GEMM kernels maps to gridDim.z, which is
limited to 16 bits.
// batchCount in cutlass batched GEMM kernels maps to gridDim.z, which is
// To implement batched GEMM with larger batch size, we
fragment it into
//
limited to 16 bits.
To implement batched GEMM with larger batch size, we
// smaller batched GEMMs of gridDim.z <= 64k
//
fragment it into
smaller batched GEMMs of gridDim.z <= 64k
long
batchesLeft
=
batchCount
;
long
batchesLeft
=
batchCount
;
long
iterBatchCount
=
std
::
min
(
batchesLeft
,
static_cast
<
long
>
((
1
<<
16
)
-
1
));
long
iterBatchCount
=
std
::
min
(
batchesLeft
,
static_cast
<
long
>
((
1
<<
16
)
-
1
));
do
{
//printf("CUTLASS-> %c%c M: %ld N: %ld K: %ld %d%d%d LDA: %ld LDB: %ld LDC: %ld strideA: %ld strideB: %ld strideC: %ld Alpha: %f Beta: %f TotalBatches: %ld iterBatchCount %ld\n", ((int)A_LAYOUT == 0 ? 'T' : 'N'), ((int)B_LAYOUT ==0 ? 'T' : 'N'), m, n, k, SRC_A,SRC_B,DST_C, lda, ldb, ldc, strideA, strideB, strideC, alpha, beta, batchesLeft, iterBatchCount);
int
result
=
params
.
initialize
(
m
,
// M dimension for each batch
n
,
// N dimension for each batch
k
,
// K dimension for each batch
alpha
,
// scalar alpha
a
,
lda
,
strideA
,
// distance in memory between the first element of neighboring batch
b
,
ldb
,
strideB
,
// distance in memory between the first element of neighboring batch
beta
,
// scalar beta
c
,
// source matrix C
ldc
,
strideC
,
// distance in memory between the first element of neighboring batch
c
,
// destination matrix C (may be different memory than source C matrix)
ldc
,
strideC
,
// distance in memory between the first element of neighboring batch
iterBatchCount
);
AT_ASSERTM
(
result
==
0
,
"Failed to initialize CUTLASS Gemm::Params object."
);
do
{
// printf("CUTLASS-> %c%c M: %ld N: %ld K: %ld %d%d%d LDA: %ld LDB: %ld LDC:
// %ld strideA: %ld strideB: %ld strideC: %ld Alpha: %f Beta: %f
// TotalBatches: %ld iterBatchCount %ld\n", ((int)A_LAYOUT == 0 ? 'T' : 'N'),
// ((int)B_LAYOUT ==0 ? 'T' : 'N'), m, n, k, SRC_A,SRC_B,DST_C, lda, ldb,
// ldc, strideA, strideB, strideC, alpha, beta, batchesLeft, iterBatchCount);
int
result
=
params
.
initialize
(
m
,
// M dimension for each batch
n
,
// N dimension for each batch
k
,
// K dimension for each batch
alpha
,
// scalar alpha
a
,
lda
,
strideA
,
// distance in memory between the first
// element of neighboring batch
b
,
ldb
,
strideB
,
// distance in memory between the first
// element of neighboring batch
beta
,
// scalar beta
c
,
// source matrix C
ldc
,
strideC
,
// distance in memory between the first
// element of neighboring batch
c
,
// destination matrix C (may be different memory
// than source C matrix)
ldc
,
strideC
,
// distance in memory between the first
// element of neighboring batch
iterBatchCount
);
AT_ASSERTM
(
result
==
0
,
"Failed to initialize CUTLASS Gemm::Params object."
);
// Launch the CUTLASS GEMM kernel.
// Launch the CUTLASS GEMM kernel.
C10_CUDA_CHECK
(
Gemm
::
launch
(
params
,
stream
));
C10_CUDA_CHECK
(
Gemm
::
launch
(
params
,
stream
));
...
@@ -139,269 +145,490 @@ void CutlassGemm_FP32Accum(cudaStream_t stream, long m, long n, long k,
...
@@ -139,269 +145,490 @@ void CutlassGemm_FP32Accum(cudaStream_t stream, long m, long n, long k,
batchesLeft
=
batchesLeft
-
iterBatchCount
;
batchesLeft
=
batchesLeft
-
iterBatchCount
;
a
+=
iterBatchCount
*
strideA
;
a
+=
iterBatchCount
*
strideA
;
b
+=
iterBatchCount
*
strideB
;
b
+=
iterBatchCount
*
strideB
;
c
+=
iterBatchCount
*
strideC
;;
c
+=
iterBatchCount
*
strideC
;
;
iterBatchCount
=
std
::
min
(
batchesLeft
,
static_cast
<
long
>
((
1
<<
16
)
-
1
));
iterBatchCount
=
std
::
min
(
batchesLeft
,
static_cast
<
long
>
((
1
<<
16
)
-
1
));
}
while
(
batchesLeft
>
0
);
}
while
(
batchesLeft
>
0
);
}
}
void
gemm_switch_fp32accum
(
THCState
*
state
,
char
transa
,
char
transb
,
long
m
,
long
n
,
long
k
,
void
gemm_switch_fp32accum
(
char
transa
,
char
transb
,
long
m
,
float
alpha
,
const
half
*
a
,
long
lda
,
long
strideA
,
const
half
*
b
,
long
ldb
,
long
strideB
,
long
n
,
long
k
,
float
alpha
,
const
half
*
a
,
long
lda
,
float
beta
,
half
*
c
,
long
ldc
,
long
strideC
,
long
batchCount
)
{
long
strideA
,
const
half
*
b
,
long
ldb
,
long
strideB
,
float
beta
,
half
*
c
,
long
ldc
,
long
strideC
,
long
batchCount
)
{
auto
stream
=
c10
::
cuda
::
getCurrentCUDAStream
();
auto
stream
=
c10
::
cuda
::
getCurrentCUDAStream
();
//printf("GEMM -> %c%c M: %i N: %i K: %i Alpha: %f Beta: %f\n", (transa == 't' ? 'T' : 'N'), (transb =='t' ? 'T' : 'N'), m, n, k, alpha, beta);
// printf("GEMM -> %c%c M: %i N: %i K: %i Alpha: %f Beta: %f\n", (transa ==
if
(
(
transa
==
't'
)
&&
(
transb
==
'n'
)
)
{
// 't' ? 'T' : 'N'), (transb =='t' ? 'T' : 'N'), m, n, k, alpha, beta);
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x7
))
{
CublasStridedBatchedGemm
(
state
,
transa
,
transb
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
,
CUBLAS_GEMM_ALGO0_TENSOR_OP
);
}
if
((
transa
==
't'
)
&&
(
transb
==
'n'
))
{
/*if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x7
))
{
int m_rem = m % 64;
CublasStridedBatchedGemm
(
transa
,
transb
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
int n_rem = n % 64;
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
if ( (m_rem > 48) && ( m <= 192) && (n_rem > 48) && (n <= 192 ) ) {
batchCount
,
CUBLAS_GEMM_ALGO0_TENSOR_OP
);
CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);
}
} else if ( (m_rem > 32) && ( m > 192) && (n_rem > 32) && (n > 192) ) {
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x3
))
{
CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
} else {
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
8
,
4
>
(
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
}
ldc
,
strideC
,
batchCount
);
}*/
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x1
))
{
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
8
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
8
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
8
,
2
>
(
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
4
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
4
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
ldc
,
strideC
,
batchCount
);
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
4
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x7
))
{
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
2
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
2
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
4
,
8
>
(
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
2
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
8
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
ldc
,
strideC
,
batchCount
);
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
8
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x3
))
{
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
8
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
4
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
4
,
4
>
(
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
4
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
4
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
ldc
,
strideC
,
batchCount
);
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
2
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x1
))
{
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
2
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
2
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
4
,
2
>
(
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
8
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
8
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
ldc
,
strideC
,
batchCount
);
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
8
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x7
))
{
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
4
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
4
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
2
,
8
>
(
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
4
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
2
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
ldc
,
strideC
,
batchCount
);
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
2
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x3
))
{
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
2
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
else
{
CublasStridedBatchedGemm
(
state
,
transa
,
transb
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
2
,
4
>
(
}
else
if
(
(
transa
==
'n'
)
&&
(
transb
==
'n'
)
)
{
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x7
))
{
CublasStridedBatchedGemm
(
state
,
transa
,
transb
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
,
CUBLAS_GEMM_ALGO0_TENSOR_OP
);
}
ldc
,
strideC
,
batchCount
);
/*if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x1
))
{
int m_rem = m % 64;
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
int n_rem = n % 64;
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
2
,
2
>
(
if ( (m_rem > 48) && ( m <= 192) && (n_rem > 48) && (n <= 192 ) ) {
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);
ldc
,
strideC
,
batchCount
);
} else if ( (m_rem > 32) && ( m > 192) && (n_rem > 32) && (n > 192) ) {
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x7
))
{
CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
} else {
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
8
,
8
>
(
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
}
ldc
,
strideC
,
batchCount
);
}*/
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x3
))
{
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
8
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
8
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
8
,
4
>
(
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
4
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
4
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
ldc
,
strideC
,
batchCount
);
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
4
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x1
))
{
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
2
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
2
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
8
,
2
>
(
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
2
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
8
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
ldc
,
strideC
,
batchCount
);
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
8
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x7
))
{
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
8
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
4
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
4
,
8
>
(
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
4
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
4
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
ldc
,
strideC
,
batchCount
);
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
2
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x3
))
{
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
2
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
2
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
4
,
4
>
(
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
8
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
8
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
ldc
,
strideC
,
batchCount
);
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
8
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x1
))
{
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
4
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
4
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
4
,
2
>
(
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
4
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
2
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
ldc
,
strideC
,
batchCount
);
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
2
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x7
))
{
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
2
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
else
{
CublasStridedBatchedGemm
(
state
,
transa
,
transb
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
2
,
8
>
(
}
else
if
(
(
transa
==
'n'
)
&&
(
transb
==
't'
)
)
{
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x7
))
{
CublasStridedBatchedGemm
(
state
,
transa
,
transb
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
,
CUBLAS_GEMM_ALGO0_TENSOR_OP
);
}
ldc
,
strideC
,
batchCount
);
/*if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x3
))
{
int m_rem = m % 64;
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
int n_rem = n % 64;
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
2
,
4
>
(
if ( (m_rem > 48) && ( m <= 192) && (n_rem > 48) && (n <= 192 ) ) {
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);
ldc
,
strideC
,
batchCount
);
} else if ( (m_rem > 32) && ( m > 192) && (n_rem > 32) && (n > 192) ) {
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x1
))
{
CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
} else {
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
2
,
2
>
(
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
}
ldc
,
strideC
,
batchCount
);
}*/
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x7
))
{
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
8
,
8
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
8
,
8
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
8
,
8
>
(
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
8
,
4
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
8
,
4
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
ldc
,
strideC
,
batchCount
);
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
8
,
4
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x3
))
{
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
8
,
2
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
8
,
2
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
8
,
4
>
(
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
8
,
2
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
4
,
8
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
ldc
,
strideC
,
batchCount
);
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
4
,
8
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x1
))
{
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
4
,
8
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
4
,
4
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
8
,
2
>
(
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
4
,
4
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
4
,
2
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
ldc
,
strideC
,
batchCount
);
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
4
,
2
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x7
))
{
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
4
,
2
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
2
,
8
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
4
,
8
>
(
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
2
,
8
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
2
,
8
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
ldc
,
strideC
,
batchCount
);
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
2
,
4
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x3
))
{
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
2
,
4
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
2
,
4
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
4
,
4
>
(
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
2
,
2
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
2
,
2
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
ldc
,
strideC
,
batchCount
);
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
2
,
2
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x1
))
{
else
{
CublasStridedBatchedGemm
(
state
,
transa
,
transb
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
4
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
2
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
2
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
2
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
{
CublasStridedBatchedGemm
(
transa
,
transb
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
}
else
if
((
transa
==
'n'
)
&&
(
transb
==
'n'
))
{
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x7
))
{
CublasStridedBatchedGemm
(
transa
,
transb
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
,
CUBLAS_GEMM_ALGO0_TENSOR_OP
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
8
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
8
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
4
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
4
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
4
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
2
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
2
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
2
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
8
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
8
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
8
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
4
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
4
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
4
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
2
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
2
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
2
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
8
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
8
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
8
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
4
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
4
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
4
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
2
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
2
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
2
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
{
CublasStridedBatchedGemm
(
transa
,
transb
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
}
else
if
((
transa
==
'n'
)
&&
(
transb
==
't'
))
{
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x7
))
{
CublasStridedBatchedGemm
(
transa
,
transb
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
,
CUBLAS_GEMM_ALGO0_TENSOR_OP
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
8
,
8
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
8
,
8
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
8
,
4
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
8
,
4
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
8
,
4
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
8
,
2
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
8
,
2
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
8
,
2
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
4
,
8
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
4
,
8
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
4
,
8
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
4
,
4
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
4
,
4
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
4
,
2
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
4
,
2
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
4
,
2
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
2
,
8
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
2
,
8
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
2
,
8
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
2
,
4
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
2
,
4
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
2
,
4
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
2
,
2
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
2
,
2
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
2
,
2
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
{
CublasStridedBatchedGemm
(
transa
,
transb
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
}
else
{
}
else
{
AT_ASSERTM
(
false
,
"TransA and TransB are invalid"
);
AT_ASSERTM
(
false
,
"TransA and TransB are invalid"
);
}
}
}
}
// Normalize degenerate leading dimensions for a level-3 BLAS call.
//
// Leading dimensions are generally validated by BLAS to be > 0 and at least
// as large as the result requires (even when the value will not actually be
// used), so any dimension made degenerate by a size-1 extent is bumped up to
// the smallest legal value here.
//
// @param transa  't'/'T' if A is transposed, anything else means not.
// @param transb  't'/'T' if B is transposed, anything else means not.
// @param m, n, k GEMM problem sizes (C is m x n, inner dimension k).
// @param lda, ldb, ldc  in/out leading dimensions, adjusted in place.
void adjustLdLevel3(char transa, char transb, int64_t m, int64_t n, int64_t k,
                    int64_t *lda, int64_t *ldb, int64_t *ldc) {
  const bool a_transposed = (transa == 't') || (transa == 'T');
  const bool b_transposed = (transb == 't') || (transb == 'T');

  // A single output column makes ldc irrelevant; pin it to a legal minimum.
  if (n <= 1) {
    *ldc = std::max<int64_t>(m, 1);
  }

  // For A the relevant extent depends on whether it is stored transposed.
  if (a_transposed && m <= 1) {
    *lda = std::max<int64_t>(k, 1);
  } else if (!a_transposed && k <= 1) {
    *lda = std::max<int64_t>(m, 1);
  }

  // Same reasoning for B.
  if (b_transposed && k <= 1) {
    *ldb = std::max<int64_t>(n, 1);
  } else if (!b_transposed && n <= 1) {
    *ldb = std::max<int64_t>(k, 1);
  }
}
void
HgemmStridedBatched
(
THCState
*
state
,
char
transa
,
char
transb
,
long
m
,
long
n
,
long
k
,
void
HgemmStridedBatched
(
char
transa
,
char
transb
,
long
m
,
float
alpha
,
const
half
*
a
,
long
lda
,
long
strideA
,
const
half
*
b
,
long
ldb
,
long
strideB
,
long
n
,
long
k
,
float
alpha
,
const
half
*
a
,
long
lda
,
float
beta
,
half
*
c
,
long
ldc
,
long
strideC
,
long
batchCount
)
long
strideA
,
const
half
*
b
,
long
ldb
,
long
strideB
,
{
float
beta
,
half
*
c
,
long
ldc
,
long
strideC
,
if
(
(
m
>=
INT_MAX
)
||
(
n
>=
INT_MAX
)
||
(
k
>=
INT_MAX
)
||
(
lda
>=
INT_MAX
)
||
(
ldb
>=
INT_MAX
)
||
(
ldc
>=
INT_MAX
)
||
(
batchCount
>=
INT_MAX
)
)
long
batchCount
)
{
if
((
m
>=
INT_MAX
)
||
(
n
>=
INT_MAX
)
||
(
k
>=
INT_MAX
)
||
(
lda
>=
INT_MAX
)
||
(
ldb
>=
INT_MAX
)
||
(
ldc
>=
INT_MAX
)
||
(
batchCount
>=
INT_MAX
))
{
{
AT_ERROR
(
"Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount"
AT_ERROR
(
"Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, "
"with the bound [val] <= %d"
,
INT_MAX
);
"batchCount"
"with the bound [val] <= %d"
,
INT_MAX
);
}
}
adjustLdLevel3
(
transa
,
transb
,
m
,
n
,
k
,
&
lda
,
&
ldb
,
&
ldc
);
adjustLdLevel3
(
transa
,
transb
,
m
,
n
,
k
,
&
lda
,
&
ldb
,
&
ldc
);
//gemm_switch(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
gemm_switch_fp32accum
(
transa
,
transb
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
gemm_switch_fp32accum
(
state
,
transa
,
transb
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
/******
at::Tensor strided_batched_gemm_cuda(
float beta,
at::Tensor in_result,
float alpha,
at::Tensor batch1,
at::Tensor batch2) {
bool transpose_result;
char transpose_batch1, transpose_batch2;
int64_t lda, ldb, ldc;
at::Tensor result, input1, input2;
if (in_result.stride(1) == 1)
{
transpose_result = false;
result = in_result;
ldc = result.stride(2);
}
else if (in_result.stride(2) == 1)
{
transpose_result = true;
at::Tensor swap = batch2;
batch2 = batch1;
batch1 = swap;
result = in_result;
ldc = result.stride(1);
} else {
AT_ASSERTM(false, "result should be contiguous");
}
if (batch1.stride(transpose_result ? 2 : 1) == 1 &&
batch1.stride(transpose_result ? 1 : 2) != 0) {
transpose_batch1 = 'n';
input1 = batch1;
lda = input1.stride(transpose_result ? 1 : 2);
} else if (batch1.stride(transpose_result ? 1 : 2) == 1 &&
batch1.stride(transpose_result ? 2 : 1) != 0) {
transpose_batch1 = 't';
input1 = batch1;
lda = input1.stride(transpose_result ? 2 : 1);
} else {
AT_ASSERTM(false, "input1 should be contiguous");
}
if (batch2.stride(transpose_result ? 2 : 1) == 1 &&
batch2.stride(transpose_result ? 1 : 2) != 0) {
transpose_batch2 = 'n';
input2 = batch2;
ldb = input2.stride(transpose_result ? 1 : 2);
} else if (batch2.stride(transpose_result ? 1 : 2) == 1 &&
batch2.stride(transpose_result ? 2 : 1) != 0) {
transpose_batch2 = 't';
input2 = batch2;
ldb = input2.stride(transpose_result ? 2 : 1);
} else {
AT_ASSERTM(false, "input2 should be contiguous");
}
int64_t num_batches = result.size(0);
HgemmStridedBatched(
state,
transpose_batch1,
transpose_batch2,
result.size(transpose_result ? 2 : 1),
result.size(transpose_result ? 1 : 2),
input1.size(transpose_result ? 1 : 2),
alpha,
static_cast<const half*>(input1.data_ptr()), lda, input1.stride(0),
static_cast<const half*>(input2.data_ptr()), ldb, input2.stride(0),
beta,
static_cast<half*>(result.data_ptr()), ldc, result.stride(0),
num_batches);
return in_result;
}
}
***/
apex/contrib/multihead_attn/self_multihead_attn.py
View file @
1203099a
...
@@ -160,7 +160,7 @@ class SelfMultiheadAttn(nn.Module):
...
@@ -160,7 +160,7 @@ class SelfMultiheadAttn(nn.Module):
outputs
=
self
.
attn_func
(
attn_mask
is
not
None
,
is_training
,
self
.
num_heads
,
self
.
scaling
,
lyr_nrm_results
,
outputs
=
self
.
attn_func
(
attn_mask
is
not
None
,
is_training
,
self
.
num_heads
,
self
.
scaling
,
lyr_nrm_results
,
input_weights
,
self
.
out_proj_weight
,
input_weights
,
self
.
out_proj_weight
,
input_bias
,
self
.
out_proj_bias
,
input_bias
,
self
.
out_proj_bias
,
mask
,
self
.
dropout
)
mask
,
self
.
mask_additive
,
self
.
dropout
)
if
is_training
:
if
is_training
:
outputs
=
jit_dropout_add
(
outputs
,
query
,
self
.
dropout
,
is_training
)
outputs
=
jit_dropout_add
(
outputs
,
query
,
self
.
dropout
,
is_training
)
else
:
else
:
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment