Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
250a89f3
Commit
250a89f3
authored
Oct 14, 2024
by
Mirza Halilcevic
Browse files
Replace gemm_gemm with gemm_multiple_d_gemm_multiple_d.
parent
d1e9682a
Changes
5
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
638 additions
and
43 deletions
+638
-43
codegen/include/ck/host/device_batched_gemm_multiple_d_gemm_multiple_d/operation.hpp
...ice_batched_gemm_multiple_d_gemm_multiple_d/operation.hpp
+24
-20
codegen/include/ck/host/device_batched_gemm_multiple_d_gemm_multiple_d/problem.hpp
...evice_batched_gemm_multiple_d_gemm_multiple_d/problem.hpp
+51
-0
codegen/include/ck/host/operation/gemm.hpp
codegen/include/ck/host/operation/gemm.hpp
+27
-16
codegen/src/device_batched_gemm_multiple_d_gemm_multiple_d.cpp
...en/src/device_batched_gemm_multiple_d_gemm_multiple_d.cpp
+8
-7
codegen/src/device_batched_gemm_multiple_d_gemm_multiple_d_operation_xdl_cshuffle.cpp
...emm_multiple_d_gemm_multiple_d_operation_xdl_cshuffle.cpp
+528
-0
No files found.
codegen/include/ck/host/device_
gemm_elementwise_gemm
/operation.hpp
→
codegen/include/ck/host/device_
batched_gemm_multiple_d_gemm_multiple_d
/operation.hpp
View file @
250a89f3
...
...
@@ -8,13 +8,13 @@
#include <string>
#include "ck/host/types.hpp"
#include "ck/host/operation/gemm.hpp"
#include "ck/host/device_
gemm_elementwise_gemm
/problem.hpp"
#include "ck/host/device_
batched_gemm_multiple_d_gemm_multiple_d
/problem.hpp"
namespace
ck
{
namespace
host
{
namespace
device_
gemm_elementwise_gemm
{
namespace
device_
batched_gemm_multiple_d_gemm_multiple_d
{
// defines all values need for an instance
of fwd conv
// defines all values needed for an instance
struct
Operation_Xdl_CShuffle
{
// returns a vector of instances, only given fusion operators: will use default problem spec
...
...
@@ -23,36 +23,40 @@ struct Operation_Xdl_CShuffle
// returns a vector of instances, given a problem spec and fusion operators
static
std
::
vector
<
Operation_Xdl_CShuffle
>
CreateOperations
(
const
Problem
&
prob
,
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
);
TensorDesc
A
{};
TensorDesc
A
0
{};
TensorDesc
B0
{};
std
::
vector
<
TensorDesc
>
D0s
=
{};
TensorDesc
B1
{};
TensorDesc
C
{};
DataType
acc
=
DataType
::
Float
;
DataType
c
s
_type
=
DataType
::
Half
;
std
::
string
a_elem_op
=
PassThrough
;
std
::
string
b
0_elem_op
=
PassThrough
;
std
::
string
acc
0_elem_op
=
PassThrough
;
std
::
string
b1
_elem_op
=
PassThrough
;
std
::
string
c
_elem_op
=
PassThrough
;
std
::
string
prologue
=
""
;
std
::
string
epi
logue
=
""
;
std
::
string
gemm_specialization
=
"ck::tensor_operation::device::GemmSpecialization::Default
"
;
std
::
vector
<
TensorDesc
>
D1s
=
{};
TensorDesc
E1
{}
;
DataType
ac
c_type
=
DataType
::
Float
;
DataType
cshuffle_type
=
DataType
::
Float
;
std
::
string
a
0_elem_op
=
PassThrough
;
std
::
string
b
0_elem_op
=
PassThrough
;
std
::
string
cde0
_elem_op
=
PassThrough
;
std
::
string
b1
_elem_op
=
PassThrough
;
std
::
string
cde1_elem_op
=
PassThrough
;
std
::
string
pro
logue
=
""
;
std
::
string
epilogue
=
"
"
;
// tuning parameters
operation
::
TileDescGemmElementwiseGemm
tile_desc
{};
operation
::
BlockTransferDesc
a_block_transfer
{};
operation
::
PaddingDesc
padding_desc
{};
operation
::
TileDescGemmGemm
tile_desc
{};
operation
::
BlockTransferDesc
a0_block_transfer
{};
operation
::
BlockTransferDesc
b0_block_transfer
{};
operation
::
BlockTransferDesc
cde0_block_transfer
{};
operation
::
BlockTransferDesc
b1_block_transfer
{};
operation
::
CShuffleDesc
cshuffle
{};
operation
::
CBlockTransferDesc
c_block_transfer
{};
operation
::
CBlockTransferDesc
c
de1
_block_transfer
{};
// functions to update fusion operators if provided
void
update_prologue
(
const
std
::
string
&
prologue
);
void
update_epilogue
(
const
std
::
string
&
epilogue
);
/**constexpr**/
bool
IsSupported
(
std
::
size_t
MRaw_
,
std
::
size_t
NRaw_
,
std
::
size_t
KRaw_
);
/**constexpr**/
bool
IsSupported
(
std
::
size_t
MRaw_
,
std
::
size_t
NRaw_
,
std
::
size_t
KRaw_
,
std
::
size_t
Gemm1NRaw_
);
// returns a templated instance
Solution
ToSolution
()
const
;
};
}
// namespace device_
gemm_elementwise_gemm
}
// namespace device_
batched_gemm_multiple_d_gemm_multiple_d
}
// namespace host
}
// namespace ck
codegen/include/ck/host/device_
gemm_elementwise_gemm
/problem.hpp
→
codegen/include/ck/host/device_
batched_gemm_multiple_d_gemm_multiple_d
/problem.hpp
View file @
250a89f3
...
...
@@ -10,28 +10,32 @@
namespace
ck
{
namespace
host
{
namespace
device_
gemm_elementwise_gemm
{
namespace
device_
batched_gemm_multiple_d_gemm_multiple_d
{
// defines the problem specification for a GEMM operation
// defines the problem specification for a GEMM_ELEMENTWISE_GEMM operation
struct
Problem
{
std
::
size_t
M
=
0
;
std
::
size_t
N
=
0
;
std
::
size_t
K
=
0
;
std
::
size_t
O
=
0
;
bool
TransA
=
false
;
bool
TransB0
=
false
;
bool
TransB1
=
false
;
bool
TransC
=
false
;
DataType
ADataType
=
DataType
::
Half
;
DataType
B0DataType
=
DataType
::
Half
;
DataType
B1DataType
=
DataType
::
Half
;
DataType
CDataType
=
DataType
::
Half
;
std
::
string
AElementOp
=
PassThrough
;
std
::
string
B0ElementOp
=
PassThrough
;
std
::
string
Acc0ElementOp
=
PassThrough
;
std
::
string
B1ElementOp
=
PassThrough
;
std
::
string
CElementOp
=
PassThrough
;
std
::
size_t
M
=
0
;
std
::
size_t
N
=
0
;
std
::
size_t
K
=
0
;
std
::
size_t
O
=
0
;
bool
TransA0
=
false
;
bool
TransB0
=
false
;
std
::
vector
<
bool
>
D0sTrans
=
{};
bool
TransB1
=
false
;
std
::
vector
<
bool
>
D1sTrans
=
{};
bool
TransE1
=
false
;
DataType
A0DataType
=
DataType
::
Half
;
DataType
B0DataType
=
DataType
::
Half
;
std
::
vector
<
DataType
>
D0sDataType
=
{};
DataType
B1DataType
=
DataType
::
Half
;
std
::
vector
<
DataType
>
D1sDataType
=
{};
DataType
E1DataType
=
DataType
::
Half
;
std
::
string
A0ElementOp
=
PassThrough
;
std
::
string
B0ElementOp
=
PassThrough
;
std
::
string
CDE0ElementOp
=
PassThrough
;
std
::
string
B1ElementOp
=
PassThrough
;
std
::
string
CDE1ElementOp
=
PassThrough
;
// returns the correct device op file for the operation
std
::
string
GetIncludeHeader
()
const
;
...
...
@@ -42,6 +46,6 @@ struct Problem
const
std
::
string
&
epilogue
)
const
;
};
}
// namespace device_
gemm_elementwise_gemm
}
// namespace device_
batched_gemm_multiple_d_gemm_multiple_d
}
// namespace host
}
// namespace ck
codegen/include/ck/host/operation/gemm.hpp
View file @
250a89f3
...
...
@@ -9,6 +9,15 @@ namespace ck {
namespace
host
{
namespace
operation
{
// Five independent boolean padding flags; every flag defaults to false.
// NOTE(review): the member names suggest one flag per padded dimension of the
// first (gemm0: M, N, K) and second (gemm1: N, K) GEMM -- confirm against the
// code that reads these flags.
struct PaddingDesc
{
    bool pad_gemm0_m = false;
    bool pad_gemm0_n = false;
    bool pad_gemm0_k = false;
    bool pad_gemm1_n = false;
    bool pad_gemm1_k = false;
};
struct
TileDesc
{
int
block_size
=
0
;
...
...
@@ -24,23 +33,23 @@ struct TileDesc
int
num_gemmk_prefetch_stage
=
0
;
};
struct
TileDescGemm
Elementwise
Gemm
struct
TileDescGemmGemm
{
int
block_size
=
0
;
int
gemm0
1
_m_per_block
=
0
;
int
gemm0_n_per_block
=
0
;
int
gemm0_k_per_block
=
0
;
int
gemm1_n_per_block
=
0
;
int
gemm1_k_per_block
=
0
;
int
ak1
=
0
;
int
bk1
=
0
;
int
b1k1
=
0
;
int
m_per_XDL
=
0
;
int
n_per_XDL
=
0
;
int
gemm0_m_Xdl_per_wave
=
0
;
int
gemm0_n_Xdl_per_wave
=
0
;
int
gemm1_n_Xdl_per_wave
=
0
;
int
num_gemmk_prefetch_stage
=
0
;
int
block_size
=
0
;
int
gemm0_m_per_block
=
0
;
int
gemm0_n_per_block
=
0
;
int
gemm0_k_per_block
=
0
;
int
gemm1_n_per_block
=
0
;
int
gemm1_k_per_block
=
0
;
int
a
0
k1
=
0
;
int
b
0
k1
=
0
;
int
b1k1
=
0
;
int
m_per_XDL
=
0
;
int
n_per_XDL
=
0
;
int
gemm0_m_Xdl_per_wave
=
0
;
int
gemm0_n_Xdl_per_wave
=
0
;
int
gemm1_n_Xdl_per_wave
=
0
;
int
num_gemm
0
k_prefetch_stage
=
0
;
};
struct
BlockTransferDesc
...
...
@@ -53,11 +62,13 @@ struct BlockTransferDesc
int
dst_scalar_per_vector_k1
=
0
;
int
lds_add_extra_dim
=
0
;
};
// Two integer CShuffle tuning values; both default to zero.
// NOTE(review): presumably the number of XDL tiles per wave handled in one
// shuffle step along M and N -- confirm against the kernel template that
// consumes them.
struct CShuffleDesc
{
    int m_Xdl_per_wave_per_shuffle{0};
    int n_Xdl_per_wave_per_shuffle{0};
};
struct
CBlockTransferDesc
{
std
::
string
cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl
=
""
;
...
...
codegen/src/device_
gemm_elementwise_gemm
.cpp
→
codegen/src/device_
batched_gemm_multiple_d_gemm_multiple_d
.cpp
View file @
250a89f3
...
...
@@ -2,19 +2,20 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_
gemm_elementwise_gemm
/problem.hpp"
#include "ck/host/device_
gemm_elementwise_gemm
/operation.hpp"
#include "ck/host/device_
batched_gemm_multiple_d_gemm_multiple_d
/problem.hpp"
#include "ck/host/device_
batched_gemm_multiple_d_gemm_multiple_d
/operation.hpp"
#include "ck/host/utils.hpp"
#include <algorithm>
namespace
ck
{
namespace
host
{
namespace
device_
gemm_elementwise_gemm
{
namespace
device_
batched_gemm_multiple_d_gemm_multiple_d
{
// return the relevant device op file based on the operation
std
::
string
Problem
::
GetIncludeHeader
()
const
{
return
"ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp"
;
return
"ck/tensor_operation/gpu/device/impl/"
"device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp"
;
}
// returns templated instances when provided with a problem specification
...
...
@@ -24,8 +25,8 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch,
{
if
(
get_xdlop_archs
().
count
(
arch
)
==
0
)
return
{};
auto
ops
=
ck
::
host
::
device_
gemm_elementwise_gemm
::
Operation_Xdl_CShuffle
::
CreateOperations
(
*
this
,
prologue
,
epilogue
);
// obtains vector of instances
auto
ops
=
ck
::
host
::
device_
batched_gemm_multiple_d_gemm_multiple_d
::
Operation_Xdl_CShuffle
::
CreateOperations
(
*
this
,
prologue
,
epilogue
);
// obtains vector of instances
std
::
vector
<
Solution
>
result
;
std
::
transform
(
ops
.
begin
(),
ops
.
end
(),
std
::
back_inserter
(
result
),
[
&
](
const
auto
&
op
)
{
return
op
.
ToSolution
();
// template instance with correct values
...
...
@@ -33,6 +34,6 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch,
return
result
;
}
}
// namespace device_
gemm_elementwise_gemm
}
// namespace device_
batched_gemm_multiple_d_gemm_multiple_d
}
// namespace host
}
// namespace ck
codegen/src/device_
gemm_elementwise_gemm
_operation_xdl_cshuffle.cpp
→
codegen/src/device_
batched_gemm_multiple_d_gemm_multiple_d
_operation_xdl_cshuffle.cpp
View file @
250a89f3
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment