"vscode:/vscode.git/clone" did not exist on "c93e86956ca421336d1dae830e6d3f45a7d1e4e3"
Commit d7ea085c authored by Alan Turner

Add gemm-softmax-gemm fusion
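
This introduces a fused gpu::ck_gemm_softmax_gemm operation that matches the
attention-style subgraph dot -> pointwise(scale/bias) -> softmax -> dot and lowers it
to Composable Kernel's batched gemm + softmax + gemm pipeline, computing
out = softmax(scale * (A x B) + bias) x B1 in a single kernel.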

parent 4c370d64
@@ -49,6 +49,44 @@ struct ck_gemm
};
MIGRAPHX_REGISTER_OP(ck_gemm);
struct ck_gemm_scale_bias_softmax_gemm
{
operation op = make_op("dot");
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.op, "op"));
}
std::string name() const { return "gpu::ck_gemm_softmax_gemm"; }
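// CK kernels need a packed (stride-1) dimension among the last three axes of every input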
void check_gemm_shape(const shape& s) const
{
if(not contains(range(s.strides().rbegin(), s.strides().rbegin() + 3), 1))
MIGRAPHX_THROW("Invalid shape for ck_gemm_scale_bias_softmax_gemm");
}
shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
{
check_shapes{inputs, *this}.same_ndims();
// if(mods.size() != 1)
// MIGRAPHX_THROW("should have one submodule.");
if(inputs.size() < 3)
MIGRAPHX_THROW("should have at least three inputs.");
auto a = inputs[0]; // first gemm A operand
auto b = inputs[1]; // first gemm B operand
auto b1 = inputs[2]; // second gemm B operand
for(const auto& input : inputs)
{
//std::cout << input << std::endl;
check_gemm_shape(input);
}
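// The fused op's output shape is the shape of (A x B) x B1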
return op.compute_shape({op.compute_shape({a, b}), b1});
}
};
MIGRAPHX_REGISTER_OP(ck_gemm_scale_bias_softmax_gemm);
namespace {
MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
@@ -116,14 +154,73 @@ struct find_ck_gemm
}
};
struct find_ck_gemm_scale_bias_softmax_gemm
{
auto matcher() const
{
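// Match the chain gemm2(dot) <- softmax <- pointwise(scale/bias) <- gemm1(dot),
// skipping any contiguous instruction between the first dot and the pointwise op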
auto gemm1 = match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
auto pw = match::name("pointwise")(match::any_of[match::inputs()](gemm1)).bind("scale_bias");
auto softmax = match::name("softmax")(match::any_of[match::inputs()](pw)).bind("softmax");
return match::name("dot")(is_ck_gemm().bind("gemm2"))(match::any_of[match::inputs()](softmax));
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
std::cout << "Matched" << std::endl;
auto ins = r.result;
auto gemm2_ins = r.instructions["gemm2"];
auto sm_ins = r.instructions["softmax"];
auto pw_ins = r.instructions["scale_bias"];
auto gemm1_ins = r.instructions["gemm1"];
gemm2_ins->debug_print();
sm_ins->debug_print();
pw_ins->debug_print();
gemm1_ins->debug_print();
auto inputs = gemm1_ins->inputs(); // A, B
inputs.push_back(gemm2_ins->inputs().back()); // B1
//inputs.push_back(pw_ins->inputs().back()); // C
mpm.get_module().replace_instruction(ins, ck_gemm_scale_bias_softmax_gemm{gemm2_ins->get_operator()}, inputs);
}
// auto matcher() const
// {
// auto gemm1 = match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
// auto softmax = match::name("softmax")(match::any_of[match::inputs()](gemm1)).bind("softmax");
// return match::name("dot")(is_ck_gemm().bind("gemm2"))(match::any_of[match::inputs()](softmax));
// }
// void apply(module_pass_manager& mpm, const match::matcher_result& r) const
// {
// std::cout << "Matched" << std::endl;
// auto ins = r.result;
// auto gemm2_ins = r.instructions["gemm2"];
// auto sm_ins = r.instructions["softmax"];
// auto gemm1_ins = r.instructions["gemm1"];
// gemm2_ins->debug_print();
// sm_ins->debug_print();
// gemm1_ins->debug_print();
// auto inputs = gemm1_ins->inputs(); // A, B
// inputs.push_back(gemm2_ins->inputs().back()); // B1
// mpm.get_module().replace_instruction(ins, ck_gemm_scale_bias_softmax_gemm{gemm2_ins->get_operator()}, inputs);
// }
};
} // namespace
void fuse_ck::apply(module_pass_manager& mpm) const
{
// mpm.get_module().debug_print();
match::find_matches(mpm, find_ck_gemm_scale_bias_softmax_gemm{});
// if(not enabled(MIGRAPHX_DISABLE_CK_GEMM_FUSION{}))
//     match::find_matches(mpm, find_ck_gemm_pointwise{});
// if(not enabled(MIGRAPHX_DISABLE_CK_GEMM{}))
//     match::find_matches(mpm, find_ck_gemm{});
}
} // namespace gpu
...
@@ -26,6 +26,8 @@
#include <migraphx/reduce_dims.hpp>
#include <migraphx/permutation.hpp>
#include <cstdio>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
@@ -116,6 +118,21 @@ void gemm_impl(context& ctx,
{
beta = 0;
}
// else
// {
// if (args[2].get_shape().lens()[1] == 12 and args[2].get_shape().lens()[2] == 2)
// {
// args[2].visit([&](auto output){
// std::cout << args[2].get_shape() << std::endl;
// for (auto i = 0; i < args[2].get_shape().elements(); ++i)
// {
// //if (output[i] == 0 )
// std::cout << output[i] << ", ";
// }
// std::cout << std::endl;
// });
// }
// }
bool transa = is_transposed(args[0].get_shape());
bool transb = is_transposed(args[1].get_shape());
...
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <fstream>
#include <filesystem>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/env.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/module.hpp>
#include <migraphx/file_buffer.hpp>
const std::vector<std::string>&
get_instance(std::size_t i, const std::function<bool(const std::vector<std::string>&)>& pred);
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
using namespace migraphx::gpu::gen; // NOLINT
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_LOG_CK_GEMM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_TUNING);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_DEBUG);
// NOLINTNEXTLINE
static const char* const ck_gemm_softmax_gemm_kernel = R"__migraphx__(
#include <args.hpp>
#include <migraphx/kernels/ck_gemm_softmax_gemm.hpp>
#include <migraphx/kernels/pointwise.hpp>
namespace migraphx {
${preamble}
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck_passthrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = ck_scale; // ck::tensor_operation::element_wise::Scale
using B1ElementOp = PassThrough;
using CElementOp = PassThrough; // ck_add; ck::tensor_operation::element_wise::Add
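// Hard-coded instance for now: BlockSize=256, MPerBlock=256, NPerBlock=128, KPerBlock=32,
// Gemm1NPerBlock=64; keep these in sync with the launch computation in compile_op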
using gemm = CK_DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false, std::ratio<1, 8>>;
extern "C" {
__global__ void ${kernel}(${params})
{
// transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) {
// ck_gemm_softmax_gemm<CK_DeviceGemmMultipleD<${instance}>, ${blocks_per_batch}>(xs...);
// });
transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) {
ck_gemm_softmax_gemm<gemm, ${blocks_per_batch}>(xs...);
});
}
}
} // namespace migraphx
)__migraphx__";
static std::size_t int_div_ceil(std::size_t x, std::size_t y) { return (x + y - 1) / y; }
struct instance
{
std::vector<std::string> params;
static const std::size_t block_size_index = 17;
std::size_t int_at(std::size_t i) const { return std::stoull(params[i]); }
std::size_t get_block_size() const { return int_at(block_size_index); }
std::size_t get_pb(std::size_t i) const
{
assert(i < 4);
return int_at(block_size_index + 1 + i);
}
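// Pad each of {M, N, K} up to the next multiple of its per-block size, e.g. a config of
// {100, 60, 30} with per-block sizes {64, 32, 16} needs padding of {28, 4, 2}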
std::array<std::size_t, 3> get_pad(const std::array<std::size_t, 3>& config) const
{
std::array<std::size_t, 3> result{};
for(auto i : range(config.size()))
{
result[i] = int_div_ceil(config[i], get_pb(i)) * get_pb(i) - config[i];
}
return result;
}
std::size_t get_grid_size(const std::array<std::size_t, 3>& config) const
{
return int_div_ceil(config[0], get_pb(0)) * int_div_ceil(config[1], get_pb(1));
}
void set_ds_layout(const std::string& s)
{
assert(params[2] == "ck::Tuple<>");
params[2] = s;
}
void set_ds_type(const std::string& s)
{
assert(params[8] == "ck::Tuple<>");
params[8] = s;
}
void set_ds_op(const std::string& s)
{
assert(params[12] == "ck_passthrough");
params[12] = s;
}
void set_gemm(const std::string& s)
{
assert(params[13] == "ck::tensor_operation::device::GemmSpecialization::Default");
params[13] = s;
}
std::string str() const { return join_strings(params, ","); }
};
template <class F, class Action>
auto action_decorate(F f, Action action)
{
return [=](auto&&... xs) {
action();
f(std::forward<decltype(xs)>(xs)...);
};
}
using tuning_entry = std::pair<std::vector<shape>, size_t>;
static std::vector<tuning_entry> read_tuning(const std::string& s)
{
if(not fs::exists(s))
return {};
return from_value<std::vector<tuning_entry>>(from_json_string(read_string(s)));
}
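// MIGRAPHX_CK_TUNING points at a JSON file mapping input shape lists to instance indices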
static std::size_t get_tuning_for(const std::vector<shape>& inputs)
{
static auto tuning = read_tuning(string_value_of(MIGRAPHX_CK_TUNING{}, ""));
if(tuning.empty())
std::cout << "*********** Warning: No CK tuning!" << std::endl;
auto it = std::find_if(
tuning.begin(), tuning.end(), [&](const auto& p) { return p.first == inputs; });
if(it == tuning.end())
{
std::cout << "*********** Warning: CK tuning missing for config!" << std::endl;
return 4;
}
return it->second;
}
struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
{
static std::string get_layout(const shape& s)
{
return s.transposed() ? "ck::tensor_layout::gemm::ColumnMajor"
: "ck::tensor_layout::gemm::RowMajor";
}
static std::string get_type(const shape& s)
{
if(s.type() == shape::half_type)
return "ck::half_t";
return shape::cpp_type(s.type());
}
template <class Iterator, class F>
static std::string ck_tuple(Iterator start, Iterator last, F f)
{
std::vector<std::string> s;
std::transform(start, last, std::back_inserter(s), f);
return "ck::Tuple<" + join_strings(s, ",") + ">";
}
std::vector<std::string> names() const { return {"ck_gemm_softmax_gemm", "gpu::ck_gemm_softmax_gemm"}; }
operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const
{
auto a_shape = inputs[0];
auto b_shape = inputs[1];
auto c_shape = inputs.back();
auto rank = a_shape.lens().size();
// Use the trailing matrix dims so batched (rank > 2) inputs compute the right tile grid
auto m = a_shape.lens()[rank - 2];
auto k = a_shape.lens().back();
auto n = c_shape.lens().back();
// std::array<char, 3> keys{'M', 'N', 'K'};
// std::array<std::size_t, 3> config{
// c_shape.lens()[rank - 2], c_shape.lens().back(), a_shape.lens().back()};
// auto tuning_val = v.get("tuning_val", get_tuning_for({a_shape, b_shape, c_shape}));
// auto ip = instance{get_instance(tuning_val, [&](const auto& x) -> bool {
// return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
// get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
// get_type(b_shape) == x[5] and get_type(c_shape) == x[9];
// })};
// assert(inputs.size() < 4 or v.contains("post"));
// if(v.contains("post"))
// {
// ip.set_ds_layout(ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_layout));
// ip.set_ds_type(ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_type));
// ip.set_ds_op(v.at("post").to<std::string>());
// }
// auto padding = ip.get_pad(config);
// std::string gemm_type;
// for(auto i : range(padding.size()))
// {
// if(padding[i] != 0)
// gemm_type += keys[i];
// }
// if(gemm_type.empty())
// gemm_type = "Default";
// else
// gemm_type += "Padding";
// ip.set_gemm("ck::tensor_operation::device::GemmSpecialization::" + gemm_type);
auto gemm1_nperblock = 64;
auto gemm01_mperblock = 256;
auto blocks_per_batch = int_div_ceil(m, gemm01_mperblock) * int_div_ceil(n, gemm1_nperblock);//ip.get_grid_size(config);
auto batch_count = std::accumulate(c_shape.lens().rbegin() + 2,
c_shape.lens().rend(),
std::size_t{1},
std::multiplies<std::size_t>());
hip_compile_options options;
auto block_size = 256; //ip.get_block_size();
auto grid_size = batch_count * blocks_per_batch;
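// One workgroup per (batch, output tile) pair; set_launch_params takes the total thread count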
options.set_launch_params(v, grid_size * block_size, block_size);
options.inputs = inputs;
options.output = c_shape;
options.kernel_name = v.get("kernel", "ck_gemm_softmax_gemm_kernel");
options.virtual_inputs = inputs;
if(v.get("check", false) or enabled(MIGRAPHX_CK_DEBUG{}))
options.params += " -DMIGRAPHX_CK_CHECK=1";
auto src = interpolate_string(ck_gemm_softmax_gemm_kernel,
{{"instance", ""/* ip.str() */},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"blocks_per_batch", to_string(blocks_per_batch)},
{"preamble", v.get("preamble", std::string{})},
{"kernel", options.kernel_name}});
return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
auto v = op.to_value();
v["kernel"] = "ck_gemm_softmax_gemm_kernel";
if(not ins->module_inputs().empty())
{
auto* pm = ins->module_inputs().front();
v["preamble"] = generate_pointwise(*pm, "post_ck_gemm_softmax_gemm_function") +
"\nMIGRAPHX_LIFT_CLASS(post_ck_gemm_softmax_gemm, post_ck_gemm_softmax_gemm_function);";
v["post"] = "ck_function_adaptor<post_ck_gemm_softmax_gemm>";
v["kernel"] = "ck_gemm_softmax_gemm_" + generate_name_from_ops(*pm) + "_kernel";
}
auto shapes = to_shapes(ins->inputs());
return action_decorate(replace(compile_op(ctx, shapes, v)), [=] {
if(enabled(MIGRAPHX_LOG_CK_GEMM{}))
{
std::vector<shape> gemm_shapes{shapes[0], shapes[1], shapes.back()};
std::cout << "ck_gemm_softmax_gemm: " << to_json_string(to_value(gemm_shapes)) << std::endl;
}
});
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -109,6 +109,28 @@ struct ck_passthrough
}
};
struct ck_scale
{
constexpr ck_scale(float s) : scale(s) {}
template <class T, class U>
constexpr void operator()(T& y, U x) const
{
y = x * static_cast<U>(scale);
}
float scale;
};
struct ck_add
{
template <class T, class U>
constexpr void operator()(T& y, U x) const
{
y += x;
}
};
#ifdef MIGRAPHX_CK_CHECK
#define MIGRAPHX_CK_STATIC_ASSERT static_assert
#else
...
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_CK_GEMM_SOFTMAX_GEMM_HPP
#define MIGRAPHX_GUARD_KERNELS_CK_GEMM_SOFTMAX_GEMM_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/ck.hpp>
#include <migraphx/kernels/ck_gemm_softmax_gemm_includes.hpp>
#include <migraphx/kernels/gemm_batcher.hpp>
namespace migraphx {
// In CK, the B matrix is ordered as N,K instead of K,N
template <class Dims>
constexpr auto ck_transposeb_dims(Dims dims)
{
return unpack(dims, [](auto k, auto n) { return make_const_array(n, k); });
}
template <class Tensor>
using ck_transposeb = decltype(make_shape(ck_transposeb_dims(get_shape_c<Tensor>{}.lens),
ck_transposeb_dims(get_shape_c<Tensor>{}.strides)));
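// e.g. a row-major B with lens {K, N} = {64, 128} and strides {128, 1} is viewed as an
// N,K (column-major) tensor with lens {128, 64} and strides {1, 128}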
template <class G, class C, class A, class B, class B1>
__device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
{
constexpr const G gemm{};
constexpr const auto a_shape = get_shape_c<A>{};
constexpr const auto m = a_shape.lens[0];
constexpr const auto k = a_shape.lens[1];
constexpr const auto sa = a_shape.strides[0];
constexpr const auto a_tensor = ck::make_naive_tensor_descriptor(ck::make_tuple(m, k),
ck::make_tuple(sa, 1));
constexpr const auto a_grid_desc_mraw_kraw = gemm.matrix_padder.PadADescriptor_M_K(a_tensor);
constexpr const auto AK1 = gemm.get_AK1();
constexpr const auto AK0 = k / AK1;
constexpr const auto a_grid_desc_ak0_m_ak1 = ck::transform_tensor_descriptor(a_grid_desc_mraw_kraw,
ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(AK0, AK1)),
ck::make_pass_through_transform(m)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
constexpr const auto b_shape = get_shape_c<B>{};
constexpr const auto n = b_shape.lens[0]; // col-major
constexpr const auto sb = b_shape.strides[1]; // col-major
constexpr const auto BK1 = gemm.get_BK1();
constexpr const auto BK0 = k / BK1;
constexpr const auto b_tensor = ck::make_naive_tensor_descriptor(ck::make_tuple(n, k),
ck::make_tuple(sb, 1));
constexpr const auto b_grid_desc_nraw_kraw = gemm.matrix_padder.PadBDescriptor_N_K(b_tensor);
constexpr const auto b_grid_desc_bk0_n_bk1 = ck::transform_tensor_descriptor(b_grid_desc_nraw_kraw,
ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(BK0, BK1)),
ck::make_pass_through_transform(n)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
constexpr const auto b1_shape = get_shape_c<B1>{};
constexpr const auto k1 = b1_shape.lens[0]; // row-major
constexpr const auto n1 = b1_shape.lens[1]; // row-major
constexpr const auto sb1 = b1_shape.strides[0]; // row-major
constexpr const auto B1K1 = gemm.get_B1K1();
constexpr const auto B1K0 = k1 / B1K1;
constexpr const auto b1_tensor = ck::make_naive_tensor_descriptor(ck::make_tuple(n1, k1),
ck::make_tuple(1, sb1));
constexpr const auto b1_grid_desc_nraw_kraw = gemm.matrix_padder.PadB1Descriptor_N_K(b1_tensor);
constexpr const auto b1_grid_desc_bk0_n_bk1 = ck::transform_tensor_descriptor(b1_grid_desc_nraw_kraw,
ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(B1K0, B1K1)),
ck::make_pass_through_transform(n1)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
constexpr const auto c_shape = get_shape_c<C>{};
constexpr const auto sc = c_shape.strides[0];
constexpr const auto c_tensor = ck::make_naive_tensor_descriptor(ck::make_tuple(m, n1),
ck::make_tuple(sc, 1));
constexpr const auto c_grid_desc_m_n = gemm.matrix_padder.PadCDescriptor_M_N(c_tensor);
constexpr const auto MPerBlock = gemm.get_mperblock();
constexpr const auto Gemm1NPerBlock = gemm.get_gemm1nperblock();
constexpr const auto MBlock = m / MPerBlock;
constexpr const auto NBlock = n1 / Gemm1NPerBlock;
constexpr const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
ck::transform_tensor_descriptor(
c_grid_desc_m_n,
ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(MBlock, ck::Number<MPerBlock>{})),
ck::make_unmerge_transform(ck::make_tuple(NBlock, ck::Number<Gemm1NPerBlock>{}))),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0, 1>{}, ck::Sequence<2, 3>{}));
constexpr const auto block_2_ctile_map = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, Gemm1NPerBlock, decltype(c_grid_desc_m_n)>(
c_grid_desc_m_n);
const C0MatrixMask c0_matrix_mask(n);
const auto K =
a_grid_desc_ak0_m_ak1.GetLength(ck::Number<0>{}) * a_grid_desc_ak0_m_ak1.GetLength(ck::Number<2>{});
using gridwise = typename G::template rt_gridwisegemm<decltype(a_grid_desc_ak0_m_ak1),
decltype(b_grid_desc_bk0_n_bk1),
decltype(b1_grid_desc_bk0_n_bk1),
decltype(c_grid_desc_m_n)>;
using GridwiseGemm = typename gridwise::GridwiseGemm;
constexpr const bool HasMainKBlockLoop = GridwiseGemm::CalculateHasMainKBlockLoop(K);
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
static_assert(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_m_n,
block_2_ctile_map));
GridwiseGemm::template Run<HasMainKBlockLoop>(to_ck_const_pointer(a.data()),
to_ck_const_pointer(b.data()),
to_ck_const_pointer(b1.data()),
to_ck_pointer(c.data()),
p_shared,
gemm.a_element_op,
gemm.b_element_op,
gemm.acc_element_op,
gemm.b1_element_op,
gemm.c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map,
c0_matrix_mask);
}
template <class G, index_int BlocksPerBatch, class... Ts>
__device__ void ck_gemm_softmax_gemm(Ts... xs)
{
gemm_batch_args(make_index(), _c<BlocksPerBatch>, xs...)(
[](auto... ys) { ck_gemm_softmax_gemm_matrix<G>(ys...); });
}
} // namespace migraphx
#endif
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_CK_GSG_INCLUDES_HPP
#define MIGRAPHX_GUARD_KERNELS_CK_GSG_INCLUDES_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include <ratio>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
namespace migraphx {
template <ck::index_t MPerBlock, ck::index_t NPerBlock, typename CGridDesc_M_N>
struct BlockToCTileMap_M00_N0_M01Adapt
{
static constexpr auto I0 = ck::Number<0>{};
static constexpr auto I1 = ck::Number<1>{};
static constexpr auto I2 = ck::Number<2>{};
static constexpr auto I3 = ck::Number<3>{};
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt() = default;
__host__
__device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
ck::index_t M01 = 8)
: M01_(M01), c_grid_desc_m_n_(c_grid_desc_m_n)
{
}
__host__ __device__ constexpr ck::index_t
CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{
const auto M0 = ck::math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
const auto N0 = ck::math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
const ck::index_t grid_size = M0 * N0;
return grid_size;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
auto block_1d_id = idx_top[I0];
const auto M0 = ck::math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock);
const auto N0 = ck::math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock);
block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
ck::index_t idx_N0 = block_1d_id % N0;
ck::index_t idx_M0 = block_1d_id / N0;
const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
ck::index_t idx_M00 = idx_M0 / M01_;
ck::index_t idx_M01 = idx_M0 % M01_;
ck::index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
return ck::make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
idx_N0_M01_local / M01_adapt);
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool constexpr ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
const CTileDim& /* c_tile_dim */) const
{
return true; // always valid provided that user gets grid size from CalculateGridSize()
}
__host__ __device__ constexpr bool
CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
{
return true;
}
private:
ck::index_t M01_;
CGridDesc_M_N c_grid_desc_m_n_;
};
// to track the points which need to be set to -inf on C0
// Note: no need to reset M padding value, because they will not be stored out.
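// IsMaskedElement(m, n) is true above the diagonal (n > m) and for columns in the
// padded region at or beyond NRaw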
struct C0MatrixMask
{
__device__ C0MatrixMask(ck::index_t NRaw) : NRaw_(NRaw) {}
__device__ bool IsUpperTriangle(ck::index_t m, ck::index_t n) const { return n > m; }
__device__ bool IsNOutOfBound(/*ck::index_t m, */ ck::index_t n) const
{
return n >= NRaw_;
}
__device__ bool IsMaskedElement(ck::index_t m, ck::index_t n) const
{
return IsUpperTriangle(m, n) || IsNOutOfBound(n);
}
private:
// ck::index_t MRaw_;
ck::index_t NRaw_;
};
template <typename ALayout,
typename BLayout, // B0Layout
typename B1Layout,
typename CLayout,
typename ADataType,
typename BDataType,
typename B1DataType,
typename CDataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
ck::tensor_operation::device::GemmSpecialization GemmSpec,
ck::index_t NumGemmKPrefetchStage,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock, // Gemm0NPerBlock
ck::index_t KPerBlock, // Gemm0KPerBlock
ck::index_t Gemm1NPerBlock,
ck::index_t Gemm1KPerBlock,
ck::index_t AK1,
ck::index_t BK1,
ck::index_t B1K1,
ck::index_t MPerXDL,
ck::index_t NPerXDL,
ck::index_t MXdlPerWave,
ck::index_t NXdlPerWave,
ck::index_t Gemm1NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
typename B1BlockTransferThreadClusterArrangeOrder,
typename B1BlockTransferSrcAccessOrder,
ck::index_t B1BlockTransferSrcVectorDim,
ck::index_t B1BlockTransferSrcScalarPerVector,
ck::index_t B1BlockTransferDstScalarPerVector_BK1,
bool B1BlockLdsExtraN,
ck::index_t CShuffleMXdlPerWavePerShuffle,
ck::index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
bool MaskOutUpperTriangle,
typename Alpha,
ck::LoopScheduler LoopSched = ck::LoopScheduler::Default>
struct CK_DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
{
static constexpr auto matrix_padder =
ck::tensor_operation::device::GemmGemmPadder<GemmSpec, ck::index_t, ck::index_t, ck::index_t, ck::index_t>{
MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};
static constexpr auto get_AK1() { return AK1; };
static constexpr auto get_BK1() { return BK1; };
static constexpr auto get_B1K1() { return B1K1; };
static constexpr auto get_mperblock() { return MPerBlock; };
static constexpr auto get_gemm1nperblock() { return Gemm1NPerBlock; };
static constexpr float alpha = float(Alpha::num) / Alpha::den;
static constexpr auto get_alpha() { return alpha; };
AElementwiseOperation a_element_op{};
BElementwiseOperation b_element_op{};
B1ElementwiseOperation b1_element_op{};
CElementwiseOperation c_element_op{};
AccElementwiseOperation acc_element_op{alpha};
template<typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename B1GridDesc_BK0_N_BK1,
typename CGridDesc_M_N>
struct rt_gridwisegemm
{
// GridwiseGemm
using GridwiseGemm = ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
ADataType, // TODO: distinguish A/B datatype
GemmAccDataType,
CShuffleDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
ck::InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
B1GridDesc_BK0_N_BK1,
CGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
Gemm1NPerBlock,
Gemm1KPerBlock,
AK1,
BK1,
B1K1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
Gemm1NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
true,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
true,
BBlockLdsExtraN,
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder,
B1BlockTransferSrcAccessOrder,
B1BlockTransferSrcVectorDim,
B1BlockTransferSrcScalarPerVector,
B1BlockTransferDstScalarPerVector_BK1,
false,
B1BlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
matrix_padder.PadN,
MaskOutUpperTriangle>;
};
};
} // namespace migraphx
#endif
gemm_softmax_gemm_test.onnx (binary protobuf, not rendered): a graph of
MatMul(a, b) -> Mul(scale) -> Add(c) -> Softmax -> MatMul(b1) with half-precision
inputs a, b, c, b1, bias and output out.
@@ -1986,6 +1986,42 @@ def gemm_half_test():
return ([node], [m1, m2, m3], [y])
@onnx_test
def gemm_softmax_gemm_test():
a = helper.make_tensor_value_info('a', TensorProto.FLOAT16, [1, 1])
b = helper.make_tensor_value_info('b', TensorProto.FLOAT16, [1, 1])
c = helper.make_tensor_value_info('c', TensorProto.FLOAT16, [1, 1])
b1 = helper.make_tensor_value_info('b1', TensorProto.FLOAT16, [1, 1])
bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT16, [1, 1])
out = helper.make_tensor_value_info('out', TensorProto.FLOAT16, [1, 1])
scale_array = np.array([(1/8)])
scale_tensor = helper.make_tensor(name='scale',
data_type=TensorProto.FLOAT16,
dims=scale_array.shape,
vals=scale_array.flatten().astype(np.float16))
gemm1 = onnx.helper.make_node('MatMul',
inputs=['a', 'b'],
outputs=['gemm1_out'])
mul1 = onnx.helper.make_node('Mul',
inputs=['gemm1_out', 'scale'],
outputs=['mul1_out'])
add1 = onnx.helper.make_node('Add',
inputs=['mul1_out', 'c'],
outputs=['add1_out'])
softmax = onnx.helper.make_node('Softmax',
inputs=['add1_out'],
outputs=['softmax_out'])
gemm2 = onnx.helper.make_node('MatMul',
inputs=['softmax_out', 'b1'],
outputs=['out'])
return ([gemm1, mul1, add1, softmax, gemm2], [a, b, c, b1, bias], [out], [scale_tensor])
@onnx_test
def globalavgpool_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 16, 16])
...
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 256}};
migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 256, 256}};
auto m2_elements = 1 * 12 * 256 * 256;
auto a = mm->add_parameter("1", m1_shape);
auto b = mm->add_parameter("2", m1_shape);
auto b1 = mm->add_parameter("3", m1_shape);
auto c = mm->add_parameter("4", m1_shape);
std::vector<float> eights(m2_elements, 0.125);
auto eight = mm->add_literal(migraphx::literal{m2_shape, eights});
std::vector<float> zeros(m2_elements, 0);
auto zero = mm->add_literal(migraphx::literal{m2_shape, zeros});
std::vector<float> ones(m2_elements, 1);
auto one = mm->add_literal(migraphx::literal{m2_shape, ones});
b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b);
auto scale = mm->add_instruction(migraphx::make_op("mul"), gemm1, eight);
auto bias = mm->add_instruction(migraphx::make_op("add"), scale, zero);
auto softmax = mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), bias);
mm->add_instruction(migraphx::make_op("dot"), softmax, b1);
return p;
}
};