Commit cc2535e0 authored by turneram's avatar turneram
Browse files

Merge elementwise

parent b41a56cf
...@@ -118,6 +118,7 @@ register_migraphx_ops( ...@@ -118,6 +118,7 @@ register_migraphx_ops(
broadcast broadcast
capture capture
ceil ceil
ck_elementwise
ck_gemm ck_gemm
clip clip
concat concat
......
...@@ -40,6 +40,7 @@ ...@@ -40,6 +40,7 @@
#include <migraphx/op/broadcast.hpp> #include <migraphx/op/broadcast.hpp>
#include <migraphx/op/capture.hpp> #include <migraphx/op/capture.hpp>
#include <migraphx/op/ceil.hpp> #include <migraphx/op/ceil.hpp>
#include <migraphx/op/ck_elementwise.hpp>
#include <migraphx/op/ck_gemm.hpp> #include <migraphx/op/ck_gemm.hpp>
#include <migraphx/op/clip.hpp> #include <migraphx/op/clip.hpp>
#include <migraphx/op/common.hpp> #include <migraphx/op/common.hpp>
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// NOLINTNEXTLINE
static const char* const ck_elementwise_kernel = R"__migraphx__(
#include <migraphx/kernels/ck_elementwise.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void ck_elementwise_kernel(void* a_p, void* b_p, void* c_p)
{
make_tensors()(a_p, b_p, c_p)([](auto&&... xs) {
ck_elementwise(xs...);
});
}
}
} // namespace migraphx
)__migraphx__";
struct ck_elementwise_compiler : compiler<ck_elementwise_compiler>
{
std::vector<std::string> names() const { return {"ck_elementwise"}; }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
hip_compile_options options;
auto out_s = inputs.back();
options.set_launch_params(v, compute_global_for(ctx, out_s.elements()));
options.inputs = inputs;
options.output = out_s;
options.kernel_name = "ck_elementwise_kernel";
options.virtual_inputs = inputs;
return compile_hip_code_object(ck_elementwise_kernel, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -39,6 +39,7 @@ namespace migraphx { ...@@ -39,6 +39,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
// NOLINTNEXTLINE // NOLINTNEXTLINE
static const char* const ck_gemm_kernel = R"__migraphx__( static const char* const ck_gemm_kernel = R"__migraphx__(
#include <migraphx/kernels/ck_gemm.hpp> #include <migraphx/kernels/ck_gemm.hpp>
...@@ -47,132 +48,15 @@ static const char* const ck_gemm_kernel = R"__migraphx__( ...@@ -47,132 +48,15 @@ static const char* const ck_gemm_kernel = R"__migraphx__(
#include <migraphx/kernels/generic_constant.hpp> #include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp> #include <args.hpp>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using CDataType = F16;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
namespace migraphx { namespace migraphx {
extern "C" { extern "C" {
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
// clang-format off
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
// clang-format on
__global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p) __global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
{ {
// GEMM shape make_tensors()(a_p, b_p, c_p)([](auto&&... xs) {
ck::index_t M = 3840; ck_gemm(xs...);
ck::index_t N = 4096; });
ck::index_t K = 4096;
ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096;
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CShuffleDataType,
CDataType,
AElementOp,
BElementOp,
CElementOp,
ck::InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
CGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
const auto kernel = kernel_gemm_xdlops_v2r3<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementOp,
BElementOp,
CElementOp,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>;
kernel<<<1, 1, 1, 0>>>(p_a, p_b, p_c);
} }
} }
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_CK_ELEMENTWISE_HPP
#define MIGRAPHX_GUARD_KERNELS_CK_ELEMENTWISE_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp"
namespace migraphx {
using ADataType = float;
using BDataType = float;
using CDataType = float;
using ElementwiseFunctor = float;
static constexpr auto I0 = ck::Number<0>{};
template <class L, class S, class N>
constexpr auto MakeDescriptor_M(const L& lengths, const S& strides, const N& /* ndim */)
{
auto gridSize = 72;
auto blockSize = 1024;
constexpr auto ndim = 1;
// auto idx = make_index();
auto tupleOfShape = generate_tuple([&](auto I) { return static_cast<ck::index_t>(lengths[I]); },
ck::Number<ndim>{});
auto tupleOfStride = generate_tuple(
[&](auto I) { return static_cast<ck::index_t>(strides[I]); }, ck::Number<1>{});
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
auto desc_m = desc;
// merge nd to 1d desc - [s0 * s1 * ...]
if constexpr(ndim > 1)
{
desc_m = transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform(tupleOfShape)),
make_tuple(generate_sequence_v2([&](auto I) { return I; }, ck::Number<ndim>{})),
make_tuple(ck::Sequence<0>{}));
}
const auto M = desc_m.GetLength(I0);
const ck::index_t loop_step = /* idx.nglobal(); // */ gridSize * blockSize /* * MPerThread */;
const auto pad = ck::math::integer_least_multiple(M, loop_step) - M;
const auto desc_m_pad =
transform_tensor_descriptor(desc_m,
make_tuple(ck::make_right_pad_transform(M, pad)),
make_tuple(ck::Sequence<0>{}),
make_tuple(ck::Sequence<0>{}));
return desc_m_pad;
}
struct Add
{
template <typename Y, typename X0, typename X1>
__device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const
{
y = x0 + x1;
};
};
template <class T, class U, class V>
__device__ void ck_elementwise(const T& a_t, const U& b_t, const V& c_t)
{
auto idx = make_index();
if(idx.global == 0)
{
constexpr auto lengths = get_shape_c<T>{}.lens;
constexpr auto strides = get_shape_c<T>{}.strides;
constexpr auto a_desc = MakeDescriptor_M(lengths, strides, 1);
using AGridDesc_M = decltype(a_desc);
using GridwiseBinEltwise = ck::GridwiseBinaryElementwise_1D<ADataType,
BDataType,
CDataType,
CDataType,
AGridDesc_M,
AGridDesc_M,
AGridDesc_M,
Add,
1,
1,
1,
1>;
auto op = Add{};
GridwiseBinEltwise::Run(a_t.data(), b_t.data(), c_t.data(), a_desc, a_desc, a_desc, op);
}
}
} // namespace migraphx
#endif
...@@ -27,11 +27,156 @@ ...@@ -27,11 +27,156 @@
#include <migraphx/kernels/index.hpp> #include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp> #include <migraphx/kernels/algorithm.hpp>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
namespace migraphx { namespace migraphx {
// static constexpr auto I0 = Number<0>{};
// static constexpr auto I1 = Number<1>{};
// static constexpr auto I2 = Number<2>{};
// static constexpr auto I3 = Number<3>{};
// static constexpr auto I4 = Number<4>{};
// static constexpr auto I5 = Number<5>{};
// static constexpr auto K1Number = Number<1>{};
// static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
// {
// assert(K % K1 == 0);
// const index_t K0 = K / K1;
// const auto a_grid_desc_m_k = [&]() {
// if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
// {
// return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
// }
// else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
// {
// return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
// }
// }();
// if constexpr(GemmSpec == GemmSpecialization::MNPadding)
// {
// const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
// return transform_tensor_descriptor(
// a_grid_desc_m_k,
// make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
// make_right_pad_transform(M, PadM)),
// make_tuple(Sequence<1>{}, Sequence<0>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// }
// else
// {
// return transform_tensor_descriptor(
// a_grid_desc_m_k,
// make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
// make_pass_through_transform(M)),
// make_tuple(Sequence<1>{}, Sequence<0>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// }
// }
// static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
// {
// assert(K % K1 == 0);
// const index_t K0 = K / K1;
// const auto b_grid_desc_k_n = [&]() {
// if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
// {
// return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
// }
// else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
// {
// return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
// }
// }();
// if constexpr(GemmSpec == GemmSpecialization::MNPadding)
// {
// const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
// return transform_tensor_descriptor(
// b_grid_desc_k_n,
// make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
// make_right_pad_transform(N, PadN)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// }
// else
// {
// return transform_tensor_descriptor(
// b_grid_desc_k_n,
// make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
// make_pass_through_transform(N)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// }
// }
// static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
// {
// const auto c_grid_desc_m_n = [&]() {
// if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
// {
// return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
// }
// else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
// {
// return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
// }
// }();
// if constexpr(GemmSpec == GemmSpecialization::MNPadding)
// {
// const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
// const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
// return transform_tensor_descriptor(
// c_grid_desc_m_n,
// make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}));
// }
// else
// {
// return transform_tensor_descriptor(
// c_grid_desc_m_n,
// make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}));
// }
// }
template <class T, class U, class V> template <class T, class U, class V>
__device__ void ck_gemm(const T& /* data_t */, const U& /* indices_t */, const V& /* output_t */) __device__ void ck_gemm(const T& /* a_t */, const U& /* b_t */, const V& /* c_t */)
{ {
constexpr auto alens = get_shape_c<T>{}.lens;
constexpr auto m = alens[0];
constexpr auto k = alens[1];
constexpr auto alens = get_shape_c<U>{}.lens;
constexpr auto n = alens[1];
constexpr auto astrides = get_shape_c<T>{}.strides;
constexpr auto as = astrides[1];
constexpr auto bstrides = get_shape_c<U>{}.strides;
constexpr auto bs = bstrides[1];
constexpr auto cstrides = get_shape_c<V>{}.strides;
constexpr auto cs = cstrides[1];
printf("%i %i %i, %i %i %i\n", int(m), int(n), int(k), int(as), int(bs), int(cs));
} }
} // namespace migraphx } // namespace migraphx
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct ck_elementwise : verify_program<ck_elementwise>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {20}};
// migraphx::shape m2_shape{migraphx::shape::float_type, {20, 10}};
auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m1_shape);
mm->add_instruction(migraphx::make_op("ck_elementwise"), l1, l2);
return p;
}
};
...@@ -34,7 +34,7 @@ struct ck_gemm : verify_program<ck_gemm> ...@@ -34,7 +34,7 @@ struct ck_gemm : verify_program<ck_gemm>
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {10, 20}}; migraphx::shape m1_shape{migraphx::shape::float_type, {10, 20}};
migraphx::shape m2_shape{migraphx::shape::float_type, {20, 10}}; migraphx::shape m2_shape{migraphx::shape::float_type, {20, 20}};
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment