Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
460b059e
Commit
460b059e
authored
Jun 25, 2022
by
Chao Liu
Browse files
add KNN, KKN, MNN, MKN layout
parent
7fd0e649
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
68 additions
and
18 deletions
+68
-18
example/23_contraction/contraction_xdl_fp32.cpp
example/23_contraction/contraction_xdl_fp32.cpp
+68
-18
No files found.
example/23_contraction/contraction_xdl_fp32.cpp
View file @
460b059e
...
@@ -42,14 +42,40 @@ using CElementOp = ck::tensor_operation::element_wise::PassThrough;
...
@@ -42,14 +42,40 @@ using CElementOp = ck::tensor_operation::element_wise::PassThrough;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
// clang-format off
// clang-format off
using
DeviceOpInstance
=
ck
::
tensor_operation
::
device
::
// Fast changing dimension in A/B/C are K/N/N dimensions
using
ContractionInstanceKNN
=
ck
::
tensor_operation
::
device
::
//############################| NumDimM| NumDimN| NumDimK| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//############################| NumDimM| NumDimN| NumDimK| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//############################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContraction_Xdl_CShuffle
<
NumDimM
,
NumDimN
,
NumDimK
,
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
16
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
;
DeviceContraction_Xdl_CShuffle
<
NumDimM
,
NumDimN
,
NumDimK
,
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
16
,
4
,
1
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
;
// Fast changing dimension in A/B/C are K/K/N dimensions
using
ContractionInstanceKKN
=
ck
::
tensor_operation
::
device
::
//############################| NumDimM| NumDimN| NumDimK| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//############################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContraction_Xdl_CShuffle
<
NumDimM
,
NumDimN
,
NumDimK
,
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
16
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
;
// Fast changing dimension in A/B/C are M/N/N dimensions
using
ContractionInstanceMNN
=
ck
::
tensor_operation
::
device
::
//############################| NumDimM| NumDimN| NumDimK| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//############################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContraction_Xdl_CShuffle
<
NumDimM
,
NumDimN
,
NumDimK
,
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
16
,
1
,
1
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
;
// Fast changing dimension in A/B/C are M/K/N dimensions
using
ContractionInstanceMKN
=
ck
::
tensor_operation
::
device
::
//############################| NumDimM| NumDimN| NumDimK| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//############################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContraction_Xdl_CShuffle
<
NumDimM
,
NumDimN
,
NumDimK
,
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
16
,
1
,
4
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
;
// clang-format on
// clang-format on
using
ContractionInstance
=
ContractionInstanceMKN
;
template
<
typename
T
,
typename
Range
>
template
<
typename
T
,
typename
Range
>
void
LogRangeToFile
(
std
::
ofstream
&
fs
,
Range
&&
range
,
std
::
string
delim
)
void
LogRangeToFile
(
std
::
ofstream
&
fs
,
Range
&&
range
,
std
::
string
delim
)
...
@@ -242,6 +268,7 @@ int main(int argc, char* argv[])
...
@@ -242,6 +268,7 @@ int main(int argc, char* argv[])
#if 0
#if 0
// fast changing dimension: K/K/N
// a[m0, m1, k0, k1]
// a[m0, m1, k0, k1]
std::vector<ck::index_t> a_ms_ks_lengths{30, 128, 32, 64};
std::vector<ck::index_t> a_ms_ks_lengths{30, 128, 32, 64};
//std::vector<ck::index_t> a_ms_ks_strides{524288, 4096, 128, 1};
//std::vector<ck::index_t> a_ms_ks_strides{524288, 4096, 128, 1};
...
@@ -251,19 +278,52 @@ int main(int argc, char* argv[])
...
@@ -251,19 +278,52 @@ int main(int argc, char* argv[])
// c[m0, m1, n0, n1]
// c[m0, m1, n0, n1]
std::vector<ck::index_t> c_ms_ns_lengths{30, 128, 32, 64};
std::vector<ck::index_t> c_ms_ns_lengths{30, 128, 32, 64};
//std::vector<ck::index_t> c_ms_ns_strides{524288, 4096, 128, 1};
//std::vector<ck::index_t> c_ms_ns_strides{524288, 4096, 128, 1};
#else
#elif
0
// fast changing dimension: K/N/N
// a[m0, m1, k0, k1]
// a[m0, m1, k0, k1]
std
::
vector
<
ck
::
index_t
>
a_ms_ks_lengths
{
5
,
6
,
3
,
4
};
std
::
vector
<
ck
::
index_t
>
a_ms_ks_lengths
{
5
,
6
,
3
,
4
};
//
std::vector<ck::index_t> a_ms_ks_strides{108,20,16,1};
std
::
vector
<
ck
::
index_t
>
a_ms_ks_strides
{
108
,
20
,
16
,
1
};
// b[k0, k1, n0, n1]
// b[k0, k1, n0, n1]
std
::
vector
<
ck
::
index_t
>
b_ks_ns_lengths
{
3
,
4
,
3
,
4
};
std
::
vector
<
ck
::
index_t
>
b_ks_ns_lengths
{
3
,
4
,
3
,
4
};
//
std::vector<ck::index_t> b_ks_ns_strides{
16,1,108,20
};
std
::
vector
<
ck
::
index_t
>
b_ks_ns_strides
{
48
,
12
,
4
,
1
};
// c[m0, m1, n0, n1]
// c[m0, m1, n0, n1]
std
::
vector
<
ck
::
index_t
>
c_ms_ns_lengths
{
5
,
6
,
3
,
4
};
std
::
vector
<
ck
::
index_t
>
c_ms_ns_lengths
{
5
,
6
,
3
,
4
};
//std::vector<ck::index_t> c_ms_ns_strides{108,20,16,1};
std
::
vector
<
ck
::
index_t
>
c_ms_ns_strides
{
108
,
20
,
16
,
1
};
#elif 0
// fast changing dimension: K/K/N
// a[m0, m1, k0, k1]
std
::
vector
<
ck
::
index_t
>
a_ms_ks_lengths
{
5
,
6
,
3
,
4
};
std
::
vector
<
ck
::
index_t
>
a_ms_ks_strides
{
108
,
20
,
16
,
1
};
// b[k0, k1, n0, n1]
std
::
vector
<
ck
::
index_t
>
b_ks_ns_lengths
{
3
,
4
,
3
,
4
};
std
::
vector
<
ck
::
index_t
>
b_ks_ns_strides
{
16
,
1
,
108
,
20
};
// c[m0, m1, n0, n1]
std
::
vector
<
ck
::
index_t
>
c_ms_ns_lengths
{
5
,
6
,
3
,
4
};
std
::
vector
<
ck
::
index_t
>
c_ms_ns_strides
{
108
,
20
,
16
,
1
};
#elif 0
// fast changing dimension: M/N/N
// a[m0, m1, k0, k1]
std
::
vector
<
ck
::
index_t
>
a_ms_ks_lengths
{
5
,
6
,
3
,
4
};
std
::
vector
<
ck
::
index_t
>
a_ms_ks_strides
{
6
,
1
,
72
,
24
};
// b[k0, k1, n0, n1]
std
::
vector
<
ck
::
index_t
>
b_ks_ns_lengths
{
3
,
4
,
3
,
4
};
std
::
vector
<
ck
::
index_t
>
b_ks_ns_strides
{
48
,
12
,
4
,
1
};
// c[m0, m1, n0, n1]
std
::
vector
<
ck
::
index_t
>
c_ms_ns_lengths
{
5
,
6
,
3
,
4
};
std
::
vector
<
ck
::
index_t
>
c_ms_ns_strides
{
108
,
20
,
16
,
1
};
#elif 1
// fast changing dimension: M/K/N
// a[m0, m1, k0, k1]
std
::
vector
<
ck
::
index_t
>
a_ms_ks_lengths
{
5
,
6
,
3
,
4
};
std
::
vector
<
ck
::
index_t
>
a_ms_ks_strides
{
6
,
1
,
72
,
24
};
// b[k0, k1, n0, n1]
std
::
vector
<
ck
::
index_t
>
b_ks_ns_lengths
{
3
,
4
,
3
,
4
};
std
::
vector
<
ck
::
index_t
>
b_ks_ns_strides
{
16
,
1
,
108
,
20
};
// c[m0, m1, n0, n1]
std
::
vector
<
ck
::
index_t
>
c_ms_ns_lengths
{
5
,
6
,
3
,
4
};
std
::
vector
<
ck
::
index_t
>
c_ms_ns_strides
{
108
,
20
,
16
,
1
};
#endif
#endif
#if 0
Tensor
<
ADataType
>
a_ms_ks
(
Tensor
<
ADataType
>
a_ms_ks
(
std
::
vector
<
std
::
size_t
>
(
a_ms_ks_lengths
.
begin
(),
a_ms_ks_lengths
.
end
()),
std
::
vector
<
std
::
size_t
>
(
a_ms_ks_lengths
.
begin
(),
a_ms_ks_lengths
.
end
()),
std
::
vector
<
std
::
size_t
>
(
a_ms_ks_strides
.
begin
(),
a_ms_ks_strides
.
end
()));
std
::
vector
<
std
::
size_t
>
(
a_ms_ks_strides
.
begin
(),
a_ms_ks_strides
.
end
()));
...
@@ -276,16 +336,6 @@ int main(int argc, char* argv[])
...
@@ -276,16 +336,6 @@ int main(int argc, char* argv[])
Tensor
<
CDataType
>
c_ms_ns_device_result
(
Tensor
<
CDataType
>
c_ms_ns_device_result
(
std
::
vector
<
std
::
size_t
>
(
c_ms_ns_lengths
.
begin
(),
c_ms_ns_lengths
.
end
()),
std
::
vector
<
std
::
size_t
>
(
c_ms_ns_lengths
.
begin
(),
c_ms_ns_lengths
.
end
()),
std
::
vector
<
std
::
size_t
>
(
c_ms_ns_strides
.
begin
(),
c_ms_ns_strides
.
end
()));
std
::
vector
<
std
::
size_t
>
(
c_ms_ns_strides
.
begin
(),
c_ms_ns_strides
.
end
()));
#else
Tensor
<
ADataType
>
a_ms_ks
(
std
::
vector
<
std
::
size_t
>
(
a_ms_ks_lengths
.
begin
(),
a_ms_ks_lengths
.
end
()));
Tensor
<
BDataType
>
b_ks_ns
(
std
::
vector
<
std
::
size_t
>
(
b_ks_ns_lengths
.
begin
(),
b_ks_ns_lengths
.
end
()));
Tensor
<
CDataType
>
c_ms_ns_host_result
(
std
::
vector
<
std
::
size_t
>
(
c_ms_ns_lengths
.
begin
(),
c_ms_ns_lengths
.
end
()));
Tensor
<
CDataType
>
c_ms_ns_device_result
(
std
::
vector
<
std
::
size_t
>
(
c_ms_ns_lengths
.
begin
(),
c_ms_ns_lengths
.
end
()));
#endif
std
::
cout
<<
"a_ms_ks: "
<<
a_ms_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"a_ms_ks: "
<<
a_ms_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_ks_ns: "
<<
b_ks_ns
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_ks_ns: "
<<
b_ks_ns
.
mDesc
<<
std
::
endl
;
...
@@ -330,7 +380,7 @@ int main(int argc, char* argv[])
...
@@ -330,7 +380,7 @@ int main(int argc, char* argv[])
auto
c_element_op
=
CElementOp
{};
auto
c_element_op
=
CElementOp
{};
// device operation
// device operation
auto
op
=
DeviceOp
Instance
{};
auto
op
=
Contraction
Instance
{};
auto
invoker
=
op
.
MakeInvoker
();
auto
invoker
=
op
.
MakeInvoker
();
auto
argument
=
op
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_ms_ks_device_buf
.
GetDeviceBuffer
()),
auto
argument
=
op
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_ms_ks_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_ks_ns_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_ks_ns_device_buf
.
GetDeviceBuffer
()),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment