Commit ffebc6a8 authored by Paul's avatar Paul
Browse files

Fix verify bugs

parent c47fa173
...@@ -223,8 +223,9 @@ struct reduce_op ...@@ -223,8 +223,9 @@ struct reduce_op
} }
else if(ins->name() == "reduce_mean") else if(ins->name() == "reduce_mean")
{ {
auto reduce_elements = get_reduce_elements(ins->inputs()); auto s = ins->inputs().front()->get_shape();
auto reduce_type = ins->inputs().front()->get_shape().type(); auto reduce_elements = s.elements() / ins->get_shape().elements();
auto reduce_type = s.type();
r.reduction = "op::sum{}"; r.reduction = "op::sum{}";
std::string mean = "op::mean{" + std::to_string(reduce_elements) + "}"; std::string mean = "op::mean{" + std::to_string(reduce_elements) + "}";
// Use float accumulator when reduction size is too large for half // Use float accumulator when reduction size is too large for half
......
...@@ -47,7 +47,7 @@ ${preamble} ...@@ -47,7 +47,7 @@ ${preamble}
extern "C" { extern "C" {
__global__ void ${kernel}(${params}) __global__ void ${kernel}(${params})
{ {
transform_args(make_tensors(), ${transformers})(${args})([](auto y, auto... xs) { transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, auto... xs) {
fused_reduce<reduce::${algo}, ${reduced}>(y, partial(${lambda})(xs...)); fused_reduce<reduce::${algo}, ${reduced}>(y, partial(${lambda})(xs...));
}); });
} }
...@@ -58,11 +58,6 @@ __global__ void ${kernel}(${params}) ...@@ -58,11 +58,6 @@ __global__ void ${kernel}(${params})
)__migraphx__"; )__migraphx__";
static std::size_t get_reduce_elements(const std::vector<shape>& inputs)
{
return inputs.front().elements() / inputs.back().elements();
}
template <class T> template <class T>
static shape get_reduced_shape(const shape& s, const std::vector<T>& axes) static shape get_reduced_shape(const shape& s, const std::vector<T>& axes)
{ {
...@@ -115,7 +110,7 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler> ...@@ -115,7 +110,7 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
{ {
vec = vectorize::elements(ctx, faxis, options.virtual_inputs); vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
} }
auto relements = get_reduce_elements(options.virtual_inputs) / vec.size; auto relements = reduced_shape.elements() / vec.size;
auto nelements = options.virtual_inputs.back().elements(); auto nelements = options.virtual_inputs.back().elements();
auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs, reduced_shape.lens())); auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs, reduced_shape.lens()));
if(algo == "block") if(algo == "block")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment