"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "7a2edaf474195e8dcd72a0906b39d90b3ae88579"
Commit e17c0bec authored by Paul's avatar Paul
Browse files

Brodcast data

parent 422c825b
...@@ -37,7 +37,7 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const ...@@ -37,7 +37,7 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
auto weights_lens = weights.get_shape().lens(); auto weights_lens = weights.get_shape().lens();
auto conv_lens = conv_ins->get_shape().lens(); auto conv_lens = conv_ins->get_shape().lens();
argument new_weights{weights.get_shape()}; argument new_weights{weights.get_shape()};
argument new_bias{conv_ins->get_shape()}; argument new_bias{bias.get_shape()};
visit_all(weights, gamma, bias, mean, variance, new_weights, new_bias)( visit_all(weights, gamma, bias, mean, variance, new_weights, new_bias)(
[&](auto weights2, [&](auto weights2,
auto gamma2, auto gamma2,
...@@ -51,9 +51,9 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const ...@@ -51,9 +51,9 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
new_weights2(k, c, h, w) = new_weights2(k, c, h, w) =
gamma2(k) / std::sqrt(variance2(k) + epsilon) * weights2(k, c, h, w); gamma2(k) / std::sqrt(variance2(k) + epsilon) * weights2(k, c, h, w);
}); });
dfor(conv_lens[0], conv_lens[1], conv_lens[2], conv_lens[3])( dfor(new_bias.get_shape().elements())(
[&](std::size_t n, std::size_t c, std::size_t h, std::size_t w) { [&](std::size_t c) {
new_bias2(n, c, h, w) = new_bias2(c) =
bias2(c) - (mean2(c) / std::sqrt(variance2(c) + epsilon)); bias2(c) - (mean2(c) / std::sqrt(variance2(c) + epsilon));
}); });
}); });
...@@ -61,7 +61,8 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const ...@@ -61,7 +61,8 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()}); auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()});
auto l_bias = p.add_literal({new_bias.get_shape(), new_bias.data()}); auto l_bias = p.add_literal({new_bias.get_shape(), new_bias.data()});
auto c = p.replace_instruction(conv_ins, conv_op, {conv_ins->arguments[0], l_weights}); auto c = p.replace_instruction(conv_ins, conv_op, {conv_ins->arguments[0], l_weights});
p.replace_instruction(ins, add{}, {c, l_bias}); auto b = p.insert_instruction(ins, broadcast{1}, c, l_bias);
p.replace_instruction(ins, add{}, {c, b});
} }
} }
} // namespace migraph } // namespace migraph
...@@ -12,7 +12,7 @@ struct hip_add_relu ...@@ -12,7 +12,7 @@ struct hip_add_relu
std::string name() const { return "hip::add_relu"; } std::string name() const { return "hip::add_relu"; }
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs}.has(3).standard(); check_shapes{inputs, *this}.has(3);
return inputs.front(); return inputs.front();
} }
argument compute(context&, const shape&, const std::vector<argument>& args) const argument compute(context&, const shape&, const std::vector<argument>& args) const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment