"include/ck/utility/functional.hpp" did not exist on "8a4b59785b4f5ba48468d53618ca270c5da599a7"
Commit 8d378877 authored by turneram's avatar turneram
Browse files

Switch to elementwise

parent b41a56cf
...@@ -118,7 +118,7 @@ register_migraphx_ops( ...@@ -118,7 +118,7 @@ register_migraphx_ops(
broadcast broadcast
capture capture
ceil ceil
ck_gemm ck_elementwise
clip clip
concat concat
contiguous contiguous
......
...@@ -21,59 +21,52 @@ ...@@ -21,59 +21,52 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#ifndef MIGRAPHX_GUARD_OPERATORS_CK_GEMM_HPP #ifndef MIGRAPHX_GUARD_OPERATORS_CK_ELEMENTWISE_HPP
#define MIGRAPHX_GUARD_OPERATORS_CK_GEMM_HPP #define MIGRAPHX_GUARD_OPERATORS_CK_ELEMENTWISE_HPP
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/gemm.hpp> #include <migraphx/par_for.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct ck_gemm struct ck_elementwise
{ {
std::string name() const { return "ck_gemm"; } std::string name() const { return "ck_elementwise"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.same_type().has(2); check_shapes{inputs, *this}.has(2).same_type().same_dims();
const shape& a = inputs.at(0); auto s0 = inputs.at(0);
const shape& b = inputs.at(1); auto s1 = inputs.at(1);
auto t = a.type(); if(s0 == s1 and s0.packed())
if(!std::all_of(inputs.begin(), inputs.end(), [](auto s) { return s.lens().size() >= 2; }))
{ {
MIGRAPHX_THROW("DOT: dot only accept 2 or more dims operands"); return s0;
} }
else if(s0.packed() != s1.packed())
// only handle the case that the batch size of a and b are the same
if(!std::equal(
a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend()))
{ {
MIGRAPHX_THROW("DOT: batch size of A and B mismatch: {" + to_string_range(a.lens()) + return s0.packed() ? s0 : s1;
"} x {" + to_string_range(b.lens()) + "}");
} }
else if(s0.broadcasted() != s1.broadcasted())
std::size_t dim_0 = a.lens().size() - 2;
std::size_t dim_1 = a.lens().size() - 1;
if(a.lens()[dim_1] != b.lens()[dim_0])
{ {
MIGRAPHX_THROW("DOT: inner dimensions do not match: {" + to_string_range(a.lens()) + return s0.broadcasted() ? s1.with_lens(s0.lens()) : s0.with_lens(s0.lens());
"} x {" + to_string_range(b.lens()) + "}"); }
else
{
return {s0.type(), s0.lens()};
} }
auto out_lens = a.lens();
out_lens[dim_1] = b.lens()[dim_1];
return {t, out_lens};
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
argument result = argument{output_shape}; argument result{output_shape};
visit_all(result, args[0], args[1])( visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
[&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, 1.0f, 0.0f); }); par_for(output_shape.elements(), [&](const auto i) {
output[i] = input1[i] + input2[i];
});
});
return result; return result;
} }
}; };
......
...@@ -40,7 +40,7 @@ ...@@ -40,7 +40,7 @@
#include <migraphx/op/broadcast.hpp> #include <migraphx/op/broadcast.hpp>
#include <migraphx/op/capture.hpp> #include <migraphx/op/capture.hpp>
#include <migraphx/op/ceil.hpp> #include <migraphx/op/ceil.hpp>
#include <migraphx/op/ck_gemm.hpp> #include <migraphx/op/ck_elementwise.hpp>
#include <migraphx/op/clip.hpp> #include <migraphx/op/clip.hpp>
#include <migraphx/op/common.hpp> #include <migraphx/op/common.hpp>
#include <migraphx/op/concat.hpp> #include <migraphx/op/concat.hpp>
......
...@@ -40,139 +40,57 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -40,139 +40,57 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
// NOLINTNEXTLINE // NOLINTNEXTLINE
static const char* const ck_gemm_kernel = R"__migraphx__( static const char* const ck_elementwise_kernel = R"__migraphx__(
#include <migraphx/kernels/ck_gemm.hpp> #include <migraphx/kernels/ck_elementwise.hpp>
#include <migraphx/kernels/ops.hpp> #include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/integral_constant.hpp> #include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp> #include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp> #include <args.hpp>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using CDataType = F16;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
namespace migraphx { namespace migraphx {
extern "C" { extern "C" {
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; __global__ void ck_elementwise_kernel(void* a_p, void* b_p, void* c_p)
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
// clang-format off
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
// clang-format on
__global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
{ {
// GEMM shape using F16 = ck::half_t;
ck::index_t M = 3840; using F32 = float;
ck::index_t N = 4096;
ck::index_t K = 4096; using ABDataType = F16;
using CDataType = F16;
ck::index_t StrideA = 4096; using EltwiseComputeDataType = F32;
ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096; using Add = ck::tensor_operation::element_wise::Add;
auto a_element_op = AElementOp{}; using DeviceElementwiseAddInstance =
auto b_element_op = BElementOp{}; ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
auto c_element_op = CElementOp{}; ABDataType,
CDataType,
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); EltwiseComputeDataType,
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); Add,
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); 1,
8,
// GridwiseGemm 8,
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< 8,
ADataType, // TODO: distinguish A/B datatype 8>;
AccDataType, ck::index_t M = 1024;
CShuffleDataType, std::array<const void*, 2> input = {a_p,
CDataType, b_p};
AElementOp, std::array<void*, 1> output = {c_p};
BElementOp,
CElementOp, std::vector<ck::index_t> a_strides = {1};
ck::InMemoryDataOperationEnum::Set, std::vector<ck::index_t> b_strides = {1};
AGridDesc_AK0_M_AK1, std::vector<ck::index_t> c_strides = {1};
BGridDesc_BK0_N_BK1,
CGridDesc_M_N, auto broadcastAdd = DeviceElementwiseAddInstance{};
NumGemmKPrefetchStage, auto argument = broadcastAdd.MakeArgumentPointer(
BlockSize, input, output, {M}, {{a_strides}, b_strides}, {c_strides}, Add{});
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
const auto kernel = kernel_gemm_xdlops_v2r3<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementOp,
BElementOp,
CElementOp,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>;
kernel<<<1, 1, 1, 0>>>(p_a, p_b, p_c);
} }
} }
...@@ -181,9 +99,9 @@ __global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p) ...@@ -181,9 +99,9 @@ __global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
)__migraphx__"; )__migraphx__";
struct ck_gemm_compiler : compiler<ck_gemm_compiler> struct ck_elementwise_compiler : compiler<ck_elementwise_compiler>
{ {
std::vector<std::string> names() const { return {"ck_gemm"}; } std::vector<std::string> names() const { return {"ck_elementwise"}; }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{ {
...@@ -192,10 +110,10 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -192,10 +110,10 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
options.set_launch_params(v, compute_global_for(ctx, out_s.elements())); options.set_launch_params(v, compute_global_for(ctx, out_s.elements()));
options.inputs = inputs; options.inputs = inputs;
options.output = out_s; options.output = out_s;
options.kernel_name = "ck_gemm_kernel"; options.kernel_name = "ck_elementwise_kernel";
options.virtual_inputs = inputs; options.virtual_inputs = inputs;
return compile_hip_code_object(ck_gemm_kernel, options); return compile_hip_code_object(ck_elementwise_kernel, options);
} }
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGemm : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
\ No newline at end of file
...@@ -21,8 +21,8 @@ ...@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#ifndef MIGRAPHX_GUARD_KERNELS_CK_GEMM_HPP #ifndef MIGRAPHX_GUARD_KERNELS_CK_ELEMENTWISE_HPP
#define MIGRAPHX_GUARD_KERNELS_CK_GEMM_HPP #define MIGRAPHX_GUARD_KERNELS_CK_ELEMENTWISE_HPP
#include <migraphx/kernels/index.hpp> #include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp> #include <migraphx/kernels/algorithm.hpp>
...@@ -30,9 +30,10 @@ ...@@ -30,9 +30,10 @@
namespace migraphx { namespace migraphx {
template <class T, class U, class V> template <class T, class U, class V>
__device__ void ck_gemm(const T& /* data_t */, const U& /* indices_t */, const V& /* output_t */) __device__ void ck_elementwise(const T& /* data_t */, const U& /* indices_t */, const V& /* output_t */)
{ {
} }
} // namespace migraphx } // namespace migraphx
#endif #endif
...@@ -27,18 +27,18 @@ ...@@ -27,18 +27,18 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct ck_gemm : verify_program<ck_gemm> struct ck_elementwise : verify_program<ck_elementwise>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {10, 20}}; migraphx::shape m1_shape{migraphx::shape::float_type, {10, 20}};
migraphx::shape m2_shape{migraphx::shape::float_type, {20, 10}}; //migraphx::shape m2_shape{migraphx::shape::float_type, {20, 10}};
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m1_shape);
mm->add_instruction(migraphx::make_op("ck_gemm"), l1, l2); mm->add_instruction(migraphx::make_op("ck_elementwise"), l1, l2);
return p; return p;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment