"vscode:/vscode.git/clone" did not exist on "300582fc15a2b6d4e9c4a5ae3ba1e51758fcb558"
Unverified Commit 73b8a773 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Enforce types to avoid compilation error in pointwise fusions (#1077)

Enforce types to avoid compilation error in pointwise fusions
This fixes compile failure: gpt-2, fp16 on Navi
parent b20e3d4d
......@@ -88,6 +88,7 @@ struct cpp_generator_impl
std::stringstream fs{};
std::size_t function_count = 0;
std::function<std::string(std::string)> fmap = nullptr;
std::function<std::string(shape)> fresult = nullptr;
std::unordered_map<std::string, std::string> point_op_map = {};
};
cpp_generator::cpp_generator() : impl(std::make_unique<cpp_generator_impl>()) {}
......@@ -104,6 +105,8 @@ cpp_generator::~cpp_generator() noexcept = default;
void cpp_generator::fmap(const std::function<std::string(std::string)>& f) { impl->fmap = f; }
void cpp_generator::fresult(const std::function<std::string(shape)>& f) { impl->fresult = f; }
void cpp_generator::add_point_op(const std::string& op_name, const std::string& code)
{
impl->point_op_map[op_name] = code;
......@@ -174,7 +177,12 @@ cpp_generator::function cpp_generator::generate_module(const module& m)
ins->inputs().end(),
std::back_inserter(args),
[&](auto i) { return names.at(i); });
return this->generate_point_op(ins->get_operator(), args);
auto s = this->generate_point_op(ins->get_operator(), args);
if(impl->fresult)
return impl->fresult(ins->get_shape()) + '(' + s + ')';
else
return s;
});
return f;
}
......
......@@ -68,6 +68,8 @@ struct cpp_generator
void fmap(const std::function<std::string(std::string)>& f);
void fresult(const std::function<std::string(shape)>& f);
void add_point_op(const std::string& op_name, const std::string& code);
std::string generate_point_op(const operation& op, const std::vector<std::string>& args);
......
......@@ -78,8 +78,8 @@ void par_for_impl(std::size_t n, std::size_t threadsize, F f)
template <class F>
void par_for(std::size_t n, std::size_t min_grain, F f)
{
const auto threadsize =
std::min<std::size_t>(std::thread::hardware_concurrency(), n / min_grain);
const auto threadsize = std::min<std::size_t>(std::thread::hardware_concurrency(),
n / std::max<std::size_t>(1, min_grain));
par_for_impl(n, threadsize, f);
}
......
......@@ -12,6 +12,8 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_COMPILE_PARALLEL);
struct precompile_op
{
operation op = op::identity{};
......@@ -70,6 +72,14 @@ struct compiled_result
instruction_ref ins;
};
template <class F>
void par_compile(std::size_t n, F f)
{
if(n == 0)
return;
par_for(n, n / value_of(MIGRAPHX_GPU_COMPILE_PARALLEL{}, n), f);
}
void compile_ops::apply(module& m) const
{
auto compilers = make_compilers(pointwise_compiler{});
......@@ -85,7 +95,7 @@ void compile_ops::apply(module& m) const
compiles.emplace_back([=]() -> compiled_result { return {c(*ctx, ins, preop), ins}; });
}
std::vector<compiled_result> results(compiles.size());
par_for(compiles.size(), 1, [&](auto i) { results[i] = compiles[i](); });
par_compile(compiles.size(), [&](auto i) { results[i] = compiles[i](); });
for(const auto& cr : results)
{
m.replace_instruction(cr.ins, cr.op, cr.ins->inputs());
......
......@@ -70,6 +70,9 @@ operation compile_pointwise(context& ctx, const std::vector<shape>& inputs, modu
g.add_point_op("less", "migraphx::abs(${0} < ${1})");
g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
g.add_point_op("not", "migraphx::abs(not ${0})");
// Add explict conversions
g.fresult(
[](const shape& s) { return "migraphx::convert<" + shape::cpp_type(s.type()) + ">"; });
auto name =
g.create_function(g.generate_module(m).set_attributes({"__device__"}).set_generic_types(m));
return compile_pointwise((ctx), inputs, "MIGRAPHX_LIFT(" + name + ")", g.str());
......
......@@ -66,6 +66,9 @@ __device__ __host__ auto as_vec(T* x)
return reinterpret_cast<vec<T, N>*>(x);
}
template <class T, index_int N>
using safe_vec = vec<std::conditional_t<std::is_same<T, bool>{}, uint8_t, T>, N>;
template <class... Ts>
constexpr auto vec_transform(Ts... xs)
{
......@@ -74,7 +77,7 @@ constexpr auto vec_transform(Ts... xs)
{
using type = decltype(f(vec_at(xs, 0)...));
constexpr auto size = common_vec_size<Ts...>();
vec<type, size> result = {0};
safe_vec<type, size> result = {0};
for(int i = 0; i < size; i++)
result[i] = f(vec_at(xs, i)...);
return result;
......
......@@ -50,14 +50,14 @@ constexpr auto shape_step(Shape s, Axis)
});
}
// Bools can not be used as a vector type so convert it to int8
// Bools can not be used as a vector type so convert it to uint8
template <class T>
__device__ __host__ T* remove_bool(T* x)
{
return x;
}
inline __device__ __host__ int8_t* remove_bool(bool* x) { return reinterpret_cast<int8_t*>(x); }
inline __device__ __host__ uint8_t* remove_bool(bool* x) { return reinterpret_cast<uint8_t*>(x); }
template <index_int N, class T, class Axis>
__device__ __host__ auto as_vec(T x, Axis axis)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment