Commit ffebc6a8 authored by Paul
Browse files

Fix verify bugs

parent c47fa173
......@@ -223,8 +223,9 @@ struct reduce_op
}
else if(ins->name() == "reduce_mean")
{
auto reduce_elements = get_reduce_elements(ins->inputs());
auto reduce_type = ins->inputs().front()->get_shape().type();
auto s = ins->inputs().front()->get_shape();
auto reduce_elements = s.elements() / ins->get_shape().elements();
auto reduce_type = s.type();
r.reduction = "op::sum{}";
std::string mean = "op::mean{" + std::to_string(reduce_elements) + "}";
// Use float accumulator when reduction size is too large for half
......
......@@ -47,7 +47,7 @@ ${preamble}
extern "C" {
__global__ void ${kernel}(${params})
{
transform_args(make_tensors(), ${transformers})(${args})([](auto y, auto... xs) {
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, auto... xs) {
fused_reduce<reduce::${algo}, ${reduced}>(y, partial(${lambda})(xs...));
});
}
......@@ -58,11 +58,6 @@ __global__ void ${kernel}(${params})
)__migraphx__";
// Returns the element count of the first shape in `inputs` divided by the
// element count of the last shape in `inputs` (integer division).
//
// `inputs` must be non-empty: front()/back() on an empty vector is undefined.
// NOTE(review): presumably the last entry is the reduction's output shape, so
// the quotient is the number of elements folded into each output value —
// confirm against the callers.
static std::size_t get_reduce_elements(const std::vector<shape>& inputs)
{
    const auto first_count = inputs.front().elements();
    const auto last_count  = inputs.back().elements();
    return first_count / last_count;
}
template <class T>
static shape get_reduced_shape(const shape& s, const std::vector<T>& axes)
{
......@@ -115,7 +110,7 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
{
vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
}
auto relements = get_reduce_elements(options.virtual_inputs) / vec.size;
auto relements = reduced_shape.elements() / vec.size;
auto nelements = options.virtual_inputs.back().elements();
auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs, reduced_shape.lens()));
if(algo == "block")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment