Unverified Commit 73b8a773 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Enforce types to avoid compilation error in pointwise fusions (#1077)

Enforce types to avoid compilation error in pointwise fusions
This fixes compile failure: gpt-2, fp16 on Navi
parent b20e3d4d
......@@ -88,6 +88,7 @@ struct cpp_generator_impl
std::stringstream fs{};
std::size_t function_count = 0;
std::function<std::string(std::string)> fmap = nullptr;
std::function<std::string(shape)> fresult = nullptr;
std::unordered_map<std::string, std::string> point_op_map = {};
};
cpp_generator::cpp_generator() : impl(std::make_unique<cpp_generator_impl>()) {}
......@@ -104,6 +105,8 @@ cpp_generator::~cpp_generator() noexcept = default;
void cpp_generator::fmap(const std::function<std::string(std::string)>& f) { impl->fmap = f; }
void cpp_generator::fresult(const std::function<std::string(shape)>& f) { impl->fresult = f; }
void cpp_generator::add_point_op(const std::string& op_name, const std::string& code)
{
impl->point_op_map[op_name] = code;
......@@ -174,7 +177,12 @@ cpp_generator::function cpp_generator::generate_module(const module& m)
ins->inputs().end(),
std::back_inserter(args),
[&](auto i) { return names.at(i); });
return this->generate_point_op(ins->get_operator(), args);
auto s = this->generate_point_op(ins->get_operator(), args);
if(impl->fresult)
return impl->fresult(ins->get_shape()) + '(' + s + ')';
else
return s;
});
return f;
}
......
......@@ -68,6 +68,8 @@ struct cpp_generator
void fmap(const std::function<std::string(std::string)>& f);
void fresult(const std::function<std::string(shape)>& f);
void add_point_op(const std::string& op_name, const std::string& code);
std::string generate_point_op(const operation& op, const std::vector<std::string>& args);
......
......@@ -78,8 +78,8 @@ void par_for_impl(std::size_t n, std::size_t threadsize, F f)
template <class F>
void par_for(std::size_t n, std::size_t min_grain, F f)
{
const auto threadsize =
std::min<std::size_t>(std::thread::hardware_concurrency(), n / min_grain);
const auto threadsize = std::min<std::size_t>(std::thread::hardware_concurrency(),
n / std::max<std::size_t>(1, min_grain));
par_for_impl(n, threadsize, f);
}
......
......@@ -12,6 +12,8 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_COMPILE_PARALLEL);
struct precompile_op
{
operation op = op::identity{};
......@@ -70,6 +72,14 @@ struct compiled_result
instruction_ref ins;
};
template <class F>
void par_compile(std::size_t n, F f)
{
if(n == 0)
return;
par_for(n, n / value_of(MIGRAPHX_GPU_COMPILE_PARALLEL{}, n), f);
}
void compile_ops::apply(module& m) const
{
auto compilers = make_compilers(pointwise_compiler{});
......@@ -85,7 +95,7 @@ void compile_ops::apply(module& m) const
compiles.emplace_back([=]() -> compiled_result { return {c(*ctx, ins, preop), ins}; });
}
std::vector<compiled_result> results(compiles.size());
par_for(compiles.size(), 1, [&](auto i) { results[i] = compiles[i](); });
par_compile(compiles.size(), [&](auto i) { results[i] = compiles[i](); });
for(const auto& cr : results)
{
m.replace_instruction(cr.ins, cr.op, cr.ins->inputs());
......
......@@ -70,6 +70,9 @@ operation compile_pointwise(context& ctx, const std::vector<shape>& inputs, modu
g.add_point_op("less", "migraphx::abs(${0} < ${1})");
g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
g.add_point_op("not", "migraphx::abs(not ${0})");
// Add explict conversions
g.fresult(
[](const shape& s) { return "migraphx::convert<" + shape::cpp_type(s.type()) + ">"; });
auto name =
g.create_function(g.generate_module(m).set_attributes({"__device__"}).set_generic_types(m));
return compile_pointwise((ctx), inputs, "MIGRAPHX_LIFT(" + name + ")", g.str());
......
......@@ -66,6 +66,9 @@ __device__ __host__ auto as_vec(T* x)
return reinterpret_cast<vec<T, N>*>(x);
}
template <class T, index_int N>
using safe_vec = vec<std::conditional_t<std::is_same<T, bool>{}, uint8_t, T>, N>;
template <class... Ts>
constexpr auto vec_transform(Ts... xs)
{
......@@ -74,7 +77,7 @@ constexpr auto vec_transform(Ts... xs)
{
using type = decltype(f(vec_at(xs, 0)...));
constexpr auto size = common_vec_size<Ts...>();
vec<type, size> result = {0};
safe_vec<type, size> result = {0};
for(int i = 0; i < size; i++)
result[i] = f(vec_at(xs, i)...);
return result;
......
......@@ -50,14 +50,14 @@ constexpr auto shape_step(Shape s, Axis)
});
}
// Bools can not be used as a vector type so convert it to int8
// Bools can not be used as a vector type so convert it to uint8
template <class T>
__device__ __host__ T* remove_bool(T* x)
{
return x;
}
inline __device__ __host__ int8_t* remove_bool(bool* x) { return reinterpret_cast<int8_t*>(x); }
inline __device__ __host__ uint8_t* remove_bool(bool* x) { return reinterpret_cast<uint8_t*>(x); }
template <index_int N, class T, class Axis>
__device__ __host__ auto as_vec(T x, Axis axis)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment