Commit 3f4ef63e authored by charlie's avatar charlie
Browse files

Fix operation bug at mod_compute_shape

removed the new rank<1> function as it messed up batch_quant_dot_5
verify test
parent 700d7761
...@@ -140,9 +140,9 @@ template <class T> ...@@ -140,9 +140,9 @@ template <class T>
auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs) auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.normalize_compute_shape(inputs)) -> decltype(x.normalize_compute_shape(inputs))
{ {
dependent_type<operation, T> y = x;
if(inputs.empty()) if(inputs.empty())
MIGRAPHX_THROW("At least one input is required for " + x.name()); MIGRAPHX_THROW("At least one input is required for " + x.name());
dependent_type<operation, T> y = x;
normalize_attributes(y, inputs[0].max_lens()); normalize_attributes(y, inputs[0].max_lens());
return any_cast<T>(y).normalize_compute_shape(inputs); return any_cast<T>(y).normalize_compute_shape(inputs);
} }
...@@ -168,7 +168,7 @@ shape compute_shape_op(const T& x, const std::vector<shape>& inputs) ...@@ -168,7 +168,7 @@ shape compute_shape_op(const T& x, const std::vector<shape>& inputs)
} }
template <class T> template <class T>
auto mod_compute_shape_op(rank<2>, auto mod_compute_shape_op(rank<1>,
const T& x, const T& x,
const std::vector<shape>& inputs, const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args) const std::vector<module_ref>& mod_args)
...@@ -177,15 +177,6 @@ auto mod_compute_shape_op(rank<2>, ...@@ -177,15 +177,6 @@ auto mod_compute_shape_op(rank<2>,
return x.compute_shape(inputs, mod_args); return x.compute_shape(inputs, mod_args);
} }
template <class T>
auto mod_compute_shape_op(rank<1>,
const T& x,
const std::vector<shape>& inputs,
const std::vector<module_ref>&) -> decltype(x.compute_shape(inputs))
{
return x.compute_shape(inputs);
}
template <class T> template <class T>
shape mod_compute_shape_op(rank<0>, shape mod_compute_shape_op(rank<0>,
const T& x, const T& x,
......
...@@ -168,7 +168,7 @@ shape compute_shape_op(const T& x, const std::vector<shape>& inputs) ...@@ -168,7 +168,7 @@ shape compute_shape_op(const T& x, const std::vector<shape>& inputs)
} }
template <class T> template <class T>
auto mod_compute_shape_op(rank<2>, auto mod_compute_shape_op(rank<1>,
const T& x, const T& x,
const std::vector<shape>& inputs, const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args) const std::vector<module_ref>& mod_args)
...@@ -177,15 +177,6 @@ auto mod_compute_shape_op(rank<2>, ...@@ -177,15 +177,6 @@ auto mod_compute_shape_op(rank<2>,
return x.compute_shape(inputs, mod_args); return x.compute_shape(inputs, mod_args);
} }
template <class T>
auto mod_compute_shape_op(rank<1>,
const T& x,
const std::vector<shape>& inputs,
const std::vector<module_ref>&) -> decltype(x.compute_shape(inputs))
{
return x.compute_shape(inputs);
}
template <class T> template <class T>
shape mod_compute_shape_op(rank<0>, shape mod_compute_shape_op(rank<0>,
const T& x, const T& x,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment