Unverified commit 25af8710, authored by Paul Fultz II and committed by GitHub

Add initial CK integration plus auto-tuning for kernels (#1791)

Enable with MIGRAPHX_ENABLE_CK=1 and the --exhaustive-tune flag
parent e5a33aad
...@@ -28,3 +28,4 @@ ROCmSoftwarePlatform/half@rocm-5.4.2
pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.17 -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/composable_kernel@84c5bec1d66a633802fd977bd61e0aada7a6f153 -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_HASH_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_HASH_HPP
#include <migraphx/config.hpp>
#include <functional>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class T>
std::size_t hash_value(const T& v)
{
return std::hash<T>{}(v);
}
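// Mixes the hash of v into seed; 0x9e3779b9 is the 32-bit golden-ratio constant used by the
// widely known boost::hash_combine recipe.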
template <class T>
void hash_combine(std::size_t& seed, const T& v)
{
seed ^= hash_value(v) + 0x9e3779b9 + (seed << 6u) + (seed >> 2u);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_HASH_HPP
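As an aside, a minimal sketch of how these helpers fold several fields into a single hash for a composite key. The function and field names below are illustrative, not part of this commit, and the header path migraphx/hash.hpp is assumed from the include guard:

#include <cstddef>
#include <string>
#include <migraphx/hash.hpp> // assumed install path for the header above

// Illustrative only: combine several fields into one seed with hash_combine.
std::size_t hash_gemm_key(const std::string& arch, std::size_t m, std::size_t n, std::size_t k)
{
    std::size_t seed = 0;
    migraphx::hash_combine(seed, arch);
    migraphx::hash_combine(seed, m);
    migraphx::hash_combine(seed, n);
    migraphx::hash_combine(seed, k);
    return seed;
}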
...@@ -392,8 +392,8 @@ struct value
return; \
}
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_CASE_VALUE)
-MIGRAPHX_VALUE_GENERATE_CASE(array, )
MIGRAPHX_VALUE_GENERATE_CASE_VALUE(array, )
-MIGRAPHX_VALUE_GENERATE_CASE(object, )
MIGRAPHX_VALUE_GENERATE_CASE_VALUE(object, )
}
MIGRAPHX_THROW("Unknown type");
}
...@@ -461,6 +461,8 @@ struct value
friend std::ostream& operator<<(std::ostream& os, const value& d);
std::size_t hash() const;
void debug_print(bool show_type = false) const;
type_t get_type() const;
...@@ -481,4 +483,15 @@ struct value
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
namespace std {
template <>
struct hash<migraphx::value>
{
using argument_type = migraphx::value;
using result_type = std::size_t;
result_type operator()(const migraphx::value& x) const { return x.hash(); }
};
} // namespace std
#endif
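The std::hash specialization above is what lets migraphx::value be used directly as an unordered_map key, which the problem cache added in compile_ops.cpp below relies on. A hedged sketch (the key names are illustrative, and the value.hpp include path is assumed):

#include <unordered_map>
#include <migraphx/value.hpp> // assumed header path

void value_as_key_sketch()
{
    // value now hashes via x.hash(), so it can key a hash map directly.
    std::unordered_map<migraphx::value, migraphx::value> cache;
    migraphx::value problem = {{"m", 64}, {"n", 64}, {"k", 1024}};
    cache[problem]          = {{"solution_index", 3}};
}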
...@@ -31,6 +31,8 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK)
struct parse_where : op_parser<parse_where>
{
std::vector<op_desc> operators() const { return {{"Where"}}; }
...@@ -56,6 +58,14 @@ struct parse_where : op_parser<parse_where>
auto lens =
compute_broadcasted_lens(args[0]->get_shape().lens(), args[1]->get_shape().lens());
lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens());
if(enabled(MIGRAPHX_ENABLE_CK{}))
{
// Convert condition tensor to int32 to work around CK not supporting bool type
args[0] = info.add_instruction(
make_op("convert", {{"target_type", shape::int32_type}}), args[0]);
}
if(args[0]->get_shape().lens() != lens)
{
args[0] =
...
...@@ -33,6 +33,8 @@ if(NOT TARGET MIOpen)
message(SEND_ERROR "Cant find miopen")
endif()
find_package(composable_kernel 1.0.0 COMPONENTS jit_library REQUIRED)
set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "Use hipRTC APIs")
#if(BUILD_DEV)
#set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "Use hipRTC APIs")
...@@ -96,6 +98,7 @@ add_library(migraphx_gpu
compile_miopen.cpp
compiler.cpp
device_name.cpp
fuse_ck.cpp
fuse_mlir.cpp
fuse_ops.cpp
gather.cpp
...@@ -124,6 +127,7 @@ add_library(migraphx_gpu
schedule_model.cpp
sync_device.cpp
target.cpp
time_op.cpp
topk.cpp
write_literals.cpp
${JIT_GPU_SRCS}
...@@ -242,7 +246,7 @@ else()
endif()
target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
-target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels)
target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels composable_kernel::jit_library)
add_subdirectory(driver)
add_subdirectory(hiprtc)
...
...@@ -161,7 +161,7 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
assert(not options.inputs.empty());
assert(options.inputs.size() == options.virtual_inputs.size() or
options.virtual_inputs.empty());
-std::vector<src_file> srcs;
std::vector<src_file> srcs = options.additional_src_files;
std::transform(migraphx_kernels().begin(),
migraphx_kernels().end(),
std::back_inserter(srcs),
...
...@@ -30,6 +30,7 @@
#include <migraphx/register_op.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/time_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
...@@ -76,6 +77,109 @@ struct compiled_result
instruction_ref ins;
};
struct problem_cache
{
bool has(const std::string& name, const value& problem) const
{
return contains(cache, create_key(name, problem));
}
void insert(const std::string& name, const value& problem, const value& solution)
{
assert(not solution.is_null());
cache[create_key(name, problem)] = solution;
}
void mark(const std::string& name, const value& problem)
{
cache.insert(std::make_pair(create_key(name, problem), value{}));
}
optional<value> get(const std::string& name, const value& problem) const
{
auto it = cache.find(create_key(name, problem));
if(it == cache.end())
return nullopt;
return it->second;
}
static value create_key(const std::string& name, const value& problem)
{
return {{"name", name}, {"problem", problem}};
}
std::unordered_map<value, value> cache;
};
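The cache protocol is: the first compile_plan to see a problem marks it, duplicates skip compilation until a solution is recorded, and a later pass compiles the cached winner. A minimal sketch of that flow; the op name is taken from the ck_gemm operator added below, while the problem and solution values are purely illustrative:

// Illustrative use of problem_cache during tuning (sketch, not a standalone translation unit).
void problem_cache_sketch(problem_cache& pc)
{
    value problem = {{"m", 64}, {"n", 64}, {"k", 1024}};
    // The first plan to see this problem marks it; duplicates skip until a solution exists.
    pc.mark("gpu::ck_gemm", problem);
    // ...benchmark every candidate solution, then record the fastest one...
    value solution = {{"instance", 7}};
    pc.insert("gpu::ck_gemm", problem, solution);
    // Later plans (and the second compile pass) fetch the tuned solution directly.
    auto cached = pc.get("gpu::ck_gemm", problem);
    // cached.has_value() and not cached->is_null()
}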
struct compile_plan
{
context* ctx;
operation preop;
instruction_ref ins;
optional<tuning_config> config = nullopt;
std::vector<compiled_result> results = {};
void update_config() { config = get_tuning_config(*ctx, ins, preop); }
template <class Vector>
void add_compiles(Vector& compiles, problem_cache& pc)
{
if(config.has_value())
{
const auto& problem = config->problem;
if(auto sol = pc.get(preop.name(), problem))
{
auto solution = sol.value();
// No solution yet until benchmarked so skip for now
if(solution.is_null())
return;
results.resize(1);
compiles.emplace_back([=] {
results[0] = compiled_result{compile(*ctx, ins, preop, solution), ins};
});
}
else
{
pc.mark(preop.name(), problem);
const auto& solutions = config->solutions;
results.resize(solutions.size());
for(auto i : range(solutions.size()))
{
auto solution = solutions[i];
compiles.emplace_back([=] {
results[i] = compiled_result{compile(*ctx, ins, preop, solution), ins};
});
}
}
}
else
{
results.resize(1);
compiles.emplace_back([=] {
results[0] = compiled_result{compile(*ctx, ins, preop, value{}), ins};
});
}
}
const compiled_result& benchmark(problem_cache& pc) const
{
if(results.empty())
MIGRAPHX_THROW("No configs to tune");
if(results.size() == 1)
return results.front();
if(not config)
MIGRAPHX_THROW("Multiple kernels without config");
std::cout << "Benchmarking " << preop.name() << ": " << results.size() << " configs"
<< std::endl;
std::vector<double> times;
times.reserve(results.size());
std::transform(
results.begin(), results.end(), std::back_inserter(times), [&](const auto& cr) {
return time_op(*ctx, cr.replace.code_object, to_shapes(cr.ins->inputs()), 20).first;
});
auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end()));
pc.insert(preop.name(), config->problem, config->solutions.at(i));
return results[i];
}
void replace(module& m, problem_cache& pc) const
{
const auto& cr = benchmark(pc);
cr.replace.replace(m, cr.ins);
}
};
template <class F>
void par_compile(std::size_t n, F f)
{
...@@ -84,25 +188,67 @@ void par_compile(std::size_t n, F f)
par_for(n, n / value_of(MIGRAPHX_GPU_COMPILE_PARALLEL{}, n), f);
}
-void compile_ops::apply(module& m) const
-{
-std::vector<std::function<compiled_result()>> compiles;
struct compile_manager
{
problem_cache pc;
std::vector<compile_plan> cps;
bool exhaustive = false;
template <class... Ts>
void add_plan(Ts&&... xs)
{
cps.push_back({std::forward<Ts>(xs)...});
}
void update_configs()
{
if(not exhaustive)
return;
par_compile(cps.size(), [&](auto i) { cps[i].update_config(); });
}
void compile(module& m)
{
std::vector<std::function<void()>> compiles;
for(auto& cp : cps)
{
cp.add_compiles(compiles, pc);
}
par_compile(compiles.size(), [&](auto i) { compiles[i](); });
// Replace and/or benchmark
for(const auto& cp : cps)
{
if(cp.results.empty())
continue;
cp.replace(m, pc);
}
// Remove compile_plan already executed
cps.erase(std::remove_if(cps.begin(),
cps.end(),
[](const auto& cp) { return not cp.results.empty(); }),
cps.end());
}
};
void compile_ops::apply(module& m) const
{
compile_manager cm;
cm.exhaustive = exhaustive_tune;
// Find all precompile ops
for(auto ins : iterator_for(m))
{
if(ins->name() != "gpu::precompile_op")
continue;
operation preop = any_cast<precompile_op>(ins->get_operator()).op;
-compiles.emplace_back([=]() -> compiled_result {
-return {compile(*ctx, ins, preop), ins};
-});
-}
-std::vector<compiled_result> results(compiles.size());
-par_compile(compiles.size(), [&](auto i) { results[i] = compiles[i](); });
-for(const auto& cr : results)
-{
-cr.replace.replace(m, cr.ins);
-}
cm.add_plan(ctx, preop, ins);
}
cm.update_configs();
cm.compile(m);
// Compile the plans skipped above, now that their tuned solutions are in the problem cache
cm.compile(m);
assert(cm.cps.empty());
}
} // namespace gpu
...
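This diff does not show how the pass itself is scheduled into the GPU pass list; a hypothetical sketch of forwarding an exhaustive-tune compile option onto the pass's new field (the compile_ops.hpp include path and the surrounding wiring are assumptions, only the two fields come from this commit):

#include <migraphx/gpu/compile_ops.hpp> // assumed header path
#include <migraphx/gpu/context.hpp>

// Hypothetical wiring: the flag set by --exhaustive-tune only needs to reach
// the exhaustive_tune field on compile_ops (aggregate init over its two members).
migraphx::gpu::compile_ops make_compile_pass(migraphx::gpu::context& ctx, bool exhaustive_tune)
{
    return migraphx::gpu::compile_ops{&ctx, exhaustive_tune};
}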
...@@ -28,33 +28,44 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace {
struct compiler_handle
{
compiler_compile compile;
compiler_compile_op compile_op;
compiler_tuning_config get_tuning_config;
};
} // namespace
auto& compiler_map()
{
-static std::unordered_map<std::string, compiler_compile> m; // NOLINT
static std::unordered_map<std::string, compiler_handle> m; // NOLINT
return m;
}
-auto& compiler_op_map()
-{
-static std::unordered_map<std::string, compiler_compile_op> m; // NOLINT
-return m;
-}
-void register_compiler(const std::string& name, compiler_compile c, compiler_compile_op cop)
void register_compiler(const std::string& name,
compiler_compile c,
compiler_compile_op cop,
compiler_tuning_config ctg)
{
-compiler_map()[name] = std::move(c);
-compiler_op_map()[name] = std::move(cop);
compiler_map()[name] = {std::move(c), std::move(cop), std::move(ctg)};
}
bool has_compiler_for(const std::string& name) { return compiler_map().count(name) > 0; }
-compiler_replace compile(context& ctx, instruction_ref ins, const operation& op)
compiler_replace
compile(context& ctx, instruction_ref ins, const operation& op, const value& solution)
{
-return compiler_map().at(op.name())(ctx, ins, op);
return compiler_map().at(op.name()).compile(ctx, ins, op, solution);
}
operation
compile_op(const std::string& name, context& ctx, const std::vector<shape>& inputs, const value& v)
{
-return compiler_op_map().at(name)(ctx, inputs, v);
return compiler_map().at(name).compile_op(ctx, inputs, v);
}
optional<tuning_config> get_tuning_config(context& ctx, instruction_ref ins, const operation& op)
{
return compiler_map().at(op.name()).get_tuning_config(ctx, ins, op);
}
} // namespace gpu
...
...@@ -67,18 +67,19 @@ void multinomial(hipStream_t stream,
size_t class_size = arg0.get_shape().lens().back();
size_t sample_size = result.get_shape().lens().back();
-hip_visit_all(arg0, arg1)([&](auto cdf, auto dist) {
-result.visit([&](auto out) {
-hip_visit_views(out)([&](auto output) {
visit_all(arg0, arg1)([&](auto cdf_host, auto dist_host) {
result.visit([&](auto output_host) {
hip_visit_views(cdf_host, dist_host, output_host)(
[&](auto cdf, auto dist, auto output) {
gs_launch(stream, batch_size * sample_size)([=](auto i) __device__ {
auto idx = output.get_shape().multi(i);
auto cdf_begin = cdf.begin() + (idx.front() * class_size);
auto cdf_end = cdf_begin + class_size;
-auto sample_iter =
auto* sample_iter =
upper_bound(cdf_begin, cdf_end, dist[i] * *(std::prev(cdf_end)));
output[i] = std::distance(cdf_begin, sample_iter);
});
});
});
});
}
...
...@@ -43,6 +43,8 @@ auto get_arch_name(rank<1>, const HipDeviceProp& props) -> decltype(std::string(
return std::string(props.gcnArchName);
}
std::string get_arch_name(const hipDeviceProp_t& props) { return get_arch_name(rank<1>{}, props); }
int get_device_id()
{
int device;
...@@ -58,7 +60,7 @@ std::string get_device_name()
auto status = hipGetDeviceProperties(&props, get_device_id());
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to get device properties");
-return get_arch_name(rank<1>{}, props);
return get_arch_name(props);
}
} // namespace gpu
...
...@@ -22,7 +22,7 @@
* THE SOFTWARE.
*/
#include <migraphx/gpu/driver/action.hpp>
-#include <migraphx/gpu/driver/perf.hpp>
#include <migraphx/gpu/time_op.hpp>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
...
...@@ -22,7 +22,7 @@
* THE SOFTWARE.
*/
#include <migraphx/gpu/driver/action.hpp>
-#include <migraphx/gpu/driver/perf.hpp>
#include <migraphx/gpu/time_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/make_op.hpp>
...
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/fuse_ck.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
namespace gpu {
struct ck_gemm
{
operation op = make_op("dot");
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.op, "op"));
}
std::string name() const { return "gpu::ck_gemm"; }
void check_gemm_shape(const shape& s) const
{
if(not contains(range(s.strides().rbegin(), s.strides().rbegin() + 3), 1))
MIGRAPHX_THROW("Invalid shape for ck_gemm");
}
shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
{
check_shapes{inputs, *this}.same_ndims();
if(inputs.size() < 2)
MIGRAPHX_THROW("should have at least two inputs.");
auto a = inputs[0];
auto b = inputs[1];
for(const auto& input : inputs)
check_gemm_shape(input);
auto r = op.compute_shape({a, b});
if(mods.empty())
return r;
return r.with_type(mods.front()->get_output_shapes().front().type());
}
};
MIGRAPHX_REGISTER_OP(ck_gemm);
namespace {
bool is_ck_supported_type(shape::type_t t)
{
return contains({shape::half_type, shape::int8_type, shape::int32_type}, t);
}
MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
{
if(ins->name() != "dot" and ins->name() != "quant_dot")
return false;
if(not is_ck_supported_type(ins->get_shape().type()))
return false;
auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape();
// Skipping GEMMs with a K dimension greater than 2048 is a coarse-grained strategy
// to avoid poor-performing GEMM kernels from CK
// TODO: Investigate a more precise strategy
return a.lens().back() <= 2048;
}
struct find_ck_gemm_pointwise
{
// Find a gemm followed by a pointwise operation.
auto matcher() const
{
auto gemm = match::skip(match::name("contiguous"))(
match::name("dot", "quant_dot")(is_ck_gemm().bind("gemm")));
return match::name("pointwise")(match::any_of[match::inputs()](gemm.bind("x")));
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
auto gemm_ins = r.instructions["gemm"];
auto x_ins = r.instructions["x"]; // input after contiguous
auto* pm = ins->module_inputs().front();
auto names = pm->get_parameter_names();
std::sort(names.begin(), names.end());
auto inputs = ins->inputs();
auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins);
auto gemm_idx = gemm_it - inputs.begin();
if(gemm_ins->get_shape().type() != shape::int32_type and
ins->get_shape().type() != gemm_ins->get_shape().type())
return;
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto input) {
return not is_ck_supported_type(input->get_shape().type());
}))
return;
assert(gemm_it != inputs.end());
if(gemm_idx != 0)
{
auto first_param = pm->get_parameter(names[0]);
auto gemm_param = pm->get_parameter(names[gemm_idx]);
auto new_gemm_param = pm->add_parameter(names[0] + "_0", gemm_param->get_shape());
auto new_first_param =
pm->add_parameter(names[gemm_idx] + "_0", first_param->get_shape());
pm->replace_instruction(gemm_param, new_gemm_param);
pm->replace_instruction(first_param, new_first_param);
pm->remove_instruction(first_param);
pm->remove_instruction(gemm_param);
}
inputs.erase(gemm_it);
inputs.insert(inputs.begin(), gemm_ins->inputs().begin(), gemm_ins->inputs().end());
mpm.get_module().replace_instruction(ins, ck_gemm{gemm_ins->get_operator()}, inputs, {pm});
}
};
struct find_ck_gemm
{
auto matcher() const { return match::name("dot")(is_ck_gemm().bind("gemm")); }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
mpm.get_module().replace_instruction(ins, ck_gemm{ins->get_operator()}, ins->inputs());
}
};
} // namespace
void fuse_ck::apply(module_pass_manager& mpm) const
{
match::find_matches(mpm, find_ck_gemm_pointwise{});
match::find_matches(mpm, find_ck_gemm{});
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -26,6 +26,7 @@
#include <migraphx/config.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/compile_src.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
...@@ -39,9 +40,10 @@ struct hip_compile_options
std::size_t local;
std::vector<shape> inputs;
shape output;
std::string kernel_name = "kernel";
std::string params = "";
std::vector<shape> virtual_inputs = {};
std::vector<src_file> additional_src_files = {};
/**
* @brief Set the launch parameters but allow v to override the values
...
...@@ -38,7 +38,8 @@ struct context;
struct compile_ops
{
context* ctx = nullptr;
bool exhaustive_tune = false;
std::string name() const { return "gpu::compile_ops"; }
void apply(module& m) const;
};
...
...@@ -30,6 +30,8 @@
#include <migraphx/value.hpp>
#include <migraphx/module.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/optional.hpp>
#include <migraphx/rank.hpp>
#include <functional>
namespace migraphx {
...@@ -66,16 +68,30 @@ struct compiler_replace
}
};
-using compiler_compile = std::function<compiler_replace(context&, instruction_ref, operation)>;
struct tuning_config
{
value problem;
std::vector<value> solutions;
};
using compiler_compile =
std::function<compiler_replace(context&, instruction_ref, operation, const value&)>;
using compiler_compile_op =
std::function<operation(context&, const std::vector<shape>& inputs, const value&)>;
using compiler_tuning_config =
std::function<optional<tuning_config>(context&, instruction_ref, const operation&)>;
-void register_compiler(const std::string& name, compiler_compile c, compiler_compile_op cop);
void register_compiler(const std::string& name,
compiler_compile c,
compiler_compile_op cop,
compiler_tuning_config ctg);
bool has_compiler_for(const std::string& name);
-compiler_replace compile(context& ctx, instruction_ref ins, const operation& op);
compiler_replace
compile(context& ctx, instruction_ref ins, const operation& op, const value& solution);
operation
compile_op(const std::string& name, context& ctx, const std::vector<shape>& inputs, const value& v);
optional<tuning_config> get_tuning_config(context& ctx, instruction_ref ins, const operation& op);
template <class T>
void register_compiler()
...@@ -85,8 +101,11 @@ void register_compiler()
{
register_compiler(
name,
-[=](auto&&... xs) { return c.compile(std::forward<decltype(xs)>(xs)...); },
-[=](auto&&... xs) { return c.compile_op(std::forward<decltype(xs)>(xs)...); });
[=](auto&&... xs) {
return c.invoke_compile(rank<1>{}, std::forward<decltype(xs)>(xs)...);
},
[=](auto&&... xs) { return c.compile_op(std::forward<decltype(xs)>(xs)...); },
[=](auto&&... xs) { return c.get_tuning_config(std::forward<decltype(xs)>(xs)...); });
}
}
...@@ -105,7 +124,30 @@ using auto_register_compiler = auto_register<register_compiler_action, T>;
template <class Derived>
struct compiler : auto_register_compiler<Derived>
{
const Derived& derived() const { return static_cast<const Derived&>(*this); }
optional<tuning_config> get_tuning_config(context&, instruction_ref, const operation&) const
{
return nullopt;
}
operation compile_op(context&, const std::vector<shape>&, const value&) const { return {}; }
template <class D = Derived>
auto invoke_compile(
rank<1>, context& ctx, instruction_ref ins, operation op, const value& solution) const
-> decltype(std::declval<D>().compile(ctx, ins, std::move(op), solution))
{
return derived().compile(ctx, ins, std::move(op), solution);
}
template <class D = Derived>
auto invoke_compile(
rank<0>, context& ctx, instruction_ref ins, operation op, const value& solution) const
-> decltype(std::declval<D>().compile(ctx, ins, std::move(op)))
{
assert(solution.empty());
(void)solution;
return derived().compile(ctx, ins, std::move(op));
}
};
} // namespace gpu
...
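To see how the tuning hooks fit together, here is a hedged sketch of a compiler that opts into tuning. Only the compiler&lt;Derived&gt; members introduced above (a compile overload taking a solution, compile_op, get_tuning_config) come from this commit; the struct name, op name, names() convention, and the tuning values are assumptions for illustration:

// Hypothetical tuning-aware compiler (sketch only, not part of this commit).
struct example_gemm_compiler : compiler<example_gemm_compiler>
{
    std::vector<std::string> names() const { return {"gpu::example_gemm"}; }

    // Because this overload accepts a solution, invoke_compile selects it through rank<1>;
    // a legacy three-argument compile() would still be picked up by the rank<0> overload.
    compiler_replace
    compile(context& ctx, instruction_ref ins, const operation& op, const value& solution) const
    {
        // A real compiler would use `solution` to pick one kernel instance and fall back
        // to a sensible default when it is empty; returning an empty replace is a placeholder.
        (void)ctx;
        (void)ins;
        (void)op;
        (void)solution;
        return {};
    }

    optional<tuning_config> get_tuning_config(context&, instruction_ref, const operation&) const
    {
        tuning_config tc;
        tc.problem = {{"m", 64}, {"n", 64}, {"k", 1024}}; // describes the op being tuned
        tc.solutions.push_back({{"instance", 0}});        // candidate kernels to benchmark
        tc.solutions.push_back({{"instance", 1}});
        return tc;
    }
};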
...@@ -170,7 +170,9 @@ struct hip_device
std::size_t stream_id() const { return current_stream; }
-std::string get_device_name() const { return device_props.gcnArchName; }
std::string get_device_name() const { return get_arch_name(device_props); }
std::string get_gfx_name() const { return trim(split_string(get_device_name(), ':').front()); }
std::size_t get_device_major() const { return device_props.major; }
...
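For context, gcnArchName reports strings such as "gfx90a:sramecc+:xnack-", so get_gfx_name reduces the device name to the part before the first ':'. A standalone equivalent of that split, written without MIGraphX's trim/split_string helpers:

#include <string>

// Same effect as get_gfx_name(): keep only the text before the first ':'.
std::string gfx_from_arch(const std::string& arch_name)
{
    return arch_name.substr(0, arch_name.find(':'));
}
// gfx_from_arch("gfx90a:sramecc+:xnack-") == "gfx90a"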
...@@ -27,10 +27,14 @@
#include <migraphx/config.hpp>
#include <string>
struct hipDeviceProp_t;
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
std::string get_arch_name(const hipDeviceProp_t& props);
std::string get_device_name();
int get_device_id();
...
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_GPU_FUSE_CK_HPP
#define MIGRAPHX_GUARD_GPU_FUSE_CK_HPP
#include <migraphx/config.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module_pass_manager;
namespace gpu {
struct fuse_ck
{
context* ctx = nullptr;
std::string name() const { return "gpu::fuse_ck"; }
void apply(module_pass_manager& mpm) const;
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_FUSE_CK_HPP
...@@ -31,12 +31,10 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
-namespace driver {
std::pair<double, double>
time_op(context& ictx, operation op, const std::vector<shape>& inputs, int n = 100);
-} // namespace driver
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...
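time_op is moved out of the driver namespace so the tuning pass can use it: the benchmark step above times each candidate 20 times and keeps .first as the score. A hedged sketch of a call; the header paths are taken from the includes in this diff, and treating .first as the averaged time is an assumption based on how benchmark() uses it:

#include <vector>
#include <migraphx/operation.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/time_op.hpp>

// Illustrative only: score one compiled code object the way the tuning loop does.
double score_candidate(migraphx::gpu::context& ctx,
                       const migraphx::operation& code_object,
                       const std::vector<migraphx::shape>& input_shapes)
{
    return migraphx::gpu::time_op(ctx, code_object, input_shapes, 20).first;
}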