Commit 8d378877 authored by turneram

Switch to elementwise

parent b41a56cf
@@ -118,7 +118,7 @@ register_migraphx_ops(
broadcast
capture
ceil
ck_gemm
ck_elementwise
clip
concat
contiguous
......
@@ -21,59 +21,52 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_CK_GEMM_HPP
#define MIGRAPHX_GUARD_OPERATORS_CK_GEMM_HPP
#ifndef MIGRAPHX_GUARD_OPERATORS_CK_ELEMENTWISE_HPP
#define MIGRAPHX_GUARD_OPERATORS_CK_ELEMENTWISE_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gemm.hpp>
#include <migraphx/par_for.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct ck_gemm
struct ck_elementwise
{
std::string name() const { return "ck_gemm"; }
std::string name() const { return "ck_elementwise"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.same_type().has(2);
const shape& a = inputs.at(0);
const shape& b = inputs.at(1);
auto t = a.type();
if(!std::all_of(inputs.begin(), inputs.end(), [](auto s) { return s.lens().size() >= 2; }))
check_shapes{inputs, *this}.has(2).same_type().same_dims();
auto s0 = inputs.at(0);
auto s1 = inputs.at(1);
if(s0 == s1 and s0.packed())
{
MIGRAPHX_THROW("DOT: dot only accept 2 or more dims operands");
return s0;
}
// only handle the case that the batch size of a and b are the same
if(!std::equal(
a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend()))
else if(s0.packed() != s1.packed())
{
MIGRAPHX_THROW("DOT: batch size of A and B mismatch: {" + to_string_range(a.lens()) +
"} x {" + to_string_range(b.lens()) + "}");
return s0.packed() ? s0 : s1;
}
std::size_t dim_0 = a.lens().size() - 2;
std::size_t dim_1 = a.lens().size() - 1;
if(a.lens()[dim_1] != b.lens()[dim_0])
else if(s0.broadcasted() != s1.broadcasted())
{
MIGRAPHX_THROW("DOT: inner dimensions do not match: {" + to_string_range(a.lens()) +
"} x {" + to_string_range(b.lens()) + "}");
return s0.broadcasted() ? s1.with_lens(s0.lens()) : s0.with_lens(s0.lens());
}
else
{
return {s0.type(), s0.lens()};
}
auto out_lens = a.lens();
out_lens[dim_1] = b.lens()[dim_1];
return {t, out_lens};
}
argument compute(shape output_shape, std::vector<argument> args) const
{
argument result = argument{output_shape};
visit_all(result, args[0], args[1])(
[&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, 1.0f, 0.0f); });
argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
par_for(output_shape.elements(), [&](const auto i) {
output[i] = input1[i] + input2[i];
});
});
return result;
}
};
......
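For orientation, the new compute_shape mirrors MIGraphX's stock binary ops: two inputs of the same type and dimensions resolve to a packed output, and compute() performs the reference addition on the host. A minimal sketch of exercising the op host-side (the main() harness and variable names are illustrative, assuming the standard migraphx::shape/migraphx::argument API):

#include <migraphx/op/ck_elementwise.hpp>
#include <vector>

int main()
{
    migraphx::op::ck_elementwise op;
    migraphx::shape s{migraphx::shape::float_type, {2, 3}};
    // Packed, identical inputs take the first branch and return s itself.
    auto out_s = op.compute_shape({s, s});

    std::vector<float> a(s.elements(), 1.0f);
    std::vector<float> b(s.elements(), 2.0f);
    // Reference compute: every output element becomes 1.0f + 2.0f.
    auto result = op.compute(out_s, {migraphx::argument{s, a.data()},
                                     migraphx::argument{s, b.data()}});
}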
@@ -40,7 +40,7 @@
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/capture.hpp>
#include <migraphx/op/ceil.hpp>
#include <migraphx/op/ck_gemm.hpp>
#include <migraphx/op/ck_elementwise.hpp>
#include <migraphx/op/clip.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/op/concat.hpp>
......
@@ -40,139 +40,57 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// NOLINTNEXTLINE
static const char* const ck_gemm_kernel = R"__migraphx__(
#include <migraphx/kernels/ck_gemm.hpp>
static const char* const ck_elementwise_kernel = R"__migraphx__(
#include <migraphx/kernels/ck_elementwise.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using CDataType = F16;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
namespace migraphx {
extern "C" {
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
// clang-format off
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
// clang-format on
__global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
__global__ void ck_elementwise_kernel(void* a_p, void* b_p, void* c_p)
{
// GEMM shape
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096;
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CShuffleDataType,
CDataType,
AElementOp,
BElementOp,
CElementOp,
ck::InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
CGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
const auto kernel = kernel_gemm_xdlops_v2r3<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementOp,
BElementOp,
CElementOp,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>;
kernel<<<1, 1, 1, 0>>>(a_p, b_p, c_p);
using F16 = ck::half_t;
using F32 = float;
using ABDataType = F16;
using CDataType = F16;
using EltwiseComputeDataType = F32;
using Add = ck::tensor_operation::element_wise::Add;
using DeviceElementwiseAddInstance =
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
ABDataType,
CDataType,
EltwiseComputeDataType,
Add,
1,
8,
8,
8,
8>;
ck::index_t M = 1024;
std::array<const void*, 2> input = {a_p, b_p};
std::array<void*, 1> output = {c_p};
std::vector<ck::index_t> a_strides = {1};
std::vector<ck::index_t> b_strides = {1};
std::vector<ck::index_t> c_strides = {1};
auto broadcastAdd = DeviceElementwiseAddInstance{};
auto argument = broadcastAdd.MakeArgumentPointer(
    input, output, {M}, {a_strides, b_strides}, {c_strides}, Add{});
}
}
@@ -181,9 +99,9 @@ __global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
)__migraphx__";
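The embedded source above constructs the CK argument but does not yet launch it; the Add functor's semantics are the plain elementwise sum. A minimal HIP-style sketch of that computation, for reference only (a grid-stride loop over packed half buffers of n elements; this is not the CK dispatch path):

#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>

// Semantic reference for the CK Add functor: c[i] = a[i] + b[i].
__global__ void elementwise_add_ref(const __half* a, const __half* b, __half* c, int n)
{
    for(int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
        i += blockDim.x * gridDim.x)
        c[i] = __hadd(a[i], b[i]);
}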
struct ck_gemm_compiler : compiler<ck_gemm_compiler>
struct ck_elementwise_compiler : compiler<ck_elementwise_compiler>
{
std::vector<std::string> names() const { return {"ck_gemm"}; }
std::vector<std::string> names() const { return {"ck_elementwise"}; }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
@@ -192,10 +110,10 @@ struct ck_elementwise_compiler : compiler<ck_elementwise_compiler>
options.set_launch_params(v, compute_global_for(ctx, out_s.elements()));
options.inputs = inputs;
options.output = out_s;
options.kernel_name = "ck_gemm_kernel";
options.kernel_name = "ck_elementwise_kernel";
options.virtual_inputs = inputs;
return compile_hip_code_object(ck_gemm_kernel, options);
return compile_hip_code_object(ck_elementwise_kernel, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGemm : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
\ No newline at end of file
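For context, a concrete instance behind this interface is driven through the usual CK argument/invoker pair. A sketch of that call pattern (the run_gemm wrapper and the size/stride values are hypothetical; the layout and element-op aliases are redeclared here to match the kernel source above, and the DeviceGemm header added above must be on the include path, though its exact path is elided in this diff):

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using F16 = ck::half_t;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using GemmBase = ck::tensor_operation::device::
    DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>;

void run_gemm(GemmBase& gemm, const void* a, const void* b, void* c)
{
    auto arg = gemm.MakeArgumentPointer(a, b, c,
                                        /*M=*/3840, /*N=*/4096, /*K=*/4096,
                                        /*StrideA=*/4096, /*StrideB=*/4096, /*StrideC=*/4096,
                                        PassThrough{}, PassThrough{}, PassThrough{});
    // BaseInvoker::Run dispatches the kernel; it returns the measured GPU
    // time when timing is requested via its StreamConfig argument.
    gemm.MakeInvokerPointer()->Run(arg.get());
}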
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_CK_GEMM_HPP
#define MIGRAPHX_GUARD_KERNELS_CK_GEMM_HPP
#ifndef MIGRAPHX_GUARD_KERNELS_CK_ELEMENTWISE_HPP
#define MIGRAPHX_GUARD_KERNELS_CK_ELEMENTWISE_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
@@ -30,9 +30,10 @@
namespace migraphx {
template <class T, class U, class V>
__device__ void ck_gemm(const T& /* data_t */, const U& /* indices_t */, const V& /* output_t */)
__device__ void ck_elementwise(const T& /* input1 */, const U& /* input2 */, const V& /* output */)
{
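    // Intentionally empty: the actual computation is compiled from the
    // ck_elementwise_kernel source string above; this stub only anchors
    // the kernels header.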
}
} // namespace migraphx
#endif
@@ -27,18 +27,18 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct ck_gemm : verify_program<ck_gemm>
struct ck_elementwise : verify_program<ck_elementwise>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {10, 20}};
migraphx::shape m2_shape{migraphx::shape::float_type, {20, 10}};
auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape);
auto l2 = mm->add_parameter("2", m1_shape);
mm->add_instruction(migraphx::make_op("ck_gemm"), l1, l2);
mm->add_instruction(migraphx::make_op("ck_elementwise"), l1, l2);
return p;
}
......
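The verify harness compares this program's GPU output against the reference target, so ck_elementwise must agree with the stock add operator. A hypothetical hand-written reference for the same inputs would be:

migraphx::program reference()
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape s{migraphx::shape::float_type, {10, 20}};
    auto l1 = mm->add_parameter("1", s);
    auto l2 = mm->add_parameter("2", s);
    // Stock elementwise add; ck_elementwise should produce identical results.
    mm->add_instruction(migraphx::make_op("add"), l1, l2);
    return p;
}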