Unverified commit 32d69e8e authored by Ted Themistokleous, committed by GitHub

Merge branch 'develop' into simplify_1_mul_div_ops

parents 8398fb19 bab9502a
@@ -61,8 +61,6 @@ check_type_size("half_float::detail::expr" HALF_EXPR LANGUAGE CXX)
 set(CMAKE_REQUIRED_INCLUDES)
 set(CMAKE_EXTRA_INCLUDE_FILES)

-find_package(nlohmann_json 3.8.0 REQUIRED)
-
 include(ROCMSetupVersion)
 rocm_setup_version(VERSION 2.3)
@@ -77,7 +77,7 @@ RUN cget -p $PREFIX install ccache@v4.1
 RUN cget -p /opt/cmake install kitware/cmake@v3.13.4

 ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime
-ARG ONNXRUNTIME_BRANCH=master
+ARG ONNXRUNTIME_BRANCH=main
 ARG ONNXRUNTIME_COMMIT=24f1bd6156cf5968bbc76dfb0e801a9b9c56b9fc
 RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime && \
     cd onnxruntime && \
@@ -86,7 +86,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR
 ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh

-RUN cget -p /usr/local install ROCmSoftwarePlatform/llvm-project-mlir@26a4b3cfc0a1a15181490f24ae461608fef1b04e -DBUILD_MIXR_TARGET=On
+RUN cget -p /usr/local install ROCmSoftwarePlatform/llvm-project-mlir@d2cb9e580550e92ab75a0a417e7a4abd02a24edf -DBUILD_MIXR_TARGET=On

 ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
 ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
@@ -46,6 +46,7 @@ The following is a list of prerequisites required to build MIGraphX source.
 * [pybind11](https://pybind11.readthedocs.io/en/stable/) - for python bindings
 * [JSON](https://github.com/nlohmann/json) - for model serialization to json string format
 * [MessagePack](https://msgpack.org/index.html) - for model serialization to binary format
+* [SQLite3](https://www.sqlite.org/index.html) - to create a database of kernels' tuning information, or to execute queries on an existing database

 #### Use the ROCm build tool [rbuild](https://github.com/RadeonOpenCompute/rbuild).
File mode changed from 100755 to 100644
@@ -23,7 +23,7 @@
 */
 #include <algorithm>
 #include <hip/hip_runtime.h>
-#include <rocblas.h>
+#include <rocblas/rocblas.h>
 #include <migraphx/migraphx.h>
 #include <migraphx/migraphx.hpp> // MIGraphX's C++ API
 #include <numeric>
@@ -56,11 +56,13 @@ struct sscal_custom_op final : migraphx::experimental_custom_op_base
                               migraphx::arguments args) const override
     {
         // create rocblas stream handle
-        auto rocblas_handle = create_rocblas_handle_ptr(ctx);
-        rocblas_int n       = args[1].get_shape().lengths()[0];
-        float* alpha        = reinterpret_cast<float*>(args[0].data());
-        float* vec_ptr      = reinterpret_cast<float*>(args[1].data());
-        MIGRAPHX_ROCBLAS_ASSERT(rocblas_sscal(rocblas_handle, n, alpha, vec_ptr, 1));
+        auto rb_handle = create_rocblas_handle_ptr(ctx);
+        MIGRAPHX_ROCBLAS_ASSERT(rocblas_set_pointer_mode(rb_handle, rocblas_pointer_mode_device));
+        rocblas_int n  = args[1].get_shape().lengths()[0];
+        float* alpha   = reinterpret_cast<float*>(args[0].data());
+        float* vec_ptr = reinterpret_cast<float*>(args[1].data());
+        MIGRAPHX_ROCBLAS_ASSERT(rocblas_sscal(rb_handle, n, alpha, vec_ptr, 1));
+        MIGRAPHX_ROCBLAS_ASSERT(rocblas_destroy_handle(rb_handle));
         return args[1];
     }
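The `rocblas_set_pointer_mode(..., rocblas_pointer_mode_device)` call added above matters because `alpha` is a pointer into GPU memory; by default rocBLAS reads scalar arguments from host memory. A minimal sketch of the host-pointer variant for contrast (the wrapper function and standalone setting are assumptions for illustration; the rocBLAS calls themselves are the library's real API):

```cpp
#include <rocblas/rocblas.h>

// Scale a device-resident vector by a host-resident alpha (hedged sketch).
// Contrast with the diff above, where alpha is a *device* pointer and the
// handle therefore needs rocblas_pointer_mode_device before rocblas_sscal.
void scale_with_host_alpha(float* d_vec, rocblas_int n, float host_alpha)
{
    rocblas_handle handle;
    rocblas_create_handle(&handle);
    // host mode is the default; set it explicitly for clarity
    rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host);
    rocblas_sscal(handle, n, &host_alpha, d_vec, 1); // increment of 1
    rocblas_destroy_handle(handle);
}
```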
@@ -27,3 +27,4 @@ live-clones/blaze@v3.8 -X header -DHEADER_DIR=blaze -H sha256:d0ff011f4753828517
 half,https://github.com/pfultz2/half/archive/1.12.0.tar.gz -X header -H sha256:0a08660b68abb176ebc2a0cdf8de46e3182a7f46c66443bb80dbfaaec98cf969
 pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
 msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
+sqlite3@3.17 -DCMAKE_POSITION_INDEPENDENT_CODE=On
@@ -65,6 +65,7 @@ add_library(migraphx
     operation.cpp
     opt/memory_coloring.cpp
     opt/memory_coloring_impl.cpp
+    pad_calc.cpp
     pass_manager.cpp
     permutation.cpp
     preallocate_param.cpp
@@ -79,6 +80,7 @@ add_library(migraphx
     register_target.cpp
     replace_allocate.cpp
     simplify_qdq.cpp
+    sqlite.cpp
     rewrite_batchnorm.cpp
     rewrite_pooling.cpp
     rewrite_quantization.cpp
@@ -134,6 +136,7 @@ register_migraphx_ops(
    exp
    flatten
    floor
+   fmod
    gather
    gathernd
    get_tuple_elem
@@ -156,6 +159,7 @@ register_migraphx_ops(
    lstm
    max
    min
+   mod
    mul
    multibroadcast
    multinomial
@@ -239,6 +243,13 @@ target_include_directories(migraphx SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLU
 find_package(Threads)
 target_link_libraries(migraphx PUBLIC Threads::Threads)

+find_package(nlohmann_json 3.8.0 REQUIRED)
+target_link_libraries(migraphx PRIVATE nlohmann_json::nlohmann_json)
+
+find_package(PkgConfig)
+pkg_check_modules(SQLITE3 REQUIRED IMPORTED_TARGET sqlite3)
+target_link_libraries(migraphx PRIVATE PkgConfig::SQLITE3)
+
 find_package(msgpack REQUIRED)
 target_link_libraries(migraphx PRIVATE msgpackc-cxx)

 # Make this available to the tests
@@ -63,7 +63,7 @@ void auto_contiguous::apply(module& m) const
         if(ins->outputs().empty() and ins != last)
             continue;
         shape s = ins->get_shape();
-        if(not s.standard() and s.elements() != 0)
+        if(not s.dynamic() and not s.standard() and s.elements() != 0)
         {
             auto c = m.insert_instruction(std::next(ins), make_op("contiguous"), ins);
             m.replace_instruction(ins, c);
@@ -48,9 +48,10 @@ void dead_code_elimination::apply(module& m) const
             // Skip the last instruction
             if(i == last)
                 break;
-            // Skip instruction with empty shape as output unless it's a builtin, undefined,
-            // identity, or allocate
-            if(i->get_shape().elements() == 0 and i->name().front() != '@' and
+            // Skip instruction with empty shape as output unless it's [dynamic, builtin,
+            // undefined, identity, allocate]
+            if((not i->get_shape().dynamic() and i->get_shape().elements() == 0) and
+               i->name().front() != '@' and
                not contains({"undefined", "identity", "allocate"}, i->name()))
                 continue;
             assert(std::distance(m.begin(), i) <= std::distance(m.begin(), last));
@@ -27,11 +27,13 @@
 #include <algorithm>
 #include <functional>
 #include <iostream>
+#include <list>
 #include <set>
 #include <string>
 #include <sstream>
 #include <type_traits>
 #include <unordered_map>
+#include <unordered_set>
 #include <utility>
 #include <vector>
@@ -39,9 +41,16 @@
 #include <migraphx/requires.hpp>
 #include <migraphx/type_name.hpp>
 #include <migraphx/functional.hpp>
+#include <migraphx/filesystem.hpp>
 #include <migraphx/stringutils.hpp>
+#include <migraphx/algorithm.hpp>
+#include <migraphx/ranges.hpp>
 #include <migraphx/rank.hpp>

+#ifndef _WIN32
+#include <unistd.h>
+#endif
+
 namespace migraphx {
 namespace driver {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -74,6 +83,65 @@ template <class T>
 using is_multi_value =
     std::integral_constant<bool, (is_container<T>{} and not std::is_convertible<T, std::string>{})>;

+enum class color
+{
+    reset      = 0,
+    bold       = 1,
+    underlined = 4,
+    fg_red     = 31,
+    fg_green   = 32,
+    fg_yellow  = 33,
+    fg_blue    = 34,
+    fg_default = 39,
+    bg_red     = 41,
+    bg_green   = 42,
+    bg_yellow  = 43,
+    bg_blue    = 44,
+    bg_default = 49
+};
+
+inline std::ostream& operator<<(std::ostream& os, const color& c)
+{
+#ifndef _WIN32
+    static const bool use_color = isatty(STDOUT_FILENO) != 0;
+    if(use_color)
+        return os << "\033[" << static_cast<std::size_t>(c) << "m";
+#endif
+    return os;
+}
+
+inline std::string colorize(color c, const std::string& s)
+{
+    std::stringstream ss;
+    ss << c << s << color::reset;
+    return ss.str();
+}
+
+template <class T>
+struct type_name
+{
+    static const std::string& apply() { return migraphx::get_type_name<T>(); }
+};
+
+template <>
+struct type_name<std::string>
+{
+    static const std::string& apply()
+    {
+        static const std::string name = "std::string";
+        return name;
+    }
+};
+
+template <class T>
+struct type_name<std::vector<T>>
+{
+    static const std::string& apply()
+    {
+        static const std::string name = "std::vector<" + type_name<T>::apply() + ">";
+        return name;
+    }
+};
+
 template <class T>
 struct value_parser
 {
@@ -85,7 +153,7 @@ struct value_parser
         ss.str(x);
         ss >> result;
         if(ss.fail())
-            throw std::runtime_error("Failed to parse: " + x);
+            throw std::runtime_error("Failed to parse '" + x + "' as " + type_name<T>::apply());
         return result;
     }
@@ -97,7 +165,7 @@ struct value_parser
         ss.str(x);
         ss >> i;
         if(ss.fail())
-            throw std::runtime_error("Failed to parse: " + x);
+            throw std::runtime_error("Failed to parse '" + x + "' as " + type_name<T>::apply());
         return static_cast<T>(i);
     }
@@ -115,13 +183,42 @@ struct argument_parser
 {
     struct argument
     {
+        using action_function =
+            std::function<bool(argument_parser&, const std::vector<std::string>&)>;
+        using validate_function =
+            std::function<void(const argument_parser&, const std::vector<std::string>&)>;
+
         std::vector<std::string> flags;
-        std::function<bool(argument_parser&, const std::vector<std::string>&)> action{};
+        action_function action{};
         std::string type          = "";
         std::string help          = "";
         std::string metavar       = "";
         std::string default_value = "";
+        std::string group         = "";
         unsigned nargs            = 1;
+        bool required             = false;
+        std::vector<validate_function> validations{};
+
+        std::string usage(const std::string& flag) const
+        {
+            std::stringstream ss;
+            if(flag.empty())
+            {
+                ss << metavar;
+            }
+            else
+            {
+                ss << flag;
+                if(not type.empty())
+                    ss << " [" << type << "]";
+            }
+            return ss.str();
+        }
+
+        std::string usage() const
+        {
+            if(flags.empty())
+                return usage("");
+            return usage(flags.front());
+        }
     };

     template <class T, MIGRAPHX_REQUIRES(is_multi_value<T>{})>
@@ -154,12 +251,14 @@ struct argument_parser
         arguments.push_back({flags, [&](auto&&, const std::vector<std::string>& params) {
                                  if(params.empty())
                                      throw std::runtime_error("Flag with no value.");
+                                 if(not is_multi_value<T>{} and params.size() > 1)
+                                     throw std::runtime_error("Too many arguments passed.");
                                  x = value_parser<T>::apply(params.back());
                                  return false;
                              }});
         argument& arg = arguments.back();
-        arg.type      = migraphx::get_type_name<T>();
+        arg.type      = type_name<T>::apply();
         migraphx::each_args([&](auto f) { f(x, arg); }, fs...);
         if(not arg.default_value.empty() and arg.nargs > 0)
             arg.default_value = as_string_value(x);
@@ -181,6 +280,11 @@ struct argument_parser
         return [=](auto&&, auto& arg) { arg.nargs = n; };
     }

+    MIGRAPHX_DRIVER_STATIC auto required()
+    {
+        return [=](auto&&, auto& arg) { arg.required = true; };
+    }
+
     template <class F>
     MIGRAPHX_DRIVER_STATIC auto write_action(F f)
     {
@@ -215,34 +319,164 @@ struct argument_parser
         });
     }

-    MIGRAPHX_DRIVER_STATIC auto show_help(const std::string& msg = "")
+    template <class F>
+    MIGRAPHX_DRIVER_STATIC auto validate(F f)
+    {
+        return [=](const auto& x, auto& arg) {
+            arg.validations.push_back(
+                [&, f](auto& self, const std::vector<std::string>& params) { f(self, x, params); });
+        };
+    }
+
+    MIGRAPHX_DRIVER_STATIC auto file_exist()
+    {
+        return validate([](auto&, auto&, auto& params) {
+            if(params.empty())
+                throw std::runtime_error("No argument passed.");
+            if(not fs::exists(params.back()))
+                throw std::runtime_error("Path does not exists: " + params.back());
+        });
+    }
+
+    template <class F>
+    argument* find_argument(F f)
+    {
+        auto it = std::find_if(arguments.begin(), arguments.end(), f);
+        if(it == arguments.end())
+            return nullptr;
+        return std::addressof(*it);
+    }
+
+    template <class F>
+    bool has_argument(F f)
+    {
+        return find_argument(f) != nullptr;
+    }
+
+    template <class F>
+    std::vector<argument*> find_arguments(F f)
+    {
+        std::vector<argument*> result;
+        for(auto& arg : arguments)
+        {
+            if(not f(arg))
+                continue;
+            result.push_back(&arg);
+        }
+        return result;
+    }
+
+    std::vector<argument*> get_group_arguments(const std::string& group)
+    {
+        return find_arguments([&](const auto& arg) { return arg.group == group; });
+    }
+
+    std::vector<argument*> get_required_arguments()
+    {
+        return find_arguments([&](const auto& arg) { return arg.required; });
+    }
+
+    template <class SequenceContainer>
+    std::vector<std::string> get_argument_usages(SequenceContainer args)
+    {
+        std::vector<std::string> usage_flags;
+        std::unordered_set<std::string> found_groups;
+        // Remove arguments that belong to a group
+        auto it = std::remove_if(args.begin(), args.end(), [&](const argument* arg) {
+            if(arg->group.empty())
+                return false;
+            found_groups.insert(arg->group);
+            return true;
+        });
+        args.erase(it, args.end());
+        transform(found_groups, std::back_inserter(usage_flags), [&](auto&& group) {
+            std::vector<std::string> either_flags;
+            transform(get_group_arguments(group), std::back_inserter(either_flags), [](auto* arg) {
+                return arg->usage();
+            });
+            return "(" + join_strings(either_flags, "|") + ")";
+        });
+        transform(args, std::back_inserter(usage_flags), [&](auto* arg) { return arg->usage(); });
+        return usage_flags;
+    }
+
+    auto show_help(const std::string& msg = "")
     {
         return do_action([=](auto& self) {
-            for(auto&& arg : self.arguments)
+            argument* input_argument =
+                self.find_argument([](const auto& arg) { return arg.flags.empty(); });
+            auto required_usages = get_argument_usages(get_required_arguments());
+            if(required_usages.empty() && input_argument)
+                required_usages.push_back(input_argument->metavar);
+            required_usages.insert(required_usages.begin(), "<options>");
+            print_usage(required_usages);
+            std::cout << std::endl;
+            if(self.find_argument([](const auto& arg) { return arg.nargs == 0; }))
             {
+                std::cout << color::fg_yellow << "FLAGS:" << color::reset << std::endl;
                 std::cout << std::endl;
-                std::string prefix = "    ";
-                if(arg.flags.empty())
+                for(auto&& arg : self.arguments)
                 {
-                    std::cout << prefix;
-                    std::cout << arg.metavar;
+                    if(arg.nargs != 0)
+                        continue;
+                    const int col_align = 35;
+                    std::string prefix  = "    ";
+                    int len             = 0;
+                    std::cout << color::fg_green;
+                    for(const std::string& a : arg.flags)
+                    {
+                        len += prefix.length() + a.length();
+                        std::cout << prefix;
+                        std::cout << a;
+                        prefix = ", ";
+                    }
+                    std::cout << color::reset;
+                    int spaces = col_align - len;
+                    if(spaces < 0)
+                    {
+                        std::cout << std::endl;
+                    }
+                    else
+                    {
+                        for(int i = 0; i < spaces; i++)
+                            std::cout << " ";
+                    }
+                    std::cout << arg.help << std::endl;
                 }
-                for(const std::string& a : arg.flags)
+                std::cout << std::endl;
+            }
+            if(self.find_argument([](const auto& arg) { return arg.nargs != 0; }))
+            {
+                std::cout << color::fg_yellow << "OPTIONS:" << color::reset << std::endl;
+                for(auto&& arg : self.arguments)
                 {
-                    std::cout << prefix;
-                    std::cout << a;
-                    prefix = ", ";
-                }
-                if(not arg.type.empty())
-                {
-                    std::cout << " [" << arg.type << "]";
-                    if(not arg.default_value.empty())
-                        std::cout << " (Default: " << arg.default_value << ")";
+                    if(arg.nargs == 0)
+                        continue;
+                    std::cout << std::endl;
+                    std::string prefix = "    ";
+                    std::cout << color::fg_green;
+                    if(arg.flags.empty())
+                    {
+                        std::cout << prefix;
+                        std::cout << arg.metavar;
+                    }
+                    for(const std::string& a : arg.flags)
+                    {
+                        std::cout << prefix;
+                        std::cout << a;
+                        prefix = ", ";
+                    }
+                    std::cout << color::reset;
+                    if(not arg.type.empty())
+                    {
+                        std::cout << " [" << color::fg_blue << arg.type << color::reset << "]";
+                        if(not arg.default_value.empty())
+                            std::cout << " (Default: " << arg.default_value << ")";
+                    }
+                    std::cout << std::endl;
+                    std::cout << "        " << arg.help << std::endl;
                 }
                 std::cout << std::endl;
-                std::cout << "        " << arg.help << std::endl;
             }
-            std::cout << std::endl;
             if(not msg.empty())
                 std::cout << msg << std::endl;
         });
@@ -263,6 +497,11 @@ struct argument_parser
         return [=](auto&, auto& arg) { arg.type = type; };
     }

+    MIGRAPHX_DRIVER_STATIC auto group(const std::string& group)
+    {
+        return [=](auto&, auto& arg) { arg.group = group; };
+    }
+
     template <class T>
     MIGRAPHX_DRIVER_STATIC auto set_value(T value)
     {
@@ -276,6 +515,109 @@ struct argument_parser
         };
     }

+    template <class T>
+    void set_exe_name_to(T& x)
+    {
+        actions.push_back([&](const auto& self) { x = self.exe_name; });
+    }
+
+    void print_try_help()
+    {
+        if(has_argument([](const auto& a) { return contains(a.flags, "--help"); }))
+        {
+            std::cout << std::endl;
+            std::cout << "For more information try '" << color::fg_green << "--help"
+                      << color::reset << "'" << std::endl;
+        }
+    }
+
+    void print_usage(const std::vector<std::string>& flags) const
+    {
+        std::cout << color::fg_yellow << "USAGE:" << color::reset << std::endl;
+        std::cout << "    " << exe_name << " ";
+        std::cout << join_strings(flags, " ") << std::endl;
+    }
+
+    auto spellcheck(const std::vector<std::string>& inputs)
+    {
+        struct result_t
+        {
+            const argument* arg     = nullptr;
+            std::string correct     = "";
+            std::string incorrect   = "";
+            std::ptrdiff_t distance = std::numeric_limits<std::ptrdiff_t>::max();
+        };
+        result_t result;
+        for(const auto& input : inputs)
+        {
+            if(input.empty())
+                continue;
+            if(input[0] != '-')
+                continue;
+            for(const auto& arg : arguments)
+            {
+                for(const auto& flag : arg.flags)
+                {
+                    if(flag.empty())
+                        continue;
+                    if(flag[0] != '-')
+                        continue;
+                    auto d =
+                        levenshtein_distance(flag.begin(), flag.end(), input.begin(), input.end());
+                    if(d < result.distance)
+                        result = result_t{&arg, flag, input, d};
+                }
+            }
+        }
+        return result;
+    }
+
+    bool
+    run_action(const argument& arg, const std::string& flag, const std::vector<std::string>& inputs)
+    {
+        std::string msg = "";
+        try
+        {
+            for(const auto& v : arg.validations)
+                v(*this, inputs);
+            return arg.action(*this, inputs);
+        }
+        catch(const std::exception& e)
+        {
+            msg = e.what();
+        }
+        catch(...)
+        {
+            msg = "unknown exception";
+        }
+        std::cout << color::fg_red << color::bold << "error: " << color::reset;
+        auto sc = spellcheck(inputs);
+        if(sc.distance < 5)
+        {
+            std::cout << "Found argument '" << color::fg_yellow << sc.incorrect << color::reset
+                      << "'"
+                      << " which wasn't expected, or isn't valid in this context" << std::endl;
+            std::cout << "       "
+                      << "Did you mean " << color::fg_green << sc.correct << color::reset << "?"
+                      << std::endl;
+            std::cout << std::endl;
+            print_usage({sc.arg->usage(sc.correct)});
+        }
+        else
+        {
+            const auto& flag_name = flag.empty() ? arg.metavar : flag;
+            std::cout << "Invalid input to '" << color::fg_yellow;
+            std::cout << arg.usage(flag_name);
+            std::cout << color::reset << "'" << std::endl;
+            std::cout << "    " << msg << std::endl;
+            std::cout << std::endl;
+            print_usage({arg.usage()});
+        }
+        std::cout << std::endl;
+        print_try_help();
+        return true;
+    }
+
     bool parse(std::vector<std::string> args)
     {
         std::unordered_map<std::string, unsigned> keywords;
@@ -286,8 +628,11 @@ struct argument_parser
         }
         auto arg_map =
             generic_parse(std::move(args), [&](const std::string& x) { return keywords[x]; });
+        std::list<const argument*> missing_arguments;
+        std::unordered_set<std::string> groups_used;
         for(auto&& arg : arguments)
         {
+            bool used  = false;
             auto flags = arg.flags;
             if(flags.empty())
                 flags = {""};
@@ -295,14 +640,41 @@ struct argument_parser
             {
                 if(arg_map.count(flag) > 0)
                 {
-                    if(arg.action(*this, arg_map[flag]))
+                    if(run_action(arg, flag, arg_map[flag]))
                         return true;
+                    used = true;
                 }
             }
+            if(used and not arg.group.empty())
+                groups_used.insert(arg.group);
+            if(arg.required and not used)
+                missing_arguments.push_back(&arg);
+        }
+        // Remove arguments from a group that is being used
+        missing_arguments.remove_if(
+            [&](const argument* arg) { return groups_used.count(arg->group); });
+        if(not missing_arguments.empty())
+        {
+            std::cout << color::fg_red << color::bold << "error: " << color::reset;
+            std::cout << "The following required arguments were not provided:" << std::endl;
+            std::cout << "    " << color::fg_red
+                      << join_strings(get_argument_usages(std::move(missing_arguments)), " ")
+                      << color::reset << std::endl;
+            std::cout << std::endl;
+            auto required_usages = get_argument_usages(get_required_arguments());
+            print_usage(required_usages);
+            print_try_help();
+            return true;
         }
+        for(auto&& action : actions)
+            action(*this);
         return false;
     }

+    void set_exe_name(const std::string& s) { exe_name = s; }
+    const std::string& get_exe_name() const { return exe_name; }
+
     using string_map = std::unordered_map<std::string, std::vector<std::string>>;
     template <class IsKeyword>
     static string_map generic_parse(std::vector<std::string> as, IsKeyword is_keyword)
@@ -337,7 +709,9 @@ struct argument_parser
     }

     private:
-    std::vector<argument> arguments;
+    std::list<argument> arguments;
+    std::string exe_name = "";
+    std::vector<std::function<void(argument_parser&)>> actions;
 };

 } // namespace MIGRAPHX_INLINE_NS
@@ -41,7 +41,10 @@ inline namespace MIGRAPHX_INLINE_NS {
 inline auto& get_commands()
 {
     // NOLINTNEXTLINE
-    static std::unordered_map<std::string, std::function<void(std::vector<std::string> args)>> m;
+    static std::unordered_map<
+        std::string,
+        std::function<void(const std::string& exe_name, std::vector<std::string> args)>>
+        m;
     return m;
 }
@@ -65,10 +68,11 @@ const std::string& command_name()
 }

 template <class T>
-void run_command(std::vector<std::string> args, bool add_help = false)
+void run_command(const std::string& exe_name, std::vector<std::string> args, bool add_help = false)
 {
     T x;
     argument_parser ap;
+    ap.set_exe_name(exe_name + " " + command_name<T>());
     if(add_help)
         ap(nullptr, {"-h", "--help"}, ap.help("Show help"), ap.show_help());
     x.parse(ap);
@@ -81,7 +85,9 @@ template <class T>
 int auto_register_command()
 {
     auto& m = get_commands();
-    m[command_name<T>()] = [](std::vector<std::string> args) { run_command<T>(args, true); };
+    m[command_name<T>()] = [](const std::string& exe_name, std::vector<std::string> args) {
+        run_command<T>(exe_name, args, true);
+    };
     return 0;
 }
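For readers unfamiliar with the pattern, `auto_register_command` relies on the classic static-initialization self-registration idiom: a namespace-scope `int` is initialized by a call that inserts the command into the registry before `main` runs, so adding a new command never requires touching a central dispatch table. A stripped-down sketch of the idiom (all names here are illustrative, not MIGraphX API):

```cpp
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

using command_fn = std::function<void(const std::string&, std::vector<std::string>)>;

// Meyers singleton avoids the static-initialization-order problem
std::unordered_map<std::string, command_fn>& registry()
{
    static std::unordered_map<std::string, command_fn> m;
    return m;
}

template <class T>
int auto_register()
{
    registry()[T::name()] = [](const std::string& exe, std::vector<std::string> args) {
        T{}.run(exe, std::move(args));
    };
    return 0;
}

struct hello_cmd
{
    static std::string name() { return "hello"; }
    void run(const std::string& exe, std::vector<std::string>) const
    {
        std::cout << exe << ": hello\n";
    }
};

// Runs at static-init time, before main() begins
static int hello_reg = auto_register<hello_cmd>();

int main() { registry().at("hello")("demo", {}); }
```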
@@ -73,8 +73,12 @@ struct loader
     void parse(argument_parser& ap)
     {
-        ap(file, {}, ap.metavar("<input file>"));
-        ap(model, {"--model"}, ap.help("Load model"), ap.type("resnet50|inceptionv3|alexnet"));
+        ap(file, {}, ap.metavar("<input file>"), ap.file_exist(), ap.required(), ap.group("input"));
+        ap(model,
+           {"--model"},
+           ap.help("Load model"),
+           ap.type("resnet50|inceptionv3|alexnet"),
+           ap.group("input"));
         ap(file_type, {"--onnx"}, ap.help("Load as onnx"), ap.set_value("onnx"));
         ap(file_type, {"--tf"}, ap.help("Load as tensorflow"), ap.set_value("tf"));
         ap(file_type, {"--migraphx"}, ap.help("Load as MIGraphX"), ap.set_value("migraphx"));
@@ -578,26 +582,62 @@ struct onnx : command<onnx>

 struct main_command
 {
-    static std::string get_command_help()
+    static std::string get_command_help(const std::string& title = colorize(color::fg_yellow,
+                                                                            "COMMANDS:"))
     {
-        std::string result = "Commands:\n";
-        return std::accumulate(get_commands().begin(),
-                               get_commands().end(),
-                               result,
-                               [](auto r, auto&& p) { return r + "    " + p.first + "\n"; });
+        std::string result = title + "\n";
+        std::vector<std::string> commands(get_commands().size());
+        std::transform(get_commands().begin(),
+                       get_commands().end(),
+                       commands.begin(),
+                       [](const auto& p) { return colorize(color::fg_green, p.first); });
+        std::sort(commands.begin(), commands.end());
+        return std::accumulate(commands.begin(), commands.end(), result, [](auto r, auto&& s) {
+            return r + "    " + s + "\n";
+        });
     }
     void parse(argument_parser& ap)
     {
         std::string version_str = "MIGraphX Version: " + std::to_string(MIGRAPHX_VERSION_MAJOR) +
                                   "." + std::to_string(MIGRAPHX_VERSION_MINOR);
+        ap(wrong_commands, {}, ap.metavar("<command>"), ap.append());
         ap(nullptr, {"-h", "--help"}, ap.help("Show help"), ap.show_help(get_command_help()));
         ap(nullptr,
            {"-v", "--version"},
            ap.help("Show MIGraphX version"),
            ap.show_help(version_str));
+        // Trim command off of exe name
+        ap.set_exe_name(ap.get_exe_name().substr(0, ap.get_exe_name().size() - 5));
+        ap.set_exe_name_to(exe_name);
     }
-    void run() {}
+
+    std::vector<std::string> wrong_commands{};
+    std::string exe_name = "<exe>";
+
+    void run()
+    {
+        std::cout << color::fg_red << color::bold << "error: " << color::reset;
+        auto it = std::find_if(wrong_commands.begin(), wrong_commands.end(), [](const auto& c) {
+            return get_commands().count(c) > 0;
+        });
+        if(it == wrong_commands.end())
+        {
+            std::cout << "'" << color::fg_yellow << wrong_commands.front() << color::reset
+                      << "' is not a valid command." << std::endl;
+            std::cout << get_command_help("Available commands:") << std::endl;
+        }
+        else
+        {
+            std::cout << "command '" << color::fg_yellow << *it << color::reset
+                      << "' must be first argument" << std::endl;
+            std::cout << std::endl;
+            std::cout << color::fg_yellow << "USAGE:" << color::reset << std::endl;
+            std::cout << "    " << exe_name << " " << *it << " <options>" << std::endl;
+        }
+        std::cout << std::endl;
+    }
 };

 } // namespace MIGRAPHX_INLINE_NS
@@ -619,11 +659,11 @@ int main(int argc, const char* argv[])
     auto cmd = args.front();
     if(m.count(cmd) > 0)
     {
-        m.at(cmd)({args.begin() + 1, args.end()});
+        m.at(cmd)(argv[0], {args.begin() + 1, args.end()});
     }
     else
     {
-        run_command<main_command>(args);
+        run_command<main_command>(argv[0], args);
     }
     return 0;
@@ -74,6 +74,22 @@ void group_unique(Iterator start, Iterator last, Output out, Predicate pred)
     }
 }

+template <class Iterator1, class Iterator2>
+std::ptrdiff_t
+levenshtein_distance(Iterator1 first1, Iterator1 last1, Iterator2 first2, Iterator2 last2)
+{
+    if(first1 == last1)
+        return std::distance(first2, last2);
+    if(first2 == last2)
+        return std::distance(first1, last1);
+    if(*first1 == *first2)
+        return levenshtein_distance(std::next(first1), last1, std::next(first2), last2);
+    auto x1 = levenshtein_distance(std::next(first1), last1, std::next(first2), last2);
+    auto x2 = levenshtein_distance(first1, last1, std::next(first2), last2);
+    auto x3 = levenshtein_distance(std::next(first1), last1, first2, last2);
+    return std::ptrdiff_t{1} + std::min({x1, x2, x3});
+}
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
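This is the textbook recursive definition of edit distance (substitution, insertion, deletion, each costing 1). The naive recursion is exponential in the worst case, which is acceptable here because the driver's spellcheck only compares short command-line flags. A small usage sketch, assuming the header above is reachable as `<migraphx/algorithm.hpp>` (the argument parser includes it under that path):

```cpp
#include <migraphx/algorithm.hpp>
#include <cassert>
#include <string>

int main()
{
    const std::string typed = "--verbos";
    const std::string flag  = "--verbose";
    // one insertion turns "--verbos" into "--verbose", so the distance is 1
    assert(migraphx::levenshtein_distance(
               typed.begin(), typed.end(), flag.begin(), flag.end()) == 1);
    return 0;
}
```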
@@ -45,6 +45,11 @@ struct literal : raw_data<literal>
 {
     literal() {}

+    /*!
+     * Empty literal with a specific shape type
+     */
+    explicit literal(shape::type_t shape_type) : m_shape(shape_type, {}) {}
+
     template <class U, class T = deduce<U>, shape::type_t ShapeType = shape::get_type<T>{}>
     literal(U x) : buffer(make_shared_array<char>(sizeof(T))), m_shape(ShapeType)
     {
@@ -37,7 +37,9 @@ enum padding_mode_t
 {
     default_, // NOLINT
     same,
-    valid
+    valid,
+    same_lower,
+    same_upper
 };

 // The pooling modes must correspond 1-1 to the operators defined for struct parse_pooling.
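The new `same_lower` and `same_upper` values mirror ONNX's `SAME_LOWER`/`SAME_UPPER` auto-pad modes: both choose padding so the output spatial size is `ceil(W / S)`, and they differ only in which side receives the extra pixel when the total padding is odd. A hedged sketch of that arithmetic (the helper name and the ONNX-style formula are assumptions about intent, not code from this patch):

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdio>

// total = max(0, (ceil(W/S) - 1)*S + K_eff - W), split across begin/end;
// SAME_UPPER puts the odd pixel at the end, SAME_LOWER at the beginning.
void same_pads(std::size_t w, std::size_t k_eff, std::size_t s, bool upper,
               std::size_t& pad_begin, std::size_t& pad_end)
{
    std::size_t out  = (w + s - 1) / s; // ceil(w / s)
    std::ptrdiff_t t = static_cast<std::ptrdiff_t>((out - 1) * s + k_eff) -
                       static_cast<std::ptrdiff_t>(w);
    auto total = static_cast<std::size_t>(std::max<std::ptrdiff_t>(t, 0));
    pad_begin  = upper ? total / 2 : total - total / 2;
    pad_end    = total - pad_begin;
}

int main()
{
    std::size_t b = 0, e = 0;
    same_pads(5, 2, 2, /*upper=*/true, b, e); // total padding is 1
    std::printf("SAME_UPPER: begin=%zu end=%zu\n", b, e); // begin=0 end=1
    same_pads(5, 2, 2, /*upper=*/false, b, e);
    std::printf("SAME_LOWER: begin=%zu end=%zu\n", b, e); // begin=1 end=0
    return 0;
}
```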
@@ -41,8 +41,9 @@ struct convolution
     std::vector<std::size_t> stride   = {1, 1};
     std::vector<std::size_t> dilation = {1, 1};
     int group                         = 1;
     padding_mode_t padding_mode       = default_;
+    bool use_dynamic_same_auto_pad    = false;

     template <class Self, class F>
     static auto reflect(Self& self, F f)
@@ -51,7 +52,8 @@ struct convolution
                     f(self.stride, "stride"),
                     f(self.dilation, "dilation"),
                     f(self.group, "group"),
-                    f(self.padding_mode, "padding_mode"));
+                    f(self.padding_mode, "padding_mode"),
+                    f(self.use_dynamic_same_auto_pad, "use_dynamic_same_auto_pad"));
     }

     std::string name() const { return "convolution"; }
@@ -69,43 +71,137 @@ struct convolution
     shape normalize_compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(2).same_type().same_ndims().min_ndims(3);
+        check_shapes{inputs, *this, true}.has(2).same_type().same_ndims().min_ndims(3);
         check_attribute_size();
-        // dim num of input and attribute should match
-        auto input_size   = inputs[0].lens().size();
-        auto padding_size = padding.size();
+        // num of dims of input and attribute should match
+        const auto input_size   = inputs[0].max_lens().size();
+        const auto padding_size = padding.size();
         if(not(input_size == padding_size / 2 + 2 or input_size == padding_size + 2))
         {
             MIGRAPHX_THROW("CONVOLUTION: input and attribute size mismatch!");
         }

-        const shape& input   = inputs.at(0);
-        const shape& weights = inputs.at(1);
-        size_t kdims         = input_size - 2;
-        if(kdims != this->kdims())
+        const shape& x_shape          = inputs.at(0);
+        const shape& w_shape          = inputs.at(1);
+        const size_t num_spatial_dims = input_size - 2;
+        if(num_spatial_dims != this->kdims())
         {
-            MIGRAPHX_THROW("convolution: input k-dims does not match attribute size");
+            MIGRAPHX_THROW("CONVOLUTION: input k-dims does not match attribute size");
         }

-        if(input.lens().at(1) != (weights.lens().at(1) * group))
-            MIGRAPHX_THROW("CONVOLUTION: Mismatch channel numbers");
-
-        std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]};
-
-        for(size_t i = 0; i < kdims; i++)
-        {
-            auto padding_factor = 2 * padding[i];
-            if(padding_size == 2 * kdims)
-                padding_factor = padding[i] + padding[i + kdims];
-            output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(
-                1,
-                (input.lens()[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) +
-                 padding_factor) /
-                        stride[i] +
-                    1)));
-        }
-
-        return inputs[0].with_lens(output_lens);
+        if(not x_shape.dynamic() and not w_shape.dynamic() and
+           x_shape.lens().at(1) != (w_shape.lens().at(1) * group))
+            MIGRAPHX_THROW("CONVOLUTION: mismatched channel numbers");
+
+        std::vector<op::padding_mode_t> dyn_pad_modes = {op::padding_mode_t::same_upper,
+                                                         op::padding_mode_t::same_lower};
+        if(use_dynamic_same_auto_pad and not contains(dyn_pad_modes, padding_mode))
+        {
+            MIGRAPHX_THROW("CONVOLUTION: use_dynamic_same_auto_pad set with invalid padding mode");
+        }
+        if(x_shape.dynamic() or w_shape.dynamic())
+        {
+            return dynamic_compute_shape(x_shape, w_shape);
+        }
+        else
+        {
+            return fixed_compute_shape(x_shape, w_shape);
+        }
+    }
+
+    std::vector<std::size_t> calc_conv_lens(std::vector<std::size_t> x_lens,
+                                            std::vector<std::size_t> w_lens) const
+    {
+        const size_t num_spatial_dims = x_lens.size() - 2;
+        std::vector<size_t> ret       = {};
+        // calculate the output shape of the convolution: ((W - K + 2P) / S) + 1
+        for(size_t i = 0; i < num_spatial_dims; i++)
+        {
+            if(x_lens[i] == 0 or w_lens[i] == 0)
+            {
+                // for handling when a dimension = 0 (opt of dynamic_dimension)
+                ret.push_back(0);
+            }
+            else
+            {
+                auto padding_factor = 2 * padding[i];
+                if(padding.size() == 2 * num_spatial_dims)
+                {
+                    // when padding is {x0_begin, x1_begin, ... x0_end, x1_end, ...}
+                    padding_factor = padding[i] + padding[i + num_spatial_dims];
+                }
+                ret.push_back(std::size_t(std::max<std::ptrdiff_t>(
+                    1,
+                    (x_lens[i + 2] - (1 + dilation[i] * (w_lens[i + 2] - 1)) + padding_factor) /
+                            stride[i] +
+                        1)));
+            }
+        }
+        return ret;
+    }
+
+    shape dynamic_compute_shape(shape x_shape, shape w_shape) const
+    {
+        std::vector<shape::dynamic_dimension> output_dyn_dims = {};
+        auto dynamic_shape_push_back = [&](const shape& input_shape) {
+            if(input_shape.dynamic())
+            {
+                output_dyn_dims.push_back(input_shape.dyn_dims().at(0));
+            }
+            else
+            {
+                auto l = input_shape.lens().at(0);
+                output_dyn_dims.push_back({l, l, 0});
+            }
+        };
+        dynamic_shape_push_back(x_shape);
+        dynamic_shape_push_back(w_shape);
+        const size_t num_spatial_dims = x_shape.max_lens().size() - 2;
+        if(use_dynamic_same_auto_pad)
+        {
+            for(std::size_t i = 0; i < num_spatial_dims; ++i)
+            {
+                auto ceil_div = [](std::size_t x, std::size_t y) { return (x + y - 1) / y; };
+                auto s        = stride[i];
+                if(x_shape.dynamic())
+                {
+                    auto x = x_shape.dyn_dims()[i + 2];
+                    output_dyn_dims.push_back(shape::dynamic_dimension{
+                        ceil_div(x.min, s), ceil_div(x.max, s), ceil_div(x.opt, s)});
+                }
+                else
+                {
+                    auto od = ceil_div(x_shape.lens()[i + 2], s);
+                    output_dyn_dims.push_back(shape::dynamic_dimension{od, od, 0});
+                }
+            }
+        }
+        else
+        {
+            auto min_spatial_dims = calc_conv_lens(x_shape.min_lens(), w_shape.max_lens());
+            auto max_spatial_dims = calc_conv_lens(x_shape.max_lens(), w_shape.min_lens());
+            auto opt_spatial_dims = calc_conv_lens(x_shape.opt_lens(), w_shape.opt_lens());
+            for(size_t i = 0; i < num_spatial_dims; ++i)
+            {
+                output_dyn_dims.push_back(shape::dynamic_dimension{
+                    min_spatial_dims[i], max_spatial_dims[i], opt_spatial_dims[i]});
+            }
+        }
+        return shape{x_shape.type(), output_dyn_dims};
+    }
+
+    shape fixed_compute_shape(shape x_shape, shape w_shape) const
+    {
+        std::vector<size_t> output_lens{x_shape.lens()[0], w_shape.lens()[0]};
+        auto spatial_lens = calc_conv_lens(x_shape.lens(), w_shape.lens());
+        std::for_each(spatial_lens.begin(), spatial_lens.end(), [&output_lens](auto x) {
+            output_lens.push_back(x);
+        });
+        return x_shape.with_lens(output_lens);
     }

     size_t kdims() const
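As the comment in `calc_conv_lens` notes, each spatial output length follows `((W - K + 2P) / S) + 1`, with the kernel extent widened by dilation to `K_eff = 1 + dilation * (K - 1)`. A quick numeric check of that formula (a standalone sketch, not MIGraphX code):

```cpp
#include <cstddef>
#include <cstdio>

// Output length of one spatial dim: ((W - K_eff + 2P) / S) + 1,
// where K_eff = 1 + d * (K - 1) accounts for dilation.
std::size_t conv_out_len(std::size_t w, std::size_t k, std::size_t p,
                         std::size_t s, std::size_t d)
{
    std::size_t k_eff = 1 + d * (k - 1);
    return (w - k_eff + 2 * p) / s + 1;
}

int main()
{
    // ResNet-style stem: 224x224 input, 7x7 kernel, pad 3, stride 2 -> 112
    std::printf("%zu\n", conv_out_len(224, 7, 3, 2, 1));
    // A dilated 3x3 (d = 2) behaves like a 5x5: 32 input, pad 2, stride 1 -> 32
    std::printf("%zu\n", conv_out_len(32, 3, 2, 1, 2));
    return 0;
}
```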
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_FMOD_HPP
#define MIGRAPHX_GUARD_OPERATORS_FMOD_HPP
#include <array>
#include <migraphx/op/binary.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
#include <type_traits>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct fmod : binary<fmod>
{
std::string name() const { return "fmod"; }
value attributes() const
{
auto a = base_attributes();
a["commutative"] = false;
return a;
}
std::string point_function() const { return "fmod"; }
auto apply() const
{
return [](auto x, auto y) { return std::fmod(x, y); };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_MOD_HPP
#define MIGRAPHX_GUARD_OPERATORS_MOD_HPP
#include <array>
#include <migraphx/op/binary.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
#include <type_traits>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct mod : binary<mod>
{
std::string name() const { return "mod"; }
value attributes() const
{
auto a = base_attributes();
a["commutative"] = false;
return a;
}
std::string point_function() const { return "mod"; }
auto apply() const
{
return [](auto x, auto y) { return std::fmod((std::remainder(x, y)) + y, y); };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
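The two new operators differ only in sign convention: `fmod` follows C's `std::fmod` and takes the sign of the dividend, while `mod` uses the `fmod(remainder(x, y) + y, y)` construction above to take the sign of the divisor, matching Python-style `%` (and, as far as the ONNX spec goes, `Mod` with `fmod=0`). A quick standalone check of both conventions:

```cpp
#include <cmath>
#include <cstdio>

// Same expression as op::mod::apply() above, lifted out for illustration
double mod_like(double x, double y) { return std::fmod(std::remainder(x, y) + y, y); }

int main()
{
    std::printf("fmod(-7, 3) = %g\n", std::fmod(-7.0, 3.0)); // -1, sign of dividend
    std::printf("mod (-7, 3) = %g\n", mod_like(-7.0, 3.0));  //  2, sign of divisor
    std::printf("fmod( 7,-3) = %g\n", std::fmod(7.0, -3.0)); //  1
    std::printf("mod ( 7,-3) = %g\n", mod_like(7.0, -3.0));  // -2
    return 0;
}
```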
@@ -41,8 +41,9 @@ struct quant_convolution
     std::vector<std::size_t> stride   = {1, 1};
     std::vector<std::size_t> dilation = {1, 1};
     padding_mode_t padding_mode       = default_;
     int group                         = 1;
+    bool use_dynamic_same_auto_pad    = false;

     template <class Self, class F>
     static auto reflect(Self& self, F f)
@@ -51,7 +52,8 @@ struct quant_convolution
                     f(self.stride, "stride"),
                     f(self.dilation, "dilation"),
                     f(self.padding_mode, "padding_mode"),
-                    f(self.group, "group"));
+                    f(self.group, "group"),
+                    f(self.use_dynamic_same_auto_pad, "use_dynamic_same_auto_pad"));
     }

     value attributes() const
@@ -68,8 +68,10 @@ struct operation
  *
  * @param ctx This is the context created by the `target` during compilation. Implementations
  * can use the target's `context` class rather than the `context` interface class.
- * @param output This is the output shape. It is equivalent to running `compute_shape` with each
- * `shape` of the `argument`.
+ * @param output Equivalent to running `compute_shape` with each `shape` of the `argument`.
+ * For a fixed shape, the returned argument will have the same shape as `output`.
+ * For a dynamic shape, the returned `argument` will be a fixed shape within the bounds
+ * set in the dynamic shape `output`.
  * @param input This is the `argument` result from the previous instruction's computation.
  * @return Return an `argument` of the result computation. The `shape` of `argument` should be
  * the same as the `output` shape.
@@ -137,7 +139,7 @@ auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs)
     -> decltype(x.normalize_compute_shape(inputs))
 {
     dependent_type<operation, T> y = x;
-    normalize_attributes(y, inputs[0].lens());
+    normalize_attributes(y, inputs[0].max_lens());
     return any_cast<T>(y).normalize_compute_shape(inputs);
 }