Unverified Commit 25e8cf0b authored by Ted Themistokleous, committed by GitHub

Merge branch 'develop' into test_onnx_zoo

parents a313a68e 635502be
@@ -854,6 +854,25 @@ void program::print_graph(std::ostream& os, bool brief) const
     mm->print_graph(os, brief);
 }
 
+void program::print_py(std::ostream& os) const
+{
+    auto vec_modules = this->get_modules();
+    std::unordered_map<instruction_ref, std::string> names;
+    os << "p = migraphx.program()\n";
+    for(auto& mod : vec_modules)
+    {
+        std::string var_name = "m" + mod->name();
+        os << var_name << " = ";
+        if(mod->name() == "main")
+            os << "p.get_main_module()";
+        else
+            os << "p.create_module(\"" << mod->name() << "\");";
+        os << std::endl;
+        names = mod->print_py(os, var_name, names);
+        os << std::endl;
+    }
+}
+
 void program::print_cpp(std::ostream& os) const
 {
     auto vec_modules = this->get_modules();
...
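
Example (sketch, not part of the diff): how the new print_py is driven and what it emits for a program with only the main module. The per-instruction lines come from module::print_py, whose exact output is not shown in this diff:

#include <migraphx/program.hpp>
#include <iostream>

int main()
{
    migraphx::program p;
    p.get_main_module(); // the default module is named "main"
    p.print_py(std::cout);
    // Per the code above this prints:
    //   p = migraphx.program()
    //   mmain = p.get_main_module()
    // followed by whatever module::print_py writes for each instruction.
}
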
@@ -92,7 +92,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
     // process sequence length
     instruction_ref seq_lens = m.end();
-    if((args.size() >= 5) && args[4]->name() != "undefined")
+    if((args.size() >= 5) and not args[4]->is_undefined())
     {
         seq_lens = args[4];
     }
@@ -117,7 +117,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
     // process bias
     instruction_ref bias_forward = m.end();
     instruction_ref bias_reverse = m.end();
-    if(args.size() >= 4 && args[3]->name() != "undefined")
+    if(args.size() >= 4 and not args[3]->is_undefined())
     {
         bias_forward = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
@@ -129,7 +129,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
     // or the 5th one (if the sequence len argument is ignored)
     instruction_ref ih_forward{};
     instruction_ref ih_reverse{};
-    if(args.size() == 6 && args[5]->name() != "undefined")
+    if(args.size() == 6 and not args[5]->is_undefined())
     {
         ih_forward = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
@@ -195,14 +195,14 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
     // process bias and initial hidden state
     instruction_ref bias = m.end();
-    if(args.size() >= 4 && args[3]->name() != "undefined")
+    if(args.size() >= 4 and not args[3]->is_undefined())
     {
         bias = args[3];
     }
 
     // process intial hidden state
     instruction_ref ih;
-    if(args.size() == 6 && args[5]->name() != "undefined")
+    if(args.size() == 6 and not args[5]->is_undefined())
     {
         ih = args[5];
     }
@@ -398,7 +398,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
     // process sequence length
     instruction_ref seq_lens = m.end();
-    if((args.size() >= 5) && args[4]->name() != "undefined")
+    if((args.size() >= 5) and not args[4]->is_undefined())
     {
         seq_lens = args[4];
     }
@@ -423,7 +423,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
     // bias
     instruction_ref bias_forward = m.end();
     instruction_ref bias_reverse = m.end();
-    if(args.size() >= 4 && args[3]->name() != "undefined")
+    if(args.size() >= 4 and not args[3]->is_undefined())
     {
         bias_forward = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
@@ -434,7 +434,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
     // intial hidden state
     instruction_ref ih_forward{};
     instruction_ref ih_reverse{};
-    if(args.size() == 6 && args[5]->name() != "undefined")
+    if(args.size() == 6 and not args[5]->is_undefined())
     {
         ih_forward = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
@@ -501,14 +501,14 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
     // bias
     instruction_ref bias = m.end();
-    if(args.size() >= 4 && args[3]->name() != "undefined")
+    if(args.size() >= 4 and not args[3]->is_undefined())
     {
         bias = args[3];
     }
 
     // intial hidden state
     instruction_ref ih{};
-    if(args.size() == 6 && args[5]->name() != "undefined")
+    if(args.size() == 6 and not args[5]->is_undefined())
     {
         ih = args[5];
     }
@@ -784,7 +784,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
     // process sequence length
     instruction_ref seq_lens = m.end();
-    if((args.size() >= 5) && args[4]->name() != "undefined")
+    if((args.size() >= 5) and not args[4]->is_undefined())
     {
         seq_lens = args[4];
     }
@@ -813,7 +813,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
     // process bias
     instruction_ref bias_forward = m.end();
     instruction_ref bias_reverse = m.end();
-    if(args.size() >= 4 && args[3]->name() != "undefined")
+    if(args.size() >= 4 and not args[3]->is_undefined())
     {
         bias_forward = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
@@ -824,7 +824,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
     // process intial hidden state, it is the 6th argument
     instruction_ref ih_forward{};
     instruction_ref ih_reverse{};
-    if(args.size() >= 6 && args[5]->name() != "undefined")
+    if(args.size() >= 6 and not args[5]->is_undefined())
     {
         ih_forward = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
@@ -840,7 +840,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
     // process initial cell value
     instruction_ref ic_forward{};
     instruction_ref ic_reverse{};
-    if(args.size() >= 7 && args[6]->name() != "undefined")
+    if(args.size() >= 7 and not args[6]->is_undefined())
     {
         ic_forward = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[6]);
@@ -856,7 +856,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
     // process weight of the peephole
     instruction_ref pph_forward = m.end();
     instruction_ref pph_reverse = m.end();
-    if(args.size() == 8 && args[7]->name() != "undefined")
+    if(args.size() == 8 and not args[7]->is_undefined())
     {
         pph_forward = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[7]);
@@ -940,14 +940,14 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
     // bias
     instruction_ref bias = m.end();
-    if(args.size() >= 4 && args[3]->name() != "undefined")
+    if(args.size() >= 4 and not args[3]->is_undefined())
     {
         bias = args[3];
     }
 
     // initial hidden state
     instruction_ref ih{};
-    if(args.size() >= 6 && args[5]->name() != "undefined")
+    if(args.size() >= 6 and not args[5]->is_undefined())
     {
         ih = args[5];
     }
@@ -958,7 +958,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
     // initial cell value
     instruction_ref ic{};
-    if(args.size() >= 7 && args[6]->name() != "undefined")
+    if(args.size() >= 7 and not args[6]->is_undefined())
     {
         ic = args[6];
     }
@@ -969,7 +969,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
     // process weight of the peephole
     instruction_ref pph = m.end();
-    if(args.size() == 8 && args[7]->name() != "undefined")
+    if(args.size() == 8 and not args[7]->is_undefined())
     {
         pph = args[7];
     }
...
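
Every change in this file is the same mechanical substitution: a string comparison against the operator name becomes a call to the new instruction::is_undefined() helper. A standalone sketch of the idiom, with a stub standing in for migraphx::instruction (the real helper lives in instruction.cpp and may handle more cases):

#include <cassert>
#include <string>
#include <vector>

struct instruction // stub, not the real class
{
    std::string op_name;
    const std::string& name() const { return op_name; }
    bool is_undefined() const { return name() == "undefined"; } // assumed body
};

int main()
{
    instruction undef{"undefined"};
    instruction w{"@literal"};
    std::vector<instruction*> args{&w, &w, &w, &undef};
    instruction* bias = nullptr;
    // Optional RNN inputs are only consumed when present and defined:
    if(args.size() >= 4 and not args[3]->is_undefined())
        bias = args[3];
    assert(bias == nullptr);
}
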
@@ -504,6 +504,25 @@ bool shape::dynamic_dimension::is_fixed() const { return this->min == this->max;
 
 bool shape::dynamic_dimension::has_optimal() const { return opt != 0; }
 
+shape::dynamic_dimension& shape::dynamic_dimension::operator+=(const std::size_t& x)
+{
+    this->min += x;
+    this->max += x;
+    if(this->opt != 0)
+    {
+        this->opt += x;
+    };
+    return *this;
+}
+
+shape::dynamic_dimension& shape::dynamic_dimension::operator-=(const std::size_t& x)
+{
+    assert(this->min >= x);
+    assert(this->max >= x);
+    this->min -= x;
+    this->max -= x;
+    if(this->opt != 0)
+    {
+        assert(this->opt >= x);
+        this->opt -= x;
+    }
+    return *this;
+}
+
 bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
 {
     // don't check opt if both are fixed
@@ -521,6 +546,31 @@ std::ostream& operator<<(std::ostream& os, const shape::dynamic_dimension& x)
     return os;
 }
 
+bool operator==(const shape::dynamic_dimension& x, const std::size_t& y)
+{
+    return x.min == y and x.max == y;
+}
+
+bool operator==(const std::size_t& x, const shape::dynamic_dimension& y) { return y == x; }
+
+bool operator!=(const shape::dynamic_dimension& x, const std::size_t& y) { return not(x == y); }
+
+bool operator!=(const std::size_t& x, const shape::dynamic_dimension& y) { return not(x == y); }
+
+shape::dynamic_dimension operator+(const shape::dynamic_dimension& x, const std::size_t& y)
+{
+    auto dd = x;
+    return dd += y;
+}
+
+shape::dynamic_dimension operator+(const std::size_t& x, const shape::dynamic_dimension& y)
+{
+    return y + x;
+}
+
+shape::dynamic_dimension operator-(const shape::dynamic_dimension& x, const std::size_t& y)
+{
+    auto dd = x;
+    return dd -= y;
+}
+
 bool operator==(const shape& x, const shape& y)
 {
     if(x.dynamic() and y.dynamic())
...
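
A short sketch of what the new arithmetic and scalar comparisons allow; the brace-initialization order {min, max, opt} and the header declarations are assumptions here:

#include <migraphx/shape.hpp>
#include <cassert>
#include <cstddef>

int main()
{
    migraphx::shape::dynamic_dimension dd{2, 8, 4}; // min, max, opt (order assumed)
    auto grown = dd + 3;  // {5, 11, 7}: opt shifts along because it is nonzero
    auto back  = grown - 3;
    assert(back == dd);
    // Scalar comparison holds only when the dimension is fixed at that value:
    assert(migraphx::shape::dynamic_dimension{4, 4, 0} == std::size_t{4});
}
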
@@ -1065,11 +1065,23 @@ struct find_split_reshape
             return;
         }
 
+        // Only want to apply this optimization if each split output is followed by
+        // a contiguous op and a reshape
+        if(std::any_of(split_outputs.begin(), split_outputs.end(), [](auto i) {
+               if(i->outputs().size() == 1)
+               {
+                   auto cont = i->outputs().front();
+                   return cont->outputs().size() != 1;
+               }
+               return false;
+           }))
+        {
+            return;
+        }
+
         std::vector<instruction_ref> vec_rsp(split_outputs.size());
         std::transform(split_outputs.begin(), split_outputs.end(), vec_rsp.begin(), [](auto i) {
-            assert(i->outputs().size() == 1);
             auto cont = i->outputs().front();
-            assert(cont->outputs().size() == 1);
             return cont->outputs().front();
         });
...
@@ -763,16 +763,23 @@ struct find_transpose_slice
         // Compute axis before transpose to use for unsqueeze
         auto perm = ins->get_operator().to_value()["permutation"].to_vector<int64_t>();
         auto preaxis = std::find(perm.begin(), perm.end(), axis) - perm.begin();
-        // Make unsqeeze
+        // Make unsqueeze
+        std::vector<int64_t> steps(sdistance.size());
+        std::transform(
+            slice.axes.begin(),
+            slice.axes.end(),
+            sdistance.begin(),
+            steps.begin(),
+            [&](const auto ax, const auto sdis) { return ins->get_shape().lens().at(ax) / sdis; });
         auto unsqueeze = m.insert_instruction(
-            ins, make_op("unsqueeze", {{"axes", {preaxis}}, {"steps", sdistance}}), ins->inputs());
+            ins, make_op("unsqueeze", {{"axes", {preaxis}}, {"steps", steps}}), ins->inputs());
         // Make transpose
         std::transform(perm.begin(), perm.end(), perm.begin(), [&](auto i) {
-            if(i > preaxis)
+            if(i >= preaxis)
                 return i + 1;
             return i;
         });
-        perm.insert(perm.begin(), preaxis + 1);
+        perm.insert(perm.begin(), preaxis);
         auto transpose =
             m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), unsqueeze);
         // Slice and squeeze
...
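
Worked numbers for the corrected step computation and permutation fixup (shapes illustrative, not from the diff):

// steps: each sliced axis length divided by its slice distance
//   ins->get_shape().lens() = {8, 12}, slice.axes = {1}, sdistance = {4}
//   steps[0] = lens().at(1) / 4 = 3
// permutation fixup with preaxis = 1:
//   entries >= preaxis shift up:  {1, 0} -> {2, 0}
//   insert preaxis at the front:  {1, 2, 0}
// The old code's "> preaxis" / "preaxis + 1" pair was off by one.
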
@@ -185,7 +185,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
     options.push_back("-fno-gpu-rdc");
     options.push_back(" -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3"));
     options.push_back("-Wno-cuda-compat");
-    options.push_back("--cuda-gpu-arch=" + arch);
+    options.push_back("--offload-arch=" + arch);
     prog.compile(options);
     return {prog.get_code_obj()};
 }
@@ -237,7 +237,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
     }
     else if(is_hip_clang_compiler())
     {
-        params += " --cuda-gpu-arch=" + arch;
+        params += " --offload-arch=" + arch;
         params += " --cuda-device-only";
         params += " -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3") + " ";
     }
...
@@ -196,12 +196,21 @@ argument to_gpu(const argument& arg, bool host)
 argument from_gpu(const argument& arg)
 {
     argument result;
-    arg.visit([&](auto x) {
-        using type = typename decltype(x)::value_type;
-        auto v = read_from_gpu<type>(arg.data(), x.get_shape().bytes() / sizeof(type));
-        // cppcheck-suppress returnDanglingLifetime
-        result = {x.get_shape(), [v]() mutable { return v.data(); }};
-    });
+    arg.visit(
+        [&](auto x) {
+            using type = typename decltype(x)::value_type;
+            auto v = read_from_gpu<type>(arg.data(), x.get_shape().bytes() / sizeof(type));
+            // cppcheck-suppress returnDanglingLifetime
+            result = {x.get_shape(), [v]() mutable { return v.data(); }};
+        },
+        [&](const auto& xs) {
+            std::vector<argument> args;
+            std::transform(xs.begin(), xs.end(), std::back_inserter(args), [&](auto x) {
+                return from_gpu(x);
+            });
+            result = argument{args};
+        });
     return result;
 }
...
@@ -105,7 +105,7 @@ struct hip_copy_to_gpu
     std::string name() const { return "hip::copy_to_gpu"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1, 2);
+        check_shapes{inputs, *this}.has(1, 2).same_type();
         return inputs.at(0);
     }
     argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
@@ -131,7 +131,7 @@ struct hip_copy_from_gpu
     std::string name() const { return "hip::copy_from_gpu"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1, 2);
+        check_shapes{inputs, *this}.has(1, 2).same_type();
         return inputs.at(0);
     }
     argument
@@ -159,7 +159,7 @@ struct hip_copy
     std::string name() const { return "hip::copy"; }
     shape compute_shape(std::vector<shape> inputs) const
    {
-        check_shapes{inputs, *this}.has(2);
+        check_shapes{inputs, *this}.has(2).same_type();
         return inputs.at(1);
     }
     argument compute(context& ctx, const shape&, std::vector<argument> args) const
...
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// NOLINTNEXTLINE
static const char* const gather_kernel = R"__migraphx__(
#include <migraphx/kernels/gather.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void gather_kernel(void* in_data, void* in_indices, void* output)
{
make_tensors()(in_data, in_indices, output)([](auto&&... xs) {
gather<${axis}>(xs...);
});
}
}
} // namespace migraphx
)__migraphx__";
struct gather_compiler : compiler<gather_compiler>
{
std::vector<std::string> names() const { return {"gather"}; }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
hip_compile_options options;
const auto& out_s = inputs.back();
options.set_launch_params(v, compute_global_for(ctx, out_s.elements()));
options.inputs = inputs;
options.output = out_s;
options.kernel_name = "gather_kernel";
options.virtual_inputs = inputs;
auto axis = v.at("axis").to<std::string>();
auto src = interpolate_string(gather_kernel, {{"axis", axis}});
return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
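
The ${axis} placeholder above is filled at compile time. A sketch of the substitution, assuming migraphx::interpolate_string from <migraphx/stringutils.hpp> with its default ${...} delimiters:

#include <migraphx/stringutils.hpp>
#include <cassert>

int main()
{
    auto src = migraphx::interpolate_string("gather<${axis}>(xs...);", {{"axis", "1"}});
    assert(src == "gather<1>(xs...);");
}
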
@@ -24,7 +24,6 @@
 #include <migraphx/gpu/compiler.hpp>
 #include <migraphx/make_op.hpp>
 #include <migraphx/gpu/context.hpp>
 #include <migraphx/gpu/mlir.hpp>
 
 namespace migraphx {
...
@@ -156,16 +156,25 @@ struct reduce_compiler : compiler<reduce_compiler>
     compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
     {
         value v = value::object{};
-        auto reduce_elements = get_reduce_elements(ins->inputs());
         if(op.name() == "reduce_sum")
         {
             v["reduction"] = "op::sum{}";
         }
         else if(op.name() == "reduce_mean")
         {
-            v["reduction"] = "op::sum{}";
-            v["write"] = "op::mean{" + std::to_string(reduce_elements) + "}";
+            auto reduce_elements = get_reduce_elements(ins->inputs());
+            auto reduce_type = ins->inputs().front()->get_shape().type();
+            v["reduction"] = "op::sum{}";
+            std::string mean = "op::mean{" + std::to_string(reduce_elements) + "}";
+            // Use float accumulator when reduction size is too large for half
+            if(reduce_type == shape::half_type and reduce_elements > 16384)
+                v["read"] = "compose(" + mean + ", op::convert_to<float>{})";
+            else if(contains({shape::float_type, shape::half_type, shape::double_type},
+                             reduce_type))
+                v["read"] = mean;
+            else
+                v["write"] = mean;
         }
         else if(op.name() == "reduce_max")
         {
...
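
The threshold exists because fp16 cannot represent consecutive integers above 2048, so a long half-precision running sum stalls well before a 16384-element reduction finishes. A standalone illustration, assuming the compiler supports _Float16:

#include <cstdio>

int main()
{
    _Float16 acc = 0;
    for(int i = 0; i < 16384; ++i)
        acc += (_Float16)1; // rounds back to 2048 once acc reaches 2048
    std::printf("%g\n", (double)acc); // prints 2048, not 16384
}
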
@@ -187,6 +187,14 @@ constexpr auto fold(F f)
     return [=](auto&&... xs) { return fold_impl(f, static_cast<decltype(xs)&&>(xs)...); };
 }
 
+template <class... Fs>
+constexpr auto compose(Fs... fs)
+{
+    return fold([](auto f, auto g) {
+        return [=](auto&&... xs) { return f(g(static_cast<decltype(xs)>(xs)...)); };
+    })(fs...);
+}
+
 template <class... Ts>
 constexpr auto pack(Ts... xs)
 {
...
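
compose chains callables right to left, i.e. compose(f, g)(x) == f(g(x)); this is what lets the reduce compiler build its read stage as compose(op::mean{n}, op::convert_to<float>{}). A minimal host-side analogue (the real version folds with migraphx::fold in a device header):

#include <cassert>

template <class F, class G>
constexpr auto compose(F f, G g)
{
    return [=](auto... xs) { return f(g(xs...)); };
}
template <class F, class G, class... Fs>
constexpr auto compose(F f, G g, Fs... fs)
{
    return compose(compose(f, g), fs...);
}

int main()
{
    auto add1  = [](int x) { return x + 1; };
    auto twice = [](int x) { return x * 2; };
    assert(compose(add1, twice)(3) == 7);       // add1(twice(3))
    assert(compose(add1, add1, twice)(3) == 8); // left fold over all callables
}
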
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_GATHER_HPP
#define MIGRAPHX_GUARD_KERNELS_GATHER_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/shape.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/tensor_view.hpp>
namespace migraphx {
template <int Axis, class Input, class Indices>
constexpr auto gather_shape(Input input, Indices indices)
{
auto lengths = input.lens;
lengths[Axis] = indices.elements();
return make_shape(lengths, input.strides);
}
template <int Axis, class Input, class Indices, class Output>
__device__ void gather(Input input, Indices indices, Output output)
{
auto ind = make_index();
auto axis_dim_size = input.get_shape().lens[Axis];
constexpr auto out_comp = gather_shape<Axis>(get_shape_c<Input>{}, get_shape_c<Indices>{});
ind.global_stride(output.get_shape().elements(), [&](auto i) {
auto idx = out_comp.multi(i);
auto in_index = indices[idx[Axis]];
auto new_in_index = (in_index < 0) ? in_index + axis_dim_size : in_index;
idx[Axis] = new_in_index;
output[i] = input[idx];
});
}
} // namespace migraphx
#endif
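
The kernel computes each output element by swapping the gather-axis coordinate for a (wrap-corrected) index lookup. A host-side sketch of the same semantics for Axis = 0, with illustrative shapes:

#include <array>
#include <cassert>

int main()
{
    std::array<std::array<int, 2>, 3> input{{{1, 2}, {3, 4}, {5, 6}}};
    std::array<int, 2> indices{2, -3}; // negative indices wrap by the axis size (3)
    std::array<std::array<int, 2>, 2> output{};
    for(int i = 0; i < 2; ++i)
    {
        int in_index = indices[i];
        int wrapped  = in_index < 0 ? in_index + 3 : in_index;
        output[i]    = input[wrapped]; // output[i][j] = input[indices[i]][j]
    }
    assert(output[0][0] == 5 and output[1][1] == 2);
}
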
@@ -25,6 +25,7 @@
 #define MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
 #include <migraphx/kernels/reduce.hpp>
 #include <migraphx/kernels/ops.hpp>
+#include <migraphx/kernels/vec.hpp>
 #include <migraphx/kernels/print.hpp>
 
 namespace migraphx {
...
@@ -132,9 +132,14 @@ MIGRAPHX_DEVICE_MATH_FOR(float, fmod, ::fmodf)
 // Builtin half functions
 MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, abs, ::__habs)
+MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, ceil, ::hceil)
+MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, cos, ::hcos)
 MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, exp, ::hexp)
+MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, floor, ::hfloor)
+MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, isnan, ::__hisnan)
 MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, log, ::hlog)
 MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, rsqrt, ::hrsqrt)
+// MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sin, ::hsin)
 MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sqrt, ::hsqrt)
 
 // Use float to compute half overload
@@ -144,16 +149,11 @@ MIGRAPHX_DEVICE_MATH_HALF(asin, ::asin)
 MIGRAPHX_DEVICE_MATH_HALF(asinh, ::asinh)
 MIGRAPHX_DEVICE_MATH_HALF(atan, ::atan)
 MIGRAPHX_DEVICE_MATH_HALF(atanh, ::atanh)
-MIGRAPHX_DEVICE_MATH_HALF(ceil, ::ceil)
-MIGRAPHX_DEVICE_MATH_HALF(cos, ::cos)
 MIGRAPHX_DEVICE_MATH_HALF(cosh, ::cosh)
 MIGRAPHX_DEVICE_MATH_HALF(erf, ::erf)
-MIGRAPHX_DEVICE_MATH_HALF(floor, ::floor)
-MIGRAPHX_DEVICE_MATH_HALF(isnan, ::isnan)
 MIGRAPHX_DEVICE_MATH_HALF(pow, ::pow)
 MIGRAPHX_DEVICE_MATH_HALF(remainder, ::remainder)
 MIGRAPHX_DEVICE_MATH_HALF(round, ::round)
-MIGRAPHX_DEVICE_MATH_HALF(sin, ::sin)
 MIGRAPHX_DEVICE_MATH_HALF(sinh, ::sinh)
 MIGRAPHX_DEVICE_MATH_HALF(tan, ::tan)
 MIGRAPHX_DEVICE_MATH_HALF(tanh, ::tanh)
@@ -166,19 +166,19 @@ MIGRAPHX_DEVICE_MATH_HALF(fmod, ::fmod)
 // at this time are: exp2, exp10, log2, log10, isinf
 MIGRAPHX_DEVICE_MATH_HALF2(abs, ::__habs2)
 MIGRAPHX_DEVICE_MATH_HALF2(ceil, ::h2ceil)
-MIGRAPHX_DEVICE_MATH_HALF2(floor, ::h2floor)
-MIGRAPHX_DEVICE_MATH_HALF2(sin, ::h2sin)
 MIGRAPHX_DEVICE_MATH_HALF2(cos, ::h2cos)
 MIGRAPHX_DEVICE_MATH_HALF2(exp, ::h2exp)
-MIGRAPHX_DEVICE_MATH_HALF2(exp2, ::h2exp2)
 MIGRAPHX_DEVICE_MATH_HALF2(exp10, ::h2exp10)
-MIGRAPHX_DEVICE_MATH_HALF2(log2, ::h2log2)
+MIGRAPHX_DEVICE_MATH_HALF2(exp2, ::h2exp2)
+MIGRAPHX_DEVICE_MATH_HALF2(floor, ::h2floor)
+MIGRAPHX_DEVICE_MATH_HALF2(isinf, ::__hisinf2)
+MIGRAPHX_DEVICE_MATH_HALF2(isnan, ::__hisnan2)
 MIGRAPHX_DEVICE_MATH_HALF2(log, ::h2log)
 MIGRAPHX_DEVICE_MATH_HALF2(log10, ::h2log10)
+MIGRAPHX_DEVICE_MATH_HALF2(log2, ::h2log2)
 MIGRAPHX_DEVICE_MATH_HALF2(rsqrt, ::h2rsqrt)
+// MIGRAPHX_DEVICE_MATH_HALF2(sin, ::h2sin)
 MIGRAPHX_DEVICE_MATH_HALF2(sqrt, ::h2sqrt)
-MIGRAPHX_DEVICE_MATH_HALF2(isinf, ::__hisinf2)
-MIGRAPHX_DEVICE_MATH_HALF2(isnan, ::__hisnan2)
 
 template <class T, class U>
 constexpr auto where(bool cond, const T& a, const U& b)
@@ -218,6 +218,14 @@ constexpr auto min(const T& a, const U& b)
     return min<common_type_t<T, U>>(a, b);
 }
 
+// Sin for half is broken on hip, so use cos instead
+template <class T, MIGRAPHX_REQUIRES(is_same<vec_type<T>, half>{})>
+constexpr T sin(T x)
+{
+    constexpr const T shift = M_PI_2;
+    return migraphx::cos(shift - x);
+}
+
 MIGRAPHX_DEVICE_MATH_VEC(abs)
 MIGRAPHX_DEVICE_MATH_VEC(acos)
 MIGRAPHX_DEVICE_MATH_VEC(acosh)
...
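
The workaround relies on the identity sin(x) = cos(π/2 − x). A quick host-side check (M_PI_2 from <cmath> is assumed to be available):

#include <cassert>
#include <cmath>

int main()
{
    for(double x : {0.0, 0.5, -1.0, 3.0})
        assert(std::abs(std::sin(x) - std::cos(M_PI_2 - x)) < 1e-12);
}
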
@@ -56,6 +56,16 @@ struct id
     }
 };
 
+template <class T>
+struct convert_to
+{
+    template <class U>
+    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(U x) const
+    {
+        return convert<T>(x);
+    }
+};
+
 struct mean
 {
     index_int item_num = 1;
...
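
convert_to<T> wraps the device-side convert<T> in a function object so it can be passed to compose (see functional.hpp above). A host-side analogue using static_cast in place of convert:

#include <cassert>

template <class T>
struct convert_to // host-side stand-in for the device op
{
    template <class U>
    constexpr T operator()(U x) const { return static_cast<T>(x); }
};

int main()
{
    convert_to<float> to_float;
    assert(to_float(3) == 3.0f);
}
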
@@ -33,38 +33,6 @@
 namespace migraphx {
 
-template <class T>
-struct implicit_conversion_op
-{
-    T x;
-
-    template <index_int N, class U>
-    constexpr operator vec<U, N>() const
-    {
-        if constexpr(vec_size<T>() == 0)
-        {
-            return x;
-        }
-        else
-        {
-            static_assert(vec_size<T>() == N, "Vector mismatch size");
-            return __builtin_convertvector(x, vec<U, N>);
-        }
-    }
-
-    template <class U>
-    constexpr operator U() const
-    {
-        return x;
-    }
-};
-
-template <class T>
-constexpr implicit_conversion_op<T> implicit_conversion(T x)
-{
-    return {x};
-}
-
 template <class F, class T, class... Ts>
 __device__ void pointwise_tensor(index idx, F f, T out, Ts... xs)
 {
...
@@ -128,6 +128,7 @@ struct shape
         result[0] = tidx;
         return result;
     }
 
     /// Convert multi-index into a single index
     constexpr index_int single(index_array idx) const
     {
...
@@ -185,5 +185,37 @@ constexpr auto vec_reduce(T x, Op op)
     }
 }
 
+template <class T>
+struct implicit_conversion_op
+{
+    T x;
+
+    template <index_int N, class U>
+    constexpr operator vec<U, N>() const
+    {
+        if constexpr(vec_size<T>() == 0)
+        {
+            return x;
+        }
+        else
+        {
+            static_assert(vec_size<T>() == N, "Vector mismatch size");
+            return __builtin_convertvector(x, vec<U, N>);
+        }
+    }
+
+    template <class U>
+    constexpr operator U() const
+    {
+        return x;
+    }
+};
+
+template <class T>
+constexpr implicit_conversion_op<T> implicit_conversion(T x)
+{
+    return {x};
+}
+
 } // namespace migraphx
 #endif // MIGRAPHX_GUARD_KERNELS_VEC_HPP
@@ -90,7 +90,6 @@ struct miopen_apply
         add_extend_op("argmax");
         add_extend_op("argmin");
-        add_extend_op("gather");
         add_extend_op("logsoftmax");
         add_extend_op("lrn");
         add_extend_op("multinomial");
...