Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
250a89f3
Commit
250a89f3
authored
Oct 14, 2024
by
Mirza Halilcevic
Browse files
Replace gemm_gemm with gemm_multiple_d_gemm_multiple_d.
parent
d1e9682a
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
638 additions
and
43 deletions
+638
-43
codegen/include/ck/host/device_batched_gemm_multiple_d_gemm_multiple_d/operation.hpp
...ice_batched_gemm_multiple_d_gemm_multiple_d/operation.hpp
+24
-20
codegen/include/ck/host/device_batched_gemm_multiple_d_gemm_multiple_d/problem.hpp
...evice_batched_gemm_multiple_d_gemm_multiple_d/problem.hpp
+51
-0
codegen/include/ck/host/operation/gemm.hpp
codegen/include/ck/host/operation/gemm.hpp
+27
-16
codegen/src/device_batched_gemm_multiple_d_gemm_multiple_d.cpp
...en/src/device_batched_gemm_multiple_d_gemm_multiple_d.cpp
+8
-7
codegen/src/device_batched_gemm_multiple_d_gemm_multiple_d_operation_xdl_cshuffle.cpp
...emm_multiple_d_gemm_multiple_d_operation_xdl_cshuffle.cpp
+528
-0
No files found.
codegen/include/ck/host/device_
gemm_elementwise_gemm
/operation.hpp
→
codegen/include/ck/host/device_
batched_gemm_multiple_d_gemm_multiple_d
/operation.hpp
View file @
250a89f3
...
@@ -8,13 +8,13 @@
...
@@ -8,13 +8,13 @@
#include <string>
#include <string>
#include "ck/host/types.hpp"
#include "ck/host/types.hpp"
#include "ck/host/operation/gemm.hpp"
#include "ck/host/operation/gemm.hpp"
#include "ck/host/device_
gemm_elementwise_gemm
/problem.hpp"
#include "ck/host/device_
batched_gemm_multiple_d_gemm_multiple_d
/problem.hpp"
namespace
ck
{
namespace
ck
{
namespace
host
{
namespace
host
{
namespace
device_
gemm_elementwise_gemm
{
namespace
device_
batched_gemm_multiple_d_gemm_multiple_d
{
// defines all values need for an instance
of fwd conv
// defines all values need
ed
for an instance
struct
Operation_Xdl_CShuffle
struct
Operation_Xdl_CShuffle
{
{
// returns a vector of instances, only given fusion operators: will use default problem spec
// returns a vector of instances, only given fusion operators: will use default problem spec
...
@@ -23,36 +23,40 @@ struct Operation_Xdl_CShuffle
...
@@ -23,36 +23,40 @@ struct Operation_Xdl_CShuffle
// returns a vector of instances, given a problem spec and fusion operators
// returns a vector of instances, given a problem spec and fusion operators
static
std
::
vector
<
Operation_Xdl_CShuffle
>
static
std
::
vector
<
Operation_Xdl_CShuffle
>
CreateOperations
(
const
Problem
&
prob
,
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
);
CreateOperations
(
const
Problem
&
prob
,
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
);
TensorDesc
A
{};
TensorDesc
A
0
{};
TensorDesc
B0
{};
TensorDesc
B0
{};
std
::
vector
<
TensorDesc
>
D0s
=
{};
TensorDesc
B1
{};
TensorDesc
B1
{};
TensorDesc
C
{};
std
::
vector
<
TensorDesc
>
D1s
=
{};
DataType
acc
=
DataType
::
Float
;
TensorDesc
E1
{}
;
DataType
c
s
_type
=
DataType
::
Half
;
DataType
ac
c_type
=
DataType
::
Float
;
std
::
string
a_elem_op
=
PassThrough
;
DataType
cshuffle_type
=
DataType
::
Float
;
std
::
string
b
0_elem_op
=
PassThrough
;
std
::
string
a
0_elem_op
=
PassThrough
;
std
::
string
acc
0_elem_op
=
PassThrough
;
std
::
string
b
0_elem_op
=
PassThrough
;
std
::
string
b1
_elem_op
=
PassThrough
;
std
::
string
cde0
_elem_op
=
PassThrough
;
std
::
string
c
_elem_op
=
PassThrough
;
std
::
string
b1
_elem_op
=
PassThrough
;
std
::
string
prologue
=
""
;
std
::
string
cde1_elem_op
=
PassThrough
;
std
::
string
epi
logue
=
""
;
std
::
string
pro
logue
=
""
;
std
::
string
gemm_specialization
=
"ck::tensor_operation::device::GemmSpecialization::Default
"
;
std
::
string
epilogue
=
"
"
;
// tuning parameters
// tuning parameters
operation
::
TileDescGemmElementwiseGemm
tile_desc
{};
operation
::
PaddingDesc
padding_desc
{};
operation
::
BlockTransferDesc
a_block_transfer
{};
operation
::
TileDescGemmGemm
tile_desc
{};
operation
::
BlockTransferDesc
a0_block_transfer
{};
operation
::
BlockTransferDesc
b0_block_transfer
{};
operation
::
BlockTransferDesc
b0_block_transfer
{};
operation
::
BlockTransferDesc
cde0_block_transfer
{};
operation
::
BlockTransferDesc
b1_block_transfer
{};
operation
::
BlockTransferDesc
b1_block_transfer
{};
operation
::
CShuffleDesc
cshuffle
{};
operation
::
CShuffleDesc
cshuffle
{};
operation
::
CBlockTransferDesc
c_block_transfer
{};
operation
::
CBlockTransferDesc
c
de1
_block_transfer
{};
// functions to update fusion operators if provided
// functions to update fusion operators if provided
void
update_prologue
(
const
std
::
string
&
prologue
);
void
update_prologue
(
const
std
::
string
&
prologue
);
void
update_epilogue
(
const
std
::
string
&
epilogue
);
void
update_epilogue
(
const
std
::
string
&
epilogue
);
/**constexpr**/
bool
IsSupported
(
std
::
size_t
MRaw_
,
std
::
size_t
NRaw_
,
std
::
size_t
KRaw_
);
/**constexpr**/
bool
IsSupported
(
std
::
size_t
MRaw_
,
std
::
size_t
NRaw_
,
std
::
size_t
KRaw_
,
std
::
size_t
Gemm1NRaw_
);
// returns a templated instance
// returns a templated instance
Solution
ToSolution
()
const
;
Solution
ToSolution
()
const
;
};
};
}
// namespace device_
gemm_elementwise_gemm
}
// namespace device_
batched_gemm_multiple_d_gemm_multiple_d
}
// namespace host
}
// namespace host
}
// namespace ck
}
// namespace ck
codegen/include/ck/host/device_
gemm_elementwise_gemm
/problem.hpp
→
codegen/include/ck/host/device_
batched_gemm_multiple_d_gemm_multiple_d
/problem.hpp
View file @
250a89f3
...
@@ -10,28 +10,32 @@
...
@@ -10,28 +10,32 @@
namespace
ck
{
namespace
ck
{
namespace
host
{
namespace
host
{
namespace
device_
gemm_elementwise_gemm
{
namespace
device_
batched_gemm_multiple_d_gemm_multiple_d
{
// defines the problem specification for a GEMM operation
// defines the problem specification for a GEMM
_ELEMENTWISE_GEMM
operation
struct
Problem
struct
Problem
{
{
std
::
size_t
M
=
0
;
std
::
size_t
M
=
0
;
std
::
size_t
N
=
0
;
std
::
size_t
N
=
0
;
std
::
size_t
K
=
0
;
std
::
size_t
K
=
0
;
std
::
size_t
O
=
0
;
std
::
size_t
O
=
0
;
bool
TransA
=
false
;
bool
TransA0
=
false
;
bool
TransB0
=
false
;
bool
TransB0
=
false
;
bool
TransB1
=
false
;
std
::
vector
<
bool
>
D0sTrans
=
{};
bool
TransC
=
false
;
bool
TransB1
=
false
;
DataType
ADataType
=
DataType
::
Half
;
std
::
vector
<
bool
>
D1sTrans
=
{};
DataType
B0DataType
=
DataType
::
Half
;
bool
TransE1
=
false
;
DataType
B1DataType
=
DataType
::
Half
;
DataType
A0DataType
=
DataType
::
Half
;
DataType
CDataType
=
DataType
::
Half
;
DataType
B0DataType
=
DataType
::
Half
;
std
::
string
AElementOp
=
PassThrough
;
std
::
vector
<
DataType
>
D0sDataType
=
{};
std
::
string
B0ElementOp
=
PassThrough
;
DataType
B1DataType
=
DataType
::
Half
;
std
::
string
Acc0ElementOp
=
PassThrough
;
std
::
vector
<
DataType
>
D1sDataType
=
{};
std
::
string
B1ElementOp
=
PassThrough
;
DataType
E1DataType
=
DataType
::
Half
;
std
::
string
CElementOp
=
PassThrough
;
std
::
string
A0ElementOp
=
PassThrough
;
std
::
string
B0ElementOp
=
PassThrough
;
std
::
string
CDE0ElementOp
=
PassThrough
;
std
::
string
B1ElementOp
=
PassThrough
;
std
::
string
CDE1ElementOp
=
PassThrough
;
// returns the correct device op file for the operation
// returns the correct device op file for the operation
std
::
string
GetIncludeHeader
()
const
;
std
::
string
GetIncludeHeader
()
const
;
...
@@ -42,6 +46,6 @@ struct Problem
...
@@ -42,6 +46,6 @@ struct Problem
const
std
::
string
&
epilogue
)
const
;
const
std
::
string
&
epilogue
)
const
;
};
};
}
// namespace device_
gemm_elementwise_gemm
}
// namespace device_
batched_gemm_multiple_d_gemm_multiple_d
}
// namespace host
}
// namespace host
}
// namespace ck
}
// namespace ck
codegen/include/ck/host/operation/gemm.hpp
View file @
250a89f3
...
@@ -9,6 +9,15 @@ namespace ck {
...
@@ -9,6 +9,15 @@ namespace ck {
namespace
host
{
namespace
host
{
namespace
operation
{
namespace
operation
{
// Per-dimension padding flags for a fused two-GEMM kernel instance.
// Every flag defaults to false (no padding requested).
// NOTE(review): the flags appear to be set by GetPaddingDesc() when the
// corresponding problem dimension is not a whole multiple of its per-block
// tile size -- confirm against the caller before relying on this.
struct PaddingDesc
{
    bool pad_gemm0_m = false; // first GEMM, M dimension
    bool pad_gemm0_n = false; // first GEMM, N dimension
    bool pad_gemm0_k = false; // first GEMM, K dimension
    bool pad_gemm1_n = false; // second GEMM, N dimension
    bool pad_gemm1_k = false; // second GEMM, K dimension
};
struct
TileDesc
struct
TileDesc
{
{
int
block_size
=
0
;
int
block_size
=
0
;
...
@@ -24,23 +33,23 @@ struct TileDesc
...
@@ -24,23 +33,23 @@ struct TileDesc
int
num_gemmk_prefetch_stage
=
0
;
int
num_gemmk_prefetch_stage
=
0
;
};
};
struct
TileDescGemm
Elementwise
Gemm
struct
TileDescGemmGemm
{
{
int
block_size
=
0
;
int
block_size
=
0
;
int
gemm0
1
_m_per_block
=
0
;
int
gemm0_m_per_block
=
0
;
int
gemm0_n_per_block
=
0
;
int
gemm0_n_per_block
=
0
;
int
gemm0_k_per_block
=
0
;
int
gemm0_k_per_block
=
0
;
int
gemm1_n_per_block
=
0
;
int
gemm1_n_per_block
=
0
;
int
gemm1_k_per_block
=
0
;
int
gemm1_k_per_block
=
0
;
int
ak1
=
0
;
int
a
0
k1
=
0
;
int
bk1
=
0
;
int
b
0
k1
=
0
;
int
b1k1
=
0
;
int
b1k1
=
0
;
int
m_per_XDL
=
0
;
int
m_per_XDL
=
0
;
int
n_per_XDL
=
0
;
int
n_per_XDL
=
0
;
int
gemm0_m_Xdl_per_wave
=
0
;
int
gemm0_m_Xdl_per_wave
=
0
;
int
gemm0_n_Xdl_per_wave
=
0
;
int
gemm0_n_Xdl_per_wave
=
0
;
int
gemm1_n_Xdl_per_wave
=
0
;
int
gemm1_n_Xdl_per_wave
=
0
;
int
num_gemmk_prefetch_stage
=
0
;
int
num_gemm
0
k_prefetch_stage
=
0
;
};
};
struct
BlockTransferDesc
struct
BlockTransferDesc
...
@@ -53,11 +62,13 @@ struct BlockTransferDesc
...
@@ -53,11 +62,13 @@ struct BlockTransferDesc
int
dst_scalar_per_vector_k1
=
0
;
int
dst_scalar_per_vector_k1
=
0
;
int
lds_add_extra_dim
=
0
;
int
lds_add_extra_dim
=
0
;
};
};
struct
CShuffleDesc
struct
CShuffleDesc
{
{
int
m_Xdl_per_wave_per_shuffle
=
0
;
int
m_Xdl_per_wave_per_shuffle
=
0
;
int
n_Xdl_per_wave_per_shuffle
=
0
;
int
n_Xdl_per_wave_per_shuffle
=
0
;
};
};
struct
CBlockTransferDesc
struct
CBlockTransferDesc
{
{
std
::
string
cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl
=
""
;
std
::
string
cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl
=
""
;
...
...
codegen/src/device_
gemm_elementwise_gemm
.cpp
→
codegen/src/device_
batched_gemm_multiple_d_gemm_multiple_d
.cpp
View file @
250a89f3
...
@@ -2,19 +2,20 @@
...
@@ -2,19 +2,20 @@
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_
gemm_elementwise_gemm
/problem.hpp"
#include "ck/host/device_
batched_gemm_multiple_d_gemm_multiple_d
/problem.hpp"
#include "ck/host/device_
gemm_elementwise_gemm
/operation.hpp"
#include "ck/host/device_
batched_gemm_multiple_d_gemm_multiple_d
/operation.hpp"
#include "ck/host/utils.hpp"
#include "ck/host/utils.hpp"
#include <algorithm>
#include <algorithm>
namespace
ck
{
namespace
ck
{
namespace
host
{
namespace
host
{
namespace
device_
gemm_elementwise_gemm
{
namespace
device_
batched_gemm_multiple_d_gemm_multiple_d
{
// return the relevant device op file based on the operation
// return the relevant device op file based on the operation
std
::
string
Problem
::
GetIncludeHeader
()
const
std
::
string
Problem
::
GetIncludeHeader
()
const
{
{
return
"ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp"
;
return
"ck/tensor_operation/gpu/device/impl/"
"device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp"
;
}
}
// returns templated instances when provided with a problem specification
// returns templated instances when provided with a problem specification
...
@@ -24,8 +25,8 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch,
...
@@ -24,8 +25,8 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch,
{
{
if
(
get_xdlop_archs
().
count
(
arch
)
==
0
)
if
(
get_xdlop_archs
().
count
(
arch
)
==
0
)
return
{};
return
{};
auto
ops
=
ck
::
host
::
device_
gemm_elementwise_gemm
::
Operation_Xdl_CShuffle
::
CreateOperations
(
auto
ops
=
ck
::
host
::
device_
batched_gemm_multiple_d_gemm_multiple_d
::
Operation_Xdl_CShuffle
::
*
this
,
prologue
,
epilogue
);
// obtains vector of instances
CreateOperations
(
*
this
,
prologue
,
epilogue
);
// obtains vector of instances
std
::
vector
<
Solution
>
result
;
std
::
vector
<
Solution
>
result
;
std
::
transform
(
ops
.
begin
(),
ops
.
end
(),
std
::
back_inserter
(
result
),
[
&
](
const
auto
&
op
)
{
std
::
transform
(
ops
.
begin
(),
ops
.
end
(),
std
::
back_inserter
(
result
),
[
&
](
const
auto
&
op
)
{
return
op
.
ToSolution
();
// template instance with correct values
return
op
.
ToSolution
();
// template instance with correct values
...
@@ -33,6 +34,6 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch,
...
@@ -33,6 +34,6 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch,
return
result
;
return
result
;
}
}
}
// namespace device_
gemm_elementwise_gemm
}
// namespace device_
batched_gemm_multiple_d_gemm_multiple_d
}
// namespace host
}
// namespace host
}
// namespace ck
}
// namespace ck
codegen/src/device_
gemm_elementwise_gemm
_operation_xdl_cshuffle.cpp
→
codegen/src/device_
batched_gemm_multiple_d_gemm_multiple_d
_operation_xdl_cshuffle.cpp
View file @
250a89f3
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_
gemm_elementwise_gemm
/operation.hpp"
#include "ck/host/device_
batched_gemm_multiple_d_gemm_multiple_d
/operation.hpp"
#include "ck/host/stringutils.hpp"
#include "ck/host/stringutils.hpp"
#include "ck/host/utils.hpp"
#include "ck/host/utils.hpp"
#include <cassert>
#include <cassert>
namespace
ck
{
namespace
ck
{
namespace
host
{
namespace
host
{
namespace
device_
gemm_elementwise_gemm
{
namespace
device_
batched_gemm_multiple_d_gemm_multiple_d
{
// calculate appropriate Gemm Specification based on input tensor dimensions
// calculate appropriate Gemm Specification based on input tensor dimensions
std
::
string
GetGemmSpec
(
const
std
::
size_t
m
,
operation
::
PaddingDesc
GetPaddingDesc
(
const
std
::
size_t
m
,
const
std
::
size_t
n
,
const
std
::
size_t
n
,
const
std
::
size_t
k
,
const
std
::
size_t
k
,
const
std
::
size_t
n1
,
const
std
::
size_t
n1
,
const
std
::
size_t
m_per_block
,
const
std
::
size_t
m_per_block
,
const
std
::
size_t
n_per_block
,
const
std
::
size_t
n_per_block
,
const
std
::
size_t
k_per_block
,
const
std
::
size_t
k_per_block
,
const
std
::
size_t
n1_per_block
)
const
std
::
size_t
n1_per_block
,
const
std
::
size_t
k1_per_block
)
{
{
std
::
string
spec
=
""
;
operation
::
PaddingDesc
desc
;
if
(
integer_divide_ceil
(
m
,
m_per_block
)
*
m_per_block
-
m
!=
0
)
if
(
integer_divide_ceil
(
m
,
m_per_block
)
*
m_per_block
-
m
!=
0
)
spec
+=
"M"
;
desc
.
pad_gemm0_m
=
true
;
if
(
integer_divide_ceil
(
n
,
n_per_block
)
*
n_per_block
-
n
!=
0
)
if
(
integer_divide_ceil
(
n
,
n_per_block
)
*
n_per_block
-
n
!=
0
)
spec
+=
"N"
;
desc
.
pad_gemm0_n
=
true
;
if
(
integer_divide_ceil
(
k
,
k_per_block
)
*
k_per_block
-
k
!=
0
)
if
(
integer_divide_ceil
(
k
,
k_per_block
)
*
k_per_block
-
k
!=
0
)
spec
+=
"K"
;
desc
.
pad_gemm0_k
=
true
;
if
(
integer_divide_ceil
(
n1
,
n1_per_block
)
*
n1_per_block
-
n1
!=
0
)
if
(
integer_divide_ceil
(
n1
,
n1_per_block
)
*
n1_per_block
-
n1
!=
0
)
spec
+=
"O"
;
desc
.
pad_gemm1_n
=
true
;
if
(
spec
==
""
)
if
(
integer_divide_ceil
(
n
,
k1_per_block
)
*
k1_per_block
-
n
!=
0
)
// TODO is n == k1 ?
return
"ck::tensor_operation::device::GemmSpecialization::Default"
;
desc
.
pad_gemm1_k
=
true
;
return
"ck::tensor_operation::device::GemmSpecialization::"
+
spec
+
"Padding"
;
return
desc
;
}
}
// function to update prologue/epilogue with user provided operation
// function to update prologue/epilogue with user provided operation
...
@@ -41,6 +42,9 @@ void Operation_Xdl_CShuffle::update_prologue(const std::string& pro)
...
@@ -41,6 +42,9 @@ void Operation_Xdl_CShuffle::update_prologue(const std::string& pro)
if
(
!
prologue
.
empty
())
if
(
!
prologue
.
empty
())
{
{
this
->
prologue
=
pro
;
this
->
prologue
=
pro
;
// TODO is this right?
this
->
cde0_elem_op
=
"CDE0ElementOp"
;
this
->
cde1_elem_op
=
"CDE1ElementOp"
;
}
}
else
else
{
{
...
@@ -53,6 +57,9 @@ void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi)
...
@@ -53,6 +57,9 @@ void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi)
if
(
!
epilogue
.
empty
())
if
(
!
epilogue
.
empty
())
{
{
this
->
epilogue
=
epi
;
this
->
epilogue
=
epi
;
// TODO is this right?
this
->
cde0_elem_op
=
"CDE0ElementOp"
;
this
->
cde1_elem_op
=
"CDE1ElementOp"
;
}
}
else
else
{
{
...
@@ -68,24 +75,19 @@ static Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row
...
@@ -68,24 +75,19 @@ static Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row
std
::
vector
<
Operation_Xdl_CShuffle
>
Operation_Xdl_CShuffle
::
CreateOperations
(
std
::
vector
<
Operation_Xdl_CShuffle
>
Operation_Xdl_CShuffle
::
CreateOperations
(
const
Problem
&
prob
,
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
)
const
Problem
&
prob
,
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
)
{
{
assert
(
prob
.
TransA
==
false
);
std
::
vector
<
Operation_Xdl_CShuffle
>
result
;
assert
(
prob
.
TransB0
==
true
);
assert
(
prob
.
TransC
==
false
);
const
auto
b1k1
=
prob
.
TransB1
?
4
:
2
;
const
auto
b1k1
=
prob
.
TransB1
?
4
:
2
;
std
::
vector
<
Operation_Xdl_CShuffle
>
result
;
std
::
vector
<
operation
::
TileDescGemmGemm
>
tile_descriptions
=
{
std
::
vector
<
operation
::
TileDescGemmElementwiseGemm
>
tile_descriptions
=
{
// clang-format off
// clang-format off
// Block| Gemm0
1
| Gemm0| Gemm0| Gemm1| Gemm1|
AK1|
BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|
NumGemmK|
// Block|
Gemm0| Gemm0| Gemm0| Gemm1| Gemm1|A
0
K1|B
0
K1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|NumGemm
0
K|
// Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Prefetch|
// Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Prefetch|
// | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Stage|
// | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Stage|
// | | | | | | | | | | | Wave| Wave| Wave| |
// | | | | | | | | | | | Wave| Wave| Wave| |
{
256
,
256
,
128
,
32
,
64
,
32
,
8
,
8
,
b1k1
,
32
,
32
,
2
,
4
,
2
,
1
},
//generic
{
256
,
256
,
128
,
32
,
128
,
32
,
8
,
8
,
b1k1
,
32
,
32
,
2
,
4
,
4
,
1
},
{
256
,
128
,
64
,
32
,
128
,
32
,
8
,
8
,
b1k1
,
32
,
32
,
1
,
2
,
4
,
1
},
{
256
,
128
,
256
,
32
,
64
,
32
,
8
,
8
,
b1k1
,
32
,
32
,
1
,
8
,
2
,
1
},
// no padding
{
256
,
128
,
256
,
32
,
128
,
32
,
8
,
8
,
b1k1
,
32
,
32
,
1
,
8
,
4
,
1
},
{
256
,
128
,
128
,
64
,
64
,
32
,
8
,
8
,
b1k1
,
32
,
32
,
1
,
4
,
2
,
1
},
{
256
,
128
,
128
,
64
,
64
,
32
,
8
,
8
,
b1k1
,
32
,
32
,
1
,
4
,
2
,
1
},
{
256
,
128
,
128
,
32
,
64
,
32
,
8
,
8
,
b1k1
,
32
,
32
,
1
,
4
,
2
,
1
},
{
256
,
128
,
128
,
32
,
64
,
32
,
8
,
8
,
b1k1
,
32
,
32
,
1
,
4
,
2
,
1
},
{
256
,
128
,
128
,
64
,
128
,
32
,
8
,
8
,
b1k1
,
32
,
32
,
1
,
4
,
4
,
1
},
{
256
,
128
,
128
,
64
,
128
,
32
,
8
,
8
,
b1k1
,
32
,
32
,
1
,
4
,
4
,
1
},
...
@@ -94,22 +96,29 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
...
@@ -94,22 +96,29 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{
256
,
64
,
256
,
32
,
64
,
32
,
8
,
8
,
b1k1
,
16
,
16
,
1
,
16
,
4
,
1
},
{
256
,
64
,
256
,
32
,
64
,
32
,
8
,
8
,
b1k1
,
16
,
16
,
1
,
16
,
4
,
1
},
{
256
,
64
,
256
,
64
,
128
,
32
,
8
,
8
,
b1k1
,
16
,
16
,
1
,
16
,
8
,
1
},
{
256
,
64
,
256
,
64
,
128
,
32
,
8
,
8
,
b1k1
,
16
,
16
,
1
,
16
,
8
,
1
},
{
256
,
64
,
256
,
64
,
64
,
32
,
8
,
8
,
b1k1
,
16
,
16
,
1
,
16
,
4
,
1
},
{
256
,
64
,
256
,
64
,
64
,
32
,
8
,
8
,
b1k1
,
16
,
16
,
1
,
16
,
4
,
1
},
// Padded fallback kerne
// Padded fallback kerne
l
{
256
,
128
,
128
,
64
,
128
,
32
,
8
,
8
,
b1k1
,
32
,
32
,
1
,
4
,
4
,
1
},
{
256
,
128
,
128
,
64
,
128
,
32
,
8
,
8
,
b1k1
,
32
,
32
,
1
,
4
,
4
,
1
},
{
256
,
128
,
64
,
32
,
128
,
32
,
8
,
8
,
b1k1
,
32
,
32
,
1
,
2
,
4
,
1
},
{
256
,
128
,
64
,
32
,
128
,
32
,
8
,
8
,
b1k1
,
32
,
32
,
1
,
2
,
4
,
1
},
// clang-format on
// clang-format on
};
};
if
(
prob
.
TransB1
)
{
// clang-format off
tile_descriptions
.
push_back
(
{
256
,
256
,
128
,
32
,
128
,
32
,
8
,
8
,
4
,
32
,
32
,
2
,
4
,
4
,
1
}
);
// clang-format on
}
const
std
::
vector
<
operation
::
BlockTransferDesc
>
a_block_descriptions
=
{
std
::
vector
<
operation
::
BlockTransferDesc
>
a
0
_block_descriptions
=
{
// clang-format off
// clang-format off
//
ABlockTransfer|
ABlockTransfer|
ABlockTransfer|
ABlockTransfer|
ABlockTransfer|
ABlockTransfer|
ABlockLds|
// A
0
BlockTransfer|A
0
BlockTransfer|A
0
BlockTransfer|A
0
BlockTransfer|A
0
BlockTransfer|A
0
BlockTransfer|A
0
BlockLds|
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM|
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM|
// Lengths_K0_M_K1| ArrangeOrder| | | PerVector|
PerVector_K1| |
// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_
A
K1| |
// | | | | | | |
// | | | | | | |
//generic
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
// no padding
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
},
{
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
},
{
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
},
...
@@ -123,90 +132,111 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
...
@@ -123,90 +132,111 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
// clang-format on
// clang-format on
};
};
if
(
prob
.
TransB1
)
const
auto
&
b0_block_descriptions_rowmajor
=
a_block_descriptions
;
{
const
std
::
vector
<
operation
::
BlockTransferDesc
>
b0_block_descriptions_colmajor
=
{
// clang-format off
// clang-format off
// B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds|
a0_block_descriptions
.
push_back
(
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
}
// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
);
// | | | | | | |
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
// clang-format on
// clang-format on
};
}
auto
b0_block_descriptions
=
a0_block_descriptions
;
if
(
prob
.
TransB1
)
{
b0_block_descriptions
[
1
].
lds_add_extra_dim
=
true
;
b0_block_descriptions
[
3
].
lds_add_extra_dim
=
true
;
}
const
std
::
vector
<
operation
::
BlockTransferDesc
>
b1
_block_descriptions
_rowmajor
=
{
std
::
vector
<
operation
::
BlockTransferDesc
>
cde0
_block_descriptions
=
{
// clang-format off
// clang-format off
// B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds|
// ... | CDE0BlockTransfer| CDE0BlockTransfer| ... |
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
// ... | SrcVectorDim| SrcScalar| ... |
// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
// ... | | PerVector| ... |
// | | | | | | |
// | | | |
{
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
//generic
{
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
{
""
,
""
,
""
,
9
,
1
,
0
,
0
},
{
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
// no padding
{
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
{
""
,
""
,
""
,
9
,
4
,
0
,
0
},
{
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
{
""
,
""
,
""
,
9
,
4
,
0
,
0
},
{
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
{
""
,
""
,
""
,
9
,
4
,
0
,
0
},
{
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
{
""
,
""
,
""
,
9
,
4
,
0
,
0
},
{
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
{
""
,
""
,
""
,
9
,
4
,
0
,
0
},
{
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
{
""
,
""
,
""
,
9
,
4
,
0
,
0
},
{
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
{
""
,
""
,
""
,
9
,
4
,
0
,
0
},
{
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
{
""
,
""
,
""
,
9
,
4
,
0
,
0
},
{
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
// Padded fallback kernel
// Padded fallback kernel
{
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
{
""
,
""
,
""
,
9
,
4
,
0
,
0
},
{
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
{
""
,
""
,
""
,
9
,
4
,
0
,
0
},
// clang-format on
// clang-format on
};
};
if
(
prob
.
TransB1
)
const
std
::
vector
<
operation
::
BlockTransferDesc
>
b1_block_descriptions_colmajor
=
{
{
// clang-format off
// clang-format off
// B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds|
cde0_block_descriptions
.
push_back
(
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
{
""
,
""
,
""
,
9
,
4
,
0
,
0
}
// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
);
// | | | | | | |
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
false
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
},
{
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
false
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
},
{
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
},
{
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
},
// Padded fallback kernel
{
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
false
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
},
// clang-format on
// clang-format on
};
}
const
std
::
vector
<
operation
::
BlockTransferDesc
>
b1_block_descriptions_rowmajor
=
{
// clang-format off
// B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds|
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
// | | | | | | |
//generic
{
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
// no padding
{
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
{
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
{
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
{
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
{
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
{
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
{
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
{
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
// Padded fallback kernel
{
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
{
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
// clang-format on
};
const
std
::
vector
<
operation
::
BlockTransferDesc
>
b1_block_descriptions_colmajor
=
{
// clang-format off
// B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds|
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
// | | | | | | |
//generic
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
},
// no padding
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
false
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
},
{
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
false
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
},
{
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
},
{
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
},
// Padded fallback kernel
{
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
false
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
},
// clang-format on
};
std
::
vector
<
operation
::
CShuffleDesc
>
cshuffle_descriptions
=
{
std
::
vector
<
operation
::
CShuffleDesc
>
cshuffle_descriptions
=
{
// clang-format off
// clang-format off
//
CShuffle|
CShuffle|
// C
1
Shuffle| C
1
Shuffle|
// MXdlPerWave| NXdlPerWave|
// MXdlPerWave| NXdlPerWave|
// PerShuffle| PerShuffle|
// PerShuffle| PerShuffle|
// | |
// | |
// generic
{
1
,
2
},
{
1
,
2
},
{
1
,
2
},
// no padding
{
1
,
2
},
{
1
,
2
},
{
1
,
2
},
{
1
,
2
},
{
1
,
2
},
{
1
,
2
},
{
1
,
2
},
{
1
,
2
},
...
@@ -220,69 +250,92 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
...
@@ -220,69 +250,92 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{
1
,
2
},
{
1
,
2
},
// clang-format on
// clang-format on
};
};
if
(
prob
.
TransB1
)
{
// clang-format off
cshuffle_descriptions
.
push_back
(
{
1
,
2
}
);
// clang-format on
}
std
::
vector
<
operation
::
CBlockTransferDesc
>
c_block_descriptions
=
{
std
::
vector
<
operation
::
CBlockTransferDesc
>
c
de1
_block_descriptions
=
{
// clang-format off
// clang-format off
// CBlockTransferClusterLengths| CBlockTransfer
// CDE1BlockTransferClusterLengths| CDE1BlockTransfer|
// _MBlock_MWaveMPerXdl| ScalarPerVector
// _MBlock_MWaveMPerXdl| ScalarPerVector|
// _NBlock_NWaveNPerXdl| _NWaveNPerXdl
// _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// |
// | |
{
S
<
1
,
32
,
1
,
8
>
,
8
},
// generic
{
S
<
1
,
32
,
1
,
8
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
// no padding
{
S
<
1
,
32
,
1
,
8
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
{
S
<
1
,
16
,
1
,
16
>
,
8
},
{
S
<
1
,
16
,
1
,
16
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
{
S
<
1
,
16
,
1
,
16
>
,
8
},
{
S
<
1
,
16
,
1
,
16
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
// Padded fallback kernel
// Padded fallback kernel
{
S
<
1
,
32
,
1
,
8
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
// clang-format on
// clang-format on
};
};
if
(
prob
.
TransB1
)
{
// clang-format off
cde1_block_descriptions
.
push_back
(
{
S
<
1
,
32
,
1
,
8
>
,
8
}
);
// clang-format on
}
// choose correct arrangement of tuning parameters based on the layout of each tensor
// choose correct arrangement of tuning parameters based on the layout of each tensor
const
auto
&
b0_block_descriptions
=
prob
.
TransB1
?
b0_block_descriptions_colmajor
:
b0_block_descriptions_rowmajor
;
const
auto
&
b1_block_descriptions
=
const
auto
&
b1_block_descriptions
=
prob
.
TransB1
?
b1_block_descriptions_colmajor
:
b1_block_descriptions_rowmajor
;
prob
.
TransB1
?
b1_block_descriptions_colmajor
:
b1_block_descriptions_rowmajor
;
assert
(
tile_descriptions
.
size
()
==
a_block_descriptions
.
size
());
assert
(
tile_descriptions
.
size
()
==
a0_block_descriptions
.
size
());
assert
(
tile_descriptions
.
size
()
==
b0_block_descriptions
.
size
());
assert
(
tile_descriptions
.
size
()
==
cde0_block_descriptions
.
size
());
assert
(
tile_descriptions
.
size
()
==
b1_block_descriptions
.
size
());
assert
(
tile_descriptions
.
size
()
==
b1_block_descriptions
.
size
());
assert
(
tile_descriptions
.
size
()
==
cshuffle_descriptions
.
size
());
assert
(
tile_descriptions
.
size
()
==
cshuffle_descriptions
.
size
());
assert
(
tile_descriptions
.
size
()
==
c_block_descriptions
.
size
());
assert
(
tile_descriptions
.
size
()
==
c
de1
_block_descriptions
.
size
());
// Put all values together into a single operation > store into the result vector
// Put all values together into a single operation > store into the result vector
for
(
std
::
size_t
i
=
0
;
i
<
tile_descriptions
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
tile_descriptions
.
size
();
i
++
)
{
{
Operation_Xdl_CShuffle
x
;
Operation_Xdl_CShuffle
x
;
x
.
tile_desc
=
tile_descriptions
[
i
];
x
.
tile_desc
=
tile_descriptions
[
i
];
x
.
a_block_transfer
=
a_block_descriptions
[
i
];
x
.
a
0
_block_transfer
=
a
0
_block_descriptions
[
i
];
x
.
b0_block_transfer
=
b0_block_descriptions
[
i
];
x
.
b0_block_transfer
=
b0_block_descriptions
[
i
];
x
.
cde0_block_transfer
=
cde0_block_descriptions
[
i
];
x
.
b1_block_transfer
=
b1_block_descriptions
[
i
];
x
.
b1_block_transfer
=
b1_block_descriptions
[
i
];
x
.
cshuffle
=
cshuffle_descriptions
[
i
];
x
.
cshuffle
=
cshuffle_descriptions
[
i
];
x
.
c_block_transfer
=
c_block_descriptions
[
i
];
x
.
c
de1
_block_transfer
=
c
de1
_block_descriptions
[
i
];
x
.
A
=
TensorDesc
{
prob
.
ADataType
,
ToLayout
(
prob
.
TransA
)};
x
.
A
0
=
TensorDesc
{
prob
.
A
0
DataType
,
ToLayout
(
prob
.
TransA
0
)};
x
.
B0
=
TensorDesc
{
prob
.
B0DataType
,
ToLayout
(
prob
.
TransB0
)};
x
.
B0
=
TensorDesc
{
prob
.
B0DataType
,
ToLayout
(
prob
.
TransB0
)};
x
.
D0s
=
Transform
(
prob
.
D0sTrans
,
prob
.
D0sDataType
,
[](
auto
trans
,
auto
dt
)
{
return
TensorDesc
{
dt
,
ToLayout
(
trans
)};
});
x
.
B1
=
TensorDesc
{
prob
.
B1DataType
,
ToLayout
(
prob
.
TransB1
)};
x
.
B1
=
TensorDesc
{
prob
.
B1DataType
,
ToLayout
(
prob
.
TransB1
)};
x
.
C
=
TensorDesc
{
prob
.
CDataType
,
ToLayout
(
prob
.
TransC
)};
x
.
D1s
=
Transform
(
prob
.
D1sTrans
,
prob
.
D1sDataType
,
[](
auto
trans
,
auto
dt
)
{
x
.
a_elem_op
=
prob
.
AElementOp
;
return
TensorDesc
{
dt
,
ToLayout
(
trans
)};
});
x
.
E1
=
TensorDesc
{
prob
.
E1DataType
,
ToLayout
(
prob
.
TransE1
)};
x
.
a0_elem_op
=
prob
.
A0ElementOp
;
x
.
b0_elem_op
=
prob
.
B0ElementOp
;
x
.
b0_elem_op
=
prob
.
B0ElementOp
;
x
.
cde0_elem_op
=
prob
.
CDE0ElementOp
;
x
.
b1_elem_op
=
prob
.
B1ElementOp
;
x
.
b1_elem_op
=
prob
.
B1ElementOp
;
x
.
c_elem_op
=
prob
.
CElementOp
;
x
.
c
de1
_elem_op
=
prob
.
C
DE1
ElementOp
;
x
.
acc0_elem_op
=
prob
.
Acc0ElementOp
;
x
.
padding_desc
=
GetPaddingDesc
(
prob
.
M
,
x
.
gemm_specialization
=
GetGemmSpec
(
prob
.
M
,
prob
.
N
,
prob
.
N
,
prob
.
K
,
prob
.
K
,
prob
.
O
,
prob
.
O
,
x
.
tile_desc
.
gemm0_m_per_block
,
x
.
tile_desc
.
gemm0
1_m
_per_block
,
x
.
tile_desc
.
gemm0
_n
_per_block
,
x
.
tile_desc
.
gemm0_
n
_per_block
,
x
.
tile_desc
.
gemm0_
k
_per_block
,
x
.
tile_desc
.
gemm
0_k
_per_block
,
x
.
tile_desc
.
gemm
1_n
_per_block
,
x
.
tile_desc
.
gemm1_
n
_per_block
);
x
.
tile_desc
.
gemm1_
k
_per_block
);
x
.
update_prologue
(
prologue
);
x
.
update_prologue
(
prologue
);
x
.
update_epilogue
(
epilogue
);
x
.
update_epilogue
(
epilogue
);
result
.
push_back
(
x
);
result
.
push_back
(
x
);
...
@@ -298,10 +351,7 @@ Operation_Xdl_CShuffle::CreateOperations(const std::string& prologue, const std:
...
@@ -298,10 +351,7 @@ Operation_Xdl_CShuffle::CreateOperations(const std::string& prologue, const std:
std
::
vector
<
std
::
vector
<
Operation_Xdl_CShuffle
>>
operations
;
std
::
vector
<
std
::
vector
<
Operation_Xdl_CShuffle
>>
operations
;
Problem
prob
;
Problem
prob
;
prob
.
TransA
=
false
;
prob
.
TransB0
=
true
;
prob
.
TransB0
=
true
;
prob
.
TransB1
=
false
;
prob
.
TransC
=
false
;
operations
.
push_back
(
CreateOperations
(
prob
,
prologue
,
epilogue
));
operations
.
push_back
(
CreateOperations
(
prob
,
prologue
,
epilogue
));
prob
.
TransB1
=
true
;
prob
.
TransB1
=
true
;
...
@@ -310,29 +360,43 @@ Operation_Xdl_CShuffle::CreateOperations(const std::string& prologue, const std:
...
@@ -310,29 +360,43 @@ Operation_Xdl_CShuffle::CreateOperations(const std::string& prologue, const std:
return
operations
;
return
operations
;
}
}
static
const
char
*
const
DeviceBatchedGemmGemm_Xdl_CShuffleTemplate
=
static
const
char
*
const
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffleTemplate
=
"ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle<${LayoutA}, "
"ck::tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle<"
"${LayoutB0}, ${LayoutB1}, ${LayoutC}, ${ADataType}, ${B0DataType}, ${B1DataType}, "
"${A0Layout}, ${B0Layout}, ${D0sLayout}, ${B1Layout}, ${D1sLayout}, ${E1Layout}, "
"${CDataType}, ${AccDataType}, ${CShuffleDataType}, ${AElementwiseOperation}, "
"${B0ElementwiseOperation}, ${Acc0ElementwiseOperation}, ${B1ElementwiseOperation}, "
"${A0DataType}, ${B0DataType}, ${Acc0DataType}, ${D0sDataType}, ${B1DataType}, "
"${CElementwiseOperation}, ${GemmSpecialization}, ${NumGemmkPrefetchStage}, ${BlockSize}, "
"${Acc1DataType}, ${C1ShuffleDataType}, ${D1sDataType}, ${E1DataType}, "
"${Gemm01MPerBlock}, ${Gemm0NPerBlock}, ${Gemm0KPerBlock}, ${Gemm1NPerBlock}, "
"${Gemm1KPerBlock}, ${AK1}, ${BK1}, ${B1K1}, ${MPerXDL}, ${NPerXDL}, ${Gemm0MXdlPerWave}, "
"${A0ElementwiseOperation}, ${B0ElementwiseOperation}, ${CDE0ElementwiseOperation}, "
"${Gemm0NXdlPerWave}, ${Gemm1NXdlPerWave}, ${ABlockTransferThreadClusterLengths_AK0_M_AK1}, "
"${B1ElementwiseOperation}, ${CDE1ElementwiseOperation}, "
"${ABlockTransferThreadClusterArrangeOrder}, ${ABlockTransferSrcAccessOrder}, "
"${ABlockTransferSrcVectorDim}, ${ABlockTransferSrcScalarPerVector}, "
"${PadGemm0M}, ${PadGemm0N}, ${PadGemm0K}, ${PadGemm1N}, ${PadGemm1K}, "
"${ABlockTransferDstScalarPerVector_AK1}, ${ABlockLdsExtraM}, "
"${NumGemm0KPrefetchStage}, ${BlockSize}, ${Gemm0MPerBlock}, ${Gemm0NPerBlock}, "
"${Gemm0KPerBlock}, ${Gemm1NPerBlock}, ${Gemm1KPerBlock}, ${A0K1}, ${B0K1}, ${B1K1}, "
"${MPerXDL}, ${NPerXDL}, ${Gemm0MXdlPerWave}, ${Gemm0NXdlPerWave}, ${Gemm1NXdlPerWave}, "
"${A0BlockTransferThreadClusterLengths_AK0_M_AK1}, "
"${A0BlockTransferThreadClusterArrangeOrder}, ${A0BlockTransferSrcAccessOrder}, "
"${A0BlockTransferSrcVectorDim}, ${A0BlockTransferSrcScalarPerVector}, "
"${A0BlockTransferDstScalarPerVector_AK1}, ${A0BlockLdsExtraM}, "
"${B0BlockTransferThreadClusterLengths_BK0_N_BK1}, "
"${B0BlockTransferThreadClusterLengths_BK0_N_BK1}, "
"${B0BlockTransferThreadClusterArrangeOrder}, ${B0BlockTransferSrcAccessOrder}, "
"${B0BlockTransferThreadClusterArrangeOrder}, ${B0BlockTransferSrcAccessOrder}, "
"${B0BlockTransferSrcVectorDim}, ${B0BlockTransferSrcScalarPerVector}, "
"${B0BlockTransferSrcVectorDim}, ${B0BlockTransferSrcScalarPerVector}, "
"${B0BlockTransferDstScalarPerVector_BK1}, ${B0BlockLdsExtraN}, "
"${B0BlockTransferDstScalarPerVector_BK1}, ${B0BlockLdsExtraN}, "
"${CDE0BlockTransferSrcVectorDim}, ${CDE0BlockTransferSrcScalarPerVector}, "
"${B1BlockTransferThreadClusterLengths_BK0_N_BK1}, "
"${B1BlockTransferThreadClusterLengths_BK0_N_BK1}, "
"${B1BlockTransferThreadClusterArrangeOrder}, ${B1BlockTransferSrcAccessOrder}, "
"${B1BlockTransferThreadClusterArrangeOrder}, ${B1BlockTransferSrcAccessOrder}, "
"${B1BlockTransferSrcVectorDim}, ${B1BlockTransferSrcScalarPerVector}, "
"${B1BlockTransferSrcVectorDim}, ${B1BlockTransferSrcScalarPerVector}, "
"${B1BlockTransferDstScalarPerVector_BK1}, ${B1BlockLdsExtraN}, "
"${B1BlockTransferDstScalarPerVector_BK1}, ${B1BlockLdsExtraN}, "
"${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, "
"${CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl}, "
"${C1ShuffleMXdlPerWavePerShuffle}, ${C1ShuffleGemm0NXdlPerWavePerShuffle}, "
"${CBlockTransferScalarPerVector_NWaveNPerXdl}>"
;
"${CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, "
"${CDE1ShuffleBlockTransferScalarPerVector_NPerBlock}>"
;
// use hardcoded instances from vector of operations to substitute values into instance template
// use hardcoded instances from vector of operations to substitute values into instance template
Solution
Operation_Xdl_CShuffle
::
ToSolution
()
const
Solution
Operation_Xdl_CShuffle
::
ToSolution
()
const
...
@@ -340,60 +404,80 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
...
@@ -340,60 +404,80 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
std
::
unordered_map
<
std
::
string
,
std
::
string
>
values
=
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
values
=
{
{
"name"
,
{
"name"
,
std
::
to_string
(
this
->
tile_desc
.
block_size
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
block_size
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm0
1
_m_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm0_m_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm0_n_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm0_n_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm0_k_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm0_k_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm1_n_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm1_n_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm1_k_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm1_k_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
ak1
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
bk1
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
a
0
k1
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
b
0
k1
)
+
std
::
to_string
(
this
->
tile_desc
.
b1k1
)
+
"_"
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
b1k1
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
m_per_XDL
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
m_per_XDL
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
n_per_XDL
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
n_per_XDL
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm0_m_Xdl_per_wave
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm0_m_Xdl_per_wave
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm0_n_Xdl_per_wave
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm0_n_Xdl_per_wave
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm1_n_Xdl_per_wave
)},
std
::
to_string
(
this
->
tile_desc
.
gemm1_n_Xdl_per_wave
)},
{
"LayoutA"
,
ToString
(
this
->
A
.
layout
)},
{
"LayoutB0"
,
ToString
(
this
->
B0
.
layout
)},
{
"A0Layout"
,
ToString
(
this
->
A0
.
layout
)},
{
"LayoutB1"
,
ToString
(
this
->
B1
.
layout
)},
{
"B0Layout"
,
ToString
(
this
->
B0
.
layout
)},
{
"LayoutC"
,
ToString
(
this
->
C
.
layout
)},
{
"D0sLayout"
,
{
"ADataType"
,
ToString
(
this
->
A
.
element
)},
MakeTuple
(
Transform
(
this
->
D0s
,
[](
auto
tensor
)
{
return
ToString
(
tensor
.
layout
);
}))},
{
"B1Layout"
,
ToString
(
this
->
B1
.
layout
)},
{
"D1sLayout"
,
MakeTuple
(
Transform
(
this
->
D1s
,
[](
auto
tensor
)
{
return
ToString
(
tensor
.
layout
);
}))},
{
"E1Layout"
,
ToString
(
this
->
E1
.
layout
)},
{
"ADataType"
,
ToString
(
this
->
A0
.
element
)},
{
"B0DataType"
,
ToString
(
this
->
B0
.
element
)},
{
"B0DataType"
,
ToString
(
this
->
B0
.
element
)},
{
"Acc0DataType"
,
ToString
(
this
->
acc_type
)},
{
"D0sDataType"
,
MakeTuple
(
Transform
(
this
->
D0s
,
[](
auto
tensor
)
{
return
ToString
(
tensor
.
element
);
}))},
{
"B1DataType"
,
ToString
(
this
->
B1
.
element
)},
{
"B1DataType"
,
ToString
(
this
->
B1
.
element
)},
{
"CDataType"
,
ToString
(
this
->
C
.
element
)},
{
"Acc1DataType"
,
ToString
(
this
->
acc_type
)},
{
"AccDataType"
,
ToString
(
this
->
acc
)},
{
"C1ShuffleDataType"
,
ToString
(
this
->
cshuffle_type
)},
{
"CShuffleDataType"
,
ToString
(
this
->
cs_type
)},
{
"D1sDataType"
,
{
"AElementwiseOperation"
,
this
->
a_elem_op
},
MakeTuple
(
Transform
(
this
->
D1s
,
[](
auto
tensor
)
{
return
ToString
(
tensor
.
element
);
}))},
{
"E1DataType"
,
ToString
(
this
->
E1
.
element
)},
{
"A0ElementwiseOperation"
,
this
->
a0_elem_op
},
{
"B0ElementwiseOperation"
,
this
->
b0_elem_op
},
{
"B0ElementwiseOperation"
,
this
->
b0_elem_op
},
{
"
Acc
0ElementwiseOperation"
,
this
->
acc
0_elem_op
},
{
"
CDE
0ElementwiseOperation"
,
this
->
cde
0_elem_op
},
{
"B1ElementwiseOperation"
,
this
->
b1_elem_op
},
{
"B1ElementwiseOperation"
,
this
->
b1_elem_op
},
{
"CElementwiseOperation"
,
this
->
c_elem_op
},
{
"CDE1ElementwiseOperation"
,
this
->
cde1_elem_op
},
{
"GemmSpecialization"
,
this
->
gemm_specialization
},
{
"NumGemmkPrefetchStage"
,
std
::
to_string
(
this
->
tile_desc
.
num_gemmk_prefetch_stage
)},
{
"PadGemm0M"
,
std
::
to_string
(
this
->
padding_desc
.
pad_gemm0_m
)},
{
"PadGemm0N"
,
std
::
to_string
(
this
->
padding_desc
.
pad_gemm0_n
)},
{
"PadGemm0K"
,
std
::
to_string
(
this
->
padding_desc
.
pad_gemm0_k
)},
{
"PadGemm1N"
,
std
::
to_string
(
this
->
padding_desc
.
pad_gemm1_n
)},
{
"PadGemm1K"
,
std
::
to_string
(
this
->
padding_desc
.
pad_gemm1_k
)},
{
"NumGemm0KPrefetchStage"
,
std
::
to_string
(
this
->
tile_desc
.
num_gemm0k_prefetch_stage
)},
{
"BlockSize"
,
std
::
to_string
(
this
->
tile_desc
.
block_size
)},
{
"BlockSize"
,
std
::
to_string
(
this
->
tile_desc
.
block_size
)},
{
"Gemm0
1
MPerBlock"
,
std
::
to_string
(
this
->
tile_desc
.
gemm0
1
_m_per_block
)},
{
"Gemm0MPerBlock"
,
std
::
to_string
(
this
->
tile_desc
.
gemm0_m_per_block
)},
{
"Gemm0NPerBlock"
,
std
::
to_string
(
this
->
tile_desc
.
gemm0_n_per_block
)},
{
"Gemm0NPerBlock"
,
std
::
to_string
(
this
->
tile_desc
.
gemm0_n_per_block
)},
{
"Gemm0KPerBlock"
,
std
::
to_string
(
this
->
tile_desc
.
gemm0_k_per_block
)},
{
"Gemm0KPerBlock"
,
std
::
to_string
(
this
->
tile_desc
.
gemm0_k_per_block
)},
{
"Gemm1NPerBlock"
,
std
::
to_string
(
this
->
tile_desc
.
gemm1_n_per_block
)},
{
"Gemm1NPerBlock"
,
std
::
to_string
(
this
->
tile_desc
.
gemm1_n_per_block
)},
{
"Gemm1KPerBlock"
,
std
::
to_string
(
this
->
tile_desc
.
gemm1_k_per_block
)},
{
"Gemm1KPerBlock"
,
std
::
to_string
(
this
->
tile_desc
.
gemm1_k_per_block
)},
{
"AK1"
,
std
::
to_string
(
this
->
tile_desc
.
ak1
)},
{
"A
0
K1"
,
std
::
to_string
(
this
->
tile_desc
.
a
0
k1
)},
{
"BK1"
,
std
::
to_string
(
this
->
tile_desc
.
bk1
)},
{
"B
0
K1"
,
std
::
to_string
(
this
->
tile_desc
.
b
0
k1
)},
{
"B1K1"
,
std
::
to_string
(
this
->
tile_desc
.
b1k1
)},
{
"B1K1"
,
std
::
to_string
(
this
->
tile_desc
.
b1k1
)},
{
"MPerXDL"
,
std
::
to_string
(
this
->
tile_desc
.
m_per_XDL
)},
{
"MPerXDL"
,
std
::
to_string
(
this
->
tile_desc
.
m_per_XDL
)},
{
"NPerXDL"
,
std
::
to_string
(
this
->
tile_desc
.
n_per_XDL
)},
{
"NPerXDL"
,
std
::
to_string
(
this
->
tile_desc
.
n_per_XDL
)},
{
"Gemm0MXdlPerWave"
,
std
::
to_string
(
this
->
tile_desc
.
gemm0_m_Xdl_per_wave
)},
{
"Gemm0MXdlPerWave"
,
std
::
to_string
(
this
->
tile_desc
.
gemm0_m_Xdl_per_wave
)},
{
"Gemm0NXdlPerWave"
,
std
::
to_string
(
this
->
tile_desc
.
gemm0_n_Xdl_per_wave
)},
{
"Gemm0NXdlPerWave"
,
std
::
to_string
(
this
->
tile_desc
.
gemm0_n_Xdl_per_wave
)},
{
"Gemm1NXdlPerWave"
,
std
::
to_string
(
this
->
tile_desc
.
gemm1_n_Xdl_per_wave
)},
{
"Gemm1NXdlPerWave"
,
std
::
to_string
(
this
->
tile_desc
.
gemm1_n_Xdl_per_wave
)},
{
"ABlockTransferThreadClusterLengths_AK0_M_AK1"
,
this
->
a_block_transfer
.
thread_cluster_length
},
{
"A0BlockTransferThreadClusterLengths_AK0_M_AK1"
,
{
"ABlockTransferThreadClusterArrangeOrder"
,
this
->
a0_block_transfer
.
thread_cluster_length
},
this
->
a_block_transfer
.
thread_cluster_arrange_order
},
{
"A0BlockTransferThreadClusterArrangeOrder"
,
{
"ABlockTransferSrcAccessOrder"
,
this
->
a_block_transfer
.
src_access_order
},
this
->
a0_block_transfer
.
thread_cluster_arrange_order
},
{
"ABlockTransferSrcVectorDim"
,
std
::
to_string
(
this
->
a_block_transfer
.
src_vec_dim
)},
{
"A0BlockTransferSrcAccessOrder"
,
this
->
a0_block_transfer
.
src_access_order
},
{
"ABlockTransferSrcScalarPerVector"
,
{
"A0BlockTransferSrcVectorDim"
,
std
::
to_string
(
this
->
a0_block_transfer
.
src_vec_dim
)},
std
::
to_string
(
this
->
a_block_transfer
.
src_scalar_per_vector
)},
{
"A0BlockTransferSrcScalarPerVector"
,
{
"ABlockTransferDstScalarPerVector_AK1"
,
std
::
to_string
(
this
->
a0_block_transfer
.
src_scalar_per_vector
)},
std
::
to_string
(
this
->
a_block_transfer
.
dst_scalar_per_vector_k1
)},
{
"A0BlockTransferDstScalarPerVector_AK1"
,
{
"ABlockLdsExtraM"
,
std
::
to_string
(
this
->
a_block_transfer
.
lds_add_extra_dim
)},
std
::
to_string
(
this
->
a0_block_transfer
.
dst_scalar_per_vector_k1
)},
{
"A0BlockLdsExtraM"
,
std
::
to_string
(
this
->
a0_block_transfer
.
lds_add_extra_dim
)},
{
"B0BlockTransferThreadClusterLengths_BK0_N_BK1"
,
{
"B0BlockTransferThreadClusterLengths_BK0_N_BK1"
,
this
->
b0_block_transfer
.
thread_cluster_length
},
this
->
b0_block_transfer
.
thread_cluster_length
},
{
"B0BlockTransferThreadClusterArrangeOrder"
,
{
"B0BlockTransferThreadClusterArrangeOrder"
,
...
@@ -405,6 +489,11 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
...
@@ -405,6 +489,11 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
{
"B0BlockTransferDstScalarPerVector_BK1"
,
{
"B0BlockTransferDstScalarPerVector_BK1"
,
std
::
to_string
(
this
->
b0_block_transfer
.
dst_scalar_per_vector_k1
)},
std
::
to_string
(
this
->
b0_block_transfer
.
dst_scalar_per_vector_k1
)},
{
"B0BlockLdsExtraN"
,
std
::
to_string
(
this
->
b0_block_transfer
.
lds_add_extra_dim
)},
{
"B0BlockLdsExtraN"
,
std
::
to_string
(
this
->
b0_block_transfer
.
lds_add_extra_dim
)},
{
"CDE0BlockTransferSrcVectorDim"
,
std
::
to_string
(
this
->
cde0_block_transfer
.
src_vec_dim
)},
{
"CDE0BlockTransferSrcScalarPerVector"
,
std
::
to_string
(
this
->
cde0_block_transfer
.
src_scalar_per_vector
)},
{
"B1BlockTransferThreadClusterLengths_BK0_N_BK1"
,
{
"B1BlockTransferThreadClusterLengths_BK0_N_BK1"
,
this
->
b1_block_transfer
.
thread_cluster_length
},
this
->
b1_block_transfer
.
thread_cluster_length
},
{
"B1BlockTransferThreadClusterArrangeOrder"
,
{
"B1BlockTransferThreadClusterArrangeOrder"
,
...
@@ -416,20 +505,24 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
...
@@ -416,20 +505,24 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
{
"B1BlockTransferDstScalarPerVector_BK1"
,
{
"B1BlockTransferDstScalarPerVector_BK1"
,
std
::
to_string
(
this
->
b1_block_transfer
.
dst_scalar_per_vector_k1
)},
std
::
to_string
(
this
->
b1_block_transfer
.
dst_scalar_per_vector_k1
)},
{
"B1BlockLdsExtraN"
,
std
::
to_string
(
this
->
b1_block_transfer
.
lds_add_extra_dim
)},
{
"B1BlockLdsExtraN"
,
std
::
to_string
(
this
->
b1_block_transfer
.
lds_add_extra_dim
)},
{
"CShuffleMXdlPerWavePerShuffle"
,
{
"C1ShuffleMXdlPerWavePerShuffle"
,
std
::
to_string
(
this
->
cshuffle
.
m_Xdl_per_wave_per_shuffle
)},
std
::
to_string
(
this
->
cshuffle
.
m_Xdl_per_wave_per_shuffle
)},
{
"CShuffleNXdlPerWavePerShuffle"
,
{
"C
1
Shuffle
Gemm0
NXdlPerWavePerShuffle"
,
std
::
to_string
(
this
->
cshuffle
.
n_Xdl_per_wave_per_shuffle
)},
std
::
to_string
(
this
->
cshuffle
.
n_Xdl_per_wave_per_shuffle
)},
{
"CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl"
,
this
->
c_block_transfer
.
cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl
},
{
"CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock"
,
{
"CBlockTransferScalarPerVector_NWaveNPerXdl"
,
this
->
cde1_block_transfer
std
::
to_string
(
this
->
c_block_transfer
.
scalar_per_vector_n_wave_n_per_Xdl
)},
.
cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl
},
{
"CDE1ShuffleBlockTransferScalarPerVector_NPerBlock"
,
std
::
to_string
(
this
->
cde1_block_transfer
.
scalar_per_vector_n_wave_n_per_Xdl
)},
};
};
return
Solution
{
InterpolateString
(
DeviceBatchedGemmGemm_Xdl_CShuffleTemplate
,
values
),
return
Solution
{
std
::
move
(
values
)};
InterpolateString
(
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffleTemplate
,
values
),
std
::
move
(
values
)};
}
}
}
// namespace device_
gemm_elementwise_gemm
}
// namespace device_
batched_gemm_multiple_d_gemm_multiple_d
}
// namespace host
}
// namespace host
}
// namespace ck
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment