Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
cc2535e0
Commit
cc2535e0
authored
Sep 08, 2022
by
turneram
Browse files
Merge elementwise
parent
b41a56cf
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
409 additions
and
122 deletions
+409
-122
src/CMakeLists.txt
src/CMakeLists.txt
+1
-0
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+1
-0
src/targets/gpu/jit/ck_elementwise.cpp
src/targets/gpu/jit/ck_elementwise.cpp
+92
-0
src/targets/gpu/jit/ck_gemm.cpp
src/targets/gpu/jit/ck_gemm.cpp
+4
-120
src/targets/gpu/kernels/include/migraphx/kernels/ck_elementwise.hpp
...s/gpu/kernels/include/migraphx/kernels/ck_elementwise.hpp
+119
-0
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
+146
-1
test/verify/0ck_elementwise_test.cpp
test/verify/0ck_elementwise_test.cpp
+45
-0
test/verify/0ck_gemm_test.cpp
test/verify/0ck_gemm_test.cpp
+1
-1
No files found.
src/CMakeLists.txt
View file @
cc2535e0
...
@@ -118,6 +118,7 @@ register_migraphx_ops(
...
@@ -118,6 +118,7 @@ register_migraphx_ops(
broadcast
broadcast
capture
capture
ceil
ceil
ck_elementwise
ck_gemm
ck_gemm
clip
clip
concat
concat
...
...
src/include/migraphx/operators.hpp
View file @
cc2535e0
...
@@ -40,6 +40,7 @@
...
@@ -40,6 +40,7 @@
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/capture.hpp>
#include <migraphx/op/capture.hpp>
#include <migraphx/op/ceil.hpp>
#include <migraphx/op/ceil.hpp>
#include <migraphx/op/ck_elementwise.hpp>
#include <migraphx/op/ck_gemm.hpp>
#include <migraphx/op/ck_gemm.hpp>
#include <migraphx/op/clip.hpp>
#include <migraphx/op/clip.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/op/common.hpp>
...
...
src/targets/gpu/jit/ck_elementwise.cpp
0 → 100644
View file @
cc2535e0
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
// HIP source text handed to the JIT compiler for the "ck_elementwise" op.
// It pulls in the ck_elementwise kernel header and defines the extern "C"
// entry point ck_elementwise_kernel, which passes its three raw pointers
// through make_tensors() and forwards the resulting arguments to
// ck_elementwise().
// NOLINTNEXTLINE
static
const
char
*
const
ck_elementwise_kernel
=
R"__migraphx__(
#include <migraphx/kernels/ck_elementwise.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void ck_elementwise_kernel(void* a_p, void* b_p, void* c_p)
{
make_tensors()(a_p, b_p, c_p)([](auto&&... xs) {
ck_elementwise(xs...);
});
}
}
} // namespace migraphx
)__migraphx__"
;
/// JIT compiler for the "ck_elementwise" operation. It builds a HIP code
/// object from the ck_elementwise_kernel source using the shapes of the
/// instruction being compiled.
struct ck_elementwise_compiler : compiler<ck_elementwise_compiler>
{
    /// Operation names this compiler is responsible for.
    std::vector<std::string> names() const { return {"ck_elementwise"}; }

    /// Compiles ck_elementwise_kernel for the given shapes.
    ///
    /// The last entry of `inputs` is treated as the output shape; its element
    /// count is used to derive the launch parameters. `v` supplies any launch
    /// overrides understood by set_launch_params.
    operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
    {
        const auto& output_shape = inputs.back();

        hip_compile_options options;
        options.set_launch_params(v, compute_global_for(ctx, output_shape.elements()));
        options.kernel_name    = "ck_elementwise_kernel";
        options.inputs         = inputs;
        options.virtual_inputs = inputs;
        options.output         = output_shape;

        return compile_hip_code_object(ck_elementwise_kernel, options);
    }

    /// Compiles the kernel for instruction `ins` and wraps the result in a
    /// compiler_replace via replace().
    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
    {
        const auto input_shapes = to_shapes(ins->inputs());
        return replace(compile_op(ctx, input_shapes, op.to_value()));
    }
};
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/jit/ck_gemm.cpp
View file @
cc2535e0
...
@@ -39,6 +39,7 @@ namespace migraphx {
...
@@ -39,6 +39,7 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
gpu
{
// NOLINTNEXTLINE
// NOLINTNEXTLINE
static
const
char
*
const
ck_gemm_kernel
=
R"__migraphx__(
static
const
char
*
const
ck_gemm_kernel
=
R"__migraphx__(
#include <migraphx/kernels/ck_gemm.hpp>
#include <migraphx/kernels/ck_gemm.hpp>
...
@@ -47,132 +48,15 @@ static const char* const ck_gemm_kernel = R"__migraphx__(
...
@@ -47,132 +48,15 @@ static const char* const ck_gemm_kernel = R"__migraphx__(
#include <migraphx/kernels/generic_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
#include <args.hpp>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using CDataType = F16;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
namespace migraphx {
namespace migraphx {
extern "C" {
extern "C" {
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
// clang-format off
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
// clang-format on
__global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
__global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
{
{
// GEMM shape
make_tensors()(a_p, b_p, c_p)([](auto&&... xs) {
ck::index_t M = 3840;
ck_gemm(xs...);
ck::index_t N = 4096;
});
ck::index_t K = 4096;
ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096;
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CShuffleDataType,
CDataType,
AElementOp,
BElementOp,
CElementOp,
ck::InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
CGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
const auto kernel = kernel_gemm_xdlops_v2r3<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementOp,
BElementOp,
CElementOp,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>;
kernel<<<1, 1, 1, 0>>>(p_a, p_b, p_c);
}
}
}
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ck_elementwise.hpp
0 → 100644
View file @
cc2535e0
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_CK_ELEMENTWISE_HPP
#define MIGRAPHX_GUARD_KERNELS_CK_ELEMENTWISE_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp"
namespace
migraphx
{
// Element types used to instantiate the gridwise binary elementwise kernel
// in ck_elementwise(); all three are currently fixed to float.
using
ADataType
=
float
;
using
BDataType
=
float
;
using
CDataType
=
float
;
// NOTE(review): aliased to float and not referenced anywhere in the visible
// part of this header — confirm whether it is still needed.
using
ElementwiseFunctor
=
float
;
// Compile-time index 0, used to query the first dimension of a descriptor.
static
constexpr
auto
I0
=
ck
::
Number
<
0
>
{};
/// Builds a right-padded one-dimensional CK tensor descriptor.
///
/// @param lengths  indexable container of dimension lengths
/// @param strides  indexable container of dimension strides
/// @param (unnamed) accepted but ignored; the rank is fixed by a local
///                  constant below.
/// @return the descriptor with its length padded up to the next multiple of
///         gridSize * blockSize.
///
/// Because the rank constant is 1, only element 0 of `lengths` and `strides`
/// is read and the merge branch is discarded at compile time.
template <class L, class S, class N>
constexpr auto MakeDescriptor_M(const L& lengths, const S& strides, const N& /* ndim */)
{
    // Hard-coded launch configuration used only to compute the padding.
    auto gridSize  = 72;
    auto blockSize = 1024;
    // Rank is fixed here rather than taken from the third argument.
    constexpr auto ndim = 1;

    auto tupleOfShape = generate_tuple(
        [&](auto dim) { return static_cast<ck::index_t>(lengths[dim]); }, ck::Number<ndim>{});
    auto tupleOfStride = generate_tuple(
        [&](auto dim) { return static_cast<ck::index_t>(strides[dim]); }, ck::Number<1>{});

    const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
    auto desc_m     = desc;

    // merge nd to 1d desc - [s0 * s1 * ...]
    if constexpr(ndim > 1)
    {
        desc_m = transform_tensor_descriptor(
            desc,
            make_tuple(make_merge_transform(tupleOfShape)),
            make_tuple(generate_sequence_v2([&](auto dim) { return dim; }, ck::Number<ndim>{})),
            make_tuple(ck::Sequence<0>{}));
    }

    const auto length = desc_m.GetLength(I0);
    // Number of elements covered by one sweep of the assumed launch.
    const ck::index_t loop_step = gridSize * blockSize;
    const auto pad = ck::math::integer_least_multiple(length, loop_step) - length;

    return transform_tensor_descriptor(desc_m,
                                       make_tuple(ck::make_right_pad_transform(length, pad)),
                                       make_tuple(ck::Sequence<0>{}),
                                       make_tuple(ck::Sequence<0>{}));
}
/// Binary functor: writes the sum of the two inputs into the output reference.
struct Add
{
    /// @param y   destination, assigned `x0 + x1`
    /// @param x0  left operand
    /// @param x1  right operand
    template <typename Out, typename Lhs, typename Rhs>
    __device__ constexpr void operator()(Out& y, const Lhs& x0, const Rhs& x1) const
    {
        y = x0 + x1;
    }
};
/// Device-side entry for the CK binary elementwise add.
///
/// Builds one descriptor from the shape of `T` and passes it for all three
/// operands, then invokes GridwiseBinaryElementwise_1D::Run with the Add
/// functor on the data pointers of `a_t`, `b_t` and `c_t`.
/// Only the thread whose global index is 0 does any work.
template <class T, class U, class V>
__device__ void ck_elementwise(const T& a_t, const U& b_t, const V& c_t)
{
    const auto idx = make_index();
    if(idx.global != 0)
        return;

    // Lengths and strides are taken from the first operand only.
    constexpr auto a_lens    = get_shape_c<T>{}.lens;
    constexpr auto a_strides = get_shape_c<T>{}.strides;
    constexpr auto a_desc    = MakeDescriptor_M(a_lens, a_strides, 1);

    using Desc_M   = decltype(a_desc);
    using Gridwise = ck::GridwiseBinaryElementwise_1D<ADataType,
                                                      BDataType,
                                                      CDataType,
                                                      CDataType,
                                                      Desc_M,
                                                      Desc_M,
                                                      Desc_M,
                                                      Add,
                                                      1,
                                                      1,
                                                      1,
                                                      1>;

    auto add_op = Add{};
    Gridwise::Run(a_t.data(), b_t.data(), c_t.data(), a_desc, a_desc, a_desc, add_op);
}
}
// namespace migraphx
#endif
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
View file @
cc2535e0
...
@@ -27,11 +27,156 @@
...
@@ -27,11 +27,156 @@
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
namespace
migraphx
{
namespace
migraphx
{
// static constexpr auto I0 = Number<0>{};
// static constexpr auto I1 = Number<1>{};
// static constexpr auto I2 = Number<2>{};
// static constexpr auto I3 = Number<3>{};
// static constexpr auto I4 = Number<4>{};
// static constexpr auto I5 = Number<5>{};
// static constexpr auto K1Number = Number<1>{};
// static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
// {
// assert(K % K1 == 0);
// const index_t K0 = K / K1;
// const auto a_grid_desc_m_k = [&]() {
// if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
// {
// return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
// }
// else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
// {
// return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
// }
// }();
// if constexpr(GemmSpec == GemmSpecialization::MNPadding)
// {
// const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
// return transform_tensor_descriptor(
// a_grid_desc_m_k,
// make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
// make_right_pad_transform(M, PadM)),
// make_tuple(Sequence<1>{}, Sequence<0>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// }
// else
// {
// return transform_tensor_descriptor(
// a_grid_desc_m_k,
// make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
// make_pass_through_transform(M)),
// make_tuple(Sequence<1>{}, Sequence<0>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// }
// }
// static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
// {
// assert(K % K1 == 0);
// const index_t K0 = K / K1;
// const auto b_grid_desc_k_n = [&]() {
// if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
// {
// return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
// }
// else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
// {
// return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
// }
// }();
// if constexpr(GemmSpec == GemmSpecialization::MNPadding)
// {
// const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
// return transform_tensor_descriptor(
// b_grid_desc_k_n,
// make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
// make_right_pad_transform(N, PadN)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// }
// else
// {
// return transform_tensor_descriptor(
// b_grid_desc_k_n,
// make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
// make_pass_through_transform(N)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// }
// }
// static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
// {
// const auto c_grid_desc_m_n = [&]() {
// if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
// {
// return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
// }
// else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
// {
// return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
// }
// }();
// if constexpr(GemmSpec == GemmSpecialization::MNPadding)
// {
// const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
// const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
// return transform_tensor_descriptor(
// c_grid_desc_m_n,
// make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}));
// }
// else
// {
// return transform_tensor_descriptor(
// c_grid_desc_m_n,
// make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}));
// }
// }
template
<
class
T
,
class
U
,
class
V
>
template
<
class
T
,
class
U
,
class
V
>
__device__
void
ck_gemm
(
const
T
&
/*
dat
a_t */
,
const
U
&
/*
indices
_t */
,
const
V
&
/*
output
_t */
)
__device__
void
ck_gemm
(
const
T
&
/* a_t */
,
const
U
&
/*
b
_t */
,
const
V
&
/*
c
_t */
)
{
{
constexpr
auto
alens
=
get_shape_c
<
T
>
{}.
lens
;
constexpr
auto
m
=
alens
[
0
];
constexpr
auto
k
=
alens
[
1
];
constexpr
auto
alens
=
get_shape_c
<
U
>
{}.
lens
;
constexpr
auto
n
=
alens
[
1
];
constexpr
auto
astrides
=
get_shape_c
<
T
>
{}.
strides
;
constexpr
auto
as
=
astrides
[
1
];
constexpr
auto
bstrides
=
get_shape_c
<
U
>
{}.
strides
;
constexpr
auto
bs
=
bstrides
[
1
];
constexpr
auto
cstrides
=
get_shape_c
<
V
>
{}.
strides
;
constexpr
auto
cs
=
cstrides
[
1
];
printf
(
"%i %i %i, %i %i %i
\n
"
,
int
(
m
),
int
(
n
),
int
(
k
),
int
(
as
),
int
(
bs
),
int
(
cs
));
}
}
}
// namespace migraphx
}
// namespace migraphx
...
...
test/verify/0ck_elementwise_test.cpp
0 → 100644
View file @
cc2535e0
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
/// Verification program: applies the "ck_elementwise" op to two float
/// parameters that share the one-dimensional shape {20}.
struct ck_elementwise : verify_program<ck_elementwise>
{
    /// Builds the program whose main module holds the two parameters and the
    /// single ck_elementwise instruction.
    migraphx::program create_program() const
    {
        migraphx::program prog;
        auto* main_mod = prog.get_main_module();

        migraphx::shape input_shape{migraphx::shape::float_type, {20}};

        auto lhs = main_mod->add_parameter("1", input_shape);
        auto rhs = main_mod->add_parameter("2", input_shape);
        main_mod->add_instruction(migraphx::make_op("ck_elementwise"), lhs, rhs);

        return prog;
    }
};
test/verify/0ck_gemm_test.cpp
View file @
cc2535e0
...
@@ -34,7 +34,7 @@ struct ck_gemm : verify_program<ck_gemm>
...
@@ -34,7 +34,7 @@ struct ck_gemm : verify_program<ck_gemm>
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_type
,
{
10
,
20
}};
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_type
,
{
10
,
20
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_type
,
{
20
,
1
0
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_type
,
{
20
,
2
0
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment