Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
d1e9682a
Commit
d1e9682a
authored
Oct 02, 2024
by
Mirza Halilcevic
Browse files
Introduce gemm_elementwise_gemm.
parent
11b7a4db
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
598 additions
and
0 deletions
+598
-0
codegen/include/ck/host/device_gemm_elementwise_gemm/operation.hpp
...nclude/ck/host/device_gemm_elementwise_gemm/operation.hpp
+58
-0
codegen/include/ck/host/device_gemm_elementwise_gemm/problem.hpp
.../include/ck/host/device_gemm_elementwise_gemm/problem.hpp
+47
-0
codegen/include/ck/host/operation/gemm.hpp
codegen/include/ck/host/operation/gemm.hpp
+20
-0
codegen/src/device_gemm_elementwise_gemm.cpp
codegen/src/device_gemm_elementwise_gemm.cpp
+38
-0
codegen/src/device_gemm_elementwise_gemm_operation_xdl_cshuffle.cpp
...c/device_gemm_elementwise_gemm_operation_xdl_cshuffle.cpp
+435
-0
No files found.
codegen/include/ck/host/device_gemm_elementwise_gemm/operation.hpp
0 → 100644
View file @
d1e9682a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"
#include "ck/host/operation/gemm.hpp"
#include "ck/host/device_gemm_elementwise_gemm/problem.hpp"
namespace
ck
{
namespace
host
{
namespace
device_gemm_elementwise_gemm
{
// defines all values need for an instance of fwd conv
struct
Operation_Xdl_CShuffle
{
// returns a vector of instances, only given fusion operators: will use default problem spec
static
std
::
vector
<
std
::
vector
<
Operation_Xdl_CShuffle
>>
CreateOperations
(
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
);
// returns a vector of instances, given a problem spec and fusion operators
static
std
::
vector
<
Operation_Xdl_CShuffle
>
CreateOperations
(
const
Problem
&
prob
,
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
);
TensorDesc
A
{};
TensorDesc
B0
{};
TensorDesc
B1
{};
TensorDesc
C
{};
DataType
acc
=
DataType
::
Float
;
DataType
cs_type
=
DataType
::
Half
;
std
::
string
a_elem_op
=
PassThrough
;
std
::
string
b0_elem_op
=
PassThrough
;
std
::
string
acc0_elem_op
=
PassThrough
;
std
::
string
b1_elem_op
=
PassThrough
;
std
::
string
c_elem_op
=
PassThrough
;
std
::
string
prologue
=
""
;
std
::
string
epilogue
=
""
;
std
::
string
gemm_specialization
=
"ck::tensor_operation::device::GemmSpecialization::Default"
;
// tuning parameters
operation
::
TileDescGemmElementwiseGemm
tile_desc
{};
operation
::
BlockTransferDesc
a_block_transfer
{};
operation
::
BlockTransferDesc
b0_block_transfer
{};
operation
::
BlockTransferDesc
b1_block_transfer
{};
operation
::
CShuffleDesc
cshuffle
{};
operation
::
CBlockTransferDesc
c_block_transfer
{};
// functions to update fusion operators if provided
void
update_prologue
(
const
std
::
string
&
prologue
);
void
update_epilogue
(
const
std
::
string
&
epilogue
);
/**constexpr**/
bool
IsSupported
(
std
::
size_t
MRaw_
,
std
::
size_t
NRaw_
,
std
::
size_t
KRaw_
);
// returns a templated instance
Solution
ToSolution
()
const
;
};
}
// namespace device_gemm_elementwise_gemm
}
// namespace host
}
// namespace ck
codegen/include/ck/host/device_gemm_elementwise_gemm/problem.hpp
0 → 100644
View file @
d1e9682a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"
namespace
ck
{
namespace
host
{
namespace
device_gemm_elementwise_gemm
{
// defines the problem specification for a GEMM operation
struct
Problem
{
std
::
size_t
M
=
0
;
std
::
size_t
N
=
0
;
std
::
size_t
K
=
0
;
std
::
size_t
O
=
0
;
bool
TransA
=
false
;
bool
TransB0
=
false
;
bool
TransB1
=
false
;
bool
TransC
=
false
;
DataType
ADataType
=
DataType
::
Half
;
DataType
B0DataType
=
DataType
::
Half
;
DataType
B1DataType
=
DataType
::
Half
;
DataType
CDataType
=
DataType
::
Half
;
std
::
string
AElementOp
=
PassThrough
;
std
::
string
B0ElementOp
=
PassThrough
;
std
::
string
Acc0ElementOp
=
PassThrough
;
std
::
string
B1ElementOp
=
PassThrough
;
std
::
string
CElementOp
=
PassThrough
;
// returns the correct device op file for the operation
std
::
string
GetIncludeHeader
()
const
;
// returns a list of instances based on the problem spec and provided fusion operations
std
::
vector
<
Solution
>
GetSolutions
(
const
std
::
string
&
arch
,
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
)
const
;
};
}
// namespace device_gemm_elementwise_gemm
}
// namespace host
}
// namespace ck
codegen/include/ck/host/operation/gemm.hpp
View file @
d1e9682a
...
...
@@ -23,6 +23,26 @@ struct TileDesc
int
n_Xdl_per_wave
=
0
;
int
num_gemmk_prefetch_stage
=
0
;
};
struct
TileDescGemmElementwiseGemm
{
int
block_size
=
0
;
int
gemm01_m_per_block
=
0
;
int
gemm0_n_per_block
=
0
;
int
gemm0_k_per_block
=
0
;
int
gemm1_n_per_block
=
0
;
int
gemm1_k_per_block
=
0
;
int
ak1
=
0
;
int
bk1
=
0
;
int
b1k1
=
0
;
int
m_per_XDL
=
0
;
int
n_per_XDL
=
0
;
int
gemm0_m_Xdl_per_wave
=
0
;
int
gemm0_n_Xdl_per_wave
=
0
;
int
gemm1_n_Xdl_per_wave
=
0
;
int
num_gemmk_prefetch_stage
=
0
;
};
struct
BlockTransferDesc
{
std
::
string
thread_cluster_length
=
""
;
...
...
codegen/src/device_gemm_elementwise_gemm.cpp
0 → 100644
View file @
d1e9682a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_gemm_elementwise_gemm/problem.hpp"
#include "ck/host/device_gemm_elementwise_gemm/operation.hpp"
#include "ck/host/utils.hpp"
#include <algorithm>
namespace
ck
{
namespace
host
{
namespace
device_gemm_elementwise_gemm
{
// return the relevant device op file based on the operation
std
::
string
Problem
::
GetIncludeHeader
()
const
{
return
"ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp"
;
}
// returns templated instances when provided with a problem specification
std
::
vector
<
Solution
>
Problem
::
GetSolutions
(
const
std
::
string
&
arch
,
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
)
const
{
if
(
get_xdlop_archs
().
count
(
arch
)
==
0
)
return
{};
auto
ops
=
ck
::
host
::
device_gemm_elementwise_gemm
::
Operation_Xdl_CShuffle
::
CreateOperations
(
*
this
,
prologue
,
epilogue
);
// obtains vector of instances
std
::
vector
<
Solution
>
result
;
std
::
transform
(
ops
.
begin
(),
ops
.
end
(),
std
::
back_inserter
(
result
),
[
&
](
const
auto
&
op
)
{
return
op
.
ToSolution
();
// template instance with correct values
});
return
result
;
}
}
// namespace device_gemm_elementwise_gemm
}
// namespace host
}
// namespace ck
codegen/src/device_gemm_elementwise_gemm_operation_xdl_cshuffle.cpp
0 → 100644
View file @
d1e9682a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_gemm_elementwise_gemm/operation.hpp"
#include "ck/host/stringutils.hpp"
#include "ck/host/utils.hpp"
#include <cassert>
namespace
ck
{
namespace
host
{
namespace
device_gemm_elementwise_gemm
{
// calculate appropriate Gemm Specification based on input tensor dimensions
std
::
string
GetGemmSpec
(
const
std
::
size_t
m
,
const
std
::
size_t
n
,
const
std
::
size_t
k
,
const
std
::
size_t
n1
,
const
std
::
size_t
m_per_block
,
const
std
::
size_t
n_per_block
,
const
std
::
size_t
k_per_block
,
const
std
::
size_t
n1_per_block
)
{
std
::
string
spec
=
""
;
if
(
integer_divide_ceil
(
m
,
m_per_block
)
*
m_per_block
-
m
!=
0
)
spec
+=
"M"
;
if
(
integer_divide_ceil
(
n
,
n_per_block
)
*
n_per_block
-
n
!=
0
)
spec
+=
"N"
;
if
(
integer_divide_ceil
(
k
,
k_per_block
)
*
k_per_block
-
k
!=
0
)
spec
+=
"K"
;
if
(
integer_divide_ceil
(
n1
,
n1_per_block
)
*
n1_per_block
-
n1
!=
0
)
spec
+=
"O"
;
if
(
spec
==
""
)
return
"ck::tensor_operation::device::GemmSpecialization::Default"
;
return
"ck::tensor_operation::device::GemmSpecialization::"
+
spec
+
"Padding"
;
}
// function to update prologue/epilogue with user provided operation
void
Operation_Xdl_CShuffle
::
update_prologue
(
const
std
::
string
&
pro
)
{
if
(
!
prologue
.
empty
())
{
this
->
prologue
=
pro
;
}
else
{
this
->
prologue
=
""
;
}
}
void
Operation_Xdl_CShuffle
::
update_epilogue
(
const
std
::
string
&
epi
)
{
if
(
!
epilogue
.
empty
())
{
this
->
epilogue
=
epi
;
}
else
{
this
->
epilogue
=
""
;
}
}
// accounts for all possible combinations of Row/Col major
static
Layout
ToLayout
(
bool
Trans
)
{
return
Trans
?
Layout
::
Column
:
Layout
::
Row
;
}
// Hard-code tuning parameters in modularized fashion, string them together into a vector of
// instances
std
::
vector
<
Operation_Xdl_CShuffle
>
Operation_Xdl_CShuffle
::
CreateOperations
(
const
Problem
&
prob
,
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
)
{
assert
(
prob
.
TransA
==
false
);
assert
(
prob
.
TransB0
==
true
);
assert
(
prob
.
TransC
==
false
);
const
auto
b1k1
=
prob
.
TransB1
?
4
:
2
;
std
::
vector
<
Operation_Xdl_CShuffle
>
result
;
std
::
vector
<
operation
::
TileDescGemmElementwiseGemm
>
tile_descriptions
=
{
// clang-format off
// Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| NumGemmK|
// Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Prefetch|
// | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Stage|
// | | | | | | | | | | | Wave| Wave| Wave| |
{
256
,
256
,
128
,
32
,
64
,
32
,
8
,
8
,
b1k1
,
32
,
32
,
2
,
4
,
2
,
1
},
{
256
,
256
,
128
,
32
,
128
,
32
,
8
,
8
,
b1k1
,
32
,
32
,
2
,
4
,
4
,
1
},
{
256
,
128
,
256
,
32
,
64
,
32
,
8
,
8
,
b1k1
,
32
,
32
,
1
,
8
,
2
,
1
},
{
256
,
128
,
256
,
32
,
128
,
32
,
8
,
8
,
b1k1
,
32
,
32
,
1
,
8
,
4
,
1
},
{
256
,
128
,
128
,
64
,
64
,
32
,
8
,
8
,
b1k1
,
32
,
32
,
1
,
4
,
2
,
1
},
{
256
,
128
,
128
,
32
,
64
,
32
,
8
,
8
,
b1k1
,
32
,
32
,
1
,
4
,
2
,
1
},
{
256
,
128
,
128
,
64
,
128
,
32
,
8
,
8
,
b1k1
,
32
,
32
,
1
,
4
,
4
,
1
},
{
256
,
128
,
128
,
32
,
128
,
32
,
8
,
8
,
b1k1
,
32
,
32
,
1
,
4
,
4
,
1
},
{
256
,
64
,
256
,
32
,
128
,
32
,
8
,
8
,
b1k1
,
16
,
16
,
1
,
16
,
8
,
1
},
{
256
,
64
,
256
,
32
,
64
,
32
,
8
,
8
,
b1k1
,
16
,
16
,
1
,
16
,
4
,
1
},
{
256
,
64
,
256
,
64
,
128
,
32
,
8
,
8
,
b1k1
,
16
,
16
,
1
,
16
,
8
,
1
},
{
256
,
64
,
256
,
64
,
64
,
32
,
8
,
8
,
b1k1
,
16
,
16
,
1
,
16
,
4
,
1
},
// Padded fallback kerne
{
256
,
128
,
128
,
64
,
128
,
32
,
8
,
8
,
b1k1
,
32
,
32
,
1
,
4
,
4
,
1
},
{
256
,
128
,
64
,
32
,
128
,
32
,
8
,
8
,
b1k1
,
32
,
32
,
1
,
2
,
4
,
1
},
// clang-format on
};
const
std
::
vector
<
operation
::
BlockTransferDesc
>
a_block_descriptions
=
{
// clang-format off
// ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds|
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM|
// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
// | | | | | | |
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
// Padded fallback kernel
{
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
// clang-format on
};
const
auto
&
b0_block_descriptions_rowmajor
=
a_block_descriptions
;
const
std
::
vector
<
operation
::
BlockTransferDesc
>
b0_block_descriptions_colmajor
=
{
// clang-format off
// B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds|
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
// | | | | | | |
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
{
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
},
// clang-format on
};
const
std
::
vector
<
operation
::
BlockTransferDesc
>
b1_block_descriptions_rowmajor
=
{
// clang-format off
// B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds|
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
// | | | | | | |
{
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
{
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
{
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
{
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
{
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
{
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
{
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
{
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
{
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
{
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
{
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
{
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
// Padded fallback kernel
{
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
{
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
},
// clang-format on
};
const
std
::
vector
<
operation
::
BlockTransferDesc
>
b1_block_descriptions_colmajor
=
{
// clang-format off
// B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds|
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
// | | | | | | |
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
false
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
},
{
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
false
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
},
{
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
},
{
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
},
// Padded fallback kernel
{
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
false
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
},
// clang-format on
};
std
::
vector
<
operation
::
CShuffleDesc
>
cshuffle_descriptions
=
{
// clang-format off
// CShuffle| CShuffle|
// MXdlPerWave| NXdlPerWave|
// PerShuffle| PerShuffle|
// | |
{
1
,
2
},
{
1
,
2
},
{
1
,
2
},
{
1
,
2
},
{
1
,
2
},
{
1
,
2
},
{
1
,
2
},
{
1
,
2
},
{
1
,
8
},
{
1
,
4
},
{
1
,
8
},
{
1
,
4
},
// Padded fallback kernel
{
1
,
2
},
{
1
,
2
},
// clang-format on
};
std
::
vector
<
operation
::
CBlockTransferDesc
>
c_block_descriptions
=
{
// clang-format off
// CBlockTransferClusterLengths| CBlockTransfer
// _MBlock_MWaveMPerXdl| ScalarPerVector
// _NBlock_NWaveNPerXdl| _NWaveNPerXdl
// |
{
S
<
1
,
32
,
1
,
8
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
{
S
<
1
,
16
,
1
,
16
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
{
S
<
1
,
16
,
1
,
16
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
// Padded fallback kernel
{
S
<
1
,
32
,
1
,
8
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
// clang-format on
};
// choose correct arrangement of tuning parameters based on the layout of each tensor
const
auto
&
b0_block_descriptions
=
prob
.
TransB1
?
b0_block_descriptions_colmajor
:
b0_block_descriptions_rowmajor
;
const
auto
&
b1_block_descriptions
=
prob
.
TransB1
?
b1_block_descriptions_colmajor
:
b1_block_descriptions_rowmajor
;
assert
(
tile_descriptions
.
size
()
==
a_block_descriptions
.
size
());
assert
(
tile_descriptions
.
size
()
==
b1_block_descriptions
.
size
());
assert
(
tile_descriptions
.
size
()
==
cshuffle_descriptions
.
size
());
assert
(
tile_descriptions
.
size
()
==
c_block_descriptions
.
size
());
// Put all values together into a single operation > store into the result vector
for
(
std
::
size_t
i
=
0
;
i
<
tile_descriptions
.
size
();
i
++
)
{
Operation_Xdl_CShuffle
x
;
x
.
tile_desc
=
tile_descriptions
[
i
];
x
.
a_block_transfer
=
a_block_descriptions
[
i
];
x
.
b0_block_transfer
=
b0_block_descriptions
[
i
];
x
.
b1_block_transfer
=
b1_block_descriptions
[
i
];
x
.
cshuffle
=
cshuffle_descriptions
[
i
];
x
.
c_block_transfer
=
c_block_descriptions
[
i
];
x
.
A
=
TensorDesc
{
prob
.
ADataType
,
ToLayout
(
prob
.
TransA
)};
x
.
B0
=
TensorDesc
{
prob
.
B0DataType
,
ToLayout
(
prob
.
TransB0
)};
x
.
B1
=
TensorDesc
{
prob
.
B1DataType
,
ToLayout
(
prob
.
TransB1
)};
x
.
C
=
TensorDesc
{
prob
.
CDataType
,
ToLayout
(
prob
.
TransC
)};
x
.
a_elem_op
=
prob
.
AElementOp
;
x
.
b0_elem_op
=
prob
.
B0ElementOp
;
x
.
b1_elem_op
=
prob
.
B1ElementOp
;
x
.
c_elem_op
=
prob
.
CElementOp
;
x
.
acc0_elem_op
=
prob
.
Acc0ElementOp
;
x
.
gemm_specialization
=
GetGemmSpec
(
prob
.
M
,
prob
.
N
,
prob
.
K
,
prob
.
O
,
x
.
tile_desc
.
gemm01_m_per_block
,
x
.
tile_desc
.
gemm0_n_per_block
,
x
.
tile_desc
.
gemm0_k_per_block
,
x
.
tile_desc
.
gemm1_n_per_block
);
x
.
update_prologue
(
prologue
);
x
.
update_epilogue
(
epilogue
);
result
.
push_back
(
x
);
}
return
result
;
}
// set up instances when not provided with a problem specification, use default operation values and
// all possible layout combinations
std
::
vector
<
std
::
vector
<
Operation_Xdl_CShuffle
>>
Operation_Xdl_CShuffle
::
CreateOperations
(
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
)
{
std
::
vector
<
std
::
vector
<
Operation_Xdl_CShuffle
>>
operations
;
Problem
prob
;
prob
.
TransA
=
false
;
prob
.
TransB0
=
true
;
prob
.
TransB1
=
false
;
prob
.
TransC
=
false
;
operations
.
push_back
(
CreateOperations
(
prob
,
prologue
,
epilogue
));
prob
.
TransB1
=
true
;
operations
.
push_back
(
CreateOperations
(
prob
,
prologue
,
epilogue
));
return
operations
;
}
static
const
char
*
const
DeviceBatchedGemmGemm_Xdl_CShuffleTemplate
=
"ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle<${LayoutA}, "
"${LayoutB0}, ${LayoutB1}, ${LayoutC}, ${ADataType}, ${B0DataType}, ${B1DataType}, "
"${CDataType}, ${AccDataType}, ${CShuffleDataType}, ${AElementwiseOperation}, "
"${B0ElementwiseOperation}, ${Acc0ElementwiseOperation}, ${B1ElementwiseOperation}, "
"${CElementwiseOperation}, ${GemmSpecialization}, ${NumGemmkPrefetchStage}, ${BlockSize}, "
"${Gemm01MPerBlock}, ${Gemm0NPerBlock}, ${Gemm0KPerBlock}, ${Gemm1NPerBlock}, "
"${Gemm1KPerBlock}, ${AK1}, ${BK1}, ${B1K1}, ${MPerXDL}, ${NPerXDL}, ${Gemm0MXdlPerWave}, "
"${Gemm0NXdlPerWave}, ${Gemm1NXdlPerWave}, ${ABlockTransferThreadClusterLengths_AK0_M_AK1}, "
"${ABlockTransferThreadClusterArrangeOrder}, ${ABlockTransferSrcAccessOrder}, "
"${ABlockTransferSrcVectorDim}, ${ABlockTransferSrcScalarPerVector}, "
"${ABlockTransferDstScalarPerVector_AK1}, ${ABlockLdsExtraM}, "
"${B0BlockTransferThreadClusterLengths_BK0_N_BK1}, "
"${B0BlockTransferThreadClusterArrangeOrder}, ${B0BlockTransferSrcAccessOrder}, "
"${B0BlockTransferSrcVectorDim}, ${B0BlockTransferSrcScalarPerVector}, "
"${B0BlockTransferDstScalarPerVector_BK1}, ${B0BlockLdsExtraN}, "
"${B1BlockTransferThreadClusterLengths_BK0_N_BK1}, "
"${B1BlockTransferThreadClusterArrangeOrder}, ${B1BlockTransferSrcAccessOrder}, "
"${B1BlockTransferSrcVectorDim}, ${B1BlockTransferSrcScalarPerVector}, "
"${B1BlockTransferDstScalarPerVector_BK1}, ${B1BlockLdsExtraN}, "
"${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, "
"${CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl}, "
"${CBlockTransferScalarPerVector_NWaveNPerXdl}>"
;
// use hardcoded instances from vector of operations to substitute values into instance template
Solution
Operation_Xdl_CShuffle
::
ToSolution
()
const
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
values
=
{
{
"name"
,
std
::
to_string
(
this
->
tile_desc
.
block_size
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm01_m_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm0_n_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm0_k_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm1_n_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm1_k_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
ak1
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
bk1
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
b1k1
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
m_per_XDL
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
n_per_XDL
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm0_m_Xdl_per_wave
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm0_n_Xdl_per_wave
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm1_n_Xdl_per_wave
)},
{
"LayoutA"
,
ToString
(
this
->
A
.
layout
)},
{
"LayoutB0"
,
ToString
(
this
->
B0
.
layout
)},
{
"LayoutB1"
,
ToString
(
this
->
B1
.
layout
)},
{
"LayoutC"
,
ToString
(
this
->
C
.
layout
)},
{
"ADataType"
,
ToString
(
this
->
A
.
element
)},
{
"B0DataType"
,
ToString
(
this
->
B0
.
element
)},
{
"B1DataType"
,
ToString
(
this
->
B1
.
element
)},
{
"CDataType"
,
ToString
(
this
->
C
.
element
)},
{
"AccDataType"
,
ToString
(
this
->
acc
)},
{
"CShuffleDataType"
,
ToString
(
this
->
cs_type
)},
{
"AElementwiseOperation"
,
this
->
a_elem_op
},
{
"B0ElementwiseOperation"
,
this
->
b0_elem_op
},
{
"Acc0ElementwiseOperation"
,
this
->
acc0_elem_op
},
{
"B1ElementwiseOperation"
,
this
->
b1_elem_op
},
{
"CElementwiseOperation"
,
this
->
c_elem_op
},
{
"GemmSpecialization"
,
this
->
gemm_specialization
},
{
"NumGemmkPrefetchStage"
,
std
::
to_string
(
this
->
tile_desc
.
num_gemmk_prefetch_stage
)},
{
"BlockSize"
,
std
::
to_string
(
this
->
tile_desc
.
block_size
)},
{
"Gemm01MPerBlock"
,
std
::
to_string
(
this
->
tile_desc
.
gemm01_m_per_block
)},
{
"Gemm0NPerBlock"
,
std
::
to_string
(
this
->
tile_desc
.
gemm0_n_per_block
)},
{
"Gemm0KPerBlock"
,
std
::
to_string
(
this
->
tile_desc
.
gemm0_k_per_block
)},
{
"Gemm1NPerBlock"
,
std
::
to_string
(
this
->
tile_desc
.
gemm1_n_per_block
)},
{
"Gemm1KPerBlock"
,
std
::
to_string
(
this
->
tile_desc
.
gemm1_k_per_block
)},
{
"AK1"
,
std
::
to_string
(
this
->
tile_desc
.
ak1
)},
{
"BK1"
,
std
::
to_string
(
this
->
tile_desc
.
bk1
)},
{
"B1K1"
,
std
::
to_string
(
this
->
tile_desc
.
b1k1
)},
{
"MPerXDL"
,
std
::
to_string
(
this
->
tile_desc
.
m_per_XDL
)},
{
"NPerXDL"
,
std
::
to_string
(
this
->
tile_desc
.
n_per_XDL
)},
{
"Gemm0MXdlPerWave"
,
std
::
to_string
(
this
->
tile_desc
.
gemm0_m_Xdl_per_wave
)},
{
"Gemm0NXdlPerWave"
,
std
::
to_string
(
this
->
tile_desc
.
gemm0_n_Xdl_per_wave
)},
{
"Gemm1NXdlPerWave"
,
std
::
to_string
(
this
->
tile_desc
.
gemm1_n_Xdl_per_wave
)},
{
"ABlockTransferThreadClusterLengths_AK0_M_AK1"
,
this
->
a_block_transfer
.
thread_cluster_length
},
{
"ABlockTransferThreadClusterArrangeOrder"
,
this
->
a_block_transfer
.
thread_cluster_arrange_order
},
{
"ABlockTransferSrcAccessOrder"
,
this
->
a_block_transfer
.
src_access_order
},
{
"ABlockTransferSrcVectorDim"
,
std
::
to_string
(
this
->
a_block_transfer
.
src_vec_dim
)},
{
"ABlockTransferSrcScalarPerVector"
,
std
::
to_string
(
this
->
a_block_transfer
.
src_scalar_per_vector
)},
{
"ABlockTransferDstScalarPerVector_AK1"
,
std
::
to_string
(
this
->
a_block_transfer
.
dst_scalar_per_vector_k1
)},
{
"ABlockLdsExtraM"
,
std
::
to_string
(
this
->
a_block_transfer
.
lds_add_extra_dim
)},
{
"B0BlockTransferThreadClusterLengths_BK0_N_BK1"
,
this
->
b0_block_transfer
.
thread_cluster_length
},
{
"B0BlockTransferThreadClusterArrangeOrder"
,
this
->
b0_block_transfer
.
thread_cluster_arrange_order
},
{
"B0BlockTransferSrcAccessOrder"
,
this
->
b0_block_transfer
.
src_access_order
},
{
"B0BlockTransferSrcVectorDim"
,
std
::
to_string
(
this
->
b0_block_transfer
.
src_vec_dim
)},
{
"B0BlockTransferSrcScalarPerVector"
,
std
::
to_string
(
this
->
b0_block_transfer
.
src_scalar_per_vector
)},
{
"B0BlockTransferDstScalarPerVector_BK1"
,
std
::
to_string
(
this
->
b0_block_transfer
.
dst_scalar_per_vector_k1
)},
{
"B0BlockLdsExtraN"
,
std
::
to_string
(
this
->
b0_block_transfer
.
lds_add_extra_dim
)},
{
"B1BlockTransferThreadClusterLengths_BK0_N_BK1"
,
this
->
b1_block_transfer
.
thread_cluster_length
},
{
"B1BlockTransferThreadClusterArrangeOrder"
,
this
->
b1_block_transfer
.
thread_cluster_arrange_order
},
{
"B1BlockTransferSrcAccessOrder"
,
this
->
b1_block_transfer
.
src_access_order
},
{
"B1BlockTransferSrcVectorDim"
,
std
::
to_string
(
this
->
b1_block_transfer
.
src_vec_dim
)},
{
"B1BlockTransferSrcScalarPerVector"
,
std
::
to_string
(
this
->
b1_block_transfer
.
src_scalar_per_vector
)},
{
"B1BlockTransferDstScalarPerVector_BK1"
,
std
::
to_string
(
this
->
b1_block_transfer
.
dst_scalar_per_vector_k1
)},
{
"B1BlockLdsExtraN"
,
std
::
to_string
(
this
->
b1_block_transfer
.
lds_add_extra_dim
)},
{
"CShuffleMXdlPerWavePerShuffle"
,
std
::
to_string
(
this
->
cshuffle
.
m_Xdl_per_wave_per_shuffle
)},
{
"CShuffleNXdlPerWavePerShuffle"
,
std
::
to_string
(
this
->
cshuffle
.
n_Xdl_per_wave_per_shuffle
)},
{
"CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl"
,
this
->
c_block_transfer
.
cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl
},
{
"CBlockTransferScalarPerVector_NWaveNPerXdl"
,
std
::
to_string
(
this
->
c_block_transfer
.
scalar_per_vector_n_wave_n_per_Xdl
)},
};
return
Solution
{
InterpolateString
(
DeviceBatchedGemmGemm_Xdl_CShuffleTemplate
,
values
),
std
::
move
(
values
)};
}
}
// namespace device_gemm_elementwise_gemm
}
// namespace host
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment