Commit 0672c72a authored by Umang Yadav's avatar Umang Yadav
Browse files

Disable vectorization for float8

parent 27759bd0
......@@ -54,6 +54,11 @@ vectorize vectorize::elements(std::size_t axis,
const std::vector<shape>& inputs,
const std::vector<std::size_t>& sizes)
{
// disable vectorization for fp8 types
if(std::any_of(inputs.begin(), inputs.end(), [&](auto ishape) {
return ishape.type() == migraphx::shape::float8_type;
}))
return {1, axis};
if(std::all_of(
inputs.begin(), inputs.end(), [&](const auto& s) { return s.lens()[axis] == 1; }))
return {1, axis};
......@@ -86,6 +91,11 @@ vectorize vectorize::elements(std::size_t axis,
vectorize vectorize::elements(context& ctx, std::size_t axis, const std::vector<shape>& inputs)
{
// disable vectorization for fp8 types
if(std::any_of(inputs.begin(), inputs.end(), [&](auto ishape) {
return ishape.type() == migraphx::shape::float8_type;
}))
return {1, axis};
if(inputs.empty())
return {1, axis};
std::size_t n = std::max_element(inputs.begin(),
......@@ -305,7 +315,7 @@ std::string generate_reduce(const module& m, const std::string& name)
std::transform(
params.begin(), params.end(), params.begin(), [](auto s) { return "auto " + s; });
return interpolate_string(inner_template,
{{"inner", inner_name},
{{"inner", inner_name},
{"params", join_strings(params, ", ")},
{"args", join_strings(args, ", ")},
{"call", call_function}});
......
......@@ -237,9 +237,7 @@ template <index_int N, index_int Axis, class T>
__device__ __host__ auto vectorize_tensor(T x)
{
constexpr auto shape = get_shape_c<T>{};
if constexpr(is_same<typename T::type, migraphx_fp8::fp8e4m3fnuz>{})
return x;
else if constexpr(shape.lens[Axis] == 1)
if constexpr(shape.lens[Axis] == 1)
return x;
else if constexpr(shape.strides[Axis] == 0)
return tensor_step<N>(x, _c<Axis>);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment