gaoqiong / composable_kernel · Commit 49facb91
Authored Nov 07, 2023 by Harisankar Sadasivan

files for gemv and tall and skinny gemm examples and corresponding entries to ckprofiler

Parent: 98fd41f5
Showing 20 changed files with 2786 additions and 0 deletions
example/53_gemv_splitk/CMakeLists.txt (+11 −0)
example/53_gemv_splitk/README.md (+19 −0)
example/53_gemv_splitk/common.hpp (+92 −0)
example/53_gemv_splitk/gemv_splitk_fp16.cpp (+42 −0)
example/53_gemv_splitk/run_gemv_splitk_example.inc (+196 −0)
example/54_tall_and_skinny_gemm_splitk/CMakeLists.txt (+11 −0)
example/54_tall_and_skinny_gemm_splitk/README.md (+19 −0)
example/54_tall_and_skinny_gemm_splitk/common.hpp (+92 −0)
example/54_tall_and_skinny_gemm_splitk/run_tall_and_skinny_gemm_splitk_example.inc (+194 −0)
example/54_tall_and_skinny_gemm_splitk/tall_and_skinny_gemm_splitk_fp16.cpp (+43 −0)
include/ck/tensor_operation/gpu/block/blockwise_tall_and_skinny_gemm.hpp (+152 −0)
include/ck/tensor_operation/gpu/device/device_tall_and_skinny_gemm.hpp (+42 −0)
include/ck/tensor_operation/gpu/device/impl/device_tall_and_skinny_gemm_splitk.hpp (+377 −0)
include/ck/tensor_operation/gpu/grid/gridwise_tall_and_skinny_gemm_splitk.hpp (+772 −0)
include/ck/tensor_operation/gpu/thread/threadwise_tall_and_skinny_gemm.hpp (+94 −0)
library/src/tensor_operation_instance/gpu/gemv_splitk/CMakeLists.txt (+17 −0)
library/src/tensor_operation_instance/gpu/gemv_splitk/device_gemv_splitk_f16_f16_f16_mk_kn_mn_instance.cpp (+198 −0)
library/src/tensor_operation_instance/gpu/gemv_splitk/device_gemv_splitk_f16_f16_f16_mk_nk_mn_instance.cpp (+197 −0)
library/src/tensor_operation_instance/gpu/tall_and_skinny_gemm_splitk/CMakeLists.txt (+18 −0)
library/src/tensor_operation_instance/gpu/tall_and_skinny_gemm_splitk/device_tall_and_skinny_gemm_splitk_f16_f16_f16_mk_kn_mn_instance.cpp (+200 −0)
example/53_gemv_splitk/CMakeLists.txt (new file, mode 100755)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
    if(gpu IN_LIST gpu_list AND target EQUAL 0)
        add_custom_target(example_gemv_splitk)
        add_example_executable(example_gemv_splitk_fp16 gemv_splitk_fp16.cpp)
        add_dependencies(example_gemv_splitk example_gemv_splitk_fp16)
        set(target 1)
    endif()
endforeach()
example/53_gemv_splitk/README.md (new file, mode 100755)
# Instructions for ```example_gemv_splitk```

## Run ```example_gemv_splitk```
```bash
#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: time kernel (0=no, 1=yes)
#arg4: number of splitk batches
bin/example_gemv_splitk_fp16 1 2 1 231
```

Result (MI250 @ 800 MHz, 181.05 TFlops peak FP16)
```
a_m_k: dim 2, lengths {1, 4608}, strides {4608, 1}
b_k_n: dim 2, lengths {4608, 1104}, strides {1104, 1}
c_m_n: dim 2, lengths {1, 1104}, strides {1104, 1}
Perf: 0.0111038 ms, 0.916305 TFlops, 917.334 GB/s, deviceTsmmDl<64, 1, 128, 3, 4, 1, 2, 1>
```
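
For reference, the `Perf:` numbers follow directly from the accounting in `run_gemv_splitk_example.inc`: `flop = 2 * M * N * K` and `num_btype = sizeof(fp16) * (M*K + K*N + M*N)`. With the default M=1, N=1104, K=4608 that is 10,174,464 flops and 10,185,888 bytes, so at 0.0111038 ms the kernel delivers 10,174,464 / 1e9 / 0.0111038 ≈ 0.916 TFlops and 10,185,888 / 1e6 / 0.0111038 ≈ 917 GB/s, matching the line above. At well under 1% of the quoted 181.05 TFlops FP16 peak, the GEMV is bandwidth-bound, as expected for M = 1.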
example/53_gemv_splitk/common.hpp (new file, mode 100755)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <numeric>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
struct ProblemSize final
// Default GEMV problem size
{
    ck::index_t M = 1;
    ck::index_t N = 1104;
    ck::index_t K = 4608;

    ck::index_t stride_A = K;
    ck::index_t stride_B = N; // K;
    ck::index_t stride_C = N;

    ck::index_t k_batch = 1;
};

struct ExecutionConfig final
{
    bool do_verification = true;
    int init_method      = 1;
    bool time_kernel     = false;
};

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

inline bool parse_cmd_args(int argc,
                           char* argv[],
                           ProblemSize& problem_size,
                           ExecutionConfig& config)
{
    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 5)
    {
        config.do_verification = std::stoi(argv[1]);
        config.init_method     = std::stoi(argv[2]);
        config.time_kernel     = std::stoi(argv[3]);
        problem_size.k_batch   = std::stoi(argv[4]);
    }
    else if(argc == 11)
    {
        config.do_verification = std::stoi(argv[1]);
        config.init_method     = std::stoi(argv[2]);
        config.time_kernel     = std::stoi(argv[3]);
        problem_size.k_batch   = std::stoi(argv[4]);

        problem_size.M = std::stoi(argv[5]);
        problem_size.N = std::stoi(argv[6]);
        problem_size.K = std::stoi(argv[7]);

        problem_size.stride_A = std::stoi(argv[8]);
        problem_size.stride_B = std::stoi(argv[9]);
        problem_size.stride_C = std::stoi(argv[10]);
    }
    else
    {
        std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
                  << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
                  << std::endl
                  << "arg3: time kernel (0=no, 1=yes)" << std::endl
                  << "arg4: number of splitk batches" << std::endl
                  << "arg5 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC"
                  << std::endl;

        return false;
    }

    return true;
}
example/53_gemv_splitk/gemv_splitk_fp16.cpp (new file, mode 100755)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_tall_and_skinny_gemm_splitk.hpp"
using ADataType   = ck::half_t;
using BDataType   = ck::half_t;
using CDataType   = ck::half_t;
using AccDataType = float;

using ALayout = Row;
using BLayout = Row; // Col;
using CLayout = Row;

using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;

static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;

#define K1 4
#define K0 3
#define N1 2
#define B 64 // block-size:64

// clang-format off
using DeviceGemvInstance = ck::tensor_operation::device::deviceTsmmDl
/*
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer | ABlockTransfer| ABlockTransfer | BBlockTransfer| BThreadTransfer| BThreadTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess|SrcVectorTensorLengths| SrcVectorTensor|DstVectorTensorLengths| SrcAccess| SrcVectorDim| SrcScalarPerVector| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | KBatch_K0_M0_M1_K1| KBatch_K0_M0_M1_K1| ArrangeOrder| Order| KBatch_K0_M0_M1_K1 | ContiguousDimOrder| KBatch_K0_M0_M1_K1 | Order| | | Order| | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, 64, 1, 64, 32, 2, 1, 1, 1, S<1, 1, 1, 2>, S<32, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 1>;*/
    <ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout,
     AElementOp, BElementOp, CElementOp, GemmMNPadding,
     B, 1, B * N1, K0, K1, 1, N1, 1,
     S<1, 1, 1, 1, K1>, S<1, K0, 1, 1, 1>, S<0, 1, 2, 3, 4>, S<0, 1, 2, 3, 4>,
     S<1, 1, 1, 1, K1>, S<0, 1, 2, 3, 4>, S<1, 1, 1, 1, 2>, S<0, 1, 2, 3, 4>,
     3, N1, S<0, 1, 2, 3, 4, 5>, 5, N1>;
// clang-format on
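// With the macros above (B = 64, N1 = 2, K0 = 3, K1 = 4), this instance resolves to
// BlockSize = 64, MPerBlock = 1, NPerBlock = B * N1 = 128, K0PerBlock = 3, K1 = 4,
// MPerThread = 1, NPerThread = 2, KPerThread = 1 -- i.e. the
// deviceTsmmDl<64, 1, 128, 3, 4, 1, 2, 1> string printed in the README's Perf line.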
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                                        BDataType,
                                                                        CDataType,
                                                                        AccDataType,
                                                                        AElementOp,
                                                                        BElementOp,
                                                                        CElementOp>;

#include "run_gemv_splitk_example.inc"

int main(int argc, char* argv[]) { return !run_gemv_example(argc, argv); }
example/53_gemv_splitk/run_gemv_splitk_example.inc (new file, mode 100755)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
bool run_gemv(const ProblemSize& problem_size, const ExecutionConfig& config)
{
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
    static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
#endif

    using namespace ck::literals;

    auto& [M, N, K, StrideA, StrideB, StrideC, k_batch] = problem_size;

    auto f_host_tensor_descriptor =
        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
            if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
            {
                return HostTensorDescriptor({row, col}, {stride, 1_uz});
            }
            else
            {
                return HostTensorDescriptor({row, col}, {1_uz, stride});
            }
        };

    Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
    Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));

    switch(config.init_method)
    {
    case 0: break;
    case 1:
        ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
        ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
        break;
    default:
        ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
        ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
    }

    Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
    Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));

    std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
    std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
    std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;

#ifdef BUILD_INT4_EXAMPLE
    DeviceMem a_m_k_device_buf(sizeof(KernelADataType) * a_m_k.mDesc.GetElementSpaceSize());
    DeviceMem b_k_n_device_buf(sizeof(KernelBDataType) * b_k_n.mDesc.GetElementSpaceSize());
    DeviceMem c_m_n_device_buf(sizeof(KernelCDataType) *
                               c_m_n_device_result.mDesc.GetElementSpaceSize());

    const Tensor<KernelADataType> a_m_k_converted(a_m_k);
    const Tensor<KernelBDataType> b_k_n_converted(b_k_n);

    a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data());
    b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data());
#else
    DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
    DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
    DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());

    a_m_k_device_buf.ToDevice(a_m_k.mData.data());
    b_k_n_device_buf.ToDevice(b_k_n.mData.data());
#endif

    auto a_element_op = AElementOp{};
    auto b_element_op = BElementOp{};
    auto c_element_op = CElementOp{};

    // do GEMM
    auto gemv     = DeviceGemvInstance{};
    auto invoker  = gemv.MakeInvoker();
    auto argument = gemv.MakeArgument(
#ifdef BUILD_INT4_EXAMPLE
        static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
        static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
        static_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#else
        static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
        static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
        static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#endif
        M,
        N,
        K,
        StrideA,
        StrideB,
        StrideC,
        a_element_op,
        b_element_op,
        c_element_op,
        k_batch);

    if(!gemv.IsSupportedArgument(argument))
    {
        std::cerr << gemv.GetTypeString() << " does not support this problem" << std::endl;

        return true;
    }

    c_m_n_device_buf.SetZero();

    invoker.Run(argument, StreamConfig{nullptr, false}); // Run prior to verification

    if(config.do_verification)
    {
        auto ref_gemv    = ReferenceGemmInstance{};
        auto ref_invoker = ref_gemv.MakeInvoker();

        auto ref_argument = ref_gemv.MakeArgument(
            a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);

        ref_invoker.Run(ref_argument);

#ifdef BUILD_INT4_EXAMPLE
        Tensor<CDataType> c_m_n_device_result_converted(c_m_n_host_result.mDesc);
        c_m_n_device_buf.FromDevice(c_m_n_device_result_converted.mData.data());

        c_m_n_device_result = c_m_n_device_result_converted.CopyAsType<CDataType>();
#else
        c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
#endif
    }

    float ave_time =
        invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); // Run to measure performance

    std::size_t flop      = 2_uz * M * N * K;
    std::size_t num_btype =
        sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;

    float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
              << " GB/s, " << gemv.GetTypeString() << std::endl;

#ifdef BUILD_INT4_EXAMPLE
    return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result);
#else
    return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
#endif
}

bool run_gemv_example(int argc, char* argv[])
{
    ProblemSize problem_size;
    ExecutionConfig config;

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 5)
    {
        config.do_verification = std::stoi(argv[1]);
        config.init_method     = std::stoi(argv[2]);
        config.time_kernel     = std::stoi(argv[3]);
        problem_size.k_batch   = std::stoi(argv[4]);
    }
    else if(argc == 11)
    {
        config.do_verification = std::stoi(argv[1]);
        config.init_method     = std::stoi(argv[2]);
        config.time_kernel     = std::stoi(argv[3]);
        problem_size.k_batch   = std::stoi(argv[4]);

        problem_size.M = std::stoi(argv[5]);
        problem_size.N = std::stoi(argv[6]);
        problem_size.K = std::stoi(argv[7]);

        problem_size.stride_A = std::stoi(argv[8]);
        problem_size.stride_B = std::stoi(argv[9]);
        problem_size.stride_C = std::stoi(argv[10]);
    }
    else
    {
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg3: time kernel (0=no, 1=yes)\n");
        printf("arg4: splitk\n");
        printf("arg5 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
        exit(0);
    }

    return run_gemv(problem_size, config);
}
example/54_tall_and_skinny_gemm_splitk/CMakeLists.txt (new file, mode 100755)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
    if(gpu IN_LIST gpu_list AND target EQUAL 0)
        add_custom_target(example_tall_and_skinny_gemm_splitk)
        add_example_executable(example_tall_and_skinny_gemm_splitk_fp16 tall_and_skinny_gemm_splitk_fp16.cpp)
        add_dependencies(example_tall_and_skinny_gemm_splitk example_tall_and_skinny_gemm_splitk_fp16)
        set(target 1)
    endif()
endforeach()
example/54_tall_and_skinny_gemm_splitk/README.md (new file, mode 100755)
# Instructions for ```example_tall_and_skinny_gemm_splitk```

## Run ```example_tall_and_skinny_gemm_splitk```
```bash
#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: time kernel (0=no, 1=yes)
#arg4: number of splitk batches
bin/example_tall_and_skinny_gemm_splitk_fp16 1 2 1 231
```

Result (MI250 @ 800 MHz, 181.05 TFlops peak FP16)
```
a_m_k: dim 2, lengths {16, 1024}, strides {1024, 1}
b_k_n: dim 2, lengths {1024, 16}, strides {16, 1}
c_m_n: dim 2, lengths {16, 16}, strides {16, 1}
Perf: 0.0065438 ms, 0.0801198 TFlops, 10.0932 GB/s, deviceTsmmDl<64, 16, 128, 4, 2, 16, 2, 1>
```
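
The same accounting as in the GEMV example applies here: `2 * 16 * 16 * 1024 = 524,288` flops in 0.0065438 ms is ≈ 0.0801 TFlops, and `2 * (16*1024 + 1024*16 + 16*16) = 66,048` bytes is ≈ 10.09 GB/s, matching the `Perf:` line above; with both M and N this small, the kernel sits well below both compute and bandwidth peaks, as expected.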
example/54_tall_and_skinny_gemm_splitk/common.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <numeric>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
struct ProblemSize final
// Default problem size
{
    ck::index_t M = 16;
    ck::index_t N = 16;
    ck::index_t K = 1024;

    ck::index_t stride_A = K;
    ck::index_t stride_B = N; // K;
    ck::index_t stride_C = N;

    ck::index_t k_batch = 1;
};

struct ExecutionConfig final
{
    bool do_verification = true;
    int init_method      = 1;
    bool time_kernel     = false;
};

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

inline bool parse_cmd_args(int argc,
                           char* argv[],
                           ProblemSize& problem_size,
                           ExecutionConfig& config)
{
    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 5)
    {
        config.do_verification = std::stoi(argv[1]);
        config.init_method     = std::stoi(argv[2]);
        config.time_kernel     = std::stoi(argv[3]);
        problem_size.k_batch   = std::stoi(argv[4]);
    }
    else if(argc == 11)
    {
        config.do_verification = std::stoi(argv[1]);
        config.init_method     = std::stoi(argv[2]);
        config.time_kernel     = std::stoi(argv[3]);
        problem_size.k_batch   = std::stoi(argv[4]);

        problem_size.M = std::stoi(argv[5]);
        problem_size.N = std::stoi(argv[6]);
        problem_size.K = std::stoi(argv[7]);

        problem_size.stride_A = std::stoi(argv[8]);
        problem_size.stride_B = std::stoi(argv[9]);
        problem_size.stride_C = std::stoi(argv[10]);
    }
    else
    {
        std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
                  << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
                  << std::endl
                  << "arg3: time kernel (0=no, 1=yes)" << std::endl
                  << "arg4: number of splitk batches" << std::endl
                  << "arg5 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC"
                  << std::endl;

        return false;
    }

    return true;
}
example/54_tall_and_skinny_gemm_splitk/run_tall_and_skinny_gemm_splitk_example.inc (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
bool run_tall_and_skinny_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
    static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
#endif

    using namespace ck::literals;

    auto& [M, N, K, StrideA, StrideB, StrideC, k_batch] = problem_size;

    auto f_host_tensor_descriptor =
        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
            if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
            {
                return HostTensorDescriptor({row, col}, {stride, 1_uz});
            }
            else
            {
                return HostTensorDescriptor({row, col}, {1_uz, stride});
            }
        };

    Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
    Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));

    switch(config.init_method)
    {
    case 0: break;
    case 1:
        ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
        ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
        break;
    default:
        ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
        ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
    }

    Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
    Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));

    std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
    std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
    std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;

#ifdef BUILD_INT4_EXAMPLE
    DeviceMem a_m_k_device_buf(sizeof(KernelADataType) * a_m_k.mDesc.GetElementSpaceSize());
    DeviceMem b_k_n_device_buf(sizeof(KernelBDataType) * b_k_n.mDesc.GetElementSpaceSize());
    DeviceMem c_m_n_device_buf(sizeof(KernelCDataType) *
                               c_m_n_device_result.mDesc.GetElementSpaceSize());

    const Tensor<KernelADataType> a_m_k_converted(a_m_k);
    const Tensor<KernelBDataType> b_k_n_converted(b_k_n);

    a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data());
    b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data());
#else
    DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
    DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
    DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());

    a_m_k_device_buf.ToDevice(a_m_k.mData.data());
    b_k_n_device_buf.ToDevice(b_k_n.mData.data());
#endif

    auto a_element_op = AElementOp{};
    auto b_element_op = BElementOp{};
    auto c_element_op = CElementOp{};

    // do GEMM
    auto tsmm     = DeviceTSMMInstance{};
    auto invoker  = tsmm.MakeInvoker();
    auto argument = tsmm.MakeArgument(
#ifdef BUILD_INT4_EXAMPLE
        static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
        static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
        static_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#else
        static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
        static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
        static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#endif
        M,
        N,
        K,
        StrideA,
        StrideB,
        StrideC,
        a_element_op,
        b_element_op,
        c_element_op,
        k_batch);

    if(!tsmm.IsSupportedArgument(argument))
    {
        std::cerr << tsmm.GetTypeString() << " does not support this problem" << std::endl;

        return true;
    }

    c_m_n_device_buf.SetZero();

    if(config.do_verification)
    {
        invoker.Run(argument, StreamConfig{nullptr, false}); // Run prior to verification

        auto ref_tsmm    = ReferenceGemmInstance{};
        auto ref_invoker = ref_tsmm.MakeInvoker();

        auto ref_argument = ref_tsmm.MakeArgument(
            a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);

        ref_invoker.Run(ref_argument);

#ifdef BUILD_INT4_EXAMPLE
        Tensor<CDataType> c_m_n_device_result_converted(c_m_n_host_result.mDesc);
        c_m_n_device_buf.FromDevice(c_m_n_device_result_converted.mData.data());

        c_m_n_device_result = c_m_n_device_result_converted.CopyAsType<CDataType>();
#else
        c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
#endif
    }

    float ave_time =
        invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); // Run to measure performance

    std::size_t flop      = 2_uz * M * N * K;
    std::size_t num_btype =
        sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;

    float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
              << " GB/s, " << tsmm.GetTypeString() << std::endl;

#ifdef BUILD_INT4_EXAMPLE
    return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result);
#else
    return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
#endif
}

bool run_tall_and_skinny_gemm_example(int argc, char* argv[])
{
    ProblemSize problem_size;
    ExecutionConfig config;

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 5)
    {
        config.do_verification = std::stoi(argv[1]);
        config.init_method     = std::stoi(argv[2]);
        config.time_kernel     = std::stoi(argv[3]);
        problem_size.k_batch   = std::stoi(argv[4]);
    }
    else if(argc == 11)
    {
        config.do_verification = std::stoi(argv[1]);
        config.init_method     = std::stoi(argv[2]);
        config.time_kernel     = std::stoi(argv[3]);
        problem_size.k_batch   = std::stoi(argv[4]);

        problem_size.M = std::stoi(argv[5]);
        problem_size.N = std::stoi(argv[6]);
        problem_size.K = std::stoi(argv[7]);

        problem_size.stride_A = std::stoi(argv[8]);
        problem_size.stride_B = std::stoi(argv[9]);
        problem_size.stride_C = std::stoi(argv[10]);
    }
    else
    {
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg3: time kernel (0=no, 1=yes)\n");
        printf("arg4: splitk\n");
        printf("arg5 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
        exit(0);
    }

    return run_tall_and_skinny_gemm(problem_size, config);
}
example/54_tall_and_skinny_gemm_splitk/tall_and_skinny_gemm_splitk_fp16.cpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_tall_and_skinny_gemm_splitk.hpp"
using ADataType   = ck::half_t;
using BDataType   = ck::half_t;
using CDataType   = ck::half_t;
using AccDataType = float;

using ALayout = Row;
using BLayout = Row; // Col;
using CLayout = Row;

using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;

static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;

#define K1 2
#define K0 4
#define N1 2
#define B 64 // block-size:64
#define M1 16

// clang-format off
using DeviceTSMMInstance = ck::tensor_operation::device::deviceTsmmDl
/*
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer | ABlockTransfer| ABlockTransfer | BBlockTransfer| BThreadTransfer| BThreadTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess|SrcVectorTensorLengths| SrcVectorTensor|DstVectorTensorLengths| SrcAccess| SrcVectorDim| SrcScalarPerVector| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | KBatch_K0_M0_M1_K1| KBatch_K0_M0_M1_K1| ArrangeOrder| Order| KBatch_K0_M0_M1_K1 | ContiguousDimOrder| KBatch_K0_M0_M1_K1 | Order| | | Order| | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, 64, 1, 64, 32, 2, 1, 1, 1, S<1, 1, 1, 2>, S<32, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 1>;*/
    <ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout,
     AElementOp, BElementOp, CElementOp, GemmMNPadding,
     B, M1, B * N1, K0, K1, M1, N1, 1,
     S<1, 1, 1, 1, K1>, S<1, K0, 1, M1, 1>, S<0, 1, 2, 3, 4>, S<0, 1, 2, 3, 4>,
     S<1, 1, 1, 1, K1>, S<0, 1, 2, 3, 4>, S<1, 1, 1, 1, K1>, S<0, 1, 2, 3, 4>,
     3, N1, S<0, 1, 2, 3, 4, 5>, 5, N1>;
// clang-format on
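// With the macros above (B = 64, M1 = 16, N1 = 2, K0 = 4, K1 = 2), this instance
// resolves to BlockSize = 64, MPerBlock = 16, NPerBlock = B * N1 = 128, K0PerBlock = 4,
// K1 = 2, MPerThread = 16, NPerThread = 2, KPerThread = 1 -- i.e. the
// deviceTsmmDl<64, 16, 128, 4, 2, 16, 2, 1> string printed in the README's Perf line.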
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                                        BDataType,
                                                                        CDataType,
                                                                        AccDataType,
                                                                        AElementOp,
                                                                        BElementOp,
                                                                        CElementOp>;

#include "run_tall_and_skinny_gemm_splitk_example.inc"

int main(int argc, char* argv[]) { return !run_tall_and_skinny_gemm_example(argc, argv); }
include/ck/tensor_operation/gpu/block/blockwise_tall_and_skinny_gemm.hpp (new file, mode 100755)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
#define CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tall_and_skinny_gemm.hpp"
namespace ck {

template <index_t BlockSize,
          typename FloatA,
          typename FloatB,
          typename FloatC,
          typename ABlockDesc_K0_M_K1,
          typename BThreadDesc_K0_N_K1,
          index_t MPerThread,
          index_t NPerBlock,
          index_t K0PerLoop>
struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};

    using CIndex = MultiIndex<4>;

    static constexpr auto K0 = ABlockDesc_K0_M_K1{}.GetLength(I0);
    static constexpr auto M  = ABlockDesc_K0_M_K1{}.GetLength(I1);
    static constexpr auto K1 = ABlockDesc_K0_M_K1{}.GetLength(I2);

    static constexpr auto NPerThread = BThreadDesc_K0_N_K1{}.GetLength(I1);

    static constexpr auto M0 = M / MPerThread;
    static constexpr auto M1 = MPerThread;

    static constexpr auto N  = NPerBlock;
    static constexpr auto N0 = N / NPerThread;
    static constexpr auto N1 = NPerThread;

    static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<K0PerLoop>{}, Number<MPerThread>{}, Number<K1>{}));

    static constexpr auto b_thread_mtx_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<K0PerLoop>{}, Number<NPerThread>{}, Number<K1>{}));

    static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<I1>{}, Number<M1>{}, Number<I1>{}, Number<N1>{}));

    __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()
        : c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
              get_thread_local_1d_id())},
          a_thread_copy_{make_tuple(0, c_thread_origin_data_idx_[I0] * MPerThread, 0)}
    {
        static_assert(ABlockDesc_K0_M_K1::IsKnownAtCompileTime() &&
                          BThreadDesc_K0_N_K1::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        static_assert(ABlockDesc_K0_M_K1{}.GetLength(I0) == BThreadDesc_K0_N_K1{}.GetLength(I0) &&
                          ABlockDesc_K0_M_K1{}.GetLength(I2) == BThreadDesc_K0_N_K1{}.GetLength(I2),
                      "wrong! E dimension not consistent\n");

        static_assert(K0 % K0PerLoop == 0, "");

        static_assert(M % MPerThread == 0 && N % NPerThread == 0,
                      "wrong! Cannot evenly divide work among\n");

        static_assert(BlockSize == M0 * N0, "wrong! wrong blocksize\n");
    }

    __device__ static constexpr auto GetCThreadTensorLengths_BM0_BM1_BN0_BN1()
    {
        return Sequence<I1, M1, I1, N1>{};
    }

    __device__ static CIndex CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(index_t thread_id)
    {
        constexpr auto c_threadid_to_m0_m1_n0_n1_thread_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(M0, I1, N0, I1))),
                make_tuple(Sequence<0, 1, 2, 3>{}),
                make_tuple(Sequence<0>{}));

        const auto c_m0_m1_n0_n1_thread_cluster_idx =
            c_threadid_to_m0_m1_n0_n1_thread_cluster_adaptor.CalculateBottomIndex(
                make_multi_index(thread_id));

        return c_m0_m1_n0_n1_thread_cluster_idx;
    }

    template <typename ABlockBuffer, typename BThreadBuffer, typename CThreadBuffer>
    __device__ void Run(const ABlockBuffer& a_block_buf,
                        const BThreadBuffer& b_thread_buf,
                        CThreadBuffer& c_thread_buf) const
    {
        static_assert(
            is_same<remove_cvref_t<typename ABlockBuffer::type>, remove_cvref_t<FloatA>>::value &&
            is_same<remove_cvref_t<typename BThreadBuffer::type>, remove_cvref_t<FloatB>>::value &&
            is_same<remove_cvref_t<typename CThreadBuffer::type>, remove_cvref_t<FloatC>>::value &&
            "wrong! inconsistent type");

        constexpr auto a_block_mtx = ABlockDesc_K0_M_K1{};

        // thread A buffer for GEMM
        StaticBuffer<AddressSpaceEnum::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
            a_thread_buf;

        constexpr auto threadwise_gemm =
            ThreadwiseGemmDlops_km_kn_mn_v3<FloatA,
                                            FloatB,
                                            FloatC,
                                            decltype(a_thread_mtx_),
                                            decltype(b_thread_mtx_),
                                            decltype(c_thread_mtx_)>{};

        static_for<0, K0, K0PerLoop>{}([&](auto k0_begin) {
            a_thread_copy_.Run(a_block_mtx,
                               make_tuple(k0_begin, I0, I0),
                               a_block_buf,
                               a_thread_mtx_,
                               make_tuple(I0, I0, I0),
                               a_thread_buf);

            threadwise_gemm.Run(a_thread_buf,
                                make_tuple(I0, I0, I0),
                                b_thread_buf,
                                make_tuple(k0_begin, I0, I0),
                                c_thread_buf,
                                make_tuple(I0, I0, I0, I0));
        });
    }

    private:
    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
                                                         FloatA,
                                                         ABlockDesc_K0_M_K1,
                                                         decltype(a_thread_mtx_),
                                                         Sequence<K0PerLoop, MPerThread, K1>,
                                                         Sequence<0, 1, 2>,
                                                         2,
                                                         K1,
                                                         K1>;

    CIndex c_thread_origin_data_idx_;

    AThreadCopy a_thread_copy_;
};

} // namespace ck
#endif
include/ck/tensor_operation/gpu/device/device_tall_and_skinny_gemm.hpp (new file, mode 100755)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {

template <typename ALayout,
          typename BLayout,
          typename CLayout,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
struct DeviceTsmm : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        void* p_c,
                        ck::index_t M,
                        ck::index_t N,
                        ck::index_t K,
                        ck::index_t StrideA,
                        ck::index_t StrideB,
                        ck::index_t StrideC,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CElementwiseOperation c_element_op,
                        ck::index_t KBatch = 1) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
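
`DeviceTsmm` follows the pattern of CK's other device-operator base classes: clients such as ckProfiler drive a concrete instance entirely through this type-erased pointer API. A minimal, illustrative sketch of that call pattern — `op`, `p_a`, `p_b`, `p_c`, and the sizes/strides below are assumed to be set up by the caller (e.g. as in the examples above), not names from this commit:

```cpp
// Sketch only: the polymorphic call pattern for a DeviceTsmm operator.
// `op` is a concrete deviceTsmmDl instance; `p_a`, `p_b`, `p_c` are device pointers.
auto argument_ptr = op.MakeArgumentPointer(p_a, p_b, p_c,
                                           M, N, K,
                                           StrideA, StrideB, StrideC,
                                           PassThrough{}, PassThrough{}, PassThrough{},
                                           k_batch);
auto invoker_ptr  = op.MakeInvokerPointer();

if(op.IsSupportedArgument(argument_ptr.get()))
{
    // StreamConfig{nullptr, true} asks the invoker to time the kernel, as in the examples.
    float ave_time_ms = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
}
```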
include/ck/tensor_operation/gpu/device/impl/device_tall_and_skinny_gemm_splitk.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_tall_and_skinny_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_tall_and_skinny_gemm_splitk.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AccDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          GemmSpecialization GemmSpec,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t K0PerBlock,
          index_t K1,
          index_t MPerThread,
          index_t NPerThread,
          index_t KPerThread,
          typename ABlockTransferThreadSliceLengths_KBatch_K0_M0_M1_K1,
          typename ABlockTransferThreadClusterLengths_KBatch_K0_M0_M1_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          typename ABlockTransferSrcVectorTensorLengths_KBatch_K0_M0_M1_K1,
          typename ABlockTransferSrcVectorTensorContiguousDimOrder,
          typename ABlockTransferDstVectorTensorLengths_KBatch_K0_M0_M1_K1,
          typename BThreadTransferSrcDstAccessOrder,
          index_t BThreadTransferSrcVectorDim,
          index_t BThreadTransferSrcScalarPerVector,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          enable_if_t<
              is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
                  is_same_v<BElementwiseOperation,
                            ck::tensor_operation::element_wise::PassThrough> &&
                  is_same_v<CElementwiseOperation,
                            ck::tensor_operation::element_wise::PassThrough>,
              bool> = false>
struct deviceTsmmDl : public DeviceTsmm<ALayout,
                                        BLayout,
                                        CLayout,
                                        ADataType,
                                        BDataType,
                                        CDataType,
                                        AElementwiseOperation,
                                        BElementwiseOperation,
                                        CElementwiseOperation>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};

    // GridwiseTsmm
    using GridwiseTsmm =
        GridwiseTsmmDl_km_kn_mn<BlockSize,
                                ADataType,
                                AccDataType,
                                CDataType,
                                ALayout,
                                BLayout,
                                CLayout,
                                GemmSpec,
                                MPerBlock,
                                NPerBlock,
                                K0PerBlock,
                                K1,
                                MPerThread,
                                NPerThread,
                                KPerThread,
                                ABlockTransferThreadSliceLengths_KBatch_K0_M0_M1_K1,
                                ABlockTransferThreadClusterLengths_KBatch_K0_M0_M1_K1,
                                ABlockTransferThreadClusterArrangeOrder,
                                ABlockTransferSrcAccessOrder,
                                ABlockTransferSrcVectorTensorLengths_KBatch_K0_M0_M1_K1,
                                ABlockTransferSrcVectorTensorContiguousDimOrder,
                                ABlockTransferDstVectorTensorLengths_KBatch_K0_M0_M1_K1,
                                BThreadTransferSrcDstAccessOrder,
                                BThreadTransferSrcVectorDim,
                                BThreadTransferSrcScalarPerVector,
                                CThreadTransferSrcDstAccessOrder,
                                CThreadTransferSrcDstVectorDim,
                                CThreadTransferDstScalarPerVector>;

    using DefaultBlock2CTileMap = typename GridwiseTsmm::DefaultBlock2CTileMap;
    using Argument              = typename GridwiseTsmm::Argument;

    // Invoker
    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
        {
            const index_t grid_size =
                GridwiseTsmm::CalculateGridSize(karg.M, karg.N, karg.k_batch);
            // const auto b2c_map = DefaultBlock2CTileMap{};
            const auto K0 = karg.K0;

            const bool has_main_k_block_loop = GridwiseTsmm::CalculateHasMainKBlockLoop(K0);
            const bool has_double_tail_k_block_loop =
                GridwiseTsmm::CalculateHasDoubleTailKBlockLoop(K0);

            float ave_time = 0;

            if(karg.k_batch > 1)
                hipGetErrorString(
                    hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType)));

            if(has_main_k_block_loop && has_double_tail_k_block_loop)
            {
                if(karg.k_batch == 1)
                {
                    const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
                                                            ADataType,
                                                            CDataType,
                                                            InMemoryDataOperationEnum::Set,
                                                            true,
                                                            true,
                                                            DefaultBlock2CTileMap>;
                    ave_time = launch_and_time_kernel(
                        stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
                }
                else
                {
                    const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
                                                            ADataType,
                                                            CDataType,
                                                            InMemoryDataOperationEnum::AtomicAdd,
                                                            true,
                                                            true,
                                                            DefaultBlock2CTileMap>;
                    ave_time = launch_and_time_kernel(
                        stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
                }
            }
            else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
            {
                if(karg.k_batch == 1)
                {
                    const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
                                                            ADataType,
                                                            CDataType,
                                                            InMemoryDataOperationEnum::Set,
                                                            true,
                                                            false,
                                                            DefaultBlock2CTileMap>;
                    ave_time = launch_and_time_kernel(
                        stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
                }
                else
                {
                    const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
                                                            ADataType,
                                                            CDataType,
                                                            InMemoryDataOperationEnum::AtomicAdd,
                                                            true,
                                                            false,
                                                            DefaultBlock2CTileMap>;
                    ave_time = launch_and_time_kernel(
                        stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
                }
            }
            else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
            {
                if(karg.k_batch == 1)
                {
                    const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
                                                            ADataType,
                                                            CDataType,
                                                            InMemoryDataOperationEnum::Set,
                                                            false,
                                                            true,
                                                            DefaultBlock2CTileMap>;
                    ave_time = launch_and_time_kernel(
                        stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
                }
                else
                {
                    const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
                                                            ADataType,
                                                            CDataType,
                                                            InMemoryDataOperationEnum::AtomicAdd,
                                                            false,
                                                            true,
                                                            DefaultBlock2CTileMap>;
                    ave_time = launch_and_time_kernel(
                        stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
                }
            }
            else
            {
                if(karg.k_batch == 1)
                {
                    const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
                                                            ADataType,
                                                            CDataType,
                                                            InMemoryDataOperationEnum::Set,
                                                            false,
                                                            false,
                                                            DefaultBlock2CTileMap>;
                    ave_time = launch_and_time_kernel(
                        stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
                }
                else
                {
                    const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
                                                            ADataType,
                                                            CDataType,
                                                            InMemoryDataOperationEnum::AtomicAdd,
                                                            false,
                                                            false,
                                                            DefaultBlock2CTileMap>;
                    ave_time = launch_and_time_kernel(
                        stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
                }
            }

            return ave_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
           ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
           ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
           ck::get_device_name() == "gfx1102")
        {
            return GridwiseTsmm::CheckValidity(arg);
        }
        else
        {
            return false;
        }
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const ADataType* p_a,
                             const BDataType* p_b,
                             CDataType* p_c,
                             index_t M,
                             index_t N,
                             index_t K,
                             index_t StrideA,
                             index_t StrideB,
                             index_t StrideC,
                             AElementwiseOperation,
                             BElementwiseOperation,
                             CElementwiseOperation,
                             index_t KBatch)
    {
        return Argument{p_a,
                        p_b,
                        p_c,
                        M,
                        N,
                        K,
                        StrideA,
                        StrideB,
                        StrideC,
                        // GridwiseTsmm::CalculateMPadded(M),
                        // GridwiseTsmm::CalculateNPadded(N),
                        // GridwiseTsmm::CalculateKPadded(K, KBatch),
                        GridwiseTsmm::CalculateK0(K, KBatch),
                        KBatch};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      void* p_c,
                                                      index_t M,
                                                      index_t N,
                                                      index_t K,
                                                      index_t StrideA,
                                                      index_t StrideB,
                                                      index_t StrideC,
                                                      AElementwiseOperation,
                                                      BElementwiseOperation,
                                                      CElementwiseOperation,
                                                      ck::index_t KBatch = 1) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          static_cast<CDataType*>(p_c),
                                          M,
                                          N,
                                          K,
                                          StrideA,
                                          StrideB,
                                          StrideC,
                                          // GridwiseTsmm::CalculateMPadded(M),
                                          // GridwiseTsmm::CalculateNPadded(N),
                                          // GridwiseTsmm::CalculateKPadded(K, KBatch),
                                          GridwiseTsmm::CalculateK0(K, KBatch),
                                          KBatch);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "deviceTsmmDl"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << K0PerBlock << ", "
            << K1 << ", "
            << MPerThread << ", "
            << NPerThread << ", "
            << KPerThread
            << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_tall_and_skinny_gemm_splitk.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_tall_and_skinny_gemm.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {

template <typename GridwiseTsmm,
          typename FloatAB,
          typename FloatC,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          bool HasMainKBlockLoop,
          bool HasDoubleTailKBlockLoop,
          typename Block2CTileMap>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_tsmm_dl_v1r3(typename GridwiseTsmm::Argument karg)
// in __global__ functions, struct is better for reduced load overhead
{
    GridwiseTsmm::template Run<HasMainKBlockLoop,
                               HasDoubleTailKBlockLoop,
                               GridwiseTsmm,
                               CGlobalMemoryDataOperation>(karg);
}

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          typename ALayout,
          typename BLayout,
          typename CLayout,
          tensor_operation::device::GemmSpecialization GemmSpec,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t K0PerBlock,
          index_t K1Value,
          index_t MPerThread,
          index_t NPerThread,
          index_t KPerThread,
          typename ABlockTransferThreadSliceLengths_KBatch_K0_M0_M1_K1,
          typename ABlockTransferThreadClusterLengths_KBatch_K0_M0_M1_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          typename ABlockTransferSrcVectorTensorLengths_KBatch_K0_M0_M1_K1,
          typename ABlockTransferSrcVectorTensorContiguousDimOrder,
          typename ABlockTransferDstVectorTensorLengths_KBatch_K0_M0_M1_K1,
          typename BThreadTransferSrcDstAccessOrder,
          index_t BThreadTransferSrcVectorDim,
          index_t BThreadTransferSrcScalarPerVector,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector>
struct GridwiseTsmmDl_km_kn_mn
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    // K1 should be Number<...>
    static constexpr auto K1 = Number<K1Value>{};

    // Argument
    struct Argument : public tensor_operation::device::BaseArgument
    {
        Argument(const FloatAB* p_a_grid_,
                 const FloatAB* p_b_grid_,
                 FloatC* p_c_grid_,
                 index_t M_,
                 index_t N_,
                 index_t K_,
                 index_t StrideA_,
                 index_t StrideB_,
                 index_t StrideC_,
                 // index_t MPadded_,
                 // index_t NPadded_,
                 // index_t KPadded_,
                 index_t K0_,
                 index_t k_batch_)
            : p_a_grid{p_a_grid_},
              p_b_grid{p_b_grid_},
              p_c_grid{p_c_grid_},
              M{M_},
              N{N_},
              K{K_},
              StrideA{StrideA_},
              StrideB{StrideB_},
              StrideC{StrideC_},
              // MPadded(MPadded_),
              // NPadded(NPadded_),
              // KPadded(KPadded_),
              K0(K0_),
              k_batch(k_batch_)
        {
        }

        // private:
        const FloatAB* p_a_grid;
        const FloatAB* p_b_grid;
        FloatC* p_c_grid;
        index_t M, N, K;
        index_t StrideA, StrideB, StrideC;
        // index_t MPadded;
        // index_t NPadded;
        // index_t KPadded;
        index_t K0;
        index_t k_batch;
    };

    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // TODO: change this. I think it needs multi-dimensional alignment
        constexpr auto max_lds_align = K1;

        // TODO: check alignment
        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);

        // TODO: check alignment
        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_aligned_space_size =
            math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align);

        return 2 * (a_block_aligned_space_size) * sizeof(FloatAB);
    }
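
    // Note: the factor of 2 above sizes LDS for two K0PerBlock x MPerBlock x K1 tiles
    // of A, which appears consistent with the double-buffered K loop selected through
    // CalculateHasMainKBlockLoop / CalculateHasDoubleTailKBlockLoop below.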
    __host__ __device__ static constexpr index_t
    CalculateGridSize(index_t M, index_t N, index_t k_batch)
    {
        const index_t grid_size = math::integer_divide_ceil(N, NPerBlock) *
                                  math::integer_divide_ceil(M, MPerBlock) * k_batch;

        return grid_size;
    }

    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0)
    {
        const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;

        return has_main_k_block_loop;
    }

    __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
    {
        const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0;

        return has_double_tail_k_block_loop;
    }

    __host__ __device__ static auto CalculateMPadded(index_t M)
    {
        return math::integer_least_multiple(M, MPerBlock);
    }

    __host__ __device__ static auto CalculateNPadded(index_t N)
    {
        return math::integer_least_multiple(N, NPerBlock);
    }

    __host__ __device__ static auto CalculateK0(index_t K, index_t K_Batch = 1)
    {
        // k_batch * k0 * k0_per_block * k1
        auto K_t = K_Batch * K0PerBlock * K1;
        return (K + K_t - 1) / K_t * K0PerBlock;
    }
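
    // Worked example with the gemv example's defaults (K = 4608, K0PerBlock = 3,
    // K1 = 4, K_Batch = 231): K_t = 231 * 3 * 4 = 2772, so
    // K0 = ceil(4608 / 2772) * 3 = 6, and CalculateKPadded below returns
    // 231 * 6 * 4 = 5544 -- K padded up so that each of the 231 split-K batches
    // covers a whole number of K0PerBlock chunks.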
__host__
__device__
static
auto
CalculateKPadded
(
index_t
K
,
index_t
K_Batch
=
1
)
{
auto
K0
=
CalculateK0
(
K
,
K_Batch
);
return
K_Batch
*
K0
*
K1
;
}
static
constexpr
auto
K1Number
=
Number
<
K1
>
{};
// M, K -> KBatch, K0, M, K1: M -> MPad, K->KBatch, K0, K1
__host__
__device__
static
auto
MakeAGridDescriptor_KBatch_K0_M_K1
(
index_t
M
,
index_t
MPad
,
index_t
K
,
index_t
StrideA
,
index_t
KBatch
,
index_t
K0
)
{
const
auto
a_grid_desc_m_k
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
StrideA
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
I1
,
StrideA
));
}
}();
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
)
{
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0
,
K1Number
)),
// unmerge is split 1D to 3D
make_right_pad_transform
(
M
,
MPad
-
M
)),
//
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
// mapped to input M & K; sequence 0 is M;
// 1 is K; make unmerge is working on K;
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
// input is M,K; output we want is Kbatch, K0 and K1
// -> 0, 1, 3; output is transformed from 2D to 4D
Sequence
<
2
>
{}));
// 2->M
}
else
{
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0
,
K1Number
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
}
}
__host__
__device__
static
auto
MakeBGridDescriptor_KBatch_K0_N_K1
(
index_t
K
,
index_t
NPad
,
index_t
N
,
index_t
StrideB
,
index_t
KBatch
,
index_t
K0
)
{
const
auto
b_grid_desc_k_n
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
K
,
N
),
make_tuple
(
StrideB
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
K
,
N
),
make_tuple
(
I1
,
StrideB
));
}
}();
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
)
{
return
transform_tensor_descriptor
(
b_grid_desc_k_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0
,
K1Number
)),
make_right_pad_transform
(
N
,
NPad
-
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
}
else
{
return
transform_tensor_descriptor
(
b_grid_desc_k_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0
,
K1Number
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
}
}
__host__
__device__
static
auto
MakeCGridDescriptor_M_N
(
index_t
M
,
index_t
N
,
index_t
StrideC
)
{
const
auto
c_grid_desc_m_n
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
StrideC
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
CLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
I1
,
StrideC
));
}
}();
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
)
{
const
auto
PadM
=
(
MPerBlock
-
M
%
MPerBlock
)
%
MPerBlock
;
const
auto
PadN
=
(
NPerBlock
-
N
%
NPerBlock
)
%
NPerBlock
;
return
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_right_pad_transform
(
M
,
PadM
),
make_right_pad_transform
(
N
,
PadN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
{
return
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_pass_through_transform
(
M
),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
}
__host__
__device__
static
auto
GetKPad
(
index_t
K
,
index_t
KBatch
)
{
const
index_t
K0
=
math
::
integer_divide_ceil
(
K
,
K1
*
K0PerBlock
*
KBatch
)
*
K0PerBlock
;
const
index_t
KPad
=
KBatch
*
K0
*
K1
;
return
KPad
;
}
using
AGridDesc_Kbatch_K0_M_K1
=
decltype
(
MakeAGridDescriptor_KBatch_K0_M_K1
(
1
,
1
,
1
,
1
,
1
,
1
));
using
BGridDesc_Kbatch_K0_N_K1
=
decltype
(
MakeBGridDescriptor_KBatch_K0_N_K1
(
1
,
1
,
1
,
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
    __host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
    {
        const auto MPadded = CalculateMPadded(karg.M);
        const auto NPadded = CalculateNPadded(karg.N);

        const auto a_grid_desc_kbatch_k0_m_k1 = MakeAGridDescriptor_KBatch_K0_M_K1(
            karg.M, MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0);
        const auto b_grid_desc_kbatch_k0_n_k1 = MakeBGridDescriptor_KBatch_K0_N_K1(
            karg.K, NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0);
        const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);

        const auto KBatch_a = a_grid_desc_kbatch_k0_m_k1.GetLength(I0);
        const auto KBatch_b = b_grid_desc_kbatch_k0_n_k1.GetLength(I0);
        const auto K0_      = a_grid_desc_kbatch_k0_m_k1.GetLength(I1);
        const auto M_       = a_grid_desc_kbatch_k0_m_k1.GetLength(I2);
        const auto N_       = b_grid_desc_kbatch_k0_n_k1.GetLength(I2);

        return (M_ % MPerBlock == 0 && N_ % NPerBlock == 0 && K0_ % K0PerBlock == 0 &&
                M_ == c_grid_desc_m_n.GetLength(I0) && N_ == c_grid_desc_m_n.GetLength(I1) &&
                a_grid_desc_kbatch_k0_m_k1.GetLength(I3) ==
                    b_grid_desc_kbatch_k0_n_k1.GetLength(I3) &&
                karg.k_batch >= 1 && KBatch_a == karg.k_batch && KBatch_b == karg.k_batch);
    }
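
    // Example of a rejected configuration (illustrative): karg.k_batch = 0 fails
    // the k_batch >= 1 check, and an unpadded M that is not a multiple of
    // MPerBlock fails M_ % MPerBlock == 0 unless GemmSpec pads M up first.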
    // KBatch, K0, M, K1 -> KBatch, K0, M0, M1 (MPerBlock), K1
    __host__ __device__ static constexpr auto MakeAGridDescriptor_Kbatch_K0_M0_M1_K1(
        const AGridDesc_Kbatch_K0_M_K1& a_grid_desc_kbatch_k0_m_k1)
    {
        const auto KBatch = a_grid_desc_kbatch_k0_m_k1.GetLength(I0);
        const auto K0     = a_grid_desc_kbatch_k0_m_k1.GetLength(I1);
        const auto M      = a_grid_desc_kbatch_k0_m_k1.GetLength(I2);

        const auto M1 = Number<MPerBlock>{};
        const auto M0 = M / M1;

        const auto a_grid_desc_kbatch_k0_m0_m1_k1 = transform_tensor_descriptor(
            a_grid_desc_kbatch_k0_m_k1,
            make_tuple(make_pass_through_transform(KBatch),
                       make_pass_through_transform(K0),
                       make_unmerge_transform(make_tuple(M0, M1)),
                       make_pass_through_transform(K1)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), // input dims
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{})); // output dims

        return a_grid_desc_kbatch_k0_m0_m1_k1;
    }
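
    // Illustrative decomposition (example values): with M = 512 and MPerBlock = 128,
    // the M axis unmerges into M0 = 4 block rows of M1 = 128, turning the 4-D
    // (KBatch, K0, M, K1) descriptor into the 5-D (KBatch, K0, M0, M1, K1) view
    // that the blockwise copy in Run expects.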
    __host__ __device__ static constexpr auto MakeBGridDescriptor_Kbatch_K0_N0_N1_K1(
        const BGridDesc_Kbatch_K0_N_K1& b_grid_desc_kbatch_k0_n_k1)
    {
        const auto KBatch = b_grid_desc_kbatch_k0_n_k1.GetLength(I0);
        const auto K0     = b_grid_desc_kbatch_k0_n_k1.GetLength(I1);
        const auto N      = b_grid_desc_kbatch_k0_n_k1.GetLength(I2);

        const auto N1 = Number<NPerBlock>{};
        const auto N0 = N / N1;

        const auto b_grid_desc_kbatch_k0_n0_n1_k1 = transform_tensor_descriptor(
            b_grid_desc_kbatch_k0_n_k1,
            make_tuple(make_pass_through_transform(KBatch),
                       make_pass_through_transform(K0),
                       make_unmerge_transform(make_tuple(N0, N1)),
                       make_pass_through_transform(K1)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));

        return b_grid_desc_kbatch_k0_n0_n1_k1;
    }
    __host__ __device__ static constexpr auto MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(
        const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        constexpr auto M1 = Number<MPerBlock>{};
        constexpr auto N1 = Number<NPerBlock>{};

        const auto M0 = M / M1;
        const auto N0 = N / N1;

        constexpr auto M11 = Number<MPerThread>{};
        constexpr auto N11 = Number<NPerThread>{};

        constexpr auto M10 = M1 / M11;
        constexpr auto N10 = N1 / N11;

        const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor(
            c_grid_desc_m_n,
            make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
                       make_unmerge_transform(make_tuple(N0, N10, N11))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));

        return c_grid_desc_m0_m10_m11_n0_n10_n11;
    }
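
    // Illustrative decomposition (example values only): with M = 512, MPerBlock = 128,
    // MPerThread = 4, N = 256, NPerBlock = 128 and NPerThread = 2, this yields
    // M0 = 4, M10 = 32, M11 = 4 and N0 = 2, N10 = 64, N11 = 2: each block tile is
    // further tiled into per-thread sub-tiles along both M and N.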
    // return block_id to C matrix tile index (m0, n0) mapping
    __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap()
    {
        // 3-D grid with a k-split dimension for C
        return BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>();
    }

    using DefaultBlock2CTileMap = remove_cvref_t<decltype(MakeDefaultBlock2CTileMap())>;
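
    // Sketch of the mapping (the exact linearization is internal to
    // BlockToCTileMap_3DGrid_KSplit, so treat the layout as an assumption): each
    // 1-D block id is decoded into an (m-tile, n-tile, k-batch) triple via
    // convert_1D_block_idx_to_3D_tuple, and each k-batch writes a partial C that
    // is combined according to CGlobalMemoryDataOperation.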
    using AGridDesc_K0_M0_M1_K1 =
        decltype(MakeAGridDescriptor_Kbatch_K0_M0_M1_K1(AGridDesc_Kbatch_K0_M_K1{}));
    using BGridDesc_K0_N0_N1_K1 =
        decltype(MakeBGridDescriptor_Kbatch_K0_N0_N1_K1(BGridDesc_Kbatch_K0_N_K1{}));
    using CGridDesc_M0_M10_M11_N0_N10_N11 =
        decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));

    using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap());
    template <bool HasMainKBlockLoop,
              bool HasDoubleTailKBlockLoop,
              typename GridwiseTsmm,
              InMemoryDataOperationEnum CGlobalMemoryDataOperation>
    __device__ static void Run(const Argument& karg)
    {
        constexpr index_t shared_block_size =
            GridwiseTsmm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

        __shared__ FloatAB p_shared_block[shared_block_size];

        const Block2CTileMap& block_2_ctile_map = Block2CTileMap{};

        const auto MPadded = CalculateMPadded(karg.M);
        const auto NPadded = CalculateNPadded(karg.N);

        const FloatAB* p_a_grid = karg.p_a_grid;
        const FloatAB* p_b_grid = karg.p_b_grid;
        FloatC* p_c_grid        = karg.p_c_grid;

        const auto a_grid_desc_kbatch_k0_m_k1 = GridwiseTsmm::MakeAGridDescriptor_KBatch_K0_M_K1(
            karg.M, MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0);
        const auto b_grid_desc_kbatch_k0_n_k1 = GridwiseTsmm::MakeBGridDescriptor_KBatch_K0_N_K1(
            karg.K, NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0);
        const auto c_grid_desc_m_n =
            GridwiseTsmm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);

        const auto a_grid_desc_kbatch_k0_m0_m1_k1 =
            GridwiseTsmm::MakeAGridDescriptor_Kbatch_K0_M0_M1_K1(a_grid_desc_kbatch_k0_m_k1);
        const auto b_grid_desc_kbatch_k0_n0_n1_k1 =
            GridwiseTsmm::MakeBGridDescriptor_Kbatch_K0_N0_N1_K1(b_grid_desc_kbatch_k0_n_k1);
        const auto c_grid_desc_m0_m10_m11_n0_n10_n11 =
            GridwiseTsmm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n);
        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_kbatch_k0_m0_m1_k1.GetElementSpaceSize());
        const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_kbatch_k0_n0_n1_k1.GetElementSpaceSize());
        ignore = b_global_buf;
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());

        const auto c_m0_n0_block_cluster_idx = block_2_ctile_map.convert_1D_block_idx_to_3D_tuple(
            get_block_1d_id(), karg.N, karg.k_batch);

        // HACK: this forces the index data into SGPRs
        const index_t im0       = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
        const index_t in0       = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
        const index_t kbatch_id = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I2]);

        if(!block_2_ctile_map.ValidCTileIndex(
               make_tuple(im0, in0),
               make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0),
                          c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3))))
        {
            return;
        }
        // TODO: change this. I think it needs multi-dimensional alignment
        constexpr auto max_lds_align = K1;

        constexpr auto a_block_desc_copy_kbatch_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
            make_tuple(I1, Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);

        // A matrix blockwise copy
        auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
            BlockSize,
            InMemoryDataOperationEnum::Set,
            Sequence<1, K0PerBlock, 1, MPerBlock, K1.value>, // 5 dimensions; the kbatch extent is 1
            ABlockTransferThreadSliceLengths_KBatch_K0_M0_M1_K1,
            ABlockTransferThreadClusterLengths_KBatch_K0_M0_M1_K1,
            ABlockTransferThreadClusterArrangeOrder, // 0, 1, 2, 3, 4
            FloatAB,
            FloatAB,
            remove_reference_t<decltype(a_grid_desc_kbatch_k0_m0_m1_k1)>, // global tensor desc
            decltype(a_block_desc_copy_kbatch_k0_m0_m1_k1),               // block tensor desc
            ABlockTransferSrcAccessOrder,                                 // 5-dim
            Sequence<0, 1, 2, 3, 4>,
            ABlockTransferSrcVectorTensorLengths_KBatch_K0_M0_M1_K1, // SrcVectorTensorLengths
            ABlockTransferDstVectorTensorLengths_KBatch_K0_M0_M1_K1, // DstVectorTensorLengths
            ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
            Sequence<0, 1, 2, 3, 4>,                         // DstVectorTensorContiguousDimOrder
            false,
            true>(a_grid_desc_kbatch_k0_m0_m1_k1,            // src desc
                  make_multi_index(kbatch_id, 0, im0, 0, 0), // start index of this block's K slice
                  a_block_desc_copy_kbatch_k0_m0_m1_k1,      // dst desc
                  make_multi_index(0, 0, 0, 0, 0));
        static constexpr auto b_thread_desc_copy_kbatch_k0_n0_n1_k1 =
            make_naive_tensor_descriptor_packed(
                make_tuple(I1, Number<K0PerBlock>{}, I1, Number<NPerThread>{}, Number<K1>{}));

        // this descriptor is used only for the copy
        static constexpr auto b_thread_desc_copy_k0_n0_n1_k1 =
            make_naive_tensor_descriptor_packed(
                make_tuple(I1, Number<K0PerBlock>{}, I1, Number<NPerThread>{}, Number<K1>{}));

        auto b_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
            FloatAB,
            FloatAB,
            remove_reference_t<decltype(b_grid_desc_kbatch_k0_n0_n1_k1)>,
            decltype(b_thread_desc_copy_kbatch_k0_n0_n1_k1),
            Sequence<1, K0PerBlock, 1, NPerThread, K1.value>,
            BThreadTransferSrcDstAccessOrder,
            BThreadTransferSrcVectorDim,
            BThreadTransferSrcScalarPerVector,
            1,
            false,
            true>(b_grid_desc_kbatch_k0_n0_n1_k1,
                  make_multi_index(kbatch_id, 0, in0, get_thread_local_1d_id() * NPerThread, 0));

        static constexpr auto b_k0_n_k1_thread_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<K0PerBlock>{}, Number<NPerThread>{}, Number<K1>{}));
        // TODO: check alignment
        // A matrix in LDS memory, dst of blockwise copy; be careful of LDS alignment
        constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);

        // TODO: check alignment
        // A matrix in LDS memory, for blockwise GEMM
        constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);

        static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
                          a_k0_m_k1_block_desc.GetElementSpaceSize() &&
                      "wrong!");

        const auto blockwise_tsmm =
            BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
                                                 FloatAB,
                                                 FloatAB,
                                                 FloatAcc,
                                                 decltype(a_k0_m_k1_block_desc),
                                                 decltype(b_k0_n_k1_thread_desc),
                                                 MPerThread,
                                                 NPerBlock,
                                                 KPerThread>{};

        constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
            decltype(blockwise_tsmm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();

        constexpr auto c_thread_desc_m10_m11_n10_n11 = make_naive_tensor_descriptor_packed(
            sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
            a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);

        FloatAB* p_a_block_double = p_shared_block;

        auto b_thread_odd_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
            b_k0_n_k1_thread_desc.GetElementSpaceSize());
        auto b_thread_even_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
            b_k0_n_k1_thread_desc.GetElementSpaceSize());

        // register allocation for the output
        auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
            c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize());

        // initialize C
        c_thread_buf.Clear();
        constexpr auto a_block_slice_copy_step  = make_multi_index(0, K0PerBlock, 0, 0, 0);
        constexpr auto b_thread_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0, 0);

        auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_a_block_double, a_block_desc_copy_kbatch_k0_m0_m1_k1.GetElementSpaceSize());
        auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_a_block_double + a_block_aligned_space_size,
            a_block_desc_copy_kbatch_k0_m0_m1_k1.GetElementSpaceSize());
        // LDS double buffer: preload data into LDS
        {
            // a_global_buf -> reg_tmp_buf
            a_blockwise_copy.RunRead(a_grid_desc_kbatch_k0_m0_m1_k1, a_global_buf);
            // reg_tmp_buf -> a_block_even_buf
            a_blockwise_copy.RunWrite(a_block_desc_copy_kbatch_k0_m0_m1_k1, a_block_even_buf);

            b_threadwise_copy.Run(b_grid_desc_kbatch_k0_n0_n1_k1,
                                  b_global_buf,
                                  b_thread_desc_copy_k0_n0_n1_k1,
                                  make_tuple(I0, I0, I0, I0, I0),
                                  b_thread_even_buf);
        }
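
        // Pipeline sketch (commentary): each iteration runs the GEMM on one
        // LDS/VGPR buffer while the copies for the next iteration fill the other;
        // block_sync_lds() separates the LDS write of one stage from the LDS read
        // of the next, which is what makes the even/odd unrolling below legal.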
        if constexpr(HasMainKBlockLoop)
        {
            const auto K0 = a_grid_desc_kbatch_k0_m0_m1_k1.GetLength(I1);

            index_t k_block_data_begin = 0;

            // LDS double buffer: main body
            // use a do-while loop instead of a for loop to simplify control flow
            do
            {
                // even iteration
                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_kbatch_k0_m0_m1_k1,
                                                    a_block_slice_copy_step);
                b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_kbatch_k0_n0_n1_k1,
                                                     b_thread_slice_copy_step);

                // LDS double buffer: load next data from device memory
                a_blockwise_copy.RunRead(a_grid_desc_kbatch_k0_m0_m1_k1, a_global_buf);
                b_threadwise_copy.Run(b_grid_desc_kbatch_k0_n0_n1_k1,
                                      b_global_buf,
                                      b_thread_desc_copy_k0_n0_n1_k1,
                                      make_tuple(I0, I0, I0, I0, I0),
                                      b_thread_odd_buf);

                block_sync_lds();

                // LDS double buffer: GEMM on current data
                blockwise_tsmm.Run(a_block_even_buf, b_thread_even_buf, c_thread_buf);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunWrite(a_block_desc_copy_kbatch_k0_m0_m1_k1, a_block_odd_buf);

                // odd iteration
                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_kbatch_k0_m0_m1_k1,
                                                    a_block_slice_copy_step);
                b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_kbatch_k0_n0_n1_k1,
                                                     b_thread_slice_copy_step);

                // LDS double buffer: load next data from device memory
                a_blockwise_copy.RunRead(a_grid_desc_kbatch_k0_m0_m1_k1, a_global_buf);
                b_threadwise_copy.Run(b_grid_desc_kbatch_k0_n0_n1_k1,
                                      b_global_buf,
                                      b_thread_desc_copy_k0_n0_n1_k1,
                                      make_tuple(I0, I0, I0, I0, I0),
                                      b_thread_even_buf);

                block_sync_lds();

                // LDS double buffer: GEMM on current data
                blockwise_tsmm.Run(a_block_odd_buf, b_thread_odd_buf, c_thread_buf);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunWrite(a_block_desc_copy_kbatch_k0_m0_m1_k1, a_block_even_buf);

                k_block_data_begin += 2 * K0PerBlock;
            } while(k_block_data_begin < K0 - 2 * K0PerBlock);
        }
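
        // Trip-count note (commentary): the do-while advances k_block_data_begin
        // by 2 * K0PerBlock per pass while it stays below K0 - 2 * K0PerBlock, so
        // the main body always leaves either one or two K0PerBlock chunks for the
        // tail, matching the HasDoubleTailKBlockLoop split below.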
        // LDS double buffer: tail
        if constexpr(HasDoubleTailKBlockLoop) // two iterations left
        {
            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_kbatch_k0_m0_m1_k1,
                                                a_block_slice_copy_step);
            b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_kbatch_k0_n0_n1_k1,
                                                 b_thread_slice_copy_step);

            block_sync_lds();

            // LDS double buffer: load last data from device memory
            a_blockwise_copy.RunRead(a_grid_desc_kbatch_k0_m0_m1_k1, a_global_buf);
            b_threadwise_copy.Run(b_grid_desc_kbatch_k0_n0_n1_k1,
                                  b_global_buf,
                                  b_thread_desc_copy_k0_n0_n1_k1,
                                  make_tuple(I0, I0, I0, I0, I0),
                                  b_thread_odd_buf);

            // LDS double buffer: GEMM on the second-to-last data
            blockwise_tsmm.Run(a_block_even_buf, b_thread_even_buf, c_thread_buf);

            // LDS double buffer: store last data to LDS
            a_blockwise_copy.RunWrite(a_block_desc_copy_kbatch_k0_m0_m1_k1, a_block_odd_buf);

            block_sync_lds();

            // LDS double buffer: GEMM on the last data
            blockwise_tsmm.Run(a_block_odd_buf, b_thread_odd_buf, c_thread_buf);
        }
        else // one iteration left
        {
            __syncthreads();

            // LDS double buffer: GEMM on the last data
            blockwise_tsmm.Run(a_block_even_buf, b_thread_even_buf, c_thread_buf);
        }
        // output: register to global memory
        {
            constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 =
                make_naive_tensor_descriptor_packed(
                    make_tuple(I1,
                               Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
                               Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
                               I1,
                               Number<c_m10_m11_n10_n11_thread_tensor_lengths[I2]>{},
                               Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}));

            const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
                blockwise_tsmm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
                    get_thread_local_1d_id());

            ThreadwiseTensorSliceTransfer_v1r3<
                FloatAcc,
                FloatC,
                decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
                decltype(c_grid_desc_m0_m10_m11_n0_n10_n11),
                ck::tensor_operation::element_wise::PassThrough,
                Sequence<1,
                         c_m10_m11_n10_n11_thread_tensor_lengths[I0],
                         c_m10_m11_n10_n11_thread_tensor_lengths[I1],
                         1,
                         c_m10_m11_n10_n11_thread_tensor_lengths[I2],
                         c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
                CThreadTransferSrcDstAccessOrder,
                CThreadTransferSrcDstVectorDim,
                CThreadTransferDstScalarPerVector,
                CGlobalMemoryDataOperation,
                1,
                true>{c_grid_desc_m0_m10_m11_n0_n10_n11,
                      make_multi_index(im0,
                                       c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
                                       c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
                                       in0,
                                       c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
                                       c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]),
                      ck::tensor_operation::element_wise::PassThrough{}}
                .Run(c_thread_desc_m0_m10_m11_n0_n10_n11,
                     make_tuple(I0, I0, I0, I0, I0, I0),
                     c_thread_buf,
                     c_grid_desc_m0_m10_m11_n0_n10_n11,
                     c_grid_buf);
        }
    }
};

} // namespace ck
include/ck/tensor_operation/gpu/thread/threadwise_tall_and_skinny_gemm.hpp
0 → 100755
View file @
49facb91
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#ifndef CK_THREADWISE_GEMM_DLOPS_V3_HPP
#define CK_THREADWISE_GEMM_DLOPS_V3_HPP

#include "ck/utility/common_header.hpp"

namespace ck {

// C[M, N] += transpose(A[K, M]) * B[K, N]
// Elements of the matrices can be vectorized data.
template <typename FloatA,
          typename FloatB,
          typename FloatC,
          typename AThreadDesc_K0_M_K1,
          typename BThreadDesc_K0_N_K1,
          typename CThreadDesc_M_N,
          typename enable_if<AThreadDesc_K0_M_K1::IsKnownAtCompileTime() &&
                                 BThreadDesc_K0_N_K1::IsKnownAtCompileTime() &&
                                 CThreadDesc_M_N::IsKnownAtCompileTime(),
                             bool>::type = false>
struct ThreadwiseGemmDlops_km_kn_mn_v3
{
    template <typename ABuffer,
              typename AOriginIdx,
              typename BBuffer,
              typename BOriginIdx,
              typename CBuffer,
              typename COriginIdx>
    __device__ static void Run(const ABuffer& a_buf,
                               AOriginIdx,
                               const BBuffer& b_buf,
                               BOriginIdx,
                               CBuffer& c_buf,
                               COriginIdx)
    {
        static_assert(AThreadDesc_K0_M_K1::IsKnownAtCompileTime() &&
                          BThreadDesc_K0_N_K1::IsKnownAtCompileTime() &&
                          CThreadDesc_M_N::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        static_assert(is_known_at_compile_time<remove_cvref_t<AOriginIdx>>::value &&
                          is_known_at_compile_time<remove_cvref_t<BOriginIdx>>::value &&
                          is_known_at_compile_time<remove_cvref_t<COriginIdx>>::value,
                      "wrong! AOriginIdx, BOriginIdx, COriginIdx should be known at compile-time");

        static_assert(
            is_same<remove_cvref_t<typename ABuffer::type>, remove_cvref_t<FloatA>>::value &&
            is_same<remove_cvref_t<typename BBuffer::type>, remove_cvref_t<FloatB>>::value &&
            is_same<remove_cvref_t<typename CBuffer::type>, remove_cvref_t<FloatC>>::value &&
            "wrong! inconsistent type");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};

        constexpr auto K0 = AThreadDesc_K0_M_K1{}.GetLength(I0);
        constexpr auto M  = AThreadDesc_K0_M_K1{}.GetLength(I1);
        constexpr auto K1 = AThreadDesc_K0_M_K1{}.GetLength(I2);
        constexpr auto N  = BThreadDesc_K0_N_K1{}.GetLength(I1);

        constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
        constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
        constexpr auto c_origin_idx = to_multi_index(COriginIdx{});

        // fully unrolled M x N x K0 x K1 inner-product loop nest
        static_for<0, M, 1>{}([&](auto m) {
            static_for<0, N, 1>{}([&](auto n) {
                static_for<0, K0, 1>{}([&](auto k0) {
                    static_for<0, K1, 1>{}([&](auto k1) {
                        constexpr index_t a_offset = AThreadDesc_K0_M_K1{}.CalculateOffset(
                            a_origin_idx + make_tuple(k0, m, k1));
                        constexpr index_t b_offset = BThreadDesc_K0_N_K1{}.CalculateOffset(
                            b_origin_idx + make_tuple(k0, n, k1));
                        constexpr index_t c_offset = CThreadDesc_M_N{}.CalculateOffset(
                            c_origin_idx + make_tuple(0, m, 0, n));

                        inner_product<FloatA, FloatB, FloatC>(a_buf[Number<a_offset>{}],
                                                              b_buf[Number<b_offset>{}],
                                                              c_buf(Number<c_offset>{}));
                    });
                });
            });
        });
    }
};

} // namespace ck
#endif
library/src/tensor_operation_instance/gpu/gemv_splitk/CMakeLists.txt
0 → 100755
View file @
49facb91
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
    if(gpu IN_LIST gpu_list AND target EQUAL 0)
        set(GEMV_SPLITK_INSTANCES)
        if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
            list(APPEND GEMV_SPLITK_INSTANCES device_gemv_splitk_f16_f16_f16_mk_kn_mn_instance.cpp)
            list(APPEND GEMV_SPLITK_INSTANCES device_gemv_splitk_f16_f16_f16_mk_nk_mn_instance.cpp)
        endif()
        add_instance_library(device_gemv_splitk_instance ${GEMV_SPLITK_INSTANCES})
        set(target 1)
    endif()
endforeach()
\ No newline at end of file
library/src/tensor_operation_instance/gpu/gemv_splitk/device_gemv_splitk_f16_f16_f16_mk_kn_mn_instance.cpp
0 → 100755
View file @
49facb91
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_tall_and_skinny_gemm_splitk.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto GemmMNPadding =
    ck::tensor_operation::device::GemmSpecialization::MNPadding;
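
// How to read one instance below (a decoding sketch; the labels follow the
// column comment that precedes the tuple): in
//   deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough,
//                PassThrough, GemmMNPadding, 64, 1, 128, 1, 2, 1, 2, 1, ...>
// BlockSize = 64, MPerBlock = 1, NPerBlock = 128, K0PerBlock = 1, K1 = 2,
// M1PerThread = 1, N1PerThread = 2 and KPerThread = 1; the Sequence parameters
// that follow describe the A-block and B-thread transfer tilings in
// KBatch/K0/M0/M1/K1 order, and the final three entries are the C write-out
// access order, its vector dimension (5) and the scalars written per vector.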
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_gemv_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple<
    // clang-format off
    // ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer | ABlockTransfer| ABlockTransfer | BBlockTransfer| BThreadTransfer| BThreadTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
    // ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess|SrcVectorTensorLengths| SrcVectorTensor|DstVectorTensorLengths| SrcAccess| SrcVectorDim| SrcScalarPerVector| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
    // ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | KBatch_K0_M0_M1_K1| KBatch_K0_M0_M1_K1| ArrangeOrder| Order| KBatch_K0_M0_M1_K1 | ContiguousDimOrder| KBatch_K0_M0_M1_K1 | Order| | | Order| | |
    // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    ///< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, B, M1, B*N1, K0, K1, 1, N1, 1, S<1,1, 1, 1, K1>, S<1,K0, 1,M1, 1>,S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, K1>, S<0,1,2,3,4>, S<1,1, 1, 1, K1>, S<0,1,2,3,4>, 3, N1, S<0, 1, 2, 3, 4, 5>, 5, N1>;
    //N1=2
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 1, 2, 1, 2, 1, S<1,1,1,1,2>, S<1,1,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, 3, 2, S<0,1,2,3,4,5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 1, 4, 1, 2, 1, S<1,1,1,1,4>, S<1,1,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, 3, 2, S<0,1,2,3,4,5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 1, 8, 1, 2, 1, S<1,1,1,1,8>, S<1,1,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, 3, 2, S<0,1,2,3,4,5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 2, 2, 1, 2, 1, S<1,1,1,1,2>, S<1,2,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, 3, 2, S<0,1,2,3,4,5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 2, 4, 1, 2, 1, S<1,1,1,1,4>, S<1,2,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, 3, 2, S<0,1,2,3,4,5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 2, 8, 1, 2, 1, S<1,1,1,1,8>, S<1,2,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, 3, 2, S<0,1,2,3,4,5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 3, 2, 1, 2, 1, S<1,1,1,1,2>, S<1,3,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, 3, 2, S<0,1,2,3,4,5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 3, 4, 1, 2, 1, S<1,1,1,1,4>, S<1,3,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, 3, 2, S<0,1,2,3,4,5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 3, 8, 1, 2, 1, S<1,1,1,1,8>, S<1,3,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, 3, 2, S<0,1,2,3,4,5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 4, 2, 1, 2, 1, S<1,1,1,1,2>, S<1,4,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, 3, 2, S<0,1,2,3,4,5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 4, 4, 1, 2, 1, S<1,1,1,1,4>, S<1,4,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, 3, 2, S<0,1,2,3,4,5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 4, 8, 1, 2, 1, S<1,1,1,1,8>, S<1,4,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, 3, 2, S<0,1,2,3,4,5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 5, 2, 1, 2, 1, S<1,1,1,1,2>, S<1,5,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, 3, 2, S<0,1,2,3,4,5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 5, 4, 1, 2, 1, S<1,1,1,1,4>, S<1,5,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, 3, 2, S<0,1,2,3,4,5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 5, 8, 1, 2, 1, S<1,1,1,1,8>, S<1,5,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, 3, 2, S<0,1,2,3,4,5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 6, 2, 1, 2, 1, S<1,1,1,1,2>, S<1,6,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, 3, 2, S<0,1,2,3,4,5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 6, 4, 1, 2, 1, S<1,1,1,1,4>, S<1,6,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, 3, 2, S<0,1,2,3,4,5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 6, 8, 1, 2, 1, S<1,1,1,1,8>, S<1,6,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, 3, 2, S<0,1,2,3,4,5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 7, 2, 1, 2, 1, S<1,1,1,1,2>, S<1,7,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, 3, 2, S<0,1,2,3,4,5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 7, 4, 1, 2, 1, S<1,1,1,1,4>, S<1,7,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, 3, 2, S<0,1,2,3,4,5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 7, 8, 1, 2, 1, S<1,1,1,1,8>, S<1,7,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, 3, 2, S<0,1,2,3,4,5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 8, 2, 1, 2, 1, S<1,1,1,1,2>, S<1,8,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, 3, 2, S<0,1,2,3,4,5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 8, 4, 1, 2, 1, S<1,1,1,1,4>, S<1,8,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, 3, 2, S<0,1,2,3,4,5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 8, 8, 1, 2, 1, S<1,1,1,1,8>, S<1,8,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, 3, 2, S<0,1,2,3,4,5>, 5, 2>,
    //N1=4
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 1, 2, 1, 4, 1, S<1,1,1,1,2>, S<1,1,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, 3, 4, S<0,1,2,3,4,5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 1, 4, 1, 4, 1, S<1,1,1,1,4>, S<1,1,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, 3, 4, S<0,1,2,3,4,5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 1, 8, 1, 4, 1, S<1,1,1,1,8>, S<1,1,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, 3, 4, S<0,1,2,3,4,5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 2, 2, 1, 4, 1, S<1,1,1,1,2>, S<1,2,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, 3, 4, S<0,1,2,3,4,5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 2, 4, 1, 4, 1, S<1,1,1,1,4>, S<1,2,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, 3, 4, S<0,1,2,3,4,5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 2, 8, 1, 4, 1, S<1,1,1,1,8>, S<1,2,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, 3, 4, S<0,1,2,3,4,5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 3, 2, 1, 4, 1, S<1,1,1,1,2>, S<1,3,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, 3, 4, S<0,1,2,3,4,5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 3, 4, 1, 4, 1, S<1,1,1,1,4>, S<1,3,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, 3, 4, S<0,1,2,3,4,5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 3, 8, 1, 4, 1, S<1,1,1,1,8>, S<1,3,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, 3, 4, S<0,1,2,3,4,5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 4, 2, 1, 4, 1, S<1,1,1,1,2>, S<1,4,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, 3, 4, S<0,1,2,3,4,5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 4, 4, 1, 4, 1, S<1,1,1,1,4>, S<1,4,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, 3, 4, S<0,1,2,3,4,5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 4, 8, 1, 4, 1, S<1,1,1,1,8>, S<1,4,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, 3, 4, S<0,1,2,3,4,5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 5, 2, 1, 4, 1, S<1,1,1,1,2>, S<1,5,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, 3, 4, S<0,1,2,3,4,5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 5, 4, 1, 4, 1, S<1,1,1,1,4>, S<1,5,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, 3, 4, S<0,1,2,3,4,5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 5, 8, 1, 4, 1, S<1,1,1,1,8>, S<1,5,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, 3, 4, S<0,1,2,3,4,5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 6, 2, 1, 4, 1, S<1,1,1,1,2>, S<1,6,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, 3, 4, S<0,1,2,3,4,5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 6, 4, 1, 4, 1, S<1,1,1,1,4>, S<1,6,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, 3, 4, S<0,1,2,3,4,5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 6, 8, 1, 4, 1, S<1,1,1,1,8>, S<1,6,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, 3, 4, S<0,1,2,3,4,5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 7, 2, 1, 4, 1, S<1,1,1,1,2>, S<1,7,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, 3, 4, S<0,1,2,3,4,5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 7, 4, 1, 4, 1, S<1,1,1,1,4>, S<1,7,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, 3, 4, S<0,1,2,3,4,5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 7, 8, 1, 4, 1, S<1,1,1,1,8>, S<1,7,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, 3, 4, S<0,1,2,3,4,5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 8, 2, 1, 4, 1, S<1,1,1,1,2>, S<1,8,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, 3, 4, S<0,1,2,3,4,5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 8, 4, 1, 4, 1, S<1,1,1,1,4>, S<1,8,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, 3, 4, S<0,1,2,3,4,5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 8, 8, 1, 4, 1, S<1,1,1,1,8>, S<1,8,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, 3, 4, S<0,1,2,3,4,5>, 5, 4>,
    //N1=8
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 1, 2, 1, 8, 1, S<1,1,1,1,2>, S<1,1,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, 3, 8, S<0,1,2,3,4,5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 1, 4, 1, 8, 1, S<1,1,1,1,4>, S<1,1,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, 3, 8, S<0,1,2,3,4,5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 1, 8, 1, 8, 1, S<1,1,1,1,8>, S<1,1,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, 3, 8, S<0,1,2,3,4,5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 2, 2, 1, 8, 1, S<1,1,1,1,2>, S<1,2,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, 3, 8, S<0,1,2,3,4,5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 2, 4, 1, 8, 1, S<1,1,1,1,4>, S<1,2,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, 3, 8, S<0,1,2,3,4,5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 2, 8, 1, 8, 1, S<1,1,1,1,8>, S<1,2,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, 3, 8, S<0,1,2,3,4,5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 3, 2, 1, 8, 1, S<1,1,1,1,2>, S<1,3,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, 3, 8, S<0,1,2,3,4,5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 3, 4, 1, 8, 1, S<1,1,1,1,4>, S<1,3,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, 3, 8, S<0,1,2,3,4,5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 3, 8, 1, 8, 1, S<1,1,1,1,8>, S<1,3,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, 3, 8, S<0,1,2,3,4,5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 4, 2, 1, 8, 1, S<1,1,1,1,2>, S<1,4,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, 3, 8, S<0,1,2,3,4,5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 4, 4, 1, 8, 1, S<1,1,1,1,4>, S<1,4,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, 3, 8, S<0,1,2,3,4,5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 4, 8, 1, 8, 1, S<1,1,1,1,8>, S<1,4,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, 3, 8, S<0,1,2,3,4,5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 5, 2, 1, 8, 1, S<1,1,1,1,2>, S<1,5,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, 3, 8, S<0,1,2,3,4,5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 5, 4, 1, 8, 1, S<1,1,1,1,4>, S<1,5,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, 3, 8, S<0,1,2,3,4,5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 5, 8, 1, 8, 1, S<1,1,1,1,8>, S<1,5,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, 3, 8, S<0,1,2,3,4,5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 6, 2, 1, 8, 1, S<1,1,1,1,2>, S<1,6,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, 3, 8, S<0,1,2,3,4,5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 6, 4, 1, 8, 1, S<1,1,1,1,4>, S<1,6,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, 3, 8, S<0,1,2,3,4,5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 6, 8, 1, 8, 1, S<1,1,1,1,8>, S<1,6,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, 3, 8, S<0,1,2,3,4,5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 7, 2, 1, 8, 1, S<1,1,1,1,2>, S<1,7,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, 3, 8, S<0,1,2,3,4,5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 7, 4, 1, 8, 1, S<1,1,1,1,4>, S<1,7,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, 3, 8, S<0,1,2,3,4,5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 7, 8, 1, 8, 1, S<1,1,1,1,8>, S<1,7,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, 3, 8, S<0,1,2,3,4,5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 8, 2, 1, 8, 1, S<1,1,1,1,2>, S<1,8,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, S<1,1,1,1,2>, S<0,1,2,3,4>, 3, 8, S<0,1,2,3,4,5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 8, 4, 1, 8, 1, S<1,1,1,1,4>, S<1,8,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, S<1,1,1,1,4>, S<0,1,2,3,4>, 3, 8, S<0,1,2,3,4,5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 8, 8, 1, 8, 1, S<1,1,1,1,8>, S<1,8,1,1,1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, S<1,1,1,1,8>, S<0,1,2,3,4>, 3, 8, S<0,1,2,3,4,5>, 5, 8>
    // clang-format on
    >;
void add_device_gemv_splitk_f16_f16_f16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceTsmm<Row,
                                           Row,
                                           Row,
                                           F16,
                                           F16,
                                           F16,
                                           PassThrough,
                                           PassThrough,
                                           PassThrough>>>& instances)
{
    add_device_operation_instances(instances,
                                   device_gemv_splitk_f16_f16_f16_mk_kn_mn_instances{});
}
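
// Usage sketch (illustrative; the profiler plumbing around this factory is an
// assumption and not part of this file):
//
//   std::vector<std::unique_ptr<DeviceTsmm<Row, Row, Row, F16, F16, F16,
//       PassThrough, PassThrough, PassThrough>>> tsmm_ptrs;
//   add_device_gemv_splitk_f16_f16_f16_mk_kn_mn_instances(tsmm_ptrs);
//   // then construct an Argument for each op, skip unsupported ones, and time
//   // the rest to pick the fastest instance for the given problem size.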
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/gemv_splitk/device_gemv_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
0 → 100755
View file @
49facb91
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_tall_and_skinny_gemm_splitk.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto GemmMNPadding =
    ck::tensor_operation::device::GemmSpecialization::MNPadding;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using
device_gemv_splitk_f16_f16_f16_mk_nk_mn_instances
=
std
::
tuple
<
// clang-format off
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer | ABlockTransfer| ABlockTransfer | BBlockTransfer| BThreadTransfer| BThreadTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess|SrcVectorTensorLengths| SrcVectorTensor|DstVectorTensorLengths| SrcAccess| SrcVectorDim| SrcScalarPerVector| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | KBatch_K0_M0_M1_K1| KBatch_K0_M0_M1_K1| ArrangeOrder| Order| KBatch_K0_M0_M1_K1 | ContiguousDimOrder| KBatch_K0_M0_M1_K1 | Order| | | Order| | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
///< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, B, M1, B*N1, K0, K1, 1, N1, 1, S<1,1, 1, 1, K1>, S<1,K0, 1,M1, 1>,S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, K1>, S<0,1,2,3,4>, S<1,1, 1, 1, K1>, S<0,1,2,3,4>, 4, K1, S<0, 1, 2, 3, 4, 5>, 5, N1>;
//N1=2
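// Illustrative note on reading the rows below against the legend above: in
// the first entry, BlockSize=64, MPerBlock=1, NPerBlock=128, K0PerBlock=1,
// K1=2, M1PerThread=1, N1PerThread=2 and KPerThread=1; the vector/slice
// sequences S<1,1, 1, 1, 2> carry K1 elements, SrcScalarPerVector tracks
// K1, and DstScalarPerVector tracks N1 (here 2).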
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 1, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,1, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 1, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,1, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 1, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,1, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 2, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,2, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 2, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,2, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 2, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,2, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 3, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,3, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 3, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,3, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 3, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,3, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 4, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,4, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 4, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,4, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 4, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,4, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 5, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,5, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 5, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,5, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 5, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,5, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 6, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,6, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 6, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,6, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 6, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,6, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 7, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,7, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 7, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,7, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 7, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,7, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 8, 2, 1, 2, 1, S<1,1, 1, 1, 2>, S<1,8, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 8, 4, 1, 2, 1, S<1,1, 1, 1, 4>, S<1,8, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 128, 8, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,8, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
//N1=4
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 1, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,1, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 1, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,1, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 1, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,1, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 2, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,2, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 2, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,2, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 2, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,2, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 3, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,3, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 3, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,3, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 3, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,3, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 4, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,4, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 4, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,4, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 4, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,4, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 5, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,5, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 5, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,5, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 5, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,5, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 6, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,6, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 6, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,6, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 6, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,6, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 7, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,7, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 7, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,7, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 7, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,7, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 8, 2, 1, 4, 1, S<1,1, 1, 1, 2>, S<1,8, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 8, 4, 1, 4, 1, S<1,1, 1, 1, 4>, S<1,8, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 256, 8, 8, 1, 4, 1, S<1,1, 1, 1, 8>, S<1,8, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
//N1=8
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 1, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,1, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 1, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,1, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 1, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,1, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 2, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,2, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 2, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,2, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 2, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,2, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 3, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,3, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 3, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,3, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 3, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,3, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 4, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,4, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 4, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,4, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 4, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,4, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 5, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,5, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 5, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,5, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 5, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,5, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 6, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,6, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 6, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,6, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 6, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,6, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 7, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,7, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 7, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,7, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 7, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,7, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 8, 2, 1, 8, 1, S<1,1, 1, 1, 2>, S<1,8, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 8, 4, 1, 8, 1, S<1,1, 1, 1, 4>, S<1,8, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 1, 512, 8, 8, 1, 8, 1, S<1,1, 1, 1, 8>, S<1,8, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>
    // clang-format on
    >;
void add_device_gemv_splitk_f16_f16_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceTsmm<Row,
                                           Col,
                                           Row,
                                           F16,
                                           F16,
                                           F16,
                                           PassThrough,
                                           PassThrough,
                                           PassThrough>>>& instances)
{
    add_device_operation_instances(instances,
                                   device_gemv_splitk_f16_f16_f16_mk_nk_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/tall_and_skinny_gemm_splitk/CMakeLists.txt
0 → 100755
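# Note (illustrative): the guard below builds these instances only when some
# entry of GPU_TARGETS matches gpu_list, and the fp16 sources are added when
# DTYPES selects "fp16" or is left undefined.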
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
    if(gpu IN_LIST gpu_list AND target EQUAL 0)
        set(TALL_AND_SKINNY_GEMM_SPLITK_INSTANCES)
        if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
            list(APPEND TALL_AND_SKINNY_GEMM_SPLITK_INSTANCES
                 device_tall_and_skinny_gemm_splitk_f16_f16_f16_mk_kn_mn_instance.cpp)
            list(APPEND TALL_AND_SKINNY_GEMM_SPLITK_INSTANCES
                 device_tall_and_skinny_gemm_splitk_f16_f16_f16_mk_nk_mn_instance.cpp)
        endif()
        add_instance_library(device_tall_and_skinny_gemm_splitk_instance
                             ${TALL_AND_SKINNY_GEMM_SPLITK_INSTANCES})
        set(target 1)
    endif()
endforeach()
\ No newline at end of file
library/src/tensor_operation_instance/gpu/tall_and_skinny_gemm_splitk/device_tall_and_skinny_gemm_splitk_f16_f16_f16_mk_kn_mn_instance.cpp
0 → 100755
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_tall_and_skinny_gemm_splitk.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto GemmMNPadding =
    ck::tensor_operation::device::GemmSpecialization::MNPadding;

// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_tall_and_skinny_gemm_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple<
    // clang-format off
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer | ABlockTransfer| ABlockTransfer | BBlockTransfer| BThreadTransfer| BThreadTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess|SrcVectorTensorLengths| SrcVectorTensor|DstVectorTensorLengths| SrcAccess| SrcVectorDim| SrcScalarPerVector| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | KBatch_K0_M0_M1_K1| KBatch_K0_M0_M1_K1| ArrangeOrder| Order| KBatch_K0_M0_M1_K1 | ContiguousDimOrder| KBatch_K0_M0_M1_K1 | Order| | | Order| | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
///< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, B, M1, B*N1, K0, K1, M1, N1, 1, S<1,1, 1, 1, K1>, S<1,K0, 1,M1, 1>,S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, K1>, S<0,1,2,3,4>, S<1,1, 1, 1, K1>, S<0,1,2,3,4>, 3, N1, S<0, 1, 2, 3, 4, 5>, 5, N1>;
//M1 is always tied to 16
//N1=2
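// Illustrative note: in the first entry below, BlockSize=64, MPerBlock=16,
// NPerBlock=128, K0PerBlock=1, K1=2, M1PerThread=16 and N1PerThread=2; the
// ABlockTransfer thread-cluster S<1,K0, 1,16, 1> spreads the 16 M1 rows of
// the block across threads, which is why M1 stays fixed at 16 here.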
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 1, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 1, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 1, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 2, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 2, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 2, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 3, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 3, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 3, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 4, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 4, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 4, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 5, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 5, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 5, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 6, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 6, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 6, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
//ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 7, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
//ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 7, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
//ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 7, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 8, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 8, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 8, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
//N1=4
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 1, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 1, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 1, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 2, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 2, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 2, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 3, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 3, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 3, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 4, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 4, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 4, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 5, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 5, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 5, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 6, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 6, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 6, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 7, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 7, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 7, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 8, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 8, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 8, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// N1=8
ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 1, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 1, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 1, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 2, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 2, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 2, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 3, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 3, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 3, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 4, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 4, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 4, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 5, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 5, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 5, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 6, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 6, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 6, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 7, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 7, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 7, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 8, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 8, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 8, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 3, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>
// clang-format on
>;
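// Appends one object per active tuple entry above to the caller's vector; this
// is the hook through which ckprofiler and the instance factory discover these
// split-K tall-and-skinny GEMM kernels.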
void add_device_tall_and_skinny_gemm_splitk_f16_f16_f16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceTsmm<Row,
                                           Row,
                                           Row,
                                           F16,
                                           F16,
                                           F16,
                                           PassThrough,
                                           PassThrough,
                                           PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances, device_tall_and_skinny_gemm_splitk_f16_f16_f16_mk_kn_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
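For reference, below is a minimal sketch of how a client such as ckprofiler might consume this factory function. It is illustrative only: the driver is not part of the commit, and it assumes the aliases this commit defines elsewhere (DeviceTsmm from device_tall_and_skinny_gemm.hpp, plus the usual CK shorthands F16, Row, and PassThrough) and the GetTypeString() introspection that CK device operations conventionally provide.

// Illustrative driver (not in the commit): enumerate the registered split-K
// tall-and-skinny GEMM instances and print one line per tuned configuration.
#include <iostream>
#include <memory>
#include <vector>

#include "ck/tensor_operation/gpu/device/device_tall_and_skinny_gemm.hpp"

int main()
{
    // Assumed aliases, matching the conventions used in CK instance files.
    using ck::tensor_operation::device::DeviceTsmm;
    using F16         = ck::half_t;
    using Row         = ck::tensor_layout::gemm::RowMajor;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    std::vector<std::unique_ptr<DeviceTsmm<Row,
                                           Row,
                                           Row,
                                           F16,
                                           F16,
                                           F16,
                                           PassThrough,
                                           PassThrough,
                                           PassThrough>>>
        instances;

    // Populate the vector with every instance registered by this file.
    ck::tensor_operation::device::instance::
        add_device_tall_and_skinny_gemm_splitk_f16_f16_f16_mk_kn_mn_instances(instances);

    // Each tuned configuration reports a descriptive name.
    for(const auto& op : instances)
        std::cout << op->GetTypeString() << '\n';

    return 0;
}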