Unverified commit 9c91c08d authored by Chris Austen, committed by GitHub

Merge branch 'develop' into enable_navi_32_ci

parents a56bb11d c1b8c975
@@ -52,14 +52,6 @@ static void quantize_module(module& m, const std::vector<std::string>& ins_names
         auto mod_inputs = ins->module_inputs();
         auto s          = ins->get_shape();
-        // Convert back to original type before quantizing the inputs
-        if(mod_inputs.empty())
-        {
-            auto r = m.insert_instruction(
-                std::next(ins), make_op("convert", {{"target_type", s.type()}}), ins);
-            m.replace_instruction(ins, r);
-        }
         // Convert each of the inputs that are floating point to fp16
         auto inputs = ins->inputs();
         std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
@@ -70,8 +62,17 @@ static void quantize_module(module& m, const std::vector<std::string>& ins_names
                 ins, make_op("convert", {{"target_type", shape::half_type}}), input);
         });
-        // Replace inputs
-        m.replace_instruction(ins, ins->get_operator(), inputs, mod_inputs);
+        // Insert quantized ins
+        auto converted_ins = m.insert_instruction(ins, ins->get_operator(), inputs, mod_inputs);
+        // Convert back to original type after quantizing
+        if(mod_inputs.empty())
+        {
+            converted_ins = m.insert_instruction(
+                ins, make_op("convert", {{"target_type", s.type()}}), converted_ins);
+        }
+        // Replace original instruction
+        m.replace_instruction(ins, converted_ins);
     }
 }
...
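The hunk above reorders the fp16 rewrite so the quantized instruction is created first and the convert back to the original type consumes it, instead of splicing the convert in before the inputs are quantized. A minimal sketch of the instruction stream this produces for a float add, built with the public API (a standalone illustration, not the pass itself):

```cpp
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/module.hpp>
#include <migraphx/program.hpp>

int main()
{
    migraphx::program p;
    auto* m = p.get_main_module();
    migraphx::shape s{migraphx::shape::float_type, {4}};
    auto x = m->add_parameter("x", s);
    auto y = m->add_parameter("y", s);

    // Convert each floating-point input to fp16
    auto xh = m->add_instruction(
        migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), x);
    auto yh = m->add_instruction(
        migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), y);
    // The op itself runs in fp16...
    auto sum = m->add_instruction(migraphx::make_op("add"), xh, yh);
    // ...and only afterwards is the result converted back to the original type,
    // mirroring converted_ins in the fixed pass
    m->add_instruction(
        migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), sum);
    p.debug_print();
}
```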
@@ -21,6 +21,7 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
+#include <migraphx/pass_manager.hpp>
 #include <migraphx/replace_allocate.hpp>
 #include <migraphx/instruction.hpp>
 #include <migraphx/program.hpp>
@@ -84,10 +85,11 @@ void insert_submod_allocations(instruction_ref ins, module& mod, const allocatio
     mod.replace_instruction(ins, ins->get_operator(), inputs, mod_args);
 }

-void replace_allocate::apply(module& m) const
+void replace_allocate::apply(module_pass_manager& mpm) const
 {
+    module& m             = mpm.get_module();
     auto mod_output_names = create_output_names(m);
-    bool main_offload_copy = m.name() == "main" ? this->offload_copy : false;
+    bool root_offload_copy = (*mpm.get_root_module() == m) ? this->offload_copy : false;
     for(auto ins : iterator_for(m))
     {
         auto op = ins->get_operator();
@@ -104,7 +106,7 @@ void replace_allocate::apply(module& m) const
             continue;
         auto s = ins->get_shape();
-        if(not main_offload_copy and model.needs_out_params() and contains(mod_output_names, ins))
+        if(not root_offload_copy and model.needs_out_params() and contains(mod_output_names, ins))
         {
             auto out_param = m.add_parameter(mod_output_names[ins], s);
             m.replace_instruction(ins, out_param);
...
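Taking a module_pass_manager lets the pass compare the module it is visiting against the root module by identity, rather than trusting the name to be "main". A toy sketch of that check with stand-in types (the real API compares modules, not raw pointers):

```cpp
#include <iostream>
#include <string>

// Toy stand-ins for migraphx::module and module_pass_manager, sketching the
// change above: offload copies are only used when the pass is visiting the
// root module itself, whatever that module happens to be named.
struct module
{
    std::string name;
};

struct module_pass_manager
{
    module* current;
    module* root;
    module& get_module() { return *current; }
    module* get_root_module() { return root; }
};

bool use_offload_copy(module_pass_manager& mpm, bool offload_copy)
{
    module& m = mpm.get_module();
    return (mpm.get_root_module() == &m) ? offload_copy : false;
}

int main()
{
    module root{"main"}, sub{"if_then"};
    module_pass_manager at_root{&root, &root};
    module_pass_manager at_sub{&sub, &root};
    std::cout << use_offload_copy(at_root, true) << "\n"; // 1: root gets offload copies
    std::cout << use_offload_copy(at_sub, true) << "\n";  // 0: submodules never do
}
```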
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/target.hpp>
#include <migraphx/register_target.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

void migraphx_to_value(value& v, const target& t) { v["name"] = t.name(); }

void migraphx_from_value(const value& v, target& t)
{
    t = make_target(v.at("name").to<std::string>());
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
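These hooks serialize a target by name only and rebuild it through the target registry on the way back in. A hedged usage sketch (it assumes the named target, here "ref", has been registered by linking its library before deserializing):

```cpp
#include <migraphx/register_target.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/target.hpp>
#include <migraphx/value.hpp>

int main()
{
    // Round trip: only the name is stored, so make_target() must be able to
    // find "ref" in the registry when the value is read back.
    migraphx::target t = migraphx::make_target("ref");
    migraphx::value v  = migraphx::to_value(t);
    migraphx::target u = migraphx::from_value<migraphx::target>(v);
    return u.name() == "ref" ? 0 : 1;
}
```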
@@ -88,8 +88,6 @@ foreach(LIBRARY ${OpenMP_CXX_LIBRARIES})
     endif()
 endforeach()

-target_link_libraries(migraphx_all_targets INTERFACE migraphx_cpu)
-
 rocm_install_targets(
     TARGETS migraphx_cpu
     INCLUDE
...
@@ -33,7 +33,13 @@ if(NOT TARGET MIOpen)
     message(SEND_ERROR "Cant find miopen")
 endif()

-set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "Use hipClang APIs")
+find_package(composable_kernel 1.0.0 COMPONENTS jit_library REQUIRED)
+
+if(BUILD_DEV)
+    set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "Use hipRTC APIs")
+else()
+    set(MIGRAPHX_USE_HIPRTC ON CACHE BOOL "Use hipRTC APIs")
+endif()

 include(Embed)
 file(GLOB KERNEL_FILES ${CONFIGURE_DEPENDS}
@@ -91,6 +97,7 @@ add_library(migraphx_gpu
     compile_miopen.cpp
     compiler.cpp
     device_name.cpp
+    fuse_ck.cpp
     fuse_mlir.cpp
     fuse_ops.cpp
     gather.cpp
@@ -119,6 +126,7 @@ add_library(migraphx_gpu
     schedule_model.cpp
     sync_device.cpp
     target.cpp
+    time_op.cpp
     topk.cpp
     write_literals.cpp
     ${JIT_GPU_SRCS}
@@ -237,7 +245,7 @@ else()
 endif()

 target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
-target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels)
+target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels composable_kernel::jit_library)

 add_subdirectory(driver)
 add_subdirectory(hiprtc)
...
@@ -161,7 +161,7 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
     assert(not options.inputs.empty());
     assert(options.inputs.size() == options.virtual_inputs.size() or
            options.virtual_inputs.empty());
-    std::vector<src_file> srcs;
+    std::vector<src_file> srcs = options.additional_src_files;
     std::transform(migraphx_kernels().begin(),
                    migraphx_kernels().end(),
                    std::back_inserter(srcs),
...
@@ -30,6 +30,7 @@
 #include <migraphx/register_op.hpp>
 #include <migraphx/op/identity.hpp>
 #include <migraphx/gpu/compiler.hpp>
+#include <migraphx/gpu/time_op.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -76,33 +77,201 @@ struct compiled_result
     instruction_ref ins;
 };

+struct problem_cache
+{
+    bool has(const std::string& name, const value& problem) const
+    {
+        return contains(cache, create_key(name, problem));
+    }
+    void insert(const std::string& name, const value& problem, const value& solution)
+    {
+        assert(not solution.is_null());
+        cache[create_key(name, problem)] = solution;
+    }
+    void mark(const std::string& name, const value& problem)
+    {
+        cache.insert(std::make_pair(create_key(name, problem), value{}));
+    }
+    optional<value> get(const std::string& name, const value& problem) const
+    {
+        auto it = cache.find(create_key(name, problem));
+        if(it == cache.end())
+            return nullopt;
+        return it->second;
+    }
+    static value create_key(const std::string& name, const value& problem)
+    {
+        return {{"name", name}, {"problem", problem}};
+    }
+    std::unordered_map<value, value> cache;
+};
+
+struct compile_plan
+{
+    context* ctx;
+    operation preop;
+    instruction_ref ins;
+    optional<tuning_config> config                 = nullopt;
+    std::vector<optional<compiled_result>> results = {};
+    void update_config(bool exhaustive)
+    {
+        config = get_tuning_config(*ctx, ins, preop, exhaustive);
+    }
+    template <class Vector>
+    void insert_compiles(Vector& compiles, const value& solution, std::size_t i)
+    {
+        compiles.emplace_back([=] {
+            try
+            {
+                results[i] = compiled_result{compile(*ctx, ins, preop, solution), ins};
+            }
+            catch(...)
+            {
+                results[i] = nullopt;
+            }
+        });
+    }
+    template <class Vector>
+    void add_compiles(Vector& compiles, problem_cache& pc)
+    {
+        if(config.has_value())
+        {
+            const auto& problem = config->problem;
+            if(auto sol = pc.get(preop.name(), problem))
+            {
+                auto solution = sol.value();
+                // No solution yet until benchmarked so skip for now
+                if(solution.is_null())
+                    return;
+                results.resize(1);
+                insert_compiles(compiles, solution, 0);
+            }
+            else
+            {
+                pc.mark(preop.name(), problem);
+                const auto& solutions = config->solutions;
+                results.resize(solutions.size());
+                for(auto i : range(solutions.size()))
+                {
+                    auto solution = solutions[i];
+                    insert_compiles(compiles, solution, i);
+                }
+            }
+        }
+        else
+        {
+            results.resize(1);
+            insert_compiles(compiles, value{}, 0);
+        }
+    }
+    const compiled_result& benchmark(problem_cache& pc) const
+    {
+        if(results.empty())
+            MIGRAPHX_THROW("No configs to tune");
+        if(results.size() == 1)
+        {
+            if(not results.front().has_value())
+                MIGRAPHX_THROW("No configs to tune");
+            return *results.front();
+        }
+        if(not config)
+            MIGRAPHX_THROW("Multiple kernels without config");
+        std::cout << "Benchmarking " << preop.name() << ": " << results.size() << " configs"
+                  << std::endl;
+        std::vector<double> times;
+        times.reserve(results.size());
+        std::transform(
+            results.begin(), results.end(), std::back_inserter(times), [&](const auto& cr) {
+                if(not cr.has_value())
+                    return std::numeric_limits<double>::max();
+                return time_op(*ctx, cr->replace.code_object, to_shapes(cr->ins->inputs()), 20)
+                    .first;
+            });
+        auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end()));
+        std::cout << "Fastest solution: " << config->solutions.at(i) << std::endl;
+        pc.insert(preop.name(), config->problem, config->solutions.at(i));
+        if(not results[i].has_value())
+            MIGRAPHX_THROW("No valid tuned compilation.");
+        return *results[i];
+    }
+    void replace(module& m, problem_cache& pc) const
+    {
+        const auto& cr = benchmark(pc);
+        cr.replace.replace(m, cr.ins);
+    }
+};
+
 template <class F>
 void par_compile(std::size_t n, F f)
 {
     if(n == 0)
         return;
-    par_for(n, n / value_of(MIGRAPHX_GPU_COMPILE_PARALLEL{}, n), f);
+    auto d = value_of(MIGRAPHX_GPU_COMPILE_PARALLEL{});
+    if(d == 0)
+        d = n;
+    par_for(n, n / d, f);
 }

-void compile_ops::apply(module& m) const
+struct compile_manager
 {
-    std::vector<std::function<compiled_result()>> compiles;
+    problem_cache pc;
+    std::vector<compile_plan> cps;
+    bool exhaustive = false;
+
+    template <class... Ts>
+    void add_plan(Ts&&... xs)
+    {
+        cps.push_back({std::forward<Ts>(xs)...});
+    }
+
+    void update_configs()
+    {
+        par_compile(cps.size(), [&](auto i) { cps[i].update_config(exhaustive); });
+    }
+
+    void compile(module& m)
+    {
+        std::vector<std::function<void()>> compiles;
+        for(auto& cp : cps)
+        {
+            cp.add_compiles(compiles, pc);
+        }
+        par_compile(compiles.size(), [&](auto i) { compiles[i](); });
+
+        // Replace and/or benchmark
+        for(const auto& cp : cps)
+        {
+            if(cp.results.empty())
+                continue;
+            cp.replace(m, pc);
+        }
+
+        // Remove compile_plan already executed
+        cps.erase(std::remove_if(cps.begin(),
+                                 cps.end(),
+                                 [](const auto& cp) { return not cp.results.empty(); }),
+                  cps.end());
+    }
+};
+
+void compile_ops::apply(module& m) const
+{
+    compile_manager cm;
+    cm.exhaustive = exhaustive_tune;
+    // Find all precompile ops
     for(auto ins : iterator_for(m))
     {
         if(ins->name() != "gpu::precompile_op")
             continue;
         operation preop = any_cast<precompile_op>(ins->get_operator()).op;
-        compiles.emplace_back([=]() -> compiled_result {
-            return {compile(*ctx, ins, preop), ins};
-        });
+        cm.add_plan(ctx, preop, ins);
     }
-    std::vector<compiled_result> results(compiles.size());
-    par_compile(compiles.size(), [&](auto i) { results[i] = compiles[i](); });
-    for(const auto& cr : results)
-    {
-        cr.replace.replace(m, cr.ins);
-    }
+    cm.update_configs();
+    cm.compile(m);
+    // Compile already tuned configs
+    cm.compile(m);
+    assert(cm.cps.empty());
 }

 } // namespace gpu
...
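The compile() driver above is deliberately run twice by apply(): the first pass compiles and benchmarks one representative plan per distinct problem (marking the problem in problem_cache on the way in), while plans that hit a marked-but-unsolved entry return early from add_compiles and remain in cps; the second pass then compiles those stragglers against the solution the benchmark recorded. A toy model of that mark/insert/get protocol, with string keys and an int "solution" standing in for the value-keyed cache:

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <string>

// Toy stand-in for problem_cache: an entry holding std::nullopt means
// "seen but not yet benchmarked".
struct toy_problem_cache
{
    std::map<std::string, std::optional<int>> cache;
    bool has(const std::string& k) const { return cache.count(k) > 0; }
    void mark(const std::string& k) { cache.emplace(k, std::nullopt); }
    void insert(const std::string& k, int sol) { cache[k] = sol; }
    std::optional<int> get(const std::string& k) const
    {
        auto it = cache.find(k);
        return it == cache.end() ? std::nullopt : it->second;
    }
};

int main()
{
    toy_problem_cache pc;
    const std::string key = "ck_gemm/m64n64k768"; // illustrative key
    // First compile() pass, first plan with this problem: mark it, compile
    // every candidate solution, benchmark, then record the fastest.
    pc.mark(key);
    pc.insert(key, 3); // benchmark picked solution 3
    // Any other plan with the same problem returned early while the entry
    // was still null; on the second compile() pass it finds the cached
    // winner and compiles only that single solution.
    assert(pc.get(key).value() == 3);
}
```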
@@ -28,33 +28,45 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {

-auto& compiler_map()
+namespace {
+struct compiler_handle
 {
-    static std::unordered_map<std::string, compiler_compile> m; // NOLINT
-    return m;
-}
+    compiler_compile compile;
+    compiler_compile_op compile_op;
+    compiler_tuning_config get_tuning_config;
+};
+} // namespace

-auto& compiler_op_map()
+auto& compiler_map()
 {
-    static std::unordered_map<std::string, compiler_compile_op> m; // NOLINT
+    static std::unordered_map<std::string, compiler_handle> m; // NOLINT
     return m;
 }

-void register_compiler(const std::string& name, compiler_compile c, compiler_compile_op cop)
+void register_compiler(const std::string& name,
+                       compiler_compile c,
+                       compiler_compile_op cop,
+                       compiler_tuning_config ctg)
 {
-    compiler_map()[name] = std::move(c);
-    compiler_op_map()[name] = std::move(cop);
+    compiler_map()[name] = {std::move(c), std::move(cop), std::move(ctg)};
 }

 bool has_compiler_for(const std::string& name) { return compiler_map().count(name) > 0; }

-compiler_replace compile(context& ctx, instruction_ref ins, const operation& op)
+compiler_replace
+compile(context& ctx, instruction_ref ins, const operation& op, const value& solution)
 {
-    return compiler_map().at(op.name())(ctx, ins, op);
+    return compiler_map().at(op.name()).compile(ctx, ins, op, solution);
 }

 operation
 compile_op(const std::string& name, context& ctx, const std::vector<shape>& inputs, const value& v)
 {
-    return compiler_op_map().at(name)(ctx, inputs, v);
+    return compiler_map().at(name).compile_op(ctx, inputs, v);
 }

+optional<tuning_config>
+get_tuning_config(context& ctx, instruction_ref ins, const operation& op, bool exhaustive)
+{
+    return compiler_map().at(op.name()).get_tuning_config(ctx, ins, op, exhaustive);
+}

 } // namespace gpu
...
@@ -67,18 +67,19 @@ void multinomial(hipStream_t stream,
     size_t class_size  = arg0.get_shape().lens().back();
     size_t sample_size = result.get_shape().lens().back();

-    hip_visit_all(arg0, arg1)([&](auto cdf, auto dist) {
-        result.visit([&](auto out) {
-            hip_visit_views(out)([&](auto output) {
-                gs_launch(stream, batch_size * sample_size)([=](auto i) __device__ {
-                    auto idx       = output.get_shape().multi(i);
-                    auto cdf_begin = cdf.begin() + (idx.front() * class_size);
-                    auto cdf_end   = cdf_begin + class_size;
-                    auto sample_iter =
-                        upper_bound(cdf_begin, cdf_end, dist[i] * *(std::prev(cdf_end)));
-                    output[i] = std::distance(cdf_begin, sample_iter);
+    visit_all(arg0, arg1)([&](auto cdf_host, auto dist_host) {
+        result.visit([&](auto output_host) {
+            hip_visit_views(cdf_host, dist_host, output_host)(
+                [&](auto cdf, auto dist, auto output) {
+                    gs_launch(stream, batch_size * sample_size)([=](auto i) __device__ {
+                        auto idx       = output.get_shape().multi(i);
+                        auto cdf_begin = cdf.begin() + (idx.front() * class_size);
+                        auto cdf_end   = cdf_begin + class_size;
+                        auto* sample_iter =
+                            upper_bound(cdf_begin, cdf_end, dist[i] * *(std::prev(cdf_end)));
+                        output[i] = std::distance(cdf_begin, sample_iter);
+                    });
                 });
-            });
         });
     });
 }
...
@@ -43,6 +43,8 @@ auto get_arch_name(rank<1>, const HipDeviceProp& props) -> decltype(std::string(
     return std::string(props.gcnArchName);
 }

+std::string get_arch_name(const hipDeviceProp_t& props) { return get_arch_name(rank<1>{}, props); }
+
 int get_device_id()
 {
     int device;
@@ -58,7 +60,7 @@ std::string get_device_name()
     auto status = hipGetDeviceProperties(&props, get_device_id());
     if(status != hipSuccess)
         MIGRAPHX_THROW("Failed to get device properties");
-    return get_arch_name(rank<1>{}, props);
+    return get_arch_name(props);
 }

 } // namespace gpu
...
@@ -22,7 +22,7 @@
  * THE SOFTWARE.
  */
 #include <migraphx/gpu/driver/action.hpp>
-#include <migraphx/gpu/driver/perf.hpp>
+#include <migraphx/gpu/time_op.hpp>
 #include <migraphx/gpu/compiler.hpp>
 #include <migraphx/gpu/context.hpp>
...
@@ -22,7 +22,7 @@
  * THE SOFTWARE.
  */
 #include <migraphx/gpu/driver/action.hpp>
-#include <migraphx/gpu/driver/perf.hpp>
+#include <migraphx/gpu/time_op.hpp>
 #include <migraphx/gpu/context.hpp>
 #include <migraphx/make_op.hpp>
...
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/fuse_ck.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct module;

namespace gpu {

struct ck_gemm
{
    operation op = make_op("dot");

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.op, "op"));
    }

    std::string name() const { return "gpu::ck_gemm"; }

    void check_gemm_shape(const shape& s) const
    {
        if(not contains(range(s.strides().rbegin(), s.strides().rbegin() + 3), 1))
            MIGRAPHX_THROW("Invalid shape for ck_gemm");
    }

    shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
    {
        check_shapes{inputs, *this}.same_ndims();
        if(inputs.size() < 2)
            MIGRAPHX_THROW("should have at least two inputs.");
        auto a = inputs[0];
        auto b = inputs[1];
        for(const auto& input : inputs)
            check_gemm_shape(input);
        auto r = op.compute_shape({a, b});
        if(mods.empty())
            return r;
        return r.with_type(mods.front()->get_output_shapes().front().type());
    }
};
MIGRAPHX_REGISTER_OP(ck_gemm);

namespace {

bool is_ck_supported_type(shape::type_t t)
{
    return contains({shape::half_type, shape::int8_type, shape::int32_type}, t);
}

MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
{
    if(ins->name() != "dot" and ins->name() != "quant_dot")
        return false;
    if(not is_ck_supported_type(ins->get_shape().type()))
        return false;
    auto a = ins->inputs().front()->get_shape();
    auto b = ins->inputs().back()->get_shape();
    auto m = a.lens()[a.lens().size() - 2];
    auto n = b.lens().back();
    auto k = a.lens().back();
    // Integer gemms must be divisible by 4 in ck
    if(contains({shape::int8_type, shape::int32_type}, ins->get_shape().type()))
    {
        if(m % 4 != 0)
            return false;
        if(n % 4 != 0)
            return false;
        if(k % 4 != 0)
            return false;
    }
    // Skipping GEMMs with a K dimension greater than 2048 is a coarse-grained strategy
    // to avoid poor-performing GEMM kernels from CK
    // TODO: Investigate a more precise strategy
    return k <= 2048;
}

struct find_ck_gemm_pointwise
{
    // Find a gemm followed by a pointwise operation.
    auto matcher() const
    {
        auto gemm = match::skip(match::name("contiguous"))(
            match::name("dot", "quant_dot")(is_ck_gemm().bind("gemm")));
        return match::name("pointwise")(match::any_of[match::inputs()](gemm.bind("x")));
    }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto ins      = r.result;
        auto gemm_ins = r.instructions["gemm"];
        auto x_ins    = r.instructions["x"]; // input after contiguous
        auto* pm      = ins->module_inputs().front();
        auto names    = pm->get_parameter_names();
        std::sort(names.begin(), names.end());
        auto inputs   = ins->inputs();
        auto gemm_it  = std::find(inputs.begin(), inputs.end(), x_ins);
        auto gemm_idx = gemm_it - inputs.begin();
        if(gemm_ins->get_shape().type() != shape::int32_type and
           ins->get_shape().type() != gemm_ins->get_shape().type())
            return;
        if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto input) {
               return not is_ck_supported_type(input->get_shape().type());
           }))
            return;
        assert(gemm_it != inputs.end());
        if(gemm_idx != 0)
        {
            auto first_param    = pm->get_parameter(names[0]);
            auto gemm_param     = pm->get_parameter(names[gemm_idx]);
            auto new_gemm_param = pm->add_parameter(names[0] + "_0", gemm_param->get_shape());
            auto new_first_param =
                pm->add_parameter(names[gemm_idx] + "_0", first_param->get_shape());
            pm->replace_instruction(gemm_param, new_gemm_param);
            pm->replace_instruction(first_param, new_first_param);
            pm->remove_instruction(first_param);
            pm->remove_instruction(gemm_param);
        }
        inputs.erase(gemm_it);
        inputs.insert(inputs.begin(), gemm_ins->inputs().begin(), gemm_ins->inputs().end());
        mpm.get_module().replace_instruction(ins, ck_gemm{gemm_ins->get_operator()}, inputs, {pm});
    }
};

struct find_ck_gemm
{
    auto matcher() const { return match::name("dot")(is_ck_gemm().bind("gemm")); }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto ins = r.result;
        mpm.get_module().replace_instruction(ins, ck_gemm{ins->get_operator()}, ins->inputs());
    }
};

} // namespace

void fuse_ck::apply(module_pass_manager& mpm) const
{
    match::find_matches(mpm, find_ck_gemm_pointwise{});
    match::find_matches(mpm, find_ck_gemm{});
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
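The is_ck_gemm predicate above encodes two eligibility rules: integer GEMMs must have M, N, and K divisible by 4, and any GEMM with K > 2048 is rejected outright. A standalone sketch of those checks on plain dimension vectors (a toy helper, not the MIGraphX API):

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Sketch of the is_ck_gemm dimension checks: a dot of a[..., m, k] with
// b[..., k, n] passes only if integer variants are 4-aligned and k stays
// within the coarse 2048 cutoff used to avoid slow CK kernels.
bool ck_eligible(const std::vector<std::size_t>& a_lens,
                 const std::vector<std::size_t>& b_lens,
                 bool integer_gemm)
{
    std::size_t m = a_lens[a_lens.size() - 2];
    std::size_t n = b_lens.back();
    std::size_t k = a_lens.back();
    if(integer_gemm and (m % 4 != 0 or n % 4 != 0 or k % 4 != 0))
        return false;
    return k <= 2048;
}

int main()
{
    std::cout << ck_eligible({64, 768}, {768, 768}, false) << "\n";  // 1: k = 768
    std::cout << ck_eligible({64, 4096}, {4096, 64}, false) << "\n"; // 0: k > 2048
    std::cout << ck_eligible({6, 8}, {8, 8}, true) << "\n";          // 0: m not 4-aligned
}
```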
@@ -139,7 +139,8 @@ struct find_mlir_op
     auto matcher() const
     {
         auto dot_or_conv = match::skip(match::name("contiguous"))(
-            match::any_of(match::name("dot"), is_mlir_conv()).bind("gemm_based_op"));
+            match::any_of(match::name("dot"), match::name("quant_dot"), is_mlir_conv())
+                .bind("gemm_based_op"));
         return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x")));
     }

@@ -190,6 +191,68 @@ struct find_mlir_op
         return {new_gemm_based_op, top_inputs};
     }

+    // Whitelist supported fusion options, including imposing type constraints
+    // for cases where MLIR only supports an operation (usually a pointwise function)
+    // on particular types.
+    bool is_pointwise_op_supported_by_mlir(const instruction& i) const
+    {
+        using type_t           = shape::type_t;
+        const auto& name       = i.name();
+        const auto result_type = i.get_shape().type();
+        const std::initializer_list<type_t> allowed_types = {type_t::float_type,
+                                                             type_t::half_type,
+                                                             type_t::int8_type,
+                                                             type_t::int32_type,
+                                                             type_t::bool_type};
+        // Preliminary type check.
+        if(not contains(allowed_types, result_type))
+        {
+            return false;
+        }
+        const std::initializer_list<std::string> any_type_ops = {"@literal", "@param", "@return"};
+        const std::initializer_list<std::string> no_bool_ops  = {"convolution",
+                                                                 "quant_convolution",
+                                                                 "dot",
+                                                                 "quant_dot",
+                                                                 "add",
+                                                                 "clip",
+                                                                 "sub",
+                                                                 "mul",
+                                                                 "div",
+                                                                 "pow",
+                                                                 "where",
+                                                                 "quantizelinear",
+                                                                 "dequantizelinear",
+                                                                 "abs",
+                                                                 "neg"};
+        const std::initializer_list<std::string> fp_only_ops  = {"ceil",
+                                                                 "erf",
+                                                                 "exp",
+                                                                 "floor",
+                                                                 "log",
+                                                                 "recip",
+                                                                 "rsqrt",
+                                                                 "sigmoid",
+                                                                 "softmax",
+                                                                 "tanh"};
+        bool is_float = contains({type_t::float_type, type_t::half_type}, result_type);
+        if(contains(any_type_ops, name))
+            return true;
+        if(result_type != type_t::bool_type && contains(no_bool_ops, name))
+            return true;
+        if(is_float && contains(fp_only_ops, name))
+            return true;
+        // Only conversions between floating types are known to be unambiguously
+        // supported.
+        if(is_float && name == "convert")
+        {
+            return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) {
+                return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type());
+            });
+        }
+        return false;
+    }
+
     void apply(module_pass_manager& mpm, const match::matcher_result& r) const
     {
         auto ins = r.result;
@@ -197,31 +260,12 @@ struct find_mlir_op
         auto x_ins = r.instructions["x"]; // input after contiguous
         auto* pm   = ins->module_inputs().front();
         auto names = pm->get_parameter_names();
-        // Whitelist pointwise operators
-        if(std::any_of(pm->begin(), pm->end(), [](const auto& i) {
-               return not contains({"@literal",
-                                    "@param",
-                                    "@return",
-                                    "convolution",
-                                    "quant_convolution",
-                                    "dot",
-                                    "add",
-                                    "relu",
-                                    "dequantizelinear",
-                                    "quantizelinear",
-                                    "mul"},
-                                   i.name());
-           }))
-            return;
-        // Only fuse with fp32/fp16/int8/int32
-        if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
-               return not contains({shape::type_t::float_type,
-                                    shape::type_t::half_type,
-                                    shape::type_t::int8_type,
-                                    shape::type_t::int32_type},
-                                   i->get_shape().type());
+        // Whitelist pointwise operators.
+        if(std::any_of(pm->begin(), pm->end(), [&](const auto& i) {
+               return not is_pointwise_op_supported_by_mlir(i);
            }))
             return;
         std::sort(names.begin(), names.end());
         module_ref mm = mpm.create_module("mlir_" + pm->name());
         mm->set_bypass();
...
@@ -165,7 +165,8 @@ struct fusion
     const std::unordered_set<std::string>& get_supported_archs()
     {
-        static std::unordered_set<std::string> supported_archs{"gfx900", "gfx906", "gfx908", "gfx1030"};
+        static std::unordered_set<std::string> supported_archs{
+            "gfx900", "gfx906", "gfx908", "gfx1030", "gfx940"};
         return supported_archs;
     }
...
@@ -146,7 +146,11 @@ std::vector<T> read_from_gpu(const void* x, std::size_t sz)
     gpu_sync();
     std::vector<T> result(sz);
     assert(not is_device_ptr(result.data()));
-    assert(is_device_ptr(x));
+    if(not is_device_ptr(x))
+    {
+        MIGRAPHX_THROW(
+            "read_from_gpu() requires Src buffer to be on the GPU, Copy from gpu failed\n");
+    }
     auto status = hipMemcpy(result.data(), x, sz * sizeof(T), hipMemcpyDeviceToHost);
     if(status != hipSuccess)
         MIGRAPHX_THROW("Copy from gpu failed: " + hip_error(status)); // NOLINT
...
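The change from assert to MIGRAPHX_THROW matters in release builds, where NDEBUG compiles asserts away and a host pointer would otherwise fail later inside hipMemcpy with a less attributable error. A small illustration of the difference, with a stub standing in for the real device-pointer query:

```cpp
#include <iostream>
#include <stdexcept>

// Stub for illustration only: the real is_device_ptr queries the HIP runtime.
bool is_device_ptr(const void* /*p*/) { return false; } // pretend everything is host memory

void check_device_src(const void* x)
{
    // A throw fires in every build type; an assert here would vanish under NDEBUG
    // and let the bad pointer reach hipMemcpy.
    if(not is_device_ptr(x))
        throw std::runtime_error("read_from_gpu() requires Src buffer to be on the GPU");
}

int main()
{
    int host_value = 42;
    try
    {
        check_device_src(&host_value);
    }
    catch(const std::exception& e)
    {
        std::cout << e.what() << "\n"; // caught even in a release build
    }
}
```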
@@ -26,6 +26,7 @@
 #include <migraphx/config.hpp>
 #include <migraphx/operation.hpp>
+#include <migraphx/compile_src.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -39,9 +40,10 @@ struct hip_compile_options
     std::size_t local;
     std::vector<shape> inputs;
     shape output;
     std::string kernel_name           = "kernel";
     std::string params                = "";
     std::vector<shape> virtual_inputs = {};
+    std::vector<src_file> additional_src_files = {};

     /**
      * @brief Set the launch parameters but allow v to override the values
...
@@ -38,7 +38,8 @@ struct context;
 struct compile_ops
 {
     context* ctx = nullptr;
+    bool exhaustive_tune = false;

     std::string name() const { return "gpu::compile_ops"; }
     void apply(module& m) const;
 };
...
@@ -30,6 +30,8 @@
 #include <migraphx/value.hpp>
 #include <migraphx/module.hpp>
 #include <migraphx/instruction.hpp>
+#include <migraphx/optional.hpp>
+#include <migraphx/rank.hpp>
 #include <functional>

 namespace migraphx {
@@ -66,16 +68,31 @@ struct compiler_replace
     }
 };

-using compiler_compile = std::function<compiler_replace(context&, instruction_ref, operation)>;
+struct tuning_config
+{
+    value problem;
+    std::vector<value> solutions;
+};
+
+using compiler_compile =
+    std::function<compiler_replace(context&, instruction_ref, operation, const value&)>;
 using compiler_compile_op =
     std::function<operation(context&, const std::vector<shape>& inputs, const value&)>;
+using compiler_tuning_config =
+    std::function<optional<tuning_config>(context&, instruction_ref, const operation&, bool)>;

-void register_compiler(const std::string& name, compiler_compile c, compiler_compile_op cop);
+void register_compiler(const std::string& name,
+                       compiler_compile c,
+                       compiler_compile_op cop,
+                       compiler_tuning_config ctg);

 bool has_compiler_for(const std::string& name);

-compiler_replace compile(context& ctx, instruction_ref ins, const operation& op);
+compiler_replace
+compile(context& ctx, instruction_ref ins, const operation& op, const value& solution);

 operation
 compile_op(const std::string& name, context& ctx, const std::vector<shape>& inputs, const value& v);

+optional<tuning_config>
+get_tuning_config(context& ctx, instruction_ref ins, const operation& op, bool exhaustive);
+
 template <class T>
 void register_compiler()
@@ -85,8 +102,11 @@ void register_compiler()
 {
     register_compiler(
         name,
-        [=](auto&&... xs) { return c.compile(std::forward<decltype(xs)>(xs)...); },
-        [=](auto&&... xs) { return c.compile_op(std::forward<decltype(xs)>(xs)...); });
+        [=](auto&&... xs) {
+            return c.invoke_compile(rank<1>{}, std::forward<decltype(xs)>(xs)...);
+        },
+        [=](auto&&... xs) { return c.compile_op(std::forward<decltype(xs)>(xs)...); },
+        [=](auto&&... xs) { return c.get_tuning_config(std::forward<decltype(xs)>(xs)...); });
 }
 }

@@ -105,7 +125,31 @@ using auto_register_compiler = auto_register<register_compiler_action, T>;
 template <class Derived>
 struct compiler : auto_register_compiler<Derived>
 {
+    const Derived& derived() const { return static_cast<const Derived&>(*this); }
+
+    optional<tuning_config>
+    get_tuning_config(context&, instruction_ref, const operation&, bool) const
+    {
+        return nullopt;
+    }
+
     operation compile_op(context&, const std::vector<shape>&, const value&) const { return {}; }
+
+    template <class D = Derived>
+    auto invoke_compile(
+        rank<1>, context& ctx, instruction_ref ins, operation op, const value& solution) const
+        -> decltype(std::declval<D>().compile(ctx, ins, std::move(op), solution))
+    {
+        return derived().compile(ctx, ins, std::move(op), solution);
+    }
+
+    template <class D = Derived>
+    auto invoke_compile(
+        rank<0>, context& ctx, instruction_ref ins, operation op, const value& solution) const
+        -> decltype(std::declval<D>().compile(ctx, ins, std::move(op)))
+    {
+        assert(solution.empty());
+        (void)solution;
+        return derived().compile(ctx, ins, std::move(op));
+    }
 };

 } // namespace gpu
...
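The invoke_compile overloads use rank-based SFINAE dispatch so that compilers written against the old three-argument compile() signature keep working while new ones can accept a tuning solution. A self-contained sketch of that dispatch with toy types:

```cpp
#include <cassert>
#include <iostream>
#include <string>

// rank<1> implicitly converts to rank<0>, so the rank<1> overload wins
// whenever its trailing-return expression is well-formed; otherwise overload
// resolution falls back to the legacy signature. Toy stand-ins throughout.
template <int N>
struct rank : rank<N - 1> {};
template <>
struct rank<0> {};

struct legacy_compiler
{
    std::string compile(int op) const { return "legacy:" + std::to_string(op); }
};

struct tunable_compiler
{
    std::string compile(int op, const std::string& solution) const
    {
        return "tuned:" + std::to_string(op) + "/" + solution;
    }
};

template <class C>
auto invoke_compile(rank<1>, const C& c, int op, const std::string& solution)
    -> decltype(c.compile(op, solution))
{
    return c.compile(op, solution);
}

template <class C>
auto invoke_compile(rank<0>, const C& c, int op, const std::string& solution)
    -> decltype(c.compile(op))
{
    assert(solution.empty()); // a legacy compiler cannot consume a solution
    return c.compile(op);
}

int main()
{
    std::cout << invoke_compile(rank<1>{}, legacy_compiler{}, 7, "") << "\n";      // legacy:7
    std::cout << invoke_compile(rank<1>{}, tunable_compiler{}, 7, "cfg3") << "\n"; // tuned:7/cfg3
}
```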
@@ -170,7 +170,9 @@ struct hip_device
     std::size_t stream_id() const { return current_stream; }

-    std::string get_device_name() const { return device_props.gcnArchName; }
+    std::string get_device_name() const { return get_arch_name(device_props); }
+
+    std::string get_gfx_name() const { return trim(split_string(get_device_name(), ':').front()); }

     std::size_t get_device_major() const { return device_props.major; }
...
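get_gfx_name() strips the feature flags that HIP appends to gcnArchName, for example "gfx90a:sramecc+:xnack-", keeping only the architecture token the rest of the pipeline keys on. A sketch of that extraction (a plain stand-in for migraphx::split_string and trim):

```cpp
#include <iostream>
#include <sstream>
#include <string>

// Keep only the first ':'-separated token of a gcnArchName string.
std::string gfx_name(const std::string& device_name)
{
    std::istringstream ss(device_name);
    std::string arch;
    std::getline(ss, arch, ':');
    return arch;
}

int main()
{
    std::cout << gfx_name("gfx90a:sramecc+:xnack-") << "\n"; // gfx90a
    std::cout << gfx_name("gfx1030") << "\n";                // gfx1030
}
```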