"tools/vscode:/vscode.git/clone" did not exist on "2d252c9e52859576a0821cb9e2f13aec1d1c1458"
Commit d7b1895a authored by Khalique Ahmed

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into nhwc_workaround

parents 1add453a a83371ca
@@ -34,7 +34,7 @@ void promote_literals::apply(module_pass_manager& mpm) const
 {
     module& m              = mpm.get_module();
     module_ref root_module = mpm.get_root_module();
-    if(m.name() == "main")
+    if(m == *root_module)
         return;
     for(auto ins : iterator_for(m))
...
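Note on the change above: checking m.name() == "main" silently misfires if the root module is ever renamed, or if a submodule happens to be named "main"; comparing against the pass manager's root module is exact. A minimal sketch of the pattern (hypothetical pass name; assumes only the module_pass_manager API visible in this diff):

    struct skip_root_example
    {
        std::string name() const { return "skip_root_example"; }
        void apply(migraphx::module_pass_manager& mpm) const
        {
            migraphx::module& m = mpm.get_module();
            // Identity comparison with the root module, not a name lookup
            if(m == *mpm.get_root_module())
                return;
            // ... rewrite non-root modules here ...
        }
    };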
@@ -52,14 +52,6 @@ static void quantize_module(module& m, const std::vector<std::string>& ins_names
         auto mod_inputs = ins->module_inputs();
         auto s          = ins->get_shape();
-        // Convert back to original type before quantizing the inputs
-        if(mod_inputs.empty())
-        {
-            auto r = m.insert_instruction(
-                std::next(ins), make_op("convert", {{"target_type", s.type()}}), ins);
-            m.replace_instruction(ins, r);
-        }
         // Convert each of the inputs that are floating point to fp16
         auto inputs = ins->inputs();
         std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
@@ -70,8 +62,17 @@ static void quantize_module(module& m, const std::vector<std::string>& ins_names
                 ins, make_op("convert", {{"target_type", shape::half_type}}), input);
         });
-        // Replace inputs
-        m.replace_instruction(ins, ins->get_operator(), inputs, mod_inputs);
+        // Insert quantized ins
+        auto converted_ins = m.insert_instruction(ins, ins->get_operator(), inputs, mod_inputs);
+        // Convert back to original type after quantizing
+        if(mod_inputs.empty())
+        {
+            converted_ins = m.insert_instruction(
+                ins, make_op("convert", {{"target_type", s.type()}}), converted_ins);
+        }
+        // Replace original instruction
+        m.replace_instruction(ins, converted_ins);
     }
 }
...
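The reordering above changes when the convert-back is created: instead of patching the original instruction in place before its inputs are quantized, the quantized instruction is built first and then wrapped in a convert back to the original type. For a float add with no module inputs, the net effect looks roughly like this (a sketch in MIGraphX-style pseudo-IR, not captured output):

    // before:              after quantize_fp16:
    // y = add(x1, x2)      h1 = convert[half_type](x1)
    //                      h2 = convert[half_type](x2)
    //                      hy = add(h1, h2)
    //                      y  = convert[float_type](hy)  // back to s.type()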
@@ -21,6 +21,7 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
+#include <migraphx/pass_manager.hpp>
 #include <migraphx/replace_allocate.hpp>
 #include <migraphx/instruction.hpp>
 #include <migraphx/program.hpp>
@@ -84,10 +85,11 @@ void insert_submod_allocations(instruction_ref ins, module& mod, const allocation_model& model)
     mod.replace_instruction(ins, ins->get_operator(), inputs, mod_args);
 }

-void replace_allocate::apply(module& m) const
+void replace_allocate::apply(module_pass_manager& mpm) const
 {
+    module& m             = mpm.get_module();
     auto mod_output_names = create_output_names(m);
-    bool main_offload_copy = m.name() == "main" ? this->offload_copy : false;
+    bool root_offload_copy = (*mpm.get_root_module() == m) ? this->offload_copy : false;
     for(auto ins : iterator_for(m))
     {
         auto op = ins->get_operator();
@@ -104,7 +106,7 @@ void replace_allocate::apply(module& m) const
             continue;

         auto s = ins->get_shape();
-        if(not main_offload_copy and model.needs_out_params() and contains(mod_output_names, ins))
+        if(not root_offload_copy and model.needs_out_params() and contains(mod_output_names, ins))
         {
             auto out_param = m.add_parameter(mod_output_names[ins], s);
             m.replace_instruction(ins, out_param);
...
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/target.hpp>
#include <migraphx/register_target.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void migraphx_to_value(value& v, const target& t) { v["name"] = t.name(); }
void migraphx_from_value(const value& v, target& t)
{
t = make_target(v.at("name").to<std::string>());
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
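The new file above adds value-based serialization for targets, storing only the target's registered name. A minimal round-trip sketch (assumes a target named "cpu" has been registered, e.g. by linking the CPU target library):

    #include <migraphx/target.hpp>
    #include <migraphx/register_target.hpp>
    #include <migraphx/value.hpp>

    void target_roundtrip_example()
    {
        migraphx::target t = migraphx::make_target("cpu");
        migraphx::value v;
        migraphx::migraphx_to_value(v, t);    // stores {"name": "cpu"}
        migraphx::target t2;
        migraphx::migraphx_from_value(v, t2); // re-creates the target by name
    }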
@@ -111,9 +111,27 @@ struct compile_plan
     context* ctx;
     operation preop;
     instruction_ref ins;
     optional<tuning_config> config = nullopt;
-    std::vector<compiled_result> results = {};
-    void update_config() { config = get_tuning_config(*ctx, ins, preop); }
+    std::vector<optional<compiled_result>> results = {};
+    void update_config(bool exhaustive)
+    {
+        config = get_tuning_config(*ctx, ins, preop, exhaustive);
+    }
+    template <class Vector>
+    void insert_compiles(Vector& compiles, const value& solution, std::size_t i)
+    {
+        compiles.emplace_back([=] {
+            try
+            {
+                results[i] = compiled_result{compile(*ctx, ins, preop, solution), ins};
+            }
+            catch(...)
+            {
+                results[i] = nullopt;
+            }
+        });
+    }
     template <class Vector>
     void add_compiles(Vector& compiles, problem_cache& pc)
     {
@@ -127,9 +145,7 @@ struct compile_plan
             if(solution.is_null())
                 return;
             results.resize(1);
-            compiles.emplace_back([=] {
-                results[0] = compiled_result{compile(*ctx, ins, preop, solution), ins};
-            });
+            insert_compiles(compiles, solution, 0);
         }
         else
         {
@@ -139,18 +155,14 @@
             for(auto i : range(solutions.size()))
             {
                 auto solution = solutions[i];
-                compiles.emplace_back([=] {
-                    results[i] = compiled_result{compile(*ctx, ins, preop, solution), ins};
-                });
+                insert_compiles(compiles, solution, i);
             }
         }
     }
     else
     {
         results.resize(1);
-        compiles.emplace_back([=] {
-            results[0] = compiled_result{compile(*ctx, ins, preop, value{}), ins};
-        });
+        insert_compiles(compiles, value{}, 0);
     }
 }
 const compiled_result& benchmark(problem_cache& pc) const
@@ -158,7 +170,11 @@
     if(results.empty())
         MIGRAPHX_THROW("No configs to tune");
     if(results.size() == 1)
-        return results.front();
+    {
+        if(not results.front().has_value())
+            MIGRAPHX_THROW("No configs to tune");
+        return *results.front();
+    }
     if(not config)
         MIGRAPHX_THROW("Multiple kernels without config");
     std::cout << "Benchmarking " << preop.name() << ": " << results.size() << " configs"
@@ -167,11 +183,17 @@
     times.reserve(results.size());
     std::transform(
         results.begin(), results.end(), std::back_inserter(times), [&](const auto& cr) {
-            return time_op(*ctx, cr.replace.code_object, to_shapes(cr.ins->inputs()), 20).first;
+            if(not cr.has_value())
+                return std::numeric_limits<double>::max();
+            return time_op(*ctx, cr->replace.code_object, to_shapes(cr->ins->inputs()), 20)
+                .first;
         });
     auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end()));
+    std::cout << "Fastest solution: " << config->solutions.at(i) << std::endl;
     pc.insert(preop.name(), config->problem, config->solutions.at(i));
-    return results[i];
+    if(not results[i].has_value())
+        MIGRAPHX_THROW("No valid tuned compilation.");
+    return *results[i];
 }
 void replace(module& m, problem_cache& pc) const
 {
@@ -185,7 +207,10 @@ void par_compile(std::size_t n, F f)
 {
     if(n == 0)
         return;
-    par_for(n, n / value_of(MIGRAPHX_GPU_COMPILE_PARALLEL{}, n), f);
+    auto d = value_of(MIGRAPHX_GPU_COMPILE_PARALLEL{});
+    if(d == 0)
+        d = n;
+    par_for(n, n / d, f);
 }

 struct compile_manager
@@ -202,9 +227,7 @@ struct compile_manager
     void update_configs()
     {
-        if(not exhaustive)
-            return;
-        par_compile(cps.size(), [&](auto i) { cps[i].update_config(); });
+        par_compile(cps.size(), [&](auto i) { cps[i].update_config(exhaustive); });
     }

     void compile(module& m)
...
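Two behavioral notes on the compile_ops changes: failed compiles are now recorded as nullopt instead of aborting the whole tuning run, and benchmark keys them with the largest double so min_element can never select them; par_compile also now guards against MIGRAPHX_GPU_COMPILE_PARALLEL being set to 0, which could previously divide by zero. The selection idiom, as a standalone sketch:

    #include <algorithm>
    #include <iterator>
    #include <limits>
    #include <optional>
    #include <vector>

    int main()
    {
        // Per-candidate times; nullopt marks a kernel that failed to compile
        std::vector<std::optional<double>> times = {std::nullopt, 1.5, 0.9};
        std::vector<double> keyed;
        std::transform(times.begin(), times.end(), std::back_inserter(keyed),
                       [](const auto& t) {
                           return t ? *t : std::numeric_limits<double>::max();
                       });
        // Picks index 2 (0.9); a failed entry only wins if everything failed,
        // which the real code rejects with MIGRAPHX_THROW
        auto best = std::distance(keyed.begin(),
                                  std::min_element(keyed.begin(), keyed.end()));
        return static_cast<int>(best); // 2
    }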
@@ -63,9 +63,10 @@ compile_op(const std::string& name, context& ctx, const std::vector<shape>& inputs
     return compiler_map().at(name).compile_op(ctx, inputs, v);
 }

-optional<tuning_config> get_tuning_config(context& ctx, instruction_ref ins, const operation& op)
+optional<tuning_config>
+get_tuning_config(context& ctx, instruction_ref ins, const operation& op, bool exhaustive)
 {
-    return compiler_map().at(op.name()).get_tuning_config(ctx, ins, op);
+    return compiler_map().at(op.name()).get_tuning_config(ctx, ins, op, exhaustive);
 }
 } // namespace gpu
...
@@ -83,10 +83,23 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
         return false;
     auto a = ins->inputs().front()->get_shape();
     auto b = ins->inputs().back()->get_shape();
+    auto m = a.lens()[a.lens().size() - 2];
+    auto n = b.lens().back();
+    auto k = a.lens().back();
+    // Integer gemms must be divisible by 4 in ck
+    if(contains({shape::int8_type, shape::int32_type}, ins->get_shape().type()))
+    {
+        if(m % 4 != 0)
+            return false;
+        if(n % 4 != 0)
+            return false;
+        if(k % 4 != 0)
+            return false;
+    }
     // Skipping GEMMs with a K dimension greater than 2048 is a coarse-grained strategy
     // to avoid poor-performing GEMM kernels from CK
     // To-do: Investigate a more precise strategy
-    return a.lens().back() <= 2048;
+    return k <= 2048;
 }

 struct find_ck_gemm_pointwise
...
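For the integer-GEMM filter above, with a shaped [..., m, k] and b shaped [..., k, n], all three dimensions must be multiples of 4 before CK is considered, and k must still pass the existing 2048 cap. A standalone sketch of the combined predicate:

    #include <cstddef>

    // Mirrors the is_ck_gemm constraints for int8/int32 gemms
    bool ck_int_gemm_dims_ok(std::size_t m, std::size_t n, std::size_t k)
    {
        return m % 4 == 0 and n % 4 == 0 and k % 4 == 0 and k <= 2048;
    }
    // ck_int_gemm_dims_ok(64, 64, 128) -> true
    // ck_int_gemm_dims_ok(6, 8, 4)     -> false (m is not a multiple of 4)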
@@ -79,7 +79,7 @@ using compiler_compile =
 using compiler_compile_op =
     std::function<operation(context&, const std::vector<shape>& inputs, const value&)>;
 using compiler_tuning_config =
-    std::function<optional<tuning_config>(context&, instruction_ref, const operation&)>;
+    std::function<optional<tuning_config>(context&, instruction_ref, const operation&, bool)>;

 void register_compiler(const std::string& name,
                        compiler_compile c,
@@ -91,7 +91,8 @@ compiler_replace
 compile(context& ctx, instruction_ref ins, const operation& op, const value& solution);

 operation
 compile_op(const std::string& name, context& ctx, const std::vector<shape>& inputs, const value& v);

-optional<tuning_config> get_tuning_config(context& ctx, instruction_ref ins, const operation& op);
+optional<tuning_config>
+get_tuning_config(context& ctx, instruction_ref ins, const operation& op, bool exhaustive);

 template <class T>
 void register_compiler()
@@ -125,7 +126,8 @@ template <class Derived>
 struct compiler : auto_register_compiler<Derived>
 {
     const Derived& derived() const { return static_cast<const Derived&>(*this); }
-    optional<tuning_config> get_tuning_config(context&, instruction_ref, const operation&) const
+    optional<tuning_config>
+    get_tuning_config(context&, instruction_ref, const operation&, bool) const
     {
         return nullopt;
     }
...
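Existing compilers that do not tune keep working: the CRTP base's default get_tuning_config now takes the extra bool and still returns nullopt. A compiler that opts in overrides it with the new signature; a sketch (hypothetical op name; compile/compile_op members omitted):

    struct my_tuned_compiler : compiler<my_tuned_compiler>
    {
        std::vector<std::string> names() const { return {"gpu::my_op"}; }

        optional<tuning_config>
        get_tuning_config(context& ctx, instruction_ref ins, const operation& op, bool exhaustive) const
        {
            if(not exhaustive)
                return nullopt; // only emit candidates during full tuning runs
            tuning_config tc;
            // ... fill tc.problem and tc.solutions from ins/op ...
            return tc;
        }
    };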
@@ -30,7 +30,7 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

-struct module;
+struct module_pass_manager;

 namespace gpu {
@@ -45,7 +45,7 @@ struct lowering
     context* ctx;
     bool offload_copy;
     std::string name() const { return "gpu::lowering"; }
-    void apply(module& m) const;
+    void apply(module_pass_manager& mpm) const;
 };

 } // namespace gpu
...
@@ -50,6 +50,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_LOG_CK_GEMM);
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_TUNING);
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_TUNING_VALUE);
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_DEBUG);
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TUNE_CK);

 // NOLINTNEXTLINE
 static const char* const ck_gemm_kernel = R"__migraphx__(
@@ -65,7 +66,7 @@ ${preamble}

 extern "C" {

-__global__ void ${kernel}(${params})
+MIGRAPHX_GLOBAL void ${kernel}(${params})
 {
     transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) {
         ck_gemm<${solution}, ${blocks_per_batch}>(xs...);
@@ -265,7 +266,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
         s = shape{s.type(), {m1, m2}};
     }

-    std::vector<std::string> names() const { return {"gpu::ck_gemm"}; }
+    std::vector<std::string> names() const { return {"ck_gemm", "gpu::ck_gemm"}; }

     static bool standard_batch(const shape& s)
     {
@@ -418,9 +419,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
     {
         auto shapes = to_shapes(ins->inputs());
         auto v      = create_settings(ins, op);
-        if(solution.is_null())
-            v["tuning_value"] = 4;
-        else
+        if(not solution.is_null())
             v["tuning_value"] = solution;
         return {compile_op(ctx, shapes, v),
                 [=](module& m, instruction_ref ins2, const operation& code_object) {
@@ -436,8 +435,10 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
     }

     optional<tuning_config>
-    get_tuning_config(context& ctx, instruction_ref ins, const operation& op) const
+    get_tuning_config(context& ctx, instruction_ref ins, const operation& op, bool exhaustive) const
     {
+        if(not exhaustive and not enabled(MIGRAPHX_TUNE_CK{}))
+            return nullopt;
         tuning_config tc;
         auto shapes  = to_shapes(ins->inputs());
         auto problem = create_problem(shapes, create_settings(ins, op));
...
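With the gating above, CK gemm tuning configs are only produced during an exhaustive tune or when the new MIGRAPHX_TUNE_CK environment variable is enabled; otherwise compilation falls through to the untuned path (note the removed tuning_value default of 4). A sketch of enabling it from code (POSIX setenv; exporting the variable in the shell before running works equally well, and I'm assuming any non-zero value satisfies MIGraphX's enabled()):

    #include <cstdlib>

    void enable_ck_tuning()
    {
        // Must be set before the GPU target compiles the program
        setenv("MIGRAPHX_TUNE_CK", "1", /*overwrite=*/1);
    }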
@@ -47,7 +47,7 @@ ${preamble}
 extern "C" {
-__global__ void ${kernel}(${params})
+MIGRAPHX_GLOBAL void ${kernel}(${params})
 {
     transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, ${concat_params}, auto... xs) {
         concat<${axis}>(${concat_args})(${post}, y, xs...);
...
@@ -44,7 +44,7 @@ namespace migraphx {
 extern "C" {
-__global__ void gather_kernel(void* in_data, void* in_indices, void* output)
+MIGRAPHX_GLOBAL void gather_kernel(void* in_data, void* in_indices, void* output)
 {
     make_tensors()(in_data, in_indices, output)([](auto&&... xs) {
         gather<${axis}>(xs...);
...
@@ -44,7 +44,7 @@ namespace migraphx {
 extern "C" {
-__global__ void gathernd_kernel(void* in_data, void* in_indices, void* output)
+MIGRAPHX_GLOBAL void gathernd_kernel(void* in_data, void* in_indices, void* output)
 {
     make_tensors()(in_data, in_indices, output)([](auto&&... xs) {
         auto settings = make_gathernd_settings(MIGRAPHX_MAKE_CONSTANT(int64_t{BATCH_DIMS}));
...
@@ -48,7 +48,7 @@ namespace migraphx {
 ${preamble}
 extern "C" {
-__global__ void ${kernel}(${params})
+MIGRAPHX_GLOBAL void ${kernel}(${params})
 {
     transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto... xs) {
         ${layernorm}<${axis}>(${post}, ${eps}, xs...);
...
@@ -44,7 +44,7 @@ static const char* const pointwise_kernel = R"__migraphx__(
 namespace migraphx {
 extern "C" {
-__global__ void pad_kernel(void* input_p, void* output_p)
+MIGRAPHX_GLOBAL void pad_kernel(void* input_p, void* output_p)
 {
     auto offsets = index_ints<${offsets}>{};
     auto idx     = make_index();
...
@@ -44,7 +44,7 @@ namespace migraphx {
 ${preamble}
 extern "C" {
-__global__ void ${kernel}(${params})
+MIGRAPHX_GLOBAL void ${kernel}(${params})
 {
     auto idx = make_index();
     pointwise(idx, ${transformers})(${lambda}, ${args});
...
@@ -45,7 +45,7 @@ namespace migraphx {
 ${preamble}
 extern "C" {
-__global__ void reduce_kernel(void* input_p, void* output_p)
+MIGRAPHX_GLOBAL void reduce_kernel(void* input_p, void* output_p)
 {
     transform_args(make_tensors(), ${transformers})(input_p, output_p)([](auto input, auto output) {
...
@@ -41,7 +41,7 @@ namespace migraphx {
 extern "C" {
-__global__ void roialign_kernel(void* in_x, void* in_rois, void* in_ind, void* y)
+MIGRAPHX_GLOBAL void roialign_kernel(void* in_x, void* in_rois, void* in_ind, void* y)
 {
     make_tensors()(in_x, in_rois, in_ind, y)([](auto&&... xs) {
         auto settings = make_roalign_settings(MIGRAPHX_MAKE_CONSTANT(float{ROIS_OFFSET}),
...
@@ -42,7 +42,7 @@ namespace migraphx {
 extern "C" {
-__global__ void scatternd_kernel(void* in_indices, void* in_updates, void* output)
+MIGRAPHX_GLOBAL void scatternd_kernel(void* in_indices, void* in_updates, void* output)
 {
     make_tensors()(in_indices, in_updates, output)([](auto&&... xs) {
         scatternd(xs..., ${reduction}{});
...
@@ -45,7 +45,7 @@ static const char* const softmax_kernel = R"__migraphx__(
 namespace migraphx {
 extern "C" {
-__global__ void softmax_kernel(void* input_p, void* output_p)
+MIGRAPHX_GLOBAL void softmax_kernel(void* input_p, void* output_p)
 {
     transform_args(make_tensors(), ${transformers})(input_p, output_p)([](auto input, auto output) {
         softmax<${axis}>(input, output);
...
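All of the kernel templates above now spell the kernel-entry attribute through MIGRAPHX_GLOBAL instead of writing __global__ directly, so the attribute can be adjusted in one place for different compilers. A representative definition (a sketch; the actual macro lives in MIGraphX's kernel headers and may differ):

    #ifndef MIGRAPHX_GLOBAL
    #define MIGRAPHX_GLOBAL __global__
    #endif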