"docs/git@developer.sourcefind.cn:Wenxuan/LightX2V.git" did not exist on "34c1b7b142723f166467083ad41eb25ab5783e21"
Commit 359bb1cd authored by Paul

Merge branch 'ck-gemm-fused-transpose' into sd-opt

parents 1ac14290 55b363c9
@@ -28,3 +28,4 @@ ROCmSoftwarePlatform/half@rocm-5.4.2
 pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
 msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
 sqlite3@3.17 -DCMAKE_POSITION_INDEPENDENT_CODE=On
+ROCmSoftwarePlatform/composable_kernel@1b62bfaa2a42ed83da2692f6797a5f929c39946f -X header
@@ -213,22 +213,28 @@ void from_value_impl(rank<6>, const value& v, optional<T>& x)
     x = from_value<T>(v);
 }
-template <class T, MIGRAPHX_REQUIRES(std::is_arithmetic<T>{} or std::is_enum<T>{})>
+template <class T, MIGRAPHX_REQUIRES(std::is_arithmetic<T>{})>
 void from_value_impl(rank<7>, const value& v, T& x)
 {
     x = v.to<T>();
 }
-inline void from_value_impl(rank<8>, const value& v, std::string& x) { x = v.to<std::string>(); }
+template <class T, MIGRAPHX_REQUIRES(std::is_enum<T>{})>
+void from_value_impl(rank<8>, const value& v, T& x)
+{
+    x = v.to<T>();
+}
+
+inline void from_value_impl(rank<9>, const value& v, std::string& x) { x = v.to<std::string>(); }
 template <class T>
-auto from_value_impl(rank<9>, const value& v, T& x) -> decltype(x.from_value(v), void())
+auto from_value_impl(rank<10>, const value& v, T& x) -> decltype(x.from_value(v), void())
 {
     x.from_value(v);
 }
 template <class T>
-auto from_value_impl(rank<10>, const value& v, T& x) -> decltype(migraphx_from_value(v, x), void())
+auto from_value_impl(rank<11>, const value& v, T& x) -> decltype(migraphx_from_value(v, x), void())
 {
     migraphx_from_value(v, x);
 }
@@ -244,7 +250,7 @@ value to_value(const T& x)
 template <class T>
 void from_value(const value& v, T& x)
 {
-    detail::from_value_impl(rank<10>{}, v, x);
+    detail::from_value_impl(rank<11>{}, v, x);
 }
 } // namespace MIGRAPHX_INLINE_NS
...
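The from_value_impl overloads above are selected with the rank<N> tag-dispatch idiom: rank<N> derives from rank<N-1>, the call site passes the highest rank, and overload resolution prefers the overload taking the most-derived (highest-numbered) rank that is still viable. Splitting the arithmetic and enum cases onto separate ranks is what bumps every later overload up by one. A minimal standalone sketch of the idiom, with illustrative names rather than MIGraphX's own:

#include <iostream>
#include <string>
#include <type_traits>

// rank<N> inherits from rank<N-1>, so a rank<2> argument also converts to
// rank<1> and rank<0>; the most-derived match wins.
template <std::size_t N>
struct rank : rank<N - 1> {};
template <>
struct rank<0> {};

// Lowest priority: anything streamable.
template <class T>
void write_impl(rank<0>, std::ostream& os, const T& x) { os << x; }

// Higher priority: enums are printed as their underlying integer.
template <class T, class = std::enable_if_t<std::is_enum<T>::value>>
void write_impl(rank<1>, std::ostream& os, T x)
{
    os << static_cast<typename std::underlying_type<T>::type>(x);
}

// Highest priority: strings get quoted.
void write_impl(rank<2>, std::ostream& os, const std::string& x) { os << '"' << x << '"'; }

template <class T>
void write(std::ostream& os, const T& x) { write_impl(rank<2>{}, os, x); }

enum class color { red = 1 };

int main()
{
    write(std::cout, 42);                 // picks the rank<0> overload
    std::cout << '\n';
    write(std::cout, color::red);         // picks the rank<1> overload
    std::cout << '\n';
    write(std::cout, std::string("hi"));  // picks the rank<2> overload
    std::cout << '\n';
}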
@@ -29,6 +29,7 @@
 #include <migraphx/reflect.hpp>
 #include <migraphx/rank.hpp>
 #include <migraphx/requires.hpp>
+#include <migraphx/optional.hpp>
 #include <migraphx/config.hpp>
 #include <vector>
@@ -99,12 +100,21 @@ void stream_write_value_impl(rank<0>, std::ostream& os, const T& x)
     os << "}";
 }
+template <class T>
+void stream_write_value_impl(rank<0>, std::ostream& os, const optional<T>& x)
+{
+    if(x.has_value())
+        stream_write_value_impl(rank<2>{}, os, *x);
+    else
+        os << "none";
+}
+
 } // namespace detail
 template <class T>
 void stream_write_value(std::ostream& os, const T& x)
 {
-    detail::stream_write_value_impl(rank<1>{}, os, x);
+    detail::stream_write_value_impl(rank<2>{}, os, x);
 }
 } // namespace MIGRAPHX_INLINE_NS
...
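With the new overload, streaming an unset optional prints "none" instead of failing to find a suitable overload. A rough standalone equivalent, using std::optional in place of migraphx::optional:

#include <iostream>
#include <optional>

// Sketch of the dispatch the header adds: print the contained value when
// present, otherwise the literal text "none".
template <class T>
void write(std::ostream& os, const std::optional<T>& x)
{
    if(x.has_value())
        os << *x;
    else
        os << "none";
}

int main()
{
    write(std::cout, std::optional<int>{3}); // prints 3
    std::cout << ' ';
    write(std::cout, std::optional<int>{});  // prints none
    std::cout << '\n';
}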
@@ -91,6 +91,7 @@ add_library(migraphx_gpu
   compile_miopen.cpp
   compiler.cpp
   device_name.cpp
+  fuse_ck.cpp
   fuse_mlir.cpp
   fuse_ops.cpp
   gather.cpp
@@ -184,19 +185,18 @@ if(MIGRAPHX_USE_HIPRTC)
   message(STATUS "MIGraphX is using hipRTC")
   target_compile_definitions(migraphx_gpu PRIVATE -DMIGRAPHX_USE_HIPRTC=1)
 else()
-  message(STATUS "MIGraphX is using HIP Clang")
   # Get flags needed to compile hip
   include(TargetFlags)
   target_flags(HIP_COMPILER_FLAGS hip::device)
   # Remove cuda arch flags
   string(REGEX REPLACE --cuda-gpu-arch=[a-z0-9]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
   string(REGEX REPLACE --offload-arch=[a-z0-9:+-]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
   # Skip library paths since hip will incorrectly treat it as a source file
   string(APPEND HIP_COMPILER_FLAGS " ")
+  # Add ck includes
+  find_path(CK_INCLUDE_PATH ck/ck.hpp)
+  message(STATUS "CK path: ${CK_INCLUDE_PATH}")
+  string(APPEND HIP_COMPILER_FLAGS " -isystem ${CK_INCLUDE_PATH}")
   foreach(_unused RANGE 2)
     string(REGEX REPLACE " /[^ ]+\\.(a|so) " " " HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
   endforeach()
...
@@ -39,16 +39,18 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_COMPILE_PARALLEL);
 struct precompile_op
 {
     operation op = op::identity{};
     std::size_t additional_args = 1;
     bool ignore_modules = false;
+    optional<shape> output_shape = {};
     template <class Self, class F>
     static auto reflect(Self& self, F f)
     {
         return pack(f(self.op, "op"),
                     f(self.additional_args, "additional_args"),
-                    f(self.ignore_modules, "ignore_modules"));
+                    f(self.ignore_modules, "ignore_modules"),
+                    f(self.output_shape, "output_shape"));
     }
     std::string name() const { return "gpu::precompile_op"; }
@@ -57,9 +59,14 @@ struct precompile_op
     {
         // Pop off additional args
         inputs.resize(inputs.size() - additional_args);
+        shape r{};
         if(ignore_modules)
-            return op.compute_shape(inputs);
-        return op.compute_shape(inputs, mods);
+            r = op.compute_shape(inputs);
+        else
+            r = op.compute_shape(inputs, mods);
+        if(output_shape.has_value())
+            r = *output_shape;
+        return r;
     }
     std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
...
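The new output_shape field lets a pass pin the shape that gpu::precompile_op reports, overriding whatever the wrapped op computes; the find_contiguous_tranpose_precompile matcher added further down writes this field to bake a transposed layout into the op. A small standalone sketch of just the override rule, with simplified stand-in types rather than the real operation interface:

#include <cassert>
#include <optional>
#include <vector>

// Stand-ins: a shape is just its lens, and "computed" is what the wrapped op
// would have returned from compute_shape.
using shape = std::vector<std::size_t>;

struct precompile_like
{
    std::optional<shape> output_shape; // when set, this layout wins

    shape compute_shape(const shape& computed) const
    {
        shape r = computed;
        if(output_shape.has_value())
            r = *output_shape;
        return r;
    }
};

int main()
{
    precompile_like op;
    assert(op.compute_shape({2, 3}) == (shape{2, 3})); // no override: pass through
    op.output_shape = shape{3, 2};
    assert(op.compute_shape({2, 3}) == (shape{3, 2})); // override replaces the result
}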
@@ -36,7 +36,7 @@ struct compile_op : action<compile_op>
     static void apply(const parser& p, const value& v)
     {
         context ctx;
-        auto inputs = p.parse_shapes(v.at("inputs"));
+        auto inputs = p.get_inputs(v);
         auto op = gpu::compile_op(v.at("name").to<std::string>(), ctx, inputs, v);
         auto [host_time, device_time] = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
         std::cout << op << ": " << host_time << "ms";
...
@@ -53,6 +53,8 @@ struct parser
     std::vector<shape> parse_shapes(const value& v) const;
+    std::vector<shape> get_inputs(const value& v) const;
+
     void load_settings(const value& v);
     static void process(const value& v);
...
@@ -55,6 +55,11 @@ std::vector<shape> parser::parse_shapes(const value& v) const
     return result;
 }
+std::vector<shape> parser::get_inputs(const value& v) const
+{
+    return parse_shapes(get(v, "inputs", value{}));
+}
+
 void parser::load_settings(const value& v)
 {
     if(v.contains("settings"))
...
@@ -56,15 +56,26 @@ time_op(context& ictx, operation op, const std::vector<shape>& inputs, int n)
         op.compute(ctx, output, args);
         ctx.finish();
     };
-    gctx.enable_perf_measurement();
     run();
+    // Measure host time
     double host_time = 0.0;
-    double device_time = 0.0;
     for(auto i : range(n))
    {
         (void)i;
         host_time += time<milliseconds>(run);
-        device_time += gctx.get_elapsed_ms();
     }
+    // Measure device time only for code_object ops which support it
+    double device_time = 0.0;
+    if(op.name() == "gpu::code_object")
+    {
+        gctx.enable_perf_measurement();
+        run();
+        for(auto i : range(n))
+        {
+            (void)i;
+            run();
+            device_time += gctx.get_elapsed_ms();
+        }
+    }
     return std::make_pair(host_time / n, device_time / n);
 }
...
@@ -36,7 +36,7 @@ struct run_op : action<run_op>
     static void apply(const parser& p, const value& v)
     {
         context ctx;
-        auto inputs = p.parse_shapes(v.at("inputs"));
+        auto inputs = p.get_inputs(v);
         auto name = v.at("name").to<std::string>();
         if(not contains(name, "::"))
             name = "gpu::" + name;
...
#include <migraphx/gpu/fuse_ck.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/env.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_CK_GEMM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_CK_GEMM_FUSION);
struct module;
namespace gpu {
struct ck_gemm
{
operation op = make_op("dot");
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.op, "op"));
}
std::string name() const { return "gpu::ck_gemm"; }
void check_gemm_shape(const shape& s) const
{
if(not contains(range(s.strides().rbegin(), s.strides().rbegin() + 3), 1))
MIGRAPHX_THROW("Invalid shape for ck_gemm");
}
shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
{
check_shapes{inputs, *this}.same_ndims();
// if(mods.size() != 1)
// MIGRAPHX_THROW("should have one submodule.");
if(inputs.size() < 2)
MIGRAPHX_THROW("should have at least two inputs.");
auto a = inputs[0];
auto b = inputs[1];
for(const auto& input : inputs)
check_gemm_shape(input);
auto r = op.compute_shape({a, b});
if(mods.empty())
return r;
return r.with_type(mods.front()->get_output_shapes().front().type());
}
};
MIGRAPHX_REGISTER_OP(ck_gemm);
namespace {
MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
{
if(ins->name() != "dot")
return false;
auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape();
if(a.lens().back() > 2048)
return false;
return true;
}
struct find_ck_gemm_pointwise
{
// Find a gemm followed by a pointwise operation.
auto matcher() const
{
auto gemm =
match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm")));
return match::name("pointwise")(match::any_of[match::inputs()](gemm.bind("x")));
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
auto gemm_ins = r.instructions["gemm"];
auto x_ins = r.instructions["x"]; // input after contiguous
auto* pm = ins->module_inputs().front();
auto names = pm->get_parameter_names();
std::sort(names.begin(), names.end());
auto inputs = ins->inputs();
auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins);
auto gemm_idx = gemm_it - inputs.begin();
assert(gemm_it != inputs.end());
if(ins->get_shape().type() != shape::half_type)
return;
if(gemm_idx != 0)
{
auto first_param = pm->get_parameter(names[0]);
auto gemm_param = pm->get_parameter(names[gemm_idx]);
auto new_gemm_param = pm->add_parameter(names[0] + ".0", gemm_param->get_shape());
auto new_first_param =
pm->add_parameter(names[gemm_idx] + ".0", first_param->get_shape());
pm->replace_instruction(gemm_param, new_gemm_param);
pm->replace_instruction(first_param, new_first_param);
pm->remove_instruction(first_param);
pm->remove_instruction(gemm_param);
}
inputs.erase(gemm_it);
inputs.insert(inputs.begin(), gemm_ins->inputs().begin(), gemm_ins->inputs().end());
mpm.get_module().replace_instruction(ins, ck_gemm{}, inputs, {pm});
}
};
struct find_ck_gemm
{
auto matcher() const { return match::name("dot")(is_ck_gemm().bind("gemm")); }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
mpm.get_module().replace_instruction(ins, ck_gemm{ins->get_operator()}, ins->inputs());
}
};
} // namespace
void fuse_ck::apply(module_pass_manager& mpm) const
{
if(not enabled(MIGRAPHX_DISABLE_CK_GEMM_FUSION{}))
match::find_matches(mpm, find_ck_gemm_pointwise{});
if(not enabled(MIGRAPHX_DISABLE_CK_GEMM{}))
match::find_matches(mpm, find_ck_gemm{});
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
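A rough sketch of how the new pass could be exercised in isolation, assuming MIGraphX's usual module-building helpers (module::add_parameter, add_instruction, run_passes); the exact calls below are illustrative rather than taken from this commit's tests. With MIGRAPHX_DISABLE_CK_GEMM unset, the dot should be rewritten into a gpu::ck_gemm instruction:

#include <migraphx/module.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/gpu/fuse_ck.hpp>

migraphx::module make_example()
{
    migraphx::module m;
    // K = 64 <= 2048 and the innermost stride is 1, so is_ck_gemm accepts it.
    migraphx::shape s{migraphx::shape::half_type, {1, 64, 64}};
    auto a   = m.add_parameter("a", s);
    auto b   = m.add_parameter("b", s);
    auto dot = m.add_instruction(migraphx::make_op("dot"), a, b);
    m.add_return({dot});
    return m;
}

int main()
{
    auto m = make_example();
    migraphx::run_passes(m, {migraphx::gpu::fuse_ck{}});
    m.debug_print(); // expect a gpu::ck_gemm instruction in place of dot
}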
@@ -675,6 +675,36 @@ struct find_gemm_pointwise
     }
 };
+struct find_contiguous_tranpose_precompile
+{
+    auto matcher() const
+    {
+        return match::name("gpu::contiguous")(match::arg(0)(
+            match::name("transpose")(
+                match::used_once(),
+                match::arg(0)(match::name("gpu::precompile_op")(match::used_once()).bind("op")))
+                .bind("transpose")));
+    }
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins = r.result;
+        auto op_ins = r.instructions["op"];
+        auto alloc = op_ins->inputs().back();
+        auto transpose = r.instructions["transpose"];
+        auto perm = transpose->get_operator().to_value()["permutation"].to_vector<int64_t>();
+        auto iperm = invert_permutation(perm);
+        auto s = shape::from_permutation(
+            op_ins->get_shape().type(), op_ins->get_shape().lens(), perm); // perm or iperm?
+        auto v = op_ins->get_operator().to_value();
+        v["output_shape"] = to_value(s);
+        auto new_op = make_op("gpu::precompile_op", v);
+        m.replace_instruction(op_ins, new_op, op_ins->inputs(), op_ins->module_inputs());
+        assert(ins->get_shape() == transpose->get_shape());
+        m.replace_instruction(ins, transpose);
+    }
+};
+
 struct find_contiguous_tranpose_gemm
 {
     auto matcher() const
@@ -852,6 +882,7 @@ void fuse_ops::apply(module& m) const
                     find_concat_pointwise{},
                     find_gemm_pointwise{},
                     find_contiguous_tranpose_gemm{},
+                    find_contiguous_tranpose_precompile{},
                     find_commutative_broadcast{});
     match::find_matches(m, find_contiguous{});
 }
...
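On the "perm or iperm?" question above: invert_permutation returns the permutation that undoes its argument, and for a transpose that only swaps the last two axes (the common case this matcher sees after a gemm) the permutation is its own inverse, so perm and iperm give the same shape. A standalone stand-in with the same behaviour as MIGraphX's invert_permutation, for illustration only:

#include <cassert>
#include <cstdint>
#include <vector>

// invert_permutation(p)[p[i]] == i: the inverse sends each destination axis
// back to its source axis.
std::vector<std::int64_t> invert_permutation(const std::vector<std::int64_t>& perm)
{
    std::vector<std::int64_t> result(perm.size());
    for(std::size_t i = 0; i < perm.size(); i++)
        result[perm[i]] = static_cast<std::int64_t>(i);
    return result;
}

int main()
{
    // A swap of the last two axes is its own inverse ...
    assert(invert_permutation({0, 2, 1}) == (std::vector<std::int64_t>{0, 2, 1}));
    // ... but a 3-cycle is not: {2, 0, 1} inverts to {1, 2, 0}.
    assert(invert_permutation({2, 0, 1}) == (std::vector<std::int64_t>{1, 2, 0}));
}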
#ifndef MIGRAPHX_GUARD_GPU_FUSE_CK_HPP
#define MIGRAPHX_GUARD_GPU_FUSE_CK_HPP
#include <migraphx/config.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module_pass_manager;
namespace gpu {
struct fuse_ck
{
context* ctx = nullptr;
std::string name() const { return "gpu::fuse_ck"; }
void apply(module_pass_manager& mpm) const;
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_FUSE_CK_HPP
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <fstream>
#include <filesystem>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/env.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/module.hpp>
#include <migraphx/env.hpp>
#include <migraphx/file_buffer.hpp>
const std::vector<std::string>&
get_instance(std::size_t i, const std::function<bool(const std::vector<std::string>&)>& pred);
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
using namespace migraphx::gpu::gen; // NOLINT
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_LOG_CK_GEMM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_TUNING);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_TUNING_VALUE);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_DEBUG);
// NOLINTNEXTLINE
static const char* const ck_gemm_kernel = R"__migraphx__(
#include <args.hpp>
#include <migraphx/kernels/ck_gemm.hpp>
#include <migraphx/kernels/pointwise.hpp>
namespace migraphx {
${preamble}
extern "C" {
__global__ void ${kernel}(${params})
{
transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) {
ck_gemm<CK_DeviceGemmMultipleD<${instance}>, ${blocks_per_batch}>(xs...);
});
}
}
} // namespace migraphx
)__migraphx__";
static std::size_t int_div_ceil(std::size_t x, std::size_t y) { return (x + y - 1) / y; }
struct instance
{
std::vector<std::string> params;
static const std::size_t block_size_index = 15;
std::size_t int_at(std::size_t i) const { return std::stoull(params[i]); }
std::size_t get_block_size() const { return int_at(block_size_index); }
std::size_t get_pb(std::size_t i) const
{
assert(i < 4);
return int_at(block_size_index + 1 + i);
}
std::array<std::size_t, 3> get_pad(const std::array<std::size_t, 3>& config) const
{
std::array<std::size_t, 3> result{};
for(auto i : range(config.size()))
{
result[i] = int_div_ceil(config[i], get_pb(i)) * get_pb(i) - config[i];
}
return result;
}
std::size_t get_grid_size(const std::array<std::size_t, 3>& config) const
{
return int_div_ceil(config[0], get_pb(0)) * int_div_ceil(config[1], get_pb(1));
}
void set_ds_layout(const std::string& s)
{
assert(params[2] == "ck::Tuple<>");
params[2] = s;
}
void set_ds_type(const std::string& s)
{
assert(params[8] == "ck::Tuple<>");
params[8] = s;
}
void set_ds_op(const std::string& s)
{
assert(params[12] == "ck_passthrough");
params[12] = s;
}
void set_gemm(const std::string& s)
{
assert(params[13] == "ck::tensor_operation::device::GemmSpecialization::Default");
params[13] = s;
}
std::string str() const { return join_strings(params, ","); }
};
static bool transposed_matrix(const shape& s) { return s.strides().back() != 1; }
template <class F, class Action>
auto action_decorate(F f, Action action)
{
return [=](auto&&... xs) {
action();
f(std::forward<decltype(xs)>(xs)...);
};
}
using tuning_entry = std::pair<std::vector<shape>, size_t>;
static std::vector<tuning_entry> read_tuning(const std::string& s)
{
if(not fs::exists(s))
return {};
return from_value<std::vector<tuning_entry>>(from_json_string(read_string(s)));
}
static float matrix_distance(const shape& x, const shape& y)
{
if(x.type() != y.type())
return std::numeric_limits<float>::max();
if(transposed_matrix(x) != transposed_matrix(y))
return std::numeric_limits<float>::max();
auto sum_squared = std::inner_product(x.lens().rbegin(),
x.lens().rbegin() + 2,
y.lens().rbegin(),
0,
std::plus<>{},
[](auto a, auto b) { return (a - b) * (a - b); });
return std::sqrt(sum_squared);
}
static std::size_t get_tuning_for(const std::vector<shape>& inputs)
{
static auto tuning = read_tuning(string_value_of(MIGRAPHX_CK_TUNING{}, ""));
if(tuning.empty())
std::cout << "*********** Warning: No CK tuning!" << std::endl;
auto it = std::find_if(
tuning.begin(), tuning.end(), [&](const auto& p) { return p.first == inputs; });
if(it == tuning.end())
{
std::cout << "*********** Warning: CK tuning missing for config!" << std::endl;
std::vector<std::pair<float, std::size_t>> w;
std::transform(tuning.begin(), tuning.end(), std::back_inserter(w), [&](const auto& p) {
if(inputs.size() < 3 or p.first.size() < 3)
MIGRAPHX_THROW("Invalid CK config");
auto avg_distance = std::inner_product(
p.first.begin(),
p.first.begin() + 3,
inputs.begin(),
0.0f,
std::plus<>{},
[](const auto& x, const auto& y) { return matrix_distance(x, y) / 3.0f; });
return std::make_pair(avg_distance, p.second);
});
std::sort(w.begin(), w.end());
std::size_t default_value = 4;
if(not w.empty())
default_value = w.front().second;
auto tuning_val = value_of(MIGRAPHX_CK_TUNING_VALUE{}, default_value);
std::cout << "*********** Warning: CK try tuning: " << tuning_val << std::endl;
return tuning_val;
}
return it->second;
}
struct ck_gemm_compiler : compiler<ck_gemm_compiler>
{
static std::string get_layout(const shape& s)
{
return transposed_matrix(s) ? "ck::tensor_layout::gemm::ColumnMajor"
: "ck::tensor_layout::gemm::RowMajor";
}
static std::string get_type(const shape& s)
{
if(s.type() == shape::half_type)
return "ck::half_t";
return shape::cpp_type(s.type());
}
template <class Iterator, class F>
static std::string ck_tuple(Iterator start, Iterator last, F f)
{
std::vector<std::string> s;
std::transform(start, last, std::back_inserter(s), f);
return "ck::Tuple<" + join_strings(s, ",") + ">";
}
static std::vector<shape> adjust_inputs(std::vector<shape> inputs, bool& swap_inputs)
{
swap_inputs = false;
auto c_shape = inputs.back();
if(not transposed_matrix(c_shape))
return inputs;
std::vector<int64_t> perm(c_shape.lens().size());
std::iota(perm.begin(), perm.end(), 0);
std::swap(perm[perm.size() - 1], perm[perm.size() - 2]);
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](shape s) {
return reorder_shape(s, perm);
});
swap_inputs = true;
return inputs;
}
std::vector<std::string> names() const { return {"ck_gemm", "gpu::ck_gemm"}; }
operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const
{
auto a_shape = inputs[0];
auto b_shape = inputs[1];
auto c_shape = inputs.back();
auto rank = a_shape.lens().size();
std::array<char, 3> keys{'M', 'N', 'K'};
std::array<std::size_t, 3> config{
c_shape.lens()[rank - 2], c_shape.lens().back(), a_shape.lens().back()};
auto tuning_val = v.get("tuning_val", get_tuning_for({a_shape, b_shape, c_shape}));
auto ip = instance{get_instance(tuning_val, [&](const auto& x) -> bool {
return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
get_type(b_shape) == x[5] and get_type(c_shape) == x[9];
})};
assert(inputs.size() < 4 or v.contains("post"));
if(v.contains("post"))
{
ip.set_ds_layout(ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_layout));
ip.set_ds_type(ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_type));
ip.set_ds_op(v.at("post").to<std::string>());
}
auto padding = ip.get_pad(config);
std::string gemm_type;
for(auto i : range(padding.size()))
{
if(padding[i] != 0)
gemm_type += keys[i];
}
if(gemm_type.empty())
gemm_type = "Default";
else
gemm_type += "Padding";
ip.set_gemm("ck::tensor_operation::device::GemmSpecialization::" + gemm_type);
auto blocks_per_batch = ip.get_grid_size(config);
auto batch_count = std::accumulate(c_shape.lens().rbegin() + 2,
c_shape.lens().rend(),
std::size_t{1},
std::multiplies<std::size_t>());
hip_compile_options options;
auto block_size = ip.get_block_size();
auto grid_size = batch_count * blocks_per_batch;
options.set_launch_params(v, grid_size * block_size, block_size);
options.inputs = inputs;
options.output = c_shape;
options.kernel_name = v.get("kernel", "ck_gemm_kernel");
options.virtual_inputs = inputs;
if(v.get("check", false) or enabled(MIGRAPHX_CK_DEBUG{}))
options.params += " -DMIGRAPHX_CK_CHECK=1";
auto src = interpolate_string(ck_gemm_kernel,
{{"instance", ip.str()},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"blocks_per_batch", to_string(blocks_per_batch)},
{"preamble", v.get("preamble", std::string{})},
{"kernel", options.kernel_name}});
return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
auto v = op.to_value();
v["kernel"] = "ck_gemm_kernel";
if(not ins->module_inputs().empty())
{
auto* pm = ins->module_inputs().front();
v["preamble"] = generate_pointwise(*pm, "post_ck_gemm_function") +
"\nMIGRAPHX_LIFT_CLASS(post_ck_gemm, post_ck_gemm_function);";
v["post"] = "ck_function_adaptor<post_ck_gemm>";
v["kernel"] = "ck_gemm_" + generate_name_from_ops(*pm) + "_kernel";
}
auto shapes = to_shapes(ins->inputs());
return action_decorate(replace(compile_op(ctx, shapes, v)), [=] {
if(enabled(MIGRAPHX_LOG_CK_GEMM{}))
{
std::vector<shape> gemm_shapes{shapes[0], shapes[1], shapes.back()};
std::cout << "ck_gemm: " << to_json_string(to_value(gemm_shapes)) << std::endl;
}
});
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
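To make the padding and grid-size arithmetic of the instance struct concrete, here is the same computation with illustrative numbers; the per-block tile sizes are passed in directly instead of being parsed out of a CK instance string:

#include <array>
#include <cassert>
#include <cstddef>

static std::size_t int_div_ceil(std::size_t x, std::size_t y) { return (x + y - 1) / y; }

int main()
{
    std::array<std::size_t, 3> config{1000, 512, 384};  // M, N, K
    std::array<std::size_t, 3> per_block{256, 128, 32}; // MPerBlock, NPerBlock, KPerBlock

    // Same rule as instance::get_pad: round each dim up to a multiple of its tile.
    std::array<std::size_t, 3> pad{};
    for(std::size_t i = 0; i < 3; i++)
        pad[i] = int_div_ceil(config[i], per_block[i]) * per_block[i] - config[i];
    // M=1000 rounds up to 1024 (pad 24); N and K already divide evenly, so the
    // GemmSpecialization chosen by compile_op would be "MPadding".
    assert(pad[0] == 24 and pad[1] == 0 and pad[2] == 0);

    // Same rule as instance::get_grid_size: one workgroup per MPerBlock x NPerBlock
    // output tile, giving 4 * 4 = 16 blocks per batch.
    std::size_t blocks_per_batch =
        int_div_ceil(config[0], per_block[0]) * int_div_ceil(config[1], per_block[1]);
    assert(blocks_per_batch == 16);
}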
@@ -272,6 +272,18 @@ struct integral_const_array : array<T, sizeof...(Xs)>
     MIGRAPHX_DEVICE_CONSTEXPR integral_const_array() : base_array({Xs...}) {}
 };
+template <class T, class... Ts>
+constexpr auto make_const_array(T x, Ts... xs)
+{
+    return integral_const_array<typename T::value_type, x, xs...>{};
+}
+
+template <class T, T... Xs, class F>
+constexpr auto unpack(integral_const_array<T, Xs...>, F f)
+{
+    return f(_c<Xs>...);
+}
+
 template <class T, T... Xs, class F>
 constexpr auto transform(integral_const_array<T, Xs...>, F f)
 {
...
#ifndef MIGRAPHX_GUARD_KERNELS_CK_HPP
#define MIGRAPHX_GUARD_KERNELS_CK_HPP
#include <migraphx/kernels/debug.hpp>
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include <ck/utility/common_header.hpp>
#include <ck/tensor_description/tensor_descriptor.hpp>
#include <ck/tensor_description/tensor_descriptor_helper.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
namespace migraphx {
namespace detail {
template <class T>
struct to_ck_type_impl
{
using type = T;
};
template <>
struct to_ck_type_impl<migraphx::half>
{
using type = ck::half_t;
};
template <class T>
struct to_ck_type_impl<const T>
{
using type = const typename to_ck_type_impl<T>::type;
};
template <class Shape>
constexpr bool is_row_major()
{
constexpr auto strides = Shape{}.strides;
MIGRAPHX_ASSERT(strides.size() >= 2);
if(strides.back() == 1)
{
MIGRAPHX_ASSERT(not Shape{}.is_transposed());
return true;
}
MIGRAPHX_ASSERT(strides[strides.size() - 2] == 1);
return false;
}
} // namespace detail
template <class T>
using to_ck_type = typename detail::to_ck_type_impl<T>::type;
template <class T>
constexpr auto to_ck_pointer(T* x)
{
return static_cast<to_ck_type<T>*>(x);
}
template <class T>
constexpr auto to_ck_const_pointer(const T* x)
{
return static_cast<const to_ck_type<T>*>(x);
}
template <class Shape>
using to_ck_gemm_layout = conditional_t<detail::is_row_major<get_shape_c<Shape>>(),
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor>;
template <class Tensor>
constexpr auto to_ck_tensor()
{
constexpr auto s = get_shape_c<Tensor>{};
return sequence(s.lens.size(), [&](auto... is) {
return ck::make_naive_tensor_descriptor(ck::make_tuple(s.lens[is]...),
ck::make_tuple(s.strides[is]...));
});
}
template <class F>
struct ck_function_adaptor : F
{
template <class... Ts>
constexpr ck_function_adaptor(Ts&&... xs) : F(static_cast<Ts&&>(xs)...)
{
}
template <class T, class... Ts>
constexpr void operator()(T& out, Ts&&... xs) const
{
out = static_cast<const F&>(*this)(static_cast<Ts&&>(xs)...);
}
};
struct ck_nop
{
template <class T>
constexpr void operator()(T&) const
{
}
};
struct ck_passthrough
{
template <class T, class U>
constexpr void operator()(T& y, U x) const
{
y = x;
}
};
#ifdef MIGRAPHX_CK_CHECK
#define MIGRAPHX_CK_STATIC_ASSERT static_assert
#else
#define MIGRAPHX_CK_STATIC_ASSERT(...)
#endif
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_CK_HPP
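is_row_major above keys entirely off which of the last two strides is 1: a unit innermost stride means the matrix is described to CK as RowMajor, otherwise as ColumnMajor. A host-side sketch of the same test using plain containers rather than the kernel's Shape type:

#include <cassert>
#include <vector>

// Row-major if the innermost dimension is unit-stride; otherwise the matrix is
// treated as column-major (its transpose is the contiguous view).
bool is_row_major(const std::vector<std::size_t>& strides)
{
    return strides.back() == 1;
}

int main()
{
    assert(is_row_major({64, 1}));     // 64x64 contiguous: RowMajor
    assert(not is_row_major({1, 64})); // same buffer viewed transposed: ColumnMajor
}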
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_CK_GEMM_HPP
#define MIGRAPHX_GUARD_KERNELS_CK_GEMM_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/ck.hpp>
#include <migraphx/kernels/ck_gemm_includes.hpp>
#include <migraphx/kernels/gemm_batcher.hpp>
namespace migraphx {
// In CK, the B matrix is ordered as N,K instead of K,N
template <class Dims>
constexpr auto ck_transposeb_dims(Dims dims)
{
return unpack(dims, [](auto k, auto n) { return make_const_array(n, k); });
}
template <class Tensor>
using ck_transposeb = decltype(make_shape(ck_transposeb_dims(get_shape_c<Tensor>{}.lens),
ck_transposeb_dims(get_shape_c<Tensor>{}.strides)));
template <class G, class E, class A, class B, class... Ds>
__device__ void ck_gemm_matrix(E e, A a, B b, Ds... ds)
{
constexpr const G gemm{};
constexpr const auto a_grid_desc_m_k = gemm.matrix_padder.PadADescriptor_M_K(to_ck_tensor<A>());
constexpr const auto b_grid_desc_n_k =
gemm.matrix_padder.PadBDescriptor_N_K(to_ck_tensor<ck_transposeb<B>>());
constexpr const auto e_grid_desc_m_n = gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<E>());
constexpr const auto ds_grid_desc_m_n =
ck::make_tuple(gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<Ds>())...);
constexpr const auto block_2_etile_map = gemm.MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
using GridwiseGemm = typename G::GridwiseGemm;
// tensor descriptors for block/thread-wise copy
constexpr auto a_grid_desc_ak0_m_ak1 =
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
constexpr auto b_grid_desc_bk0_n_bk1 =
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
constexpr auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n);
constexpr auto e_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
static_assert(GridwiseGemm::CheckValidity(
a_grid_desc_m_k, b_grid_desc_n_k, ds_grid_desc_m_n, e_grid_desc_m_n, block_2_etile_map));
__shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
constexpr const bool HasMainKBlockLoop =
GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(ck::Number<0>{}) *
a_grid_desc_ak0_m_ak1.GetLength(ck::Number<2>{}));
GridwiseGemm::template Run<HasMainKBlockLoop>(to_ck_const_pointer(a.data()),
to_ck_const_pointer(b.data()),
ck::make_tuple(to_ck_const_pointer(ds.data())...),
to_ck_pointer(e.data()),
p_shared_block,
gemm.a_element_op,
gemm.b_element_op,
gemm.cde_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
}
template <class G, index_int BlocksPerBatch, class... Ts>
__device__ void ck_gemm(Ts... xs)
{
gemm_batch_args(make_index(), _c<BlocksPerBatch>, xs...)(
[](auto... ys) { ck_gemm_matrix<G>(ys...); });
}
} // namespace migraphx
#endif
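Since CK orders B as N,K, ck_transposeb above simply swaps the last two lens and the corresponding strides; the bytes never move. Worked through with illustrative numbers, using plain std::array stand-ins for the kernel shape types:

#include <array>
#include <cassert>
#include <cstddef>

// Host-side sketch of ck_transposeb_dims: swap the two GEMM dims (and likewise
// the strides) so a K x N matrix is described to CK as N x K.
using dims2 = std::array<std::size_t, 2>;

constexpr dims2 transposeb(dims2 d) { return {d[1], d[0]}; }

int main()
{
    dims2 lens{384, 512};  // K x N
    dims2 strides{512, 1}; // row-major
    // Described to CK as a 512 x 384 matrix with strides {1, 512}: a
    // column-major N x K view of the very same buffer.
    assert(transposeb(lens) == (dims2{512, 384}));
    assert(transposeb(strides) == (dims2{1, 512}));
}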
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_CK_INCLUDES_HPP
#define MIGRAPHX_GUARD_KERNELS_CK_INCLUDES_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include <ck/utility/common_header.hpp>
#include <ck/tensor_description/tensor_descriptor.hpp>
#include <ck/tensor_description/tensor_descriptor_helper.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
#include <ck/tensor_operation/gpu/device/device_gemm.hpp>
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
#include <ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp>
#include <ck/tensor_operation/gpu/device/matrix_padder.hpp>
#include <ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp>
namespace migraphx {
template <ck::index_t MPerBlock, ck::index_t NPerBlock, typename CGridDesc_M_N>
struct BlockToCTileMap_M00_N0_M01Adapt
{
static constexpr auto I0 = ck::Number<0>{};
static constexpr auto I1 = ck::Number<1>{};
static constexpr auto I2 = ck::Number<2>{};
static constexpr auto I3 = ck::Number<3>{};
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt() = default;
__host__
__device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
ck::index_t M01 = 8)
: M01_(M01), c_grid_desc_m_n_(c_grid_desc_m_n)
{
}
__host__ __device__ constexpr ck::index_t
CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{
const auto M0 = ck::math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
const auto N0 = ck::math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
const ck::index_t grid_size = M0 * N0;
return grid_size;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
auto block_1d_id = idx_top[I0];
const auto M0 = ck::math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock);
const auto N0 = ck::math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock);
block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
ck::index_t idx_N0 = block_1d_id % N0;
ck::index_t idx_M0 = block_1d_id / N0;
const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
ck::index_t idx_M00 = idx_M0 / M01_;
ck::index_t idx_M01 = idx_M0 % M01_;
ck::index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
return ck::make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
idx_N0_M01_local / M01_adapt);
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool constexpr ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
const CTileDim& /* c_tile_dim */) const
{
return true; // always valid provided that user gets grid size from CalculateGridSize()
}
__host__ __device__ constexpr bool
CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
{
return true;
}
private:
ck::index_t M01_;
CGridDesc_M_N c_grid_desc_m_n_;
};
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ck::tensor_operation::device::GemmSpecialization GemmSpec,
ck::index_t NumGemmKPrefetchStage,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t AK1,
ck::index_t BK1,
ck::index_t MPerXDL,
ck::index_t NPerXDL,
ck::index_t MXdlPerWave,
ck::index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_AK1,
ck::index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_BK1,
ck::index_t BBlockLdsExtraN,
ck::index_t CShuffleMXdlPerWavePerShuffle,
ck::index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
ck::index_t CDEBlockTransferScalarPerVector_NPerBlock,
ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler()>
struct CK_DeviceGemmMultipleD
{
static constexpr auto I0 = ck::Number<0>{};
static constexpr auto I1 = ck::Number<1>{};
// static constexpr auto I2 = ck::Number<2>{};
// static constexpr auto I3 = ck::Number<3>{};
// static constexpr auto I4 = ck::Number<4>{};
// static constexpr auto I5 = ck::Number<5>{};
// static constexpr auto I6 = ck::Number<6>{};
// static constexpr auto I7 = ck::Number<7>{};
ck::tensor_operation::device::MatrixPadder<GemmSpec, ck::index_t, ck::index_t, ck::index_t>
matrix_padder{MPerBlock, NPerBlock, KPerBlock};
// GridwiseGemm
using GridwiseGemm = ck::GridwiseGemmMultipleD_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
ck::InMemoryDataOperationEnum::Set,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
// return block_id to E matrix tile idx (m0, n0) mapping
template <class EGridDesc_M_N>
__device__ static constexpr auto
MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n_)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
e_grid_desc_m_n_);
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename AGridDesc_M_K,
typename BGridDesc_N_K,
typename DsGridDesc_M_N,
typename EGridDesc_M_N,
typename Block2ETileMap>
static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
const BGridDesc_N_K& b_grid_desc_n_k,
const DsGridDesc_M_N& ds_grid_desc_m_n,
const EGridDesc_M_N& e_grid_desc_m_n,
const Block2ETileMap& block_2_etile_map)
{
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
// check consistency of desc
MIGRAPHX_CHECK(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1));
// check tile size
MIGRAPHX_CHECK(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0);
// check block-to-E-tile
MIGRAPHX_CHECK(block_2_etile_map.CheckValidity(e_grid_desc_m_n));
return GridwiseGemm::CheckValidity(
a_grid_desc_m_k, b_grid_desc_n_k, ds_grid_desc_m_n, e_grid_desc_m_n, block_2_etile_map);
}
AElementwiseOperation a_element_op{};
BElementwiseOperation b_element_op{};
CDEElementwiseOperation cde_element_op{};
};
} // namespace migraphx
#endif
@@ -35,6 +35,15 @@
     [](auto&&... private_lisft_xs) MIGRAPHX_RETURNS( \
         (__VA_ARGS__)(static_cast<decltype(private_lisft_xs)>(private_lisft_xs)...))
+// NOLINTNEXTLINE
+#define MIGRAPHX_LIFT_CLASS(name, ...) \
+    struct name \
+    { \
+        template <class... PrivateLiftTs> \
+        constexpr auto operator()(PrivateLiftTs&&... private_lisft_xs) const MIGRAPHX_RETURNS( \
+            (__VA_ARGS__)(static_cast<decltype(private_lisft_xs)>(private_lisft_xs)...)) \
+    }
+
 namespace migraphx {
 struct swallow
...
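MIGRAPHX_LIFT_CLASS turns a callable expression into a named, default-constructible class type, which is what allows the ck_gemm compiler above to pass a generated pointwise function as the template argument in ck_function_adaptor<post_ck_gemm>. A self-contained re-creation of the macro, using local RETURNS/LIFT_CLASS names so the sketch does not depend on the MIGraphX header:

// Local equivalents of MIGRAPHX_RETURNS / MIGRAPHX_LIFT_CLASS, for illustration.
#define RETURNS(...) -> decltype(__VA_ARGS__) { return __VA_ARGS__; }
#define LIFT_CLASS(name, ...)                                \
    struct name                                              \
    {                                                        \
        template <class... Ts>                               \
        constexpr auto operator()(Ts&&... xs) const          \
            RETURNS((__VA_ARGS__)(static_cast<Ts&&>(xs)...)) \
    }

constexpr float add3(float a, float b, float c) { return a + b + c; }

// A stateless class type whose operator() forwards to add3; usable wherever a
// type (not a function pointer) is required, e.g. as a template argument.
LIFT_CLASS(add3_class, add3);

int main()
{
    constexpr add3_class f{};
    static_assert(f(1.0f, 2.0f, 3.0f) == 6.0f, "lifted call forwards to add3");
}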