Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
d1e9682a
Commit
d1e9682a
authored
Oct 02, 2024
by
Mirza Halilcevic
Browse files
Introduce gemm_elementwise_gemm.
parent
11b7a4db
Changes
5
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
598 additions
and
0 deletions
+598
-0
codegen/include/ck/host/device_gemm_elementwise_gemm/operation.hpp
...nclude/ck/host/device_gemm_elementwise_gemm/operation.hpp
+58
-0
codegen/include/ck/host/device_gemm_elementwise_gemm/problem.hpp
.../include/ck/host/device_gemm_elementwise_gemm/problem.hpp
+47
-0
codegen/include/ck/host/operation/gemm.hpp
codegen/include/ck/host/operation/gemm.hpp
+20
-0
codegen/src/device_gemm_elementwise_gemm.cpp
codegen/src/device_gemm_elementwise_gemm.cpp
+38
-0
codegen/src/device_gemm_elementwise_gemm_operation_xdl_cshuffle.cpp
...c/device_gemm_elementwise_gemm_operation_xdl_cshuffle.cpp
+435
-0
No files found.
codegen/include/ck/host/device_gemm_elementwise_gemm/operation.hpp
0 → 100644
View file @
d1e9682a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"
#include "ck/host/operation/gemm.hpp"
#include "ck/host/device_gemm_elementwise_gemm/problem.hpp"
namespace
ck
{
namespace
host
{
namespace
device_gemm_elementwise_gemm
{
// defines all values need for an instance of fwd conv
struct
Operation_Xdl_CShuffle
{
// returns a vector of instances, only given fusion operators: will use default problem spec
static
std
::
vector
<
std
::
vector
<
Operation_Xdl_CShuffle
>>
CreateOperations
(
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
);
// returns a vector of instances, given a problem spec and fusion operators
static
std
::
vector
<
Operation_Xdl_CShuffle
>
CreateOperations
(
const
Problem
&
prob
,
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
);
TensorDesc
A
{};
TensorDesc
B0
{};
TensorDesc
B1
{};
TensorDesc
C
{};
DataType
acc
=
DataType
::
Float
;
DataType
cs_type
=
DataType
::
Half
;
std
::
string
a_elem_op
=
PassThrough
;
std
::
string
b0_elem_op
=
PassThrough
;
std
::
string
acc0_elem_op
=
PassThrough
;
std
::
string
b1_elem_op
=
PassThrough
;
std
::
string
c_elem_op
=
PassThrough
;
std
::
string
prologue
=
""
;
std
::
string
epilogue
=
""
;
std
::
string
gemm_specialization
=
"ck::tensor_operation::device::GemmSpecialization::Default"
;
// tuning parameters
operation
::
TileDescGemmElementwiseGemm
tile_desc
{};
operation
::
BlockTransferDesc
a_block_transfer
{};
operation
::
BlockTransferDesc
b0_block_transfer
{};
operation
::
BlockTransferDesc
b1_block_transfer
{};
operation
::
CShuffleDesc
cshuffle
{};
operation
::
CBlockTransferDesc
c_block_transfer
{};
// functions to update fusion operators if provided
void
update_prologue
(
const
std
::
string
&
prologue
);
void
update_epilogue
(
const
std
::
string
&
epilogue
);
/**constexpr**/
bool
IsSupported
(
std
::
size_t
MRaw_
,
std
::
size_t
NRaw_
,
std
::
size_t
KRaw_
);
// returns a templated instance
Solution
ToSolution
()
const
;
};
}
// namespace device_gemm_elementwise_gemm
}
// namespace host
}
// namespace ck
codegen/include/ck/host/device_gemm_elementwise_gemm/problem.hpp
0 → 100644
View file @
d1e9682a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"
namespace
ck
{
namespace
host
{
namespace
device_gemm_elementwise_gemm
{
// defines the problem specification for a GEMM operation
struct
Problem
{
std
::
size_t
M
=
0
;
std
::
size_t
N
=
0
;
std
::
size_t
K
=
0
;
std
::
size_t
O
=
0
;
bool
TransA
=
false
;
bool
TransB0
=
false
;
bool
TransB1
=
false
;
bool
TransC
=
false
;
DataType
ADataType
=
DataType
::
Half
;
DataType
B0DataType
=
DataType
::
Half
;
DataType
B1DataType
=
DataType
::
Half
;
DataType
CDataType
=
DataType
::
Half
;
std
::
string
AElementOp
=
PassThrough
;
std
::
string
B0ElementOp
=
PassThrough
;
std
::
string
Acc0ElementOp
=
PassThrough
;
std
::
string
B1ElementOp
=
PassThrough
;
std
::
string
CElementOp
=
PassThrough
;
// returns the correct device op file for the operation
std
::
string
GetIncludeHeader
()
const
;
// returns a list of instances based on the problem spec and provided fusion operations
std
::
vector
<
Solution
>
GetSolutions
(
const
std
::
string
&
arch
,
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
)
const
;
};
}
// namespace device_gemm_elementwise_gemm
}
// namespace host
}
// namespace ck
codegen/include/ck/host/operation/gemm.hpp
View file @
d1e9682a
...
...
@@ -23,6 +23,26 @@ struct TileDesc
int
n_Xdl_per_wave
=
0
;
int
num_gemmk_prefetch_stage
=
0
;
};
// Tile-level tuning parameters for the gemm_elementwise_gemm operation.
// A plain aggregate of ints; every field defaults to 0 and the declaration order is
// significant for positional (aggregate) initialization.
// NOTE(review): field names suggest gemm0_* / gemm1_* apply to the first / second GEMM
// and gemm01_* to both - confirm against the device op template parameters.
struct TileDescGemmElementwiseGemm
{
    int block_size{0};

    // Per-block tile extents.
    int gemm01_m_per_block{0};
    int gemm0_n_per_block{0};
    int gemm0_k_per_block{0};
    int gemm1_n_per_block{0};
    int gemm1_k_per_block{0};

    // K1 values for A, B and B1.
    int ak1{0};
    int bk1{0};
    int b1k1{0};

    // XDL tile extents and per-wave repeat counts.
    int m_per_XDL{0};
    int n_per_XDL{0};
    int gemm0_m_Xdl_per_wave{0};
    int gemm0_n_Xdl_per_wave{0};
    int gemm1_n_Xdl_per_wave{0};

    int num_gemmk_prefetch_stage{0};
};
struct
BlockTransferDesc
{
std
::
string
thread_cluster_length
=
""
;
...
...
codegen/src/device_gemm_elementwise_gemm.cpp
0 → 100644
View file @
d1e9682a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_gemm_elementwise_gemm/problem.hpp"
#include "ck/host/device_gemm_elementwise_gemm/operation.hpp"
#include "ck/host/utils.hpp"
#include <algorithm>
namespace
ck
{
namespace
host
{
namespace
device_gemm_elementwise_gemm
{
// Returns the include path of the device op header used for this operation.
// The path is fixed: it does not depend on any field of the problem.
std::string Problem::GetIncludeHeader() const
{
    static constexpr const char* device_op_header =
        "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp";
    return device_op_header;
}
// returns templated instances when provided with a problem specification
std
::
vector
<
Solution
>
Problem
::
GetSolutions
(
const
std
::
string
&
arch
,
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
)
const
{
if
(
get_xdlop_archs
().
count
(
arch
)
==
0
)
return
{};
auto
ops
=
ck
::
host
::
device_gemm_elementwise_gemm
::
Operation_Xdl_CShuffle
::
CreateOperations
(
*
this
,
prologue
,
epilogue
);
// obtains vector of instances
std
::
vector
<
Solution
>
result
;
std
::
transform
(
ops
.
begin
(),
ops
.
end
(),
std
::
back_inserter
(
result
),
[
&
](
const
auto
&
op
)
{
return
op
.
ToSolution
();
// template instance with correct values
});
return
result
;
}
}
// namespace device_gemm_elementwise_gemm
}
// namespace host
}
// namespace ck
codegen/src/device_gemm_elementwise_gemm_operation_xdl_cshuffle.cpp
0 → 100644
View file @
d1e9682a
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment