yangql / composable_kernel-1 / Commits / 6260ced2

Commit 6260ced2 (unverified), authored Jan 17, 2022 by Chao Liu, committed via GitHub on Jan 17, 2022
Parent: acbd7bd7

Fix building issue for examples (#66)

* fix build issue
Showing 3 changed files with 43 additions and 145 deletions (+43 -145):

    example/1_gemm_xdl/gemm_xdl.cpp                                                        +8   -8
    example/3_gemm_xdl_bias_relu_add/gemm_xdl_bias_relu_add.cpp                            +28  -81
    example/3_gemm_xdl_bias_relu_add/include/device_gemm_xdl_two_extra_source_reduce.hpp   +7   -56
example/1_gemm_xdl/gemm_xdl.cpp

@@ -34,11 +34,11 @@ using CElementOp = ck::tensor_operation::element_wise::PassThrough;

 // Compilation parameters for NT problem
 // clang-format off
 using DeviceGemmInstance =
-//#########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| AElementwise| BElementwise| CElementwise| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
-//#########################################| Type| Type| Type| Type| | | | Operation| Operation| Operation| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
-//#########################################| | | | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
-//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-ck::tensor_operation::device::DeviceGemmXdl<ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>;
+//#########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| AElementwise| BElementwise| CElementwise| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
+//#########################################| Type| Type| Type| Type| | | | Operation| Operation| Operation| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
+//#########################################| | | | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
+//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+ck::tensor_operation::device::DeviceGemmXdl<ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
 // clang-format on

 template <typename AType,
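For readers decoding the positional template arguments above, here is a minimal annotated sketch of the updated instantiation; the labels are taken from the column-header comment rows in the file, and the reflowed grouping and comments are added here purely for illustration:

    // Same instance as in the new (+) line above, reflowed so each positional
    // argument is labelled per the column-header comments. Note that the two
    // LDS padding flags now sit next to their A/B block-transfer parameter
    // groups instead of trailing the CThreadTransfer parameters.
    using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl<
        ADataType, BDataType, CDataType, AccDataType, // data types
        ALayout, BLayout, CLayout,                    // layouts
        AElementOp, BElementOp, CElementOp,           // elementwise operations
        256,                                          // BlockSize
        256, 128, 4, 8,                               // MPerBlock, NPerBlock, K0PerBlock, K1
        32, 32, 4, 2,                                 // MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave
        S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, // A block transfer
        true,                                         // ABlockLdsAddExtraM
        S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, // B block transfer
        true,                                         // BBlockLdsAddExtraN
        7, 1>;                                        // CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector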
@@ -90,9 +90,9 @@ int main(int argc, char* argv[])
     if(argc == 4)
     {
-        M = std::stoi(argv[4]);
-        N = std::stoi(argv[5]);
-        K = std::stoi(argv[6]);
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        nrepeat         = std::stoi(argv[3]);
     }
     else if(argc == 10)
     {
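With argc == 4 the program only receives argv[1] through argv[3], so the old branch indexed past the supplied arguments; the corrected branch reads the three run-control values instead. A minimal sketch of the fixed logic, with interpretive comments added here (they are not in the file):

    // argc == 4: only the run-control flags are given; the problem sizes keep their defaults.
    if(argc == 4)
    {
        do_verification = std::stoi(argv[1]); // compare against the host reference
        init_method     = std::stoi(argv[2]); // how the input tensors are initialized
        nrepeat         = std::stoi(argv[3]); // number of timed kernel launches
    }
    // argc == 10: presumably also takes M, N, K and the strides explicitly (not shown in this hunk).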
example/3_gemm_xdl_bias_relu_add/gemm_xdl_bias_relu_add.cpp

@@ -37,7 +37,7 @@ struct BiasReluAdd
     {
 #if 0
         float a = v1 + v0;
-        float b = max(a, float(0));
+        float b = a > 0 ? a : 0;
         float c = b + v2;

         return c;
@@ -52,70 +52,13 @@ struct BiasReluAdd
...
@@ -52,70 +52,13 @@ struct BiasReluAdd
}
}
};
};
// v0 is from A * B
struct
DoSomething
// v1 is from C0
// v2 is from C1
struct
BiasLeakyReluAdd
{
template
<
typename
T1
,
typename
T2
>
__host__
constexpr
float
operator
()(
float
v0
,
T1
v1
,
T2
v2
)
const
{
float
a
=
v0
+
v1
;
float
b
=
0.1
*
a
;
float
c
=
b
>
0
?
b
:
0
;
float
d
=
c
+
v2
;
return
d
;
}
template
<
typename
T1
,
typename
T2
>
__device__
constexpr
float
operator
()(
float
v0
,
T1
v1
,
T2
v2
)
const
{
constexpr
float
alpha
=
0.1
;
constexpr
float
alpha_inv
=
1.0
/
alpha
;
float
a
=
v2
*
alpha_inv
;
float
b
=
v1
+
v0
;
float
c
=
max
(
b
,
float
(
0
));
float
d
=
alpha
*
(
a
+
c
);
return
d
;
}
};
struct
BiasLeakyRelu
{
template
<
typename
T1
,
typename
T2
>
__host__
constexpr
float
operator
()(
float
v0
,
T1
v1
,
T2
)
const
{
float
a
=
v0
+
v1
;
float
b
=
0.1
*
a
;
float
c
=
b
>
0
?
b
:
0
;
return
c
;
}
template
<
typename
T1
,
typename
T2
>
__device__
constexpr
float
operator
()(
float
v0
,
T1
v1
,
T2
)
const
{
constexpr
float
alpha
=
0.1
;
float
b
=
v1
+
v0
;
float
c
=
max
(
b
,
float
(
0
));
float
d
=
alpha
*
c
;
return
d
;
}
};
struct
BiasAdd
{
{
#if 1
#if 1
// correct result
// correct result
// no scratch memory, good VGPR allocation (59)
// no scratch memory, good VGPR allocation (59)
// good perf (101Tflops)
// good perf (101Tflops @ 1089Mhz)
template
<
typename
T1
,
typename
T2
>
__host__
__device__
constexpr
float
operator
()(
float
v0
,
ck
::
half_t
v1
,
ck
::
half_t
v2
)
const
__host__
__device__
constexpr
float
operator
()(
float
v0
,
T1
v1
,
T2
v2
)
const
{
{
constexpr
float
alpha
=
0.1
;
constexpr
float
alpha
=
0.1
;
constexpr
float
beta
=
0.2
;
constexpr
float
beta
=
0.2
;
...
@@ -124,7 +67,7 @@ struct BiasAdd
...
@@ -124,7 +67,7 @@ struct BiasAdd
// compiler seems very volatile to the order of these calculation:
// compiler seems very volatile to the order of these calculation:
// compiler is very eager to read AccVgpr (v0) out prematurely, resulting in register
// compiler is very eager to read AccVgpr (v0) out prematurely, resulting in register
// over-allocation. Therefore, move v0 calculation to the very end
// over-allocation. Therefore, move v0 calculation to the very end
float
a
=
T1
(
beta
)
*
v1
+
T2
(
gamma
)
*
v2
;
float
a
=
ck
::
half_t
(
beta
)
*
v1
+
ck
::
half_t
(
gamma
)
*
v2
;
float
b
=
a
+
float
(
alpha
)
*
v0
;
float
b
=
a
+
float
(
alpha
)
*
v0
;
return
b
;
return
b
;
...
@@ -137,15 +80,14 @@ struct BiasAdd
...
@@ -137,15 +80,14 @@ struct BiasAdd
// wrong result
// wrong result
// lots of scratch memory
// lots of scratch memory
// huge perf drop
// huge perf drop
template
<
typename
T1
,
typename
T2
>
__host__
__device__
constexpr
float
operator
()(
float
v0
,
ck
::
half_t
v1
,
ck
::
half_t
v2
)
const
__host__
__device__
constexpr
float
operator
()(
float
v0
,
T1
v1
,
T2
v2
)
const
{
{
return
alpha
*
v0
+
beta
*
v1
+
gamma
*
v2
;
return
alpha
*
v0
+
beta
*
v1
+
gamma
*
v2
;
}
}
#elif 0
#elif 0
// correct result
// correct result
// some scratch memory (68 dword)
// some scratch memory (68 dword)
// some perf drop (94Tflops)
// some perf drop (94Tflops
@ 1089MHz
)
// fp64 instructions are used
// fp64 instructions are used
__host__
__device__
constexpr
auto
operator
()(
float
v0
,
ck
::
half_t
v1
,
ck
::
half_t
v2
)
const
__host__
__device__
constexpr
auto
operator
()(
float
v0
,
ck
::
half_t
v1
,
ck
::
half_t
v2
)
const
{
{
...
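All of the C elementwise functors in this example follow the same contract: a functor whose operator() takes the GEMM accumulator v0 (from A * B) plus the two extra sources v1 (from C0) and v2 (from C1) and returns the fused output. A minimal self-contained sketch of that contract, modelled on the BiasReluAdd math from the #if 0 branch shown earlier (the struct name and comments here are illustrative, not code from the file):

    // Hypothetical sketch of the bias + ReLU + add fusion applied per output element.
    struct BiasReluAddSketch
    {
        __host__ __device__ constexpr float operator()(float v0, ck::half_t v1, ck::half_t v2) const
        {
            float a = v1 + v0;       // accumulator plus bias (C0)
            float b = a > 0 ? a : 0; // ReLU, written as a ternary so it also compiles on the host
            float c = b + v2;        // add the residual source (C1)
            return c;
        }
    };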
@@ -185,16 +127,20 @@ using CLayout = ck::tensor_layout::gemm::RowMajor;

 using AOp = PassThrough;
 using BOp = PassThrough;
+#if 1
 using COp = BiasReluAdd;
+#else
+using COp = DoSomething;
+#endif

 // Compilation parameters for NT problem
 // clang-format off
 using DeviceGemmInstance =
-//#################################################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| AElementwise| BElementwise| CElementwise| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
-//#################################################################| Type| Type| Type| Type| | | | Operation| Operation| Operation| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
-//#################################################################| | | | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
-//#################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-ck::tensor_operation::device::DeviceGemmXdl_two_extra_source_reduce<ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AOp, BOp, COp, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>;
+//#################################################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| AElementwise| BElementwise| CElementwise| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
+//#################################################################| Type| Type| Type| Type| | | | Operation| Operation| Operation| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
+//#################################################################| | | | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
+//#################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+ck::tensor_operation::device::DeviceGemmXdl_two_extra_source_reduce<ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AOp, BOp, COp, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
 // clang-format on

 template <typename AType,
@@ -215,16 +161,15 @@ static void host_verify(const Tensor<AType>& a_m_k,
     auto f_mk_kn_mn = [&](auto m, auto n) {
         const int K = a_m_k.mDesc.GetLengths()[1];

-        double v = 0;
+        float acc = 0;

         for(int k = 0; k < K; ++k)
         {
-            v += static_cast<const double>(a_element_op(a_m_k(m, k))) *
+            acc += static_cast<const double>(a_element_op(a_m_k(m, k))) *
                    static_cast<const double>(b_element_op(b_k_n(k, n)));
         }

-        c_m_n(m, n) = c_element_op(
-            v, static_cast<const double>(c0_m_n(m, n)), static_cast<const double>(c1_m_n(m, n)));
+        c_m_n(m, n) = c_element_op(acc, c0_m_n(m, n), c1_m_n(m, n));
     };

     make_ParallelTensorFunctor(f_mk_kn_mn,
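In other words, the reference lambda computes, per output element, the dot product over k followed by the fused elementwise op applied to the two extra source tensors. A self-contained plain-C++ illustration of that computation (helper name, storage layout, and types are assumptions for this sketch, not part of the file):

    #include <vector>

    // C(m, n) = c_op( sum_k A(m, k) * B(k, n), C0(m, n), C1(m, n) )
    template <typename COp>
    float reference_element(const std::vector<float>& a, // M x K, row-major
                            const std::vector<float>& b, // K x N, row-major
                            float c0, float c1,
                            int m, int n, int K, int N,
                            COp c_op)
    {
        float acc = 0;
        for(int k = 0; k < K; ++k)
        {
            acc += a[m * K + k] * b[k * N + n];
        }
        return c_op(acc, c0, c1);
    }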
@@ -249,9 +194,9 @@ int main(int argc, char* argv[])
     if(argc == 4)
     {
-        M = std::stoi(argv[4]);
-        N = std::stoi(argv[5]);
-        K = std::stoi(argv[6]);
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        nrepeat         = std::stoi(argv[3]);
     }
     else if(argc == 10)
     {

@@ -337,7 +282,9 @@ int main(int argc, char* argv[])
     c0_m_n_device_buf.ToDevice(c0_m_n.mData.data());
     c1_m_n_device_buf.ToDevice(c1_m_n.mData.data());

-    auto c_element_op = BiasReluAdd{};
+    auto a_element_op = AOp{};
+    auto b_element_op = BOp{};
+    auto c_element_op = COp{};

     // do GEMM
     auto gemm = DeviceGemmInstance{};

@@ -354,8 +301,8 @@ int main(int argc, char* argv[])
         StrideA,
         StrideB,
         StrideC,
-        PassThrough{},
-        PassThrough{},
+        a_element_op,
+        b_element_op,
         c_element_op);

     if(!gemm.IsSupportedArgument(argument))
example/3_gemm_xdl_bias_relu_add/include/device_gemm_xdl_two_extra_source_reduce.hpp

@@ -35,24 +35,22 @@ template <typename ADataType,
           ck::index_t NPerXDL,
           ck::index_t MXdlPerWave,
           ck::index_t NXdlPerWave,
-          typename ABlockTransferThreadSliceLengths_K0_M_K1,
           typename ABlockTransferThreadClusterLengths_K0_M_K1,
           typename ABlockTransferThreadClusterArrangeOrder,
           typename ABlockTransferSrcAccessOrder,
           ck::index_t ABlockTransferSrcVectorDim,
           ck::index_t ABlockTransferSrcScalarPerVector,
           ck::index_t ABlockTransferDstScalarPerVector_K1,
-          typename BBlockTransferThreadSliceLengths_K0_N_K1,
+          bool ABlockLdsAddExtraM,
           typename BBlockTransferThreadClusterLengths_K0_N_K1,
           typename BBlockTransferThreadClusterArrangeOrder,
           typename BBlockTransferSrcAccessOrder,
           ck::index_t BBlockTransferSrcVectorDim,
           ck::index_t BBlockTransferSrcScalarPerVector,
           ck::index_t BBlockTransferDstScalarPerVector_K1,
+          bool BBlockLdsAddExtraN,
           ck::index_t CThreadTransferSrcDstVectorDim,
-          ck::index_t CThreadTransferDstScalarPerVector,
-          bool ABlockLdsAddExtraM,
-          bool BBlockLdsAddExtraN>
+          ck::index_t CThreadTransferDstScalarPerVector>
 struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
 {
     static constexpr auto I0 = Number<0>{};
@@ -137,45 +135,6 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
     using C1GridDesc_M_N =
         decltype(make_naive_tensor_descriptor(make_tuple(1, 1), make_tuple(I1, I0)));

-    // TODO remove these hacks
-    static constexpr auto a_k0_m_k1_grid_step_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0>{},   // 0+: K0
-                              Sequence<0, 0, 0>{},   // 1+: M
-                              Sequence<0, 0, 0>{}),  // 2+: K1
-                   make_tuple(Sequence<0, 0, 0>{},   // 0-: K0
-                              Sequence<0, 0, 0>{},   // 1-: M
-                              Sequence<0, 0, 0>{})); // 2-: K1
-
-    static constexpr auto b_k0_n_k1_grid_step_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0>{},   // 0+: K0
-                              Sequence<0, 0, 0>{},   // 1+: N
-                              Sequence<0, 0, 0>{}),  // 2+: K1
-                   make_tuple(Sequence<0, 0, 0>{},   // 0-: K0
-                              Sequence<0, 0, 0>{},   // 1-: N
-                              Sequence<0, 0, 0>{})); // 2-: K1
-
-    static constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: N0
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: M1
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3+: N1
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4+: M2
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5+: M3
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6+: M4
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 7+: N2
-                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: M0
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: N0
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: M1
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: N1
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4-: M2
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
-
-    static constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
-
-    static constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
-
     // GridwiseGemm
     using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5<
         BlockSize,
@@ -199,7 +158,6 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
         K1,
         MXdlPerWave,
         NXdlPerWave,
-        ABlockTransferThreadSliceLengths_K0_M_K1,
         ABlockTransferThreadClusterLengths_K0_M_K1,
         ABlockTransferThreadClusterArrangeOrder,
         ABlockTransferSrcAccessOrder,
@@ -207,25 +165,18 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
         ABlockTransferSrcScalarPerVector,
         ABlockTransferDstScalarPerVector_K1,
         false, // AThreadTransferSrcResetCoordinateAfterRun,
-        BBlockTransferThreadSliceLengths_K0_N_K1,
+        ABlockLdsAddExtraM,
         BBlockTransferThreadClusterLengths_K0_N_K1,
         BBlockTransferThreadClusterArrangeOrder,
         BBlockTransferSrcAccessOrder,
         BBlockTransferSrcVectorDim,
         BBlockTransferSrcScalarPerVector,
         BBlockTransferDstScalarPerVector_K1,
         false, // BThreadTransferSrcResetCoordinateAfterRun,
+        BBlockLdsAddExtraN,
         Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder,
         CThreadTransferSrcDstVectorDim,
-        CThreadTransferDstScalarPerVector,
-        decltype(a_k0_m_k1_grid_step_hacks),                    // AGridStepHacks,
-        decltype(b_k0_n_k1_grid_step_hacks),                    // BGridStepHacks,
-        decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),    // CGridStepHacks,
-        decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),  // AGridMoveSliceWindowStepHacks,
-        decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),  // BGridMoveSliceWindowStepHacks,
-        false,                                                  // CAccessOrderMRepeatNRepeat,
-        ABlockLdsAddExtraM,
-        BBlockLdsAddExtraN>;
+        CThreadTransferDstScalarPerVector>;

     using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
         decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));