Unverified Commit 73b8a773 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Enforce types to avoid compilation error in pointwise fusions (#1077)

Enforce types to avoid compilation error in pointwise fusions
This fixes compile failure: gpt-2, fp16 on Navi
parent b20e3d4d
...@@ -88,6 +88,7 @@ struct cpp_generator_impl ...@@ -88,6 +88,7 @@ struct cpp_generator_impl
std::stringstream fs{}; std::stringstream fs{};
std::size_t function_count = 0; std::size_t function_count = 0;
std::function<std::string(std::string)> fmap = nullptr; std::function<std::string(std::string)> fmap = nullptr;
std::function<std::string(shape)> fresult = nullptr;
std::unordered_map<std::string, std::string> point_op_map = {}; std::unordered_map<std::string, std::string> point_op_map = {};
}; };
cpp_generator::cpp_generator() : impl(std::make_unique<cpp_generator_impl>()) {} cpp_generator::cpp_generator() : impl(std::make_unique<cpp_generator_impl>()) {}
...@@ -104,6 +105,8 @@ cpp_generator::~cpp_generator() noexcept = default; ...@@ -104,6 +105,8 @@ cpp_generator::~cpp_generator() noexcept = default;
void cpp_generator::fmap(const std::function<std::string(std::string)>& f) { impl->fmap = f; } void cpp_generator::fmap(const std::function<std::string(std::string)>& f) { impl->fmap = f; }
void cpp_generator::fresult(const std::function<std::string(shape)>& f) { impl->fresult = f; }
void cpp_generator::add_point_op(const std::string& op_name, const std::string& code) void cpp_generator::add_point_op(const std::string& op_name, const std::string& code)
{ {
impl->point_op_map[op_name] = code; impl->point_op_map[op_name] = code;
...@@ -174,7 +177,12 @@ cpp_generator::function cpp_generator::generate_module(const module& m) ...@@ -174,7 +177,12 @@ cpp_generator::function cpp_generator::generate_module(const module& m)
ins->inputs().end(), ins->inputs().end(),
std::back_inserter(args), std::back_inserter(args),
[&](auto i) { return names.at(i); }); [&](auto i) { return names.at(i); });
return this->generate_point_op(ins->get_operator(), args);
auto s = this->generate_point_op(ins->get_operator(), args);
if(impl->fresult)
return impl->fresult(ins->get_shape()) + '(' + s + ')';
else
return s;
}); });
return f; return f;
} }
......
...@@ -68,6 +68,8 @@ struct cpp_generator ...@@ -68,6 +68,8 @@ struct cpp_generator
void fmap(const std::function<std::string(std::string)>& f); void fmap(const std::function<std::string(std::string)>& f);
void fresult(const std::function<std::string(shape)>& f);
void add_point_op(const std::string& op_name, const std::string& code); void add_point_op(const std::string& op_name, const std::string& code);
std::string generate_point_op(const operation& op, const std::vector<std::string>& args); std::string generate_point_op(const operation& op, const std::vector<std::string>& args);
......
...@@ -78,8 +78,8 @@ void par_for_impl(std::size_t n, std::size_t threadsize, F f) ...@@ -78,8 +78,8 @@ void par_for_impl(std::size_t n, std::size_t threadsize, F f)
template <class F> template <class F>
void par_for(std::size_t n, std::size_t min_grain, F f) void par_for(std::size_t n, std::size_t min_grain, F f)
{ {
const auto threadsize = const auto threadsize = std::min<std::size_t>(std::thread::hardware_concurrency(),
std::min<std::size_t>(std::thread::hardware_concurrency(), n / min_grain); n / std::max<std::size_t>(1, min_grain));
par_for_impl(n, threadsize, f); par_for_impl(n, threadsize, f);
} }
......
...@@ -12,6 +12,8 @@ namespace migraphx { ...@@ -12,6 +12,8 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_COMPILE_PARALLEL);
struct precompile_op struct precompile_op
{ {
operation op = op::identity{}; operation op = op::identity{};
...@@ -70,6 +72,14 @@ struct compiled_result ...@@ -70,6 +72,14 @@ struct compiled_result
instruction_ref ins; instruction_ref ins;
}; };
template <class F>
void par_compile(std::size_t n, F f)
{
if(n == 0)
return;
par_for(n, n / value_of(MIGRAPHX_GPU_COMPILE_PARALLEL{}, n), f);
}
void compile_ops::apply(module& m) const void compile_ops::apply(module& m) const
{ {
auto compilers = make_compilers(pointwise_compiler{}); auto compilers = make_compilers(pointwise_compiler{});
...@@ -85,7 +95,7 @@ void compile_ops::apply(module& m) const ...@@ -85,7 +95,7 @@ void compile_ops::apply(module& m) const
compiles.emplace_back([=]() -> compiled_result { return {c(*ctx, ins, preop), ins}; }); compiles.emplace_back([=]() -> compiled_result { return {c(*ctx, ins, preop), ins}; });
} }
std::vector<compiled_result> results(compiles.size()); std::vector<compiled_result> results(compiles.size());
par_for(compiles.size(), 1, [&](auto i) { results[i] = compiles[i](); }); par_compile(compiles.size(), [&](auto i) { results[i] = compiles[i](); });
for(const auto& cr : results) for(const auto& cr : results)
{ {
m.replace_instruction(cr.ins, cr.op, cr.ins->inputs()); m.replace_instruction(cr.ins, cr.op, cr.ins->inputs());
......
...@@ -70,6 +70,9 @@ operation compile_pointwise(context& ctx, const std::vector<shape>& inputs, modu ...@@ -70,6 +70,9 @@ operation compile_pointwise(context& ctx, const std::vector<shape>& inputs, modu
g.add_point_op("less", "migraphx::abs(${0} < ${1})"); g.add_point_op("less", "migraphx::abs(${0} < ${1})");
g.add_point_op("greater", "migraphx::abs(${0} > ${1})"); g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
g.add_point_op("not", "migraphx::abs(not ${0})"); g.add_point_op("not", "migraphx::abs(not ${0})");
// Add explict conversions
g.fresult(
[](const shape& s) { return "migraphx::convert<" + shape::cpp_type(s.type()) + ">"; });
auto name = auto name =
g.create_function(g.generate_module(m).set_attributes({"__device__"}).set_generic_types(m)); g.create_function(g.generate_module(m).set_attributes({"__device__"}).set_generic_types(m));
return compile_pointwise((ctx), inputs, "MIGRAPHX_LIFT(" + name + ")", g.str()); return compile_pointwise((ctx), inputs, "MIGRAPHX_LIFT(" + name + ")", g.str());
......
...@@ -66,6 +66,9 @@ __device__ __host__ auto as_vec(T* x) ...@@ -66,6 +66,9 @@ __device__ __host__ auto as_vec(T* x)
return reinterpret_cast<vec<T, N>*>(x); return reinterpret_cast<vec<T, N>*>(x);
} }
template <class T, index_int N>
using safe_vec = vec<std::conditional_t<std::is_same<T, bool>{}, uint8_t, T>, N>;
template <class... Ts> template <class... Ts>
constexpr auto vec_transform(Ts... xs) constexpr auto vec_transform(Ts... xs)
{ {
...@@ -74,7 +77,7 @@ constexpr auto vec_transform(Ts... xs) ...@@ -74,7 +77,7 @@ constexpr auto vec_transform(Ts... xs)
{ {
using type = decltype(f(vec_at(xs, 0)...)); using type = decltype(f(vec_at(xs, 0)...));
constexpr auto size = common_vec_size<Ts...>(); constexpr auto size = common_vec_size<Ts...>();
vec<type, size> result = {0}; safe_vec<type, size> result = {0};
for(int i = 0; i < size; i++) for(int i = 0; i < size; i++)
result[i] = f(vec_at(xs, i)...); result[i] = f(vec_at(xs, i)...);
return result; return result;
......
...@@ -50,14 +50,14 @@ constexpr auto shape_step(Shape s, Axis) ...@@ -50,14 +50,14 @@ constexpr auto shape_step(Shape s, Axis)
}); });
} }
// Bools can not be used as a vector type so convert it to int8 // Bools can not be used as a vector type so convert it to uint8
template <class T> template <class T>
__device__ __host__ T* remove_bool(T* x) __device__ __host__ T* remove_bool(T* x)
{ {
return x; return x;
} }
inline __device__ __host__ int8_t* remove_bool(bool* x) { return reinterpret_cast<int8_t*>(x); } inline __device__ __host__ uint8_t* remove_bool(bool* x) { return reinterpret_cast<uint8_t*>(x); }
template <index_int N, class T, class Axis> template <index_int N, class T, class Axis>
__device__ __host__ auto as_vec(T x, Axis axis) __device__ __host__ auto as_vec(T x, Axis axis)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment