Unverified Commit 8f9a766f authored by mvermeulen's avatar mvermeulen Committed by GitHub
Browse files

Merge pull request #339 from ROCmSoftwarePlatform/driver-params

Add parameter to driver to fill inputs with 1s instead of random values
parents 3ab91a79 1fce53ab
......@@ -28,10 +28,32 @@ inline namespace MIGRAPHX_INLINE_NS {
#define MIGRAPHX_DRIVER_STATIC static
#endif
template <class T>
using bare = std::remove_cv_t<std::remove_reference_t<T>>;
namespace detail {
template <class T>
auto is_container(int, T&& x) -> decltype(x.insert(x.end(), *x.begin()), std::true_type{});
template <class T>
std::false_type is_container(float, T&&);
} // namespace detail
template <class T>
struct is_container : decltype(detail::is_container(int(0), std::declval<T>()))
{
};
template <class T>
using is_multi_value =
std::integral_constant<bool, (is_container<T>{} and not std::is_convertible<T, std::string>{})>;
template <class T>
struct value_parser
{
template <MIGRAPHX_REQUIRES(not std::is_enum<T>{})>
template <MIGRAPHX_REQUIRES(not std::is_enum<T>{} and not is_multi_value<T>{})>
static T apply(const std::string& x)
{
T result;
......@@ -43,7 +65,7 @@ struct value_parser
return result;
}
template <MIGRAPHX_REQUIRES(std::is_enum<T>{})>
template <MIGRAPHX_REQUIRES(std::is_enum<T>{} and not is_multi_value<T>{})>
static T apply(const std::string& x)
{
std::ptrdiff_t i;
......@@ -54,6 +76,15 @@ struct value_parser
throw std::runtime_error("Failed to parse: " + x);
return static_cast<T>(i);
}
template <MIGRAPHX_REQUIRES(is_multi_value<T>{} and not std::is_enum<T>{})>
static T apply(const std::string& x)
{
T result;
using value_type = typename T::value_type;
result.insert(result.end(), value_parser<value_type>::apply(x));
return result;
}
};
struct argument_parser
......@@ -69,6 +100,18 @@ struct argument_parser
unsigned nargs = 1;
};
template <class T, MIGRAPHX_REQUIRES(is_multi_value<T>{})>
std::string as_string_value(const T& x)
{
return to_string_range(x);
}
template <class T, MIGRAPHX_REQUIRES(not is_multi_value<T>{})>
std::string as_string_value(const T& x)
{
return to_string(x);
}
template <class T, class... Fs>
void operator()(T& x, const std::vector<std::string>& flags, Fs... fs)
{
......@@ -81,7 +124,7 @@ struct argument_parser
argument& arg = arguments.back();
arg.type = migraphx::get_type_name<T>();
arg.default_value = to_string(x);
arg.default_value = as_string_value(x);
migraphx::each_args([&](auto f) { f(x, arg); }, fs...);
}
......@@ -127,7 +170,7 @@ struct argument_parser
MIGRAPHX_DRIVER_STATIC auto append()
{
return write_action([](auto&, auto& x, auto& params) {
using type = typename decltype(params)::value_type;
using type = typename bare<decltype(params)>::value_type;
std::transform(params.begin(),
params.end(),
std::inserter(x, x.end()),
......
......@@ -8,6 +8,7 @@
#include <migraphx/stringutils.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp>
......@@ -80,11 +81,13 @@ struct compiler
{
loader l;
bool gpu = true;
std::vector<std::string> fill1;
void parse(argument_parser& ap)
{
l.parse(ap);
ap(gpu, {"--gpu"}, ap.help("Compile on the gpu"), ap.set_value(true));
ap(gpu, {"--cpu"}, ap.help("Compile on the cpu"), ap.set_value(false));
ap(fill1, {"--fill1"}, ap.help("Fill parameter with 1s"), ap.append());
}
program compile()
......@@ -94,7 +97,14 @@ struct compiler
return p;
}
auto params(const program& p) { return create_param_map(p, gpu); }
auto params(const program& p)
{
program::parameter_map m;
for(auto&& s : fill1)
m[s] = fill_argument(p.get_parameter_shape(s), 1);
fill_param_map(m, p, gpu);
return m;
}
};
struct read : command<read>
......@@ -109,6 +119,19 @@ struct read : command<read>
}
};
struct params : command<params>
{
loader l;
void parse(argument_parser& ap) { l.parse(ap); }
void run()
{
auto p = l.load();
for(auto&& param : p.get_parameter_shapes())
std::cout << param.first << ": " << param.second << std::endl;
}
};
struct verify : command<verify>
{
loader l;
......
......@@ -11,6 +11,23 @@ namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
program::parameter_map fill_param_map(program::parameter_map& m, const program& p, bool gpu)
{
for(auto&& x : p.get_parameter_shapes())
{
argument& arg = m[x.first];
if(arg.empty())
arg = generate_argument(x.second);
#ifdef HAVE_GPU
if(gpu)
arg = gpu::to_gpu(arg);
#else
(void)gpu;
#endif
}
return m;
}
program::parameter_map create_param_map(const program& p, bool gpu)
{
program::parameter_map m;
......
......@@ -7,6 +7,7 @@ namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
program::parameter_map fill_param_map(program::parameter_map& m, const program& p, bool gpu);
program::parameter_map create_param_map(const program& p, bool gpu = true);
void compile_program(program& p, bool gpu = true);
......
......@@ -3,6 +3,17 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
argument fill_argument(shape s, unsigned long value)
{
argument result;
s.visit_type([&](auto as) {
using type = typename decltype(as)::type;
auto v = fill_tensor_data<type>(s, value);
result = {s, [v]() mutable { return reinterpret_cast<char*>(v.data()); }};
});
return result;
}
argument generate_argument(shape s, unsigned long seed)
{
argument result;
......
......@@ -87,6 +87,16 @@ std::vector<T> generate_tensor_data(const migraphx::shape& s, unsigned long seed
return result;
}
template <class T>
std::vector<T> fill_tensor_data(const migraphx::shape& s, unsigned long value = 0)
{
std::vector<T> result(s.elements());
std::generate(result.begin(), result.end(), [=] { return value; });
return result;
}
argument fill_argument(shape s, unsigned long value = 0);
argument generate_argument(shape s, unsigned long seed = 0);
literal generate_literal(shape s, unsigned long seed = 0);
......
......@@ -23,9 +23,10 @@ using bool_c = std::integral_constant<bool, B>;
#ifdef CPPCHECK
#define MIGRAPHX_REQUIRES(...) class = void
#else
#define MIGRAPHX_REQUIRES(...) \
bool MIGRAPHX_REQUIRES_VAR() = true, \
typename std::enable_if<(MIGRAPHX_REQUIRES_VAR() && (migraphx::and_<__VA_ARGS__>{})), \
#define MIGRAPHX_REQUIRES(...) \
long MIGRAPHX_REQUIRES_VAR() = __LINE__, \
typename std::enable_if<(MIGRAPHX_REQUIRES_VAR() == __LINE__ && \
(migraphx::and_<__VA_ARGS__>{})), \
int>::type = 0
#endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment