Unverified Commit 70ba8213 authored by Shucai Xiao, committed by GitHub

GPU batchnorm (#564)



* Initial cpu conv-nd

* Formatting

* Make index signed

* Formatting

* Assert the indices are greater than 0

* Use equal instead of lexicographical_compare

* Formatting

* change the batchnorm cpu implementation to support multiple input dimensions

* clang format

* add unit tests for cpu batch_norm nd implementation

* clang format

* support nd batchnormalization

* clang format

* add rewrite batch_norm unit tests

* clang format

* remove a unit test

* Fix tidy errors

* Formatting

* Handle different types

* Formatting

* Fix nested visits

* Formatting

* Add 3d conv test

* Formatting

* revert unnecessary changes

* remove a print line

* Fix ICE

* Formatting

* fix the per_activation mode of 2d

* clang format

* code clean up

* clang format

* add 1d and 3d gpu unit test

* clang format

* add unit test for rewrite_batchnorm

* clang format

* additional refinement

* fix review comments

* added a unit test to have more code coverage
Co-authored-by: Paul <pfultz2@yahoo.com>
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 9d16eaca
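For context, the semantics this PR extends to N-D inputs: in spatial mode the statistics (scale, bias, mean, variance) have shape {C} and broadcast over every spatial dimension, while in per_activation mode they have shape {C, d1, ..., dk} and index each activation. Below is a minimal CPU sketch of the spatial case, illustrative only and not the library implementation (the function name and flat-buffer layout are assumptions):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Reference N-D batch_norm_inference, spatial mode: statistics are indexed by
// channel only and broadcast over the product of the spatial dims d1*...*dk.
std::vector<float> batchnorm_spatial(const std::vector<float>& x,
                                     std::size_t batches,
                                     std::size_t channels,
                                     std::size_t spatial, // d1 * ... * dk
                                     const std::vector<float>& gamma,
                                     const std::vector<float>& bias,
                                     const std::vector<float>& mean,
                                     const std::vector<float>& variance,
                                     float epsilon = 1.0e-5f)
{
    std::vector<float> y(x.size());
    for(std::size_t n = 0; n < batches; n++)
        for(std::size_t c = 0; c < channels; c++)
            for(std::size_t i = 0; i < spatial; i++)
            {
                std::size_t idx = (n * channels + c) * spatial + i;
                y[idx] = gamma[c] * (x[idx] - mean[c]) / std::sqrt(variance[c] + epsilon) +
                         bias[c];
            }
    return y;
}
```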
@@ -42,7 +42,7 @@ struct convolution
     {
         if(not(padding.size() == stride.size() and padding.size() == dilation.size()))
         {
-            MIGRAPHX_THROW("convolution: inconsistent attribute sizes");
+            MIGRAPHX_THROW("CONVOLUTION: inconsistent attribute sizes");
         }
     }
@@ -50,6 +50,11 @@ struct convolution
     {
         check_shapes{inputs, *this}.has(2).same_type().same_ndims().min_ndims(3);
         check_attribute_size();
+        // the input rank must be the attribute rank plus 2 (batch and channel dims)
+        if(inputs[0].lens().size() != padding.size() + 2)
+        {
+            MIGRAPHX_THROW("CONVOLUTION: input and attribute size mismatch!");
+        }
         const shape& input   = inputs.at(0);
         const shape& weights = inputs.at(1);
......
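The new check above ties attribute rank to input rank: an N-D convolution takes a rank N+2 input (batch and channel plus N spatial dims), so padding/stride/dilation of length N require `inputs[0].lens().size() == N + 2`. A hedged usage sketch (shapes invented; header paths assumed from the test files):

```cpp
#include <migraphx/operators.hpp>
#include <migraphx/shape.hpp>

int main()
{
    // rank-5 input (N, C, D, H, W): 3 spatial dims need 3-element attributes
    migraphx::shape input_3d{migraphx::shape::float_type, {4, 3, 16, 16, 16}};
    migraphx::op::convolution ok{{0, 0, 0}, {1, 1, 1}, {1, 1, 1}}; // rank matches
    migraphx::op::convolution bad{}; // 2-D defaults: compute_shape would throw
    // "CONVOLUTION: input and attribute size mismatch!"
    (void)input_3d;
    (void)ok;
    (void)bad;
}
```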
@@ -42,7 +42,7 @@ struct pooling
     {
         if(not(padding.size() == stride.size() and padding.size() == lengths.size()))
         {
-            MIGRAPHX_THROW("pooling: inconsistent attribute sizes");
+            MIGRAPHX_THROW("POOLING: inconsistent attribute sizes");
         }
     }
......
@@ -26,7 +26,8 @@ void rewrite_batchnorm::apply(program& p) const
         if(any_of({gamma, bias, mean, variance}, [](auto arg) { return arg.empty(); }))
             continue;
-        auto s = shape{ins->get_shape().type(), {ins->get_shape().lens()[1]}};
+        std::vector<std::size_t> lens = ins->inputs()[1]->get_shape().lens();
+        shape s{ins->get_shape().type(), lens};
         // Get epsilon
         auto bn_op   = any_cast<op::batch_norm_inference>(ins->get_operator());
         auto epsilon = bn_op.epsilon;
......
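Taking the full lens of the gamma input, instead of the old hard-coded {C}, matters because rewrite_batchnorm folds the op into a per-element multiply-add whose constants must broadcast exactly as gamma does. The underlying algebra as a standalone sketch (function name is mine, not the pass's API):

```cpp
#include <cmath>
#include <utility>

// For each statistics element: y = a * x + b, with a and b precomputed once.
// In per_activation mode gamma/bias/mean/variance have shape {C, d1, ..., dk},
// so a and b inherit that shape instead of the old hard-coded {C}.
std::pair<float, float> fold_batchnorm(
    float gamma, float bias, float mean, float variance, float epsilon)
{
    float a = gamma / std::sqrt(variance + epsilon); // scale factor
    float b = bias - a * mean;                       // additive term
    return {a, b};
}
```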
@@ -8,16 +8,33 @@ namespace gpu {
 shape miopen_batch_norm_inference::compute_shape(const std::vector<shape>& inputs) const
 {
     check_shapes{inputs, *this}.has(6);
+    check_shapes{inputs.data(), inputs.data() + 1, *this}.same_ndims().max_ndims(5);
     return op.compute_shape({inputs.at(0), inputs.at(1), inputs.at(2), inputs.at(3), inputs.at(4)});
 }
 
+// MiOpen tensor descriptors expect at least 4 dimensions, so pad shorter
+// shapes to rank 4 with trailing 1s
+inline shape reshape_to_2d(const shape& input)
+{
+    auto dims = input.lens();
+    if(dims.size() >= 4)
+        return input;
+
+    std::vector<std::size_t> new_dims(dims.begin(), dims.end());
+    std::size_t num = 4 - dims.size();
+    new_dims.insert(new_dims.end(), num, 1);
+
+    return {input.type(), new_dims};
+}
+
 argument miopen_batch_norm_inference::compute(context& ctx,
                                               const shape& output_shape,
                                               const std::vector<argument>& args) const
 {
-    auto x_desc  = make_tensor(args[0].get_shape());
-    auto y_desc  = make_tensor(output_shape);
-    auto bn_desc = make_tensor(args[3].get_shape());
+    shape x_shape  = args[0].get_shape();
+    shape y_shape  = output_shape;
+    shape bn_shape = args[3].get_shape();
+
+    auto x_desc  = make_tensor(reshape_to_2d(x_shape));
+    auto y_desc  = make_tensor(reshape_to_2d(y_shape));
+    auto bn_desc = make_tensor(reshape_to_2d(bn_shape));
 
     float alpha = 1.0f;
     float beta  = 0.0f;
......
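The helper above exists because the MiOpen tensor descriptors built by make_tensor want at least four dimensions; despite its name, reshape_to_2d pads shorter shapes to rank 4 with trailing 1s. A small standalone check of that padding rule (pad_to_rank4 is a hypothetical mirror of the logic, not library code):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Mirror of reshape_to_2d's dimension handling on raw dimension vectors.
std::vector<std::size_t> pad_to_rank4(std::vector<std::size_t> dims)
{
    while(dims.size() < 4)
        dims.push_back(1); // append trailing singleton dims
    return dims;
}

int main()
{
    // 1-D batchnorm input (N, C, L) becomes (N, C, L, 1) for MiOpen
    assert((pad_to_rank4({4, 3, 8}) == std::vector<std::size_t>{4, 3, 8, 1}));
    // already rank >= 4: left untouched
    assert((pad_to_rank4({2, 2, 2, 2, 2}) == std::vector<std::size_t>{2, 2, 2, 2, 2}));
}
```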
@@ -432,21 +432,35 @@ struct miopen_apply
             auto&& op   = any_cast<op::batch_norm_inference>(ins->get_operator());
             auto output = insert_allocation(ins, ins->get_shape());
             shape old_shape = ins->inputs().at(1)->get_shape();
-            std::vector<int64_t> new_shape{1, static_cast<int64_t>(old_shape.elements()), 1, 1};
-            auto reshape_op = op::reshape{new_shape};
+            auto input      = ins->inputs()[0];
+            auto input_lens = input->get_shape().lens();
+            std::vector<int64_t> rsp_lens(input_lens.size(), 1);
+            // in per_activation mode the statistics follow every non-batch
+            // dimension of the input, so copy those dims into the reshape
+            if(op.bn_mode == op::batch_norm_inference::per_activation)
+            {
+                std::copy(input_lens.begin() + 1, input_lens.end(), rsp_lens.begin() + 1);
+            }
+            else // spatial mode: one statistic per channel
+            {
+                rsp_lens[1] = static_cast<int64_t>(old_shape.elements());
+            }
+            auto reshape_op = op::reshape{rsp_lens};
 
             std::vector<instruction_ref> reshapes;
             std::transform(ins->inputs().begin() + 1,
                            ins->inputs().end(),
                            std::back_inserter(reshapes),
                            [&](auto i) { return prog->insert_instruction(ins, reshape_op, i); });
 
             return prog->replace_instruction(ins,
                                              miopen_batch_norm_inference{op},
-                                             ins->inputs().at(0),
+                                             input,
                                              reshapes[0],
                                              reshapes[1],
                                              reshapes[2],
                                              reshapes[3],
                                              output);
         });
     }
......
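To summarize the lowering above: the statistics inputs are reshaped so they broadcast against the N-D input inside MiOpen, keeping every non-batch dimension in per_activation mode and only the channel count in spatial mode. A hedged re-derivation of rsp_lens (name and signature are mine):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Compute the reshape target for the four statistics tensors, given the
// input lens and the batchnorm mode.
std::vector<int64_t> stats_lens(const std::vector<int64_t>& input_lens,
                                int64_t stat_elements,
                                bool per_activation)
{
    std::vector<int64_t> rsp(input_lens.size(), 1);
    if(per_activation)
        // keep every non-batch dim: {1, C, d1, ..., dk}
        std::copy(input_lens.begin() + 1, input_lens.end(), rsp.begin() + 1);
    else
        // spatial mode: one statistic per channel, {1, C, 1, ..., 1}
        rsp[1] = stat_elements;
    return rsp;
}
```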
@@ -1831,6 +1831,142 @@ struct test_batchnorm_inference : verify_program<test_batchnorm_inference>
}
};
struct test_batchnorm_1d : verify_program<test_batchnorm_1d>
{
const size_t size = 3;
const size_t channels = 3;
const size_t batches = 4;
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {batches, channels, size}};
migraphx::shape vars{migraphx::shape::float_type, {channels}};
auto x = p.add_parameter("x", s);
auto scale = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1)));
auto bias = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2)));
auto mean = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
auto variance = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
p.add_instruction(migraphx::op::batch_norm_inference{}, x, scale, bias, mean, variance);
return p;
}
};
struct test_batchnorm_3d : verify_program<test_batchnorm_3d>
{
const size_t d1 = 2;
const size_t d2 = 2;
const size_t d3 = 2;
const size_t channels = 2;
const size_t batches = 2;
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {batches, channels, d1, d2, d3}};
migraphx::shape vars{migraphx::shape::float_type, {channels}};
auto x = p.add_parameter("x", s);
auto scale = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1)));
auto bias = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2)));
auto mean = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
auto variance = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
p.add_instruction(migraphx::op::batch_norm_inference{}, x, scale, bias, mean, variance);
return p;
}
};
struct test_batchnorm_1d_per_actv : verify_program<test_batchnorm_1d_per_actv>
{
const size_t d1 = 5;
const size_t channels = 2;
const size_t batches = 3;
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {batches, channels, d1}};
migraphx::shape vars{migraphx::shape::float_type, {channels, d1}};
auto x = p.add_parameter("x", s);
auto scale = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1)));
auto bias = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2)));
auto mean = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
auto variance = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
p.add_instruction(
migraphx::op::batch_norm_inference{
1.0e-5, 0.96f, migraphx::op::batch_norm_inference::per_activation},
x,
scale,
bias,
mean,
variance);
return p;
}
};
struct test_batchnorm_2d_per_actv : verify_program<test_batchnorm_2d_per_actv>
{
const size_t d1 = 2;
const size_t d2 = 4;
const size_t channels = 2;
const size_t batches = 3;
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {batches, channels, d1, d2}};
migraphx::shape vars{migraphx::shape::float_type, {channels, d1, d2}};
auto x = p.add_parameter("x", s);
auto scale = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1)));
auto bias = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2)));
auto mean = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
auto variance = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
p.add_instruction(
migraphx::op::batch_norm_inference{
1.0e-6, 0.9f, migraphx::op::batch_norm_inference::per_activation},
x,
scale,
bias,
mean,
variance);
return p;
}
};
struct test_batchnorm_3d_per_actv : verify_program<test_batchnorm_3d_per_actv>
{
const size_t d1 = 2;
const size_t d2 = 4;
const size_t d3 = 5;
const size_t channels = 2;
const size_t batches = 3;
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {batches, channels, d1, d2, d3}};
migraphx::shape vars{migraphx::shape::float_type, {channels, d1, d2, d3}};
auto x = p.add_parameter("x", s);
auto scale = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1)));
auto bias = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2)));
auto mean = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
auto variance = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
p.add_instruction(
migraphx::op::batch_norm_inference{
1.0e-6, 0.8f, migraphx::op::batch_norm_inference::per_activation},
x,
scale,
bias,
mean,
variance);
return p;
}
};
struct test_clip : verify_program<test_clip>
{
migraphx::program create_program() const
......
@@ -87,6 +87,8 @@ TEST_CASE(convolution_shape)
                  migraphx::op::convolution{{0, 0, 0}, {1, 1, 1}, {1, 1, 1}},
                  input_3d,
                  weights_3d);
+
+    throws_shape(migraphx::op::convolution{}, input_3d, weights_3d);
 }
TEST_CASE(deconvolution_shape)
......
@@ -134,6 +134,77 @@ TEST_CASE(as_literal)
visit_all(result1, result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); });
}
TEST_CASE(as_literal_1d)
{
migraphx::shape xs{migraphx::shape::float_type, {1, 3, 8}};
migraphx::shape ws{migraphx::shape::float_type, {4, 3, 1}};
migraphx::shape vars{migraphx::shape::float_type, {4}};
auto create_program = [&]() {
migraphx::program p;
auto x = p.add_literal(migraphx::generate_literal(xs, 1));
auto w = p.add_literal(migraphx::generate_literal(ws, 1));
auto conv = p.add_instruction(migraphx::op::convolution{{0}, {1}, {1}}, x, w);
auto scale = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1)));
auto bias = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2)));
auto mean = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
auto variance = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
p.add_instruction(migraphx::op::batch_norm_inference{}, conv, scale, bias, mean, variance);
return p;
};
migraphx::program p1 = create_program();
migraphx::program p2 = create_program();
migraphx::rewrite_batchnorm opt;
opt.apply(p2);
EXPECT(any_of(p1, &is_batch_norm));
EXPECT(none_of(p2, &is_batch_norm));
p1.compile(migraphx::cpu::target{});
p2.compile(migraphx::cpu::target{});
auto result1 = p1.eval({}).back();
auto result2 = p2.eval({}).back();
visit_all(result1, result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); });
}
TEST_CASE(as_literal_3d)
{
migraphx::shape xs{migraphx::shape::float_type, {1, 3, 2, 4, 8}};
migraphx::shape ws{migraphx::shape::float_type, {4, 3, 1, 1, 1}};
migraphx::shape vars{migraphx::shape::float_type, {4}};
auto create_program = [&]() {
migraphx::program p;
migraphx::op::convolution conv_op;
conv_op.padding = {0, 0, 0};
conv_op.stride = {1, 1, 1};
conv_op.dilation = {1, 1, 1};
auto x = p.add_literal(migraphx::generate_literal(xs, 1));
auto w = p.add_literal(migraphx::generate_literal(ws, 1));
auto conv = p.add_instruction(conv_op, x, w);
auto scale = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1)));
auto bias = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2)));
auto mean = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
auto variance = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
p.add_instruction(migraphx::op::batch_norm_inference{}, conv, scale, bias, mean, variance);
return p;
};
migraphx::program p1 = create_program();
migraphx::program p2 = create_program();
migraphx::rewrite_batchnorm opt;
opt.apply(p2);
EXPECT(any_of(p1, &is_batch_norm));
EXPECT(none_of(p2, &is_batch_norm));
p1.compile(migraphx::cpu::target{});
p2.compile(migraphx::cpu::target{});
auto result1 = p1.eval({}).back();
auto result2 = p2.eval({}).back();
visit_all(result1, result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); });
}
TEST_CASE(literal_reshape)
{
migraphx::shape xs{migraphx::shape::float_type, {1, 3, 8, 8}};
@@ -142,17 +213,13 @@ TEST_CASE(literal_reshape)
     auto create_program = [&]() {
         migraphx::program p;
-        auto reshape = [&](auto ins) {
-            return p.add_instruction(migraphx::op::reshape{{1, 4, 1, 1}}, ins);
-        };
         auto x    = p.add_literal(migraphx::generate_literal(xs, 1));
         auto w    = p.add_literal(migraphx::generate_literal(ws, 1));
         auto conv = p.add_instruction(migraphx::op::convolution{}, x, w);
-        auto scale = reshape(p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))));
-        auto bias  = reshape(p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))));
-        auto mean  = reshape(p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))));
-        auto variance = reshape(p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))));
+        auto scale    = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1)));
+        auto bias     = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2)));
+        auto mean     = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
+        auto variance = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
         p.add_instruction(migraphx::op::batch_norm_inference{}, conv, scale, bias, mean, variance);
         return p;
     };
@@ -172,4 +239,46 @@ TEST_CASE(literal_reshape)
visit_all(result1, result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); });
}
TEST_CASE(literal_reshape_per_actv)
{
migraphx::shape xs{migraphx::shape::float_type, {1, 3, 8, 7, 4}};
migraphx::shape ws{migraphx::shape::float_type, {4, 3, 1, 1, 1}};
migraphx::shape vars{migraphx::shape::float_type, {4, 8, 7, 4}};
auto create_program = [&]() {
migraphx::program p;
auto x = p.add_literal(migraphx::generate_literal(xs, 1));
auto w = p.add_literal(migraphx::generate_literal(ws, 1));
auto conv =
p.add_instruction(migraphx::op::convolution{{0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, x, w);
auto scale = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1)));
auto bias = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2)));
auto mean = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
auto variance = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
p.add_instruction(
migraphx::op::batch_norm_inference{
1.0e-5, 0.88, migraphx::op::batch_norm_inference::per_activation},
conv,
scale,
bias,
mean,
variance);
return p;
};
migraphx::program p1 = create_program();
migraphx::program p2 = create_program();
migraphx::rewrite_batchnorm opt;
opt.apply(p2);
EXPECT(any_of(p1, &is_batch_norm));
EXPECT(none_of(p2, &is_batch_norm));
p1.compile(migraphx::cpu::target{});
p2.compile(migraphx::cpu::target{});
auto result1 = p1.eval({}).back();
auto result2 = p2.eval({}).back();
visit_all(result1, result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); });
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }