Commit 255d6868 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into keep_std_shape

parents 2c603654 73b8a773
...@@ -88,6 +88,7 @@ struct cpp_generator_impl ...@@ -88,6 +88,7 @@ struct cpp_generator_impl
std::stringstream fs{}; std::stringstream fs{};
std::size_t function_count = 0; std::size_t function_count = 0;
std::function<std::string(std::string)> fmap = nullptr; std::function<std::string(std::string)> fmap = nullptr;
std::function<std::string(shape)> fresult = nullptr;
std::unordered_map<std::string, std::string> point_op_map = {}; std::unordered_map<std::string, std::string> point_op_map = {};
}; };
cpp_generator::cpp_generator() : impl(std::make_unique<cpp_generator_impl>()) {} cpp_generator::cpp_generator() : impl(std::make_unique<cpp_generator_impl>()) {}
...@@ -104,6 +105,8 @@ cpp_generator::~cpp_generator() noexcept = default; ...@@ -104,6 +105,8 @@ cpp_generator::~cpp_generator() noexcept = default;
void cpp_generator::fmap(const std::function<std::string(std::string)>& f) { impl->fmap = f; } void cpp_generator::fmap(const std::function<std::string(std::string)>& f) { impl->fmap = f; }
void cpp_generator::fresult(const std::function<std::string(shape)>& f) { impl->fresult = f; }
void cpp_generator::add_point_op(const std::string& op_name, const std::string& code) void cpp_generator::add_point_op(const std::string& op_name, const std::string& code)
{ {
impl->point_op_map[op_name] = code; impl->point_op_map[op_name] = code;
...@@ -174,7 +177,12 @@ cpp_generator::function cpp_generator::generate_module(const module& m) ...@@ -174,7 +177,12 @@ cpp_generator::function cpp_generator::generate_module(const module& m)
ins->inputs().end(), ins->inputs().end(),
std::back_inserter(args), std::back_inserter(args),
[&](auto i) { return names.at(i); }); [&](auto i) { return names.at(i); });
return this->generate_point_op(ins->get_operator(), args);
auto s = this->generate_point_op(ins->get_operator(), args);
if(impl->fresult)
return impl->fresult(ins->get_shape()) + '(' + s + ')';
else
return s;
}); });
return f; return f;
} }
......
...@@ -68,6 +68,8 @@ struct cpp_generator ...@@ -68,6 +68,8 @@ struct cpp_generator
void fmap(const std::function<std::string(std::string)>& f); void fmap(const std::function<std::string(std::string)>& f);
void fresult(const std::function<std::string(shape)>& f);
void add_point_op(const std::string& op_name, const std::string& code); void add_point_op(const std::string& op_name, const std::string& code);
std::string generate_point_op(const operation& op, const std::vector<std::string>& args); std::string generate_point_op(const operation& op, const std::vector<std::string>& args);
......
...@@ -89,6 +89,13 @@ inline std::ptrdiff_t operator-(basic_iota_iterator<F, Iterator> x, ...@@ -89,6 +89,13 @@ inline std::ptrdiff_t operator-(basic_iota_iterator<F, Iterator> x,
return x.index - y.index; return x.index - y.index;
} }
template <class F, class Iterator>
inline basic_iota_iterator<F, Iterator> operator-(basic_iota_iterator<F, Iterator> x,
std::ptrdiff_t y)
{
return x -= y;
}
template <class F, class Iterator> template <class F, class Iterator>
inline bool operator==(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y) inline bool operator==(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{ {
......
...@@ -78,8 +78,8 @@ void par_for_impl(std::size_t n, std::size_t threadsize, F f) ...@@ -78,8 +78,8 @@ void par_for_impl(std::size_t n, std::size_t threadsize, F f)
template <class F> template <class F>
void par_for(std::size_t n, std::size_t min_grain, F f) void par_for(std::size_t n, std::size_t min_grain, F f)
{ {
const auto threadsize = const auto threadsize = std::min<std::size_t>(std::thread::hardware_concurrency(),
std::min<std::size_t>(std::thread::hardware_concurrency(), n / min_grain); n / std::max<std::size_t>(1, min_grain));
par_for_impl(n, threadsize, f); par_for_impl(n, threadsize, f);
} }
......
...@@ -180,6 +180,63 @@ void program::finalize() ...@@ -180,6 +180,63 @@ void program::finalize()
mm->finalize(this->impl->ctx); mm->finalize(this->impl->ctx);
} }
template <class T>
std::string classify(T x)
{
switch(std::fpclassify(x))
{
case FP_INFINITE: return "inf";
case FP_NAN: return "nan";
case FP_NORMAL: return "normal";
case FP_SUBNORMAL: return "subnormal";
case FP_ZERO: return "zero";
default: return "unknown";
}
}
std::unordered_set<std::string> classify_argument(const argument& a)
{
std::unordered_set<std::string> result;
a.visit(
[&](auto t) {
for(const auto& x : t)
result.insert(classify(x));
},
[&](const auto& xs) {
for(const auto& x : xs)
{
auto r = classify_argument(x);
result.insert(r.begin(), r.end());
}
});
return result;
}
void preview_argument(std::ostream& os, const argument& a)
{
a.visit(
[&](auto t) {
if(t.size() <= 10)
{
os << t;
}
else
{
os << to_string_range(t.begin(), t.begin() + 5);
os << ", ..., ";
os << to_string_range(t.end() - 5, t.end());
}
},
[&](const auto& xs) {
for(const auto& x : xs)
{
os << '{';
preview_argument(os, x);
os << '}';
}
});
}
template <class F> template <class F>
std::vector<argument> generic_eval(const module* mod, std::vector<argument> generic_eval(const module* mod,
context& ctx, context& ctx,
...@@ -312,8 +369,21 @@ std::vector<argument> program::eval(parameter_map params) const ...@@ -312,8 +369,21 @@ std::vector<argument> program::eval(parameter_map params) const
if(trace_level > 1 and ins->name().front() != '@' and if(trace_level > 1 and ins->name().front() != '@' and
ins->name() != "load" and not result.empty()) ins->name() != "load" and not result.empty())
{ {
target tgt = make_target(this->impl->target_name); target tgt = make_target(this->impl->target_name);
std::cout << "Output: " << tgt.copy_from(result) << std::endl; auto buffer = tgt.copy_from(result);
if(trace_level == 2)
{
std::cout << "Output has "
<< to_string_range(classify_argument(buffer))
<< std::endl;
std::cout << "Output: ";
preview_argument(std::cout, buffer);
std::cout << std::endl;
}
else
{
std::cout << "Output: " << buffer << std::endl;
}
} }
return result; return result;
})); }));
......
...@@ -12,6 +12,8 @@ namespace migraphx { ...@@ -12,6 +12,8 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_COMPILE_PARALLEL);
struct precompile_op struct precompile_op
{ {
operation op = op::identity{}; operation op = op::identity{};
...@@ -70,6 +72,14 @@ struct compiled_result ...@@ -70,6 +72,14 @@ struct compiled_result
instruction_ref ins; instruction_ref ins;
}; };
template <class F>
void par_compile(std::size_t n, F f)
{
if(n == 0)
return;
par_for(n, n / value_of(MIGRAPHX_GPU_COMPILE_PARALLEL{}, n), f);
}
void compile_ops::apply(module& m) const void compile_ops::apply(module& m) const
{ {
auto compilers = make_compilers(pointwise_compiler{}); auto compilers = make_compilers(pointwise_compiler{});
...@@ -85,7 +95,7 @@ void compile_ops::apply(module& m) const ...@@ -85,7 +95,7 @@ void compile_ops::apply(module& m) const
compiles.emplace_back([=]() -> compiled_result { return {c(*ctx, ins, preop), ins}; }); compiles.emplace_back([=]() -> compiled_result { return {c(*ctx, ins, preop), ins}; });
} }
std::vector<compiled_result> results(compiles.size()); std::vector<compiled_result> results(compiles.size());
par_for(compiles.size(), 1, [&](auto i) { results[i] = compiles[i](); }); par_compile(compiles.size(), [&](auto i) { results[i] = compiles[i](); });
for(const auto& cr : results) for(const auto& cr : results)
{ {
m.replace_instruction(cr.ins, cr.op, cr.ins->inputs()); m.replace_instruction(cr.ins, cr.op, cr.ins->inputs());
......
...@@ -70,6 +70,9 @@ operation compile_pointwise(context& ctx, const std::vector<shape>& inputs, modu ...@@ -70,6 +70,9 @@ operation compile_pointwise(context& ctx, const std::vector<shape>& inputs, modu
g.add_point_op("less", "migraphx::abs(${0} < ${1})"); g.add_point_op("less", "migraphx::abs(${0} < ${1})");
g.add_point_op("greater", "migraphx::abs(${0} > ${1})"); g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
g.add_point_op("not", "migraphx::abs(not ${0})"); g.add_point_op("not", "migraphx::abs(not ${0})");
// Add explict conversions
g.fresult(
[](const shape& s) { return "migraphx::convert<" + shape::cpp_type(s.type()) + ">"; });
auto name = auto name =
g.create_function(g.generate_module(m).set_attributes({"__device__"}).set_generic_types(m)); g.create_function(g.generate_module(m).set_attributes({"__device__"}).set_generic_types(m));
return compile_pointwise((ctx), inputs, "MIGRAPHX_LIFT(" + name + ")", g.str()); return compile_pointwise((ctx), inputs, "MIGRAPHX_LIFT(" + name + ")", g.str());
......
...@@ -66,15 +66,18 @@ __device__ __host__ auto as_vec(T* x) ...@@ -66,15 +66,18 @@ __device__ __host__ auto as_vec(T* x)
return reinterpret_cast<vec<T, N>*>(x); return reinterpret_cast<vec<T, N>*>(x);
} }
template <class T, index_int N>
using safe_vec = vec<std::conditional_t<std::is_same<T, bool>{}, uint8_t, T>, N>;
template <class... Ts> template <class... Ts>
constexpr auto vec_transform(Ts... xs) constexpr auto vec_transform(Ts... xs)
{ {
return [=](auto f) { return [=](auto f) {
if constexpr(is_any_vec<Ts...>()) if constexpr(is_any_vec<Ts...>())
{ {
using type = decltype(f(vec_at(xs, 0)...)); using type = decltype(f(vec_at(xs, 0)...));
constexpr auto size = common_vec_size<Ts...>(); constexpr auto size = common_vec_size<Ts...>();
vec<type, size> result = {0}; safe_vec<type, size> result = {0};
for(int i = 0; i < size; i++) for(int i = 0; i < size; i++)
result[i] = f(vec_at(xs, i)...); result[i] = f(vec_at(xs, i)...);
return result; return result;
......
...@@ -50,14 +50,14 @@ constexpr auto shape_step(Shape s, Axis) ...@@ -50,14 +50,14 @@ constexpr auto shape_step(Shape s, Axis)
}); });
} }
// Bools can not be used as a vector type so convert it to int8 // Bools can not be used as a vector type so convert it to uint8
template <class T> template <class T>
__device__ __host__ T* remove_bool(T* x) __device__ __host__ T* remove_bool(T* x)
{ {
return x; return x;
} }
inline __device__ __host__ int8_t* remove_bool(bool* x) { return reinterpret_cast<int8_t*>(x); } inline __device__ __host__ uint8_t* remove_bool(bool* x) { return reinterpret_cast<uint8_t*>(x); }
template <index_int N, class T, class Axis> template <index_int N, class T, class Axis>
__device__ __host__ auto as_vec(T x, Axis axis) __device__ __host__ auto as_vec(T x, Axis axis)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment