Commit b8090620 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into rnn_optimization

parents c2db3b96 3540f1b9
...@@ -29,6 +29,7 @@ struct unsqueeze ...@@ -29,6 +29,7 @@ struct unsqueeze
std::string name() const { return "unsqueeze"; } std::string name() const { return "unsqueeze"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1).standard_or_scalar();
auto input_shape = inputs[0]; auto input_shape = inputs[0];
auto type = input_shape.type(); auto type = input_shape.type();
auto old_lens = input_shape.lens(); auto old_lens = input_shape.lens();
......
...@@ -69,7 +69,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name()) ...@@ -69,7 +69,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
{ {
os << x.name(); os << x.name();
char delim = '['; char delim = '[';
reflect_each(x, [&](auto& y, auto name) { reflect_each(x, [&](auto&& y, auto name) {
os << delim; os << delim;
os << name << "="; os << name << "=";
stream_write_value(os, y); stream_write_value(os, y);
...@@ -87,6 +87,8 @@ namespace operation_equal { ...@@ -87,6 +87,8 @@ namespace operation_equal {
template <class T, class U> template <class T, class U>
auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name()) auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
{ {
static_assert(is_reflectable<T>{} or sizeof(T) <= 1,
"Missing equality operator or reflect method.");
if(x.name() != y.name()) if(x.name() != y.name())
return false; return false;
const auto& yy = any_cast<T>(y); const auto& yy = any_cast<T>(y);
......
...@@ -11,9 +11,11 @@ ...@@ -11,9 +11,11 @@
#include <migraphx/op/batch_norm.hpp> #include <migraphx/op/batch_norm.hpp>
#include <migraphx/op/binary.hpp> #include <migraphx/op/binary.hpp>
#include <migraphx/op/broadcast.hpp> #include <migraphx/op/broadcast.hpp>
#include <migraphx/op/clip.hpp>
#include <migraphx/op/common.hpp> #include <migraphx/op/common.hpp>
#include <migraphx/op/concat.hpp> #include <migraphx/op/concat.hpp>
#include <migraphx/op/contiguous.hpp> #include <migraphx/op/contiguous.hpp>
#include <migraphx/op/convert.hpp>
#include <migraphx/op/convolution.hpp> #include <migraphx/op/convolution.hpp>
#include <migraphx/op/cosh.hpp> #include <migraphx/op/cosh.hpp>
#include <migraphx/op/cos.hpp> #include <migraphx/op/cos.hpp>
......
#ifndef MIGRAPHX_GUARD_OPERATORS_PAD_CALC_HPP
#define MIGRAPHX_GUARD_OPERATORS_PAD_CALC_HPP
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
inline std::size_t calculate_padding(std::size_t weight_dim, std::size_t dilation)
{
return (dilation * (weight_dim - 1)) / 2;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZATION_HPP
#define MIGRAPHX_GUARD_RTGLIB_QUANTIZATION_HPP
#include <string>
#include <vector>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
void quantize(program& prog, const std::vector<std::string>& ins_names);
void quantize(program& prog);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -27,7 +27,8 @@ struct raw_data : raw_data_base ...@@ -27,7 +27,8 @@ struct raw_data : raw_data_base
template <class Stream> template <class Stream>
friend Stream& operator<<(Stream& os, const Derived& d) friend Stream& operator<<(Stream& os, const Derived& d)
{ {
d.visit([&](auto x) { os << x; }); if(not d.empty())
d.visit([&](auto x) { os << x; });
return os; return os;
} }
...@@ -40,8 +41,11 @@ struct raw_data : raw_data_base ...@@ -40,8 +41,11 @@ struct raw_data : raw_data_base
template <class Visitor> template <class Visitor>
void visit_at(Visitor v, std::size_t n = 0) const void visit_at(Visitor v, std::size_t n = 0) const
{ {
auto&& s = static_cast<const Derived&>(*this).get_shape(); auto&& derived = static_cast<const Derived&>(*this);
auto&& buffer = static_cast<const Derived&>(*this).data(); if(derived.empty())
MIGRAPHX_THROW("Visiting empty data!");
auto&& s = derived.get_shape();
auto&& buffer = derived.data();
s.visit_type([&](auto as) { v(*(as.from(buffer) + s.index(n))); }); s.visit_type([&](auto as) { v(*(as.from(buffer) + s.index(n))); });
} }
...@@ -55,8 +59,11 @@ struct raw_data : raw_data_base ...@@ -55,8 +59,11 @@ struct raw_data : raw_data_base
template <class Visitor> template <class Visitor>
void visit(Visitor v) const void visit(Visitor v) const
{ {
auto&& s = static_cast<const Derived&>(*this).get_shape(); auto&& derived = static_cast<const Derived&>(*this);
auto&& buffer = static_cast<const Derived&>(*this).data(); if(derived.empty())
MIGRAPHX_THROW("Visiting empty data!");
auto&& s = derived.get_shape();
auto&& buffer = derived.data();
s.visit_type([&](auto as) { v(make_view(s, as.from(buffer))); }); s.visit_type([&](auto as) { v(make_view(s, as.from(buffer))); });
} }
......
...@@ -11,6 +11,15 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -11,6 +11,15 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace detail { namespace detail {
struct reflect_placeholder
{
template <class... Ts>
int operator()(Ts&&...) const
{
return 0;
}
};
template <class T, class Selector> template <class T, class Selector>
auto reflect_impl(rank<1>, T& x, Selector f) -> decltype(T::reflect(x, f)) auto reflect_impl(rank<1>, T& x, Selector f) -> decltype(T::reflect(x, f))
{ {
...@@ -23,8 +32,53 @@ auto reflect_impl(rank<0>, T&, Selector) ...@@ -23,8 +32,53 @@ auto reflect_impl(rank<0>, T&, Selector)
return pack(); return pack();
} }
template <class T>
auto reflectable_impl(rank<1>, T&& x)
-> decltype(T::reflect(x, reflect_placeholder{}), std::true_type{});
template <class T>
auto reflectable_impl(rank<0>, T &&) -> decltype(std::false_type{});
template <class T>
struct remove_rvalue_reference
{
using type = T;
};
template <class T>
struct remove_rvalue_reference<T&&>
{
using type = T;
};
template <class T>
struct wrapper
{
using type = typename remove_rvalue_reference<T>::type;
type data;
type get() const { return data; }
};
template <class T>
wrapper<T> wrap(std::remove_reference_t<T>& x)
{
return wrapper<T>{std::forward<T>(x)};
}
template <class... Ts>
using auto_tuple_t = std::tuple<typename remove_rvalue_reference<Ts>::type...>;
template <class... Ts>
auto_tuple_t<Ts...> auto_tuple(Ts&&... xs)
{
return auto_tuple_t<Ts...>{std::forward<Ts>(xs)...};
}
} // namespace detail } // namespace detail
template <class T>
using is_reflectable = decltype(detail::reflectable_impl(rank<1>{}, std::declval<T>()));
template <class T, class Selector> template <class T, class Selector>
auto reflect(T& x, Selector f) auto reflect(T& x, Selector f)
{ {
...@@ -34,17 +88,18 @@ auto reflect(T& x, Selector f) ...@@ -34,17 +88,18 @@ auto reflect(T& x, Selector f)
template <class T> template <class T>
auto reflect_tie(T& x) auto reflect_tie(T& x)
{ {
return reflect(x, [](auto&& y, auto&&...) { return std::ref(y); })( return reflect(x, [](auto&& y, auto&&...) { return detail::wrap<decltype(y)>(y); })(
[](auto&&... xs) { return std::tie(xs.get()...); }); [](auto&&... xs) { return detail::auto_tuple(xs.get()...); });
} }
template <class T, class F> template <class T, class F>
void reflect_each(T& x, F f) void reflect_each(T& x, F f)
{ {
return reflect(x, [](auto&& y, auto... ys) { return pack(std::ref(y), ys...); })( return reflect(x, [](auto&& y, auto... ys) {
[&](auto&&... xs) { return pack(detail::wrap<decltype(y)>(y), ys...);
each_args([&](auto p) { p([&](auto&& y, auto... ys) { f(y.get(), ys...); }); }, xs...); })([&](auto&&... xs) {
}); each_args([&](auto p) { p([&](auto&& y, auto... ys) { f(y.get(), ys...); }); }, xs...);
});
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -12,6 +12,14 @@ ...@@ -12,6 +12,14 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <class T>
T as_number(T x)
{
return x;
}
inline int32_t as_number(int8_t x) { return static_cast<int32_t>(x); }
inline uint32_t as_number(uint8_t x) { return static_cast<uint32_t>(x); }
template <class T> template <class T>
struct tensor_view struct tensor_view
{ {
...@@ -130,10 +138,10 @@ struct tensor_view ...@@ -130,10 +138,10 @@ struct tensor_view
{ {
if(!x.empty()) if(!x.empty())
{ {
os << x.front(); os << as_number(x.front());
for(std::size_t i = 1; i < x.m_shape.elements(); i++) for(std::size_t i = 1; i < x.m_shape.elements(); i++)
{ {
os << ", " << x.m_data[x.m_shape.index(i)]; os << ", " << as_number(x.m_data[x.m_shape.index(i)]);
} }
} }
return os; return os;
......
...@@ -28,6 +28,12 @@ void instruction::replace(const shape& r) ...@@ -28,6 +28,12 @@ void instruction::replace(const shape& r)
} }
} }
void instruction::replace(operation o)
{
op = std::move(o);
recompute_shape();
}
void instruction::recompute_shape() { replace(compute_shape(op, arguments)); } void instruction::recompute_shape() { replace(compute_shape(op, arguments)); }
void instruction::clear_arguments() void instruction::clear_arguments()
......
...@@ -63,6 +63,7 @@ struct onnx_parser ...@@ -63,6 +63,7 @@ struct onnx_parser
add_variadic_op("Max", op::max{}); add_variadic_op("Max", op::max{});
add_variadic_op("Min", op::min{}); add_variadic_op("Min", op::min{});
add_mem_op("Clip", &onnx_parser::parse_clip);
add_mem_op("LRN", &onnx_parser::parse_lrn); add_mem_op("LRN", &onnx_parser::parse_lrn);
add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler); add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu); add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
...@@ -225,6 +226,22 @@ struct onnx_parser ...@@ -225,6 +226,22 @@ struct onnx_parser
}); });
} }
instruction_ref parse_clip(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
{
op::clip op;
if(contains(attributes, "max"))
{
op.max_val = parse_value(attributes.at("max")).at<float>();
}
if(contains(attributes, "min"))
{
op.min_val = parse_value(attributes.at("min")).at<float>();
}
return prog.add_instruction(op, std::move(args));
}
instruction_ref instruction_ref
parse_softmax(const std::string&, const attribute_map&, std::vector<instruction_ref> args) parse_softmax(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{ {
......
...@@ -177,7 +177,7 @@ void memory_coloring_impl::build() ...@@ -177,7 +177,7 @@ void memory_coloring_impl::build()
void memory_coloring_impl::rewrite() void memory_coloring_impl::rewrite()
{ {
std::vector<std::size_t> dims; std::vector<std::size_t> dims;
dims.push_back(required_bytes / sizeof(float)); dims.push_back((required_bytes + sizeof(float) - 1) / sizeof(float));
shape s = {shape::float_type, dims}; shape s = {shape::float_type, dims};
instruction_ref scratch_param = p_program->add_parameter("scratch", s); instruction_ref scratch_param = p_program->add_parameter("scratch", s);
for(auto ins : iterator_for(*p_program)) for(auto ins : iterator_for(*p_program))
......
...@@ -63,11 +63,16 @@ static void print_program(const program& p, F print_func) ...@@ -63,11 +63,16 @@ static void print_program(const program& p, F print_func)
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
std::string var_name = "@" + std::to_string(count); std::string var_name;
if(ins->name() == "@param") if(ins->name() == "@param")
{ {
var_name = any_cast<builtin::param>(ins->get_operator()).parameter; var_name = any_cast<builtin::param>(ins->get_operator()).parameter;
} }
else
{
var_name = "@" + std::to_string(count);
count++;
}
names.emplace(ins, var_name); names.emplace(ins, var_name);
// TODO: Use all_of // TODO: Use all_of
...@@ -78,8 +83,6 @@ static void print_program(const program& p, F print_func) ...@@ -78,8 +83,6 @@ static void print_program(const program& p, F print_func)
} }
print_func(ins, names); print_func(ins, names);
count++;
} }
} }
...@@ -434,13 +437,20 @@ argument program::eval(std::unordered_map<std::string, argument> params) const ...@@ -434,13 +437,20 @@ argument program::eval(std::unordered_map<std::string, argument> params) const
#else #else
auto check_context = [](auto f) { return f(); }; auto check_context = [](auto f) { return f(); };
#endif #endif
if(enabled(MIGRAPHX_TRACE_EVAL{}))
auto trace_level = value_of(MIGRAPHX_TRACE_EVAL{});
if(trace_level > 0)
{ {
return generic_eval(*this, ctx, std::move(params), [&](auto& ins, auto f) { return generic_eval(*this, ctx, std::move(params), [&](auto& ins, auto f) {
ctx.finish(); ctx.finish();
std::cout << "Run instruction: "; std::cout << "Run instruction: ";
this->debug_print(ins); this->debug_print(ins);
return check_context(f); auto result = check_context(f);
ctx.finish();
if(trace_level > 1 and ins->name().front() != '@' and ins->name() != "load")
std::cout << "Ouput: " << result << std::endl;
return result;
}); });
} }
else else
......
...@@ -12,12 +12,7 @@ if(MIGRAPHX_ENABLE_PYTHON) ...@@ -12,12 +12,7 @@ if(MIGRAPHX_ENABLE_PYTHON)
C_VISIBILITY_PRESET hidden C_VISIBILITY_PRESET hidden
CXX_VISIBILITY_PRESET hidden CXX_VISIBILITY_PRESET hidden
) )
if(MIGRAPHX_ENABLE_TF) target_link_libraries(migraphx_py PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_cpu)
target_link_libraries(migraphx_py PRIVATE migraphx migraphx_tf migraphx_cpu)
target_compile_definitions(migraphx_py PRIVATE -DENABLE_TF)
else()
target_link_libraries(migraphx_py PRIVATE migraphx migraphx_onnx migraphx_cpu)
endif()
if(MIGRAPHX_ENABLE_GPU) if(MIGRAPHX_ENABLE_GPU)
target_link_libraries(migraphx_py PRIVATE migraphx_gpu) target_link_libraries(migraphx_py PRIVATE migraphx_gpu)
target_compile_definitions(migraphx_py PRIVATE -DHAVE_GPU) target_compile_definitions(migraphx_py PRIVATE -DHAVE_GPU)
......
...@@ -2,14 +2,12 @@ ...@@ -2,14 +2,12 @@
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <pybind11/stl.h> #include <pybind11/stl.h>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/cpu/target.hpp> #include <migraphx/cpu/target.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#ifdef ENABLE_TF
#include <migraphx/tf.hpp> #include <migraphx/tf.hpp>
#else
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#endif
#ifdef HAVE_GPU #ifdef HAVE_GPU
#include <migraphx/gpu/target.hpp> #include <migraphx/gpu/target.hpp>
...@@ -160,16 +158,13 @@ PYBIND11_MODULE(migraphx, m) ...@@ -160,16 +158,13 @@ PYBIND11_MODULE(migraphx, m)
.def("__ne__", std::not_equal_to<migraphx::program>{}) .def("__ne__", std::not_equal_to<migraphx::program>{})
.def("__repr__", [](const migraphx::program& p) { return migraphx::to_string(p); }); .def("__repr__", [](const migraphx::program& p) { return migraphx::to_string(p); });
#ifdef ENABLE_TF
m.def("parse_tf", m.def("parse_tf",
&migraphx::parse_tf, &migraphx::parse_tf,
"Parse tf protobuf (default format is nhwc)", "Parse tf protobuf (default format is nhwc)",
py::arg("filename"), py::arg("filename"),
py::arg("is_nhwc") = true); py::arg("is_nhwc") = true);
#else
m.def("parse_onnx", &migraphx::parse_onnx); m.def("parse_onnx", &migraphx::parse_onnx);
#endif
m.def("get_target", [](const std::string& name) -> migraphx::target { m.def("get_target", [](const std::string& name) -> migraphx::target {
if(name == "cpu") if(name == "cpu")
return migraphx::cpu::target{}; return migraphx::cpu::target{};
...@@ -181,6 +176,10 @@ PYBIND11_MODULE(migraphx, m) ...@@ -181,6 +176,10 @@ PYBIND11_MODULE(migraphx, m)
}); });
m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0); m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0);
m.def("quantize", [](migraphx::program& p, std::vector<std::string>& ins_names) {
migraphx::quantize(p, ins_names);
});
m.def("quantize", [](migraphx::program& p) { migraphx::quantize(p, {"all"}); });
#ifdef HAVE_GPU #ifdef HAVE_GPU
m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false); m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false);
......
#include <migraphx/quantization.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/op/convert.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/ranges.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
instruction_ref insert_fp16(program& prog,
instruction_ref& ins,
shape::type_t type,
std::unordered_map<instruction_ref, instruction_ref>& map_fp16)
{
if(map_fp16.count(ins) > 0)
{
return map_fp16[ins];
}
assert(ins->get_shape().type() == shape::float_type ||
ins->get_shape().type() == shape::double_type);
instruction_ref ins_fp16{};
ins_fp16 = prog.insert_instruction(std::next(ins), op::convert{type}, ins);
map_fp16[ins] = ins_fp16;
return ins_fp16;
}
void quantize(program& prog, const std::vector<std::string>& ins_names)
{
std::unordered_map<instruction_ref, instruction_ref> map_fp16;
for(auto ins : iterator_for(prog))
{
// all indicates every instruction is converted
if((not contains(ins_names, "all")) and (not contains(ins_names, ins->name())))
{
continue;
}
shape::type_t orig_type = ins->get_shape().type();
// process all inputs, if input is a fp32 or fp64, convert it
// to a fp16 by adding a convert operator.
auto inputs = ins->inputs();
std::vector<instruction_ref> converted_inputs;
for(auto input : inputs)
{
auto s = input->get_shape();
if(s.type() == shape::float_type || s.type() == shape::double_type)
{
// if the input is a convert operator, uses its input
// as its current input
instruction_ref input_fp16{};
if(input->name() == "convert")
{
input_fp16 = input->inputs().front();
}
else
{
input_fp16 = insert_fp16(prog, input, shape::half_type, map_fp16);
}
converted_inputs.push_back(input_fp16);
}
else
{
converted_inputs.push_back(input);
}
}
// no change for the input, go to the next instruction
if(inputs == converted_inputs)
{
continue;
}
auto op = ins->get_operator();
auto ins_shape = compute_shape(op, converted_inputs);
if(ins_shape.type() != orig_type)
{
// insert another convert instruction to convert it back
if(ins == std::prev(prog.end()))
{
prog.add_instruction(op::convert{orig_type}, ins);
}
else
{
// check the dead code case to avoid assert
bool output_empty = ins->outputs().empty();
auto ins_orig_type =
prog.insert_instruction(std::next(ins), op::convert{orig_type}, ins);
if(!output_empty)
{
prog.replace_instruction(ins, ins_orig_type);
}
}
}
prog.replace_instruction(ins, op, converted_inputs);
}
}
void quantize(program& prog) { quantize(prog, {"all"}); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -14,7 +14,9 @@ bool is_reshaper(instruction_ref ins) ...@@ -14,7 +14,9 @@ bool is_reshaper(instruction_ref ins)
// clang-format off // clang-format off
static const std::unordered_set<std::string> names = { static const std::unordered_set<std::string> names = {
"reshape", "reshape",
"contiguous" "contiguous",
"squeeze",
"unsqueeze"
}; };
// clang-format on // clang-format on
return contains(names, ins->name()); return contains(names, ins->name());
...@@ -45,6 +47,9 @@ void simplify_reshapes::apply(program& p) const ...@@ -45,6 +47,9 @@ void simplify_reshapes::apply(program& p) const
auto end = std::prev(p.end()); auto end = std::prev(p.end());
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
if(ins == end and ins->name() == "contiguous")
continue;
// Skip possible dead instructions
if(ins->outputs().empty() and ins != end) if(ins->outputs().empty() and ins != end)
continue; continue;
if(is_reshaper(ins)) if(is_reshaper(ins))
...@@ -94,13 +99,6 @@ void simplify_reshapes::apply(program& p) const ...@@ -94,13 +99,6 @@ void simplify_reshapes::apply(program& p) const
p.replace_instruction(ins, t->inputs().front()); p.replace_instruction(ins, t->inputs().front());
} }
} }
// Replace all reshapes with as_shape
for(auto ins : iterator_for(p))
{
if(ins->name() != "reshape")
continue;
p.replace_instruction(ins, op::as_shape{ins->get_shape()}, ins->inputs());
}
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
This diff is collapsed.
...@@ -27,11 +27,14 @@ add_library(migraphx_device ...@@ -27,11 +27,14 @@ add_library(migraphx_device
device/add_relu.cpp device/add_relu.cpp
device/contiguous.cpp device/contiguous.cpp
device/logsoftmax.cpp device/logsoftmax.cpp
device/softmax.cpp
device/convert.cpp
device/mul.cpp device/mul.cpp
device/concat.cpp device/concat.cpp
device/pad.cpp device/pad.cpp
device/gather.cpp device/gather.cpp
device/sub.cpp device/sub.cpp
device/clip.cpp
) )
set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device) set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device)
rocm_clang_tidy_check(migraphx_device) rocm_clang_tidy_check(migraphx_device)
...@@ -66,6 +69,7 @@ add_library(migraphx_gpu ...@@ -66,6 +69,7 @@ add_library(migraphx_gpu
lrn.cpp lrn.cpp
schedule_model.cpp schedule_model.cpp
adjust_allocation.cpp adjust_allocation.cpp
clip.cpp
) )
set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu) set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
rocm_clang_tidy_check(migraphx_gpu) rocm_clang_tidy_check(migraphx_gpu)
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <algorithm>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
#include <migraphx/gpu/clip.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/clip.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_clip::compute_shape(std::vector<shape> inputs) const
{
inputs.pop_back();
return op.compute_shape(inputs);
}
argument hip_clip::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::clip(ctx.get_stream().get(), args.back(), args.front(), op.max_val, op.min_val);
return args.back();
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment