Unverified Commit 0d52d99f authored by mvermeulen's avatar mvermeulen Committed by GitHub
Browse files

Merge pull request #264 from ROCmSoftwarePlatform/hip_softmax

Add hip softmax
parents 0c798442 e1758782
...@@ -18,10 +18,23 @@ namespace op { ...@@ -18,10 +18,23 @@ namespace op {
struct softmax struct softmax
{ {
int axis = 1;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}
std::string name() const { return "softmax"; } std::string name() const { return "softmax"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1).only_dims(4); check_shapes{inputs}.has(1).standard();
if(axis < 0 || axis >= inputs[0].lens().size())
{
MIGRAPHX_THROW("SoftMax: input axis value " + std::to_string(axis) +
" is out of range");
}
return inputs.at(0); return inputs.at(0);
} }
}; };
......
...@@ -517,40 +517,60 @@ struct cpu_unary ...@@ -517,40 +517,60 @@ struct cpu_unary
} }
}; };
struct softmax2d struct cpu_softmax
{ {
std::string name() const { return "cpu::softmax2d"; } op::softmax op;
shape compute_shape(const std::vector<shape>& inputs) const { return inputs.front(); }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const template <class Self, class F>
{ static auto reflect(Self& self, F f)
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
using value_type = typename decltype(input)::value_type;
auto nb = input.get_shape().lens()[0];
auto nc = input.get_shape().lens()[1];
auto nh = input.get_shape().lens()[2];
auto nw = input.get_shape().lens()[3];
dfor(nb, nh, nw)([&](std::size_t b, std::size_t i, std::size_t j) {
value_type cmax = std::numeric_limits<value_type>::lowest();
for(std::size_t c = 0; c < nc; c++)
{
cmax = std::max(cmax, input(b, c, i, j));
}
for(std::size_t c = 0; c < nc; c++)
{ {
output(b, c, i, j) = std::exp(input(b, c, i, j) - cmax); return migraphx::reflect(self.op, f);
} }
value_type sum = value_type(0);
for(std::size_t c = 0; c < nc; c++) std::string name() const { return "cpu::softmax"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
template <typename T>
std::size_t compute_batch_index(T idx, shape& batch_shape, int axis) const
{ {
sum += output(b, c, i, j); idx[axis] = 0;
return batch_shape.index(idx);
} }
for(std::size_t c = 0; c < nc; c++)
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{ {
output(b, c, i, j) = output(b, c, i, j) / sum; argument result{output_shape};
} auto batch_lens = output_shape.lens();
batch_lens[op.axis] = 1;
shape batch_shape{shape::int32_type, batch_lens};
visit_all(result, args[0])([&](auto output, auto input) {
using value_type = typename decltype(input)::value_type;
std::vector<value_type> batch_max(batch_shape.elements(),
std::numeric_limits<value_type>::lowest());
shape_for_each(output_shape, [&](auto idx) {
auto index = this->compute_batch_index(idx, batch_shape, op.axis);
batch_max[index] = std::max(batch_max[index], input(idx.begin(), idx.end()));
}); });
shape_for_each(output_shape, [&](auto idx) {
auto index = this->compute_batch_index(idx, batch_shape, op.axis);
output(idx.begin(), idx.end()) =
std::exp(input(idx.begin(), idx.end()) - batch_max[index]);
}); });
std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0));
shape_for_each(output_shape, [&](auto idx) {
auto index = this->compute_batch_index(idx, batch_shape, op.axis);
batch_sum[index] += output(idx.begin(), idx.end());
});
shape_for_each(output_shape, [&](auto idx) {
auto index = this->compute_batch_index(idx, batch_shape, op.axis);
output(idx.begin(), idx.end()) /= batch_sum[index];
});
});
return result; return result;
} }
}; };
...@@ -646,7 +666,7 @@ struct cpu_apply ...@@ -646,7 +666,7 @@ struct cpu_apply
apply_map["logsoftmax"] = extend_op<cpu_logsoftmax, op::logsoftmax>(); apply_map["logsoftmax"] = extend_op<cpu_logsoftmax, op::logsoftmax>();
apply_map["lrn"] = extend_op<cpu_lrn, op::lrn>(); apply_map["lrn"] = extend_op<cpu_lrn, op::lrn>();
apply_map["pad"] = extend_op<cpu_pad, op::pad>(); apply_map["pad"] = extend_op<cpu_pad, op::pad>();
apply_map["softmax"] = simple_op<softmax2d>(); apply_map["softmax"] = extend_op<cpu_softmax, op::softmax>();
} }
void apply() void apply()
......
...@@ -27,6 +27,7 @@ add_library(migraphx_device ...@@ -27,6 +27,7 @@ add_library(migraphx_device
device/add_relu.cpp device/add_relu.cpp
device/contiguous.cpp device/contiguous.cpp
device/logsoftmax.cpp device/logsoftmax.cpp
device/softmax.cpp
device/convert.cpp device/convert.cpp
device/mul.cpp device/mul.cpp
device/concat.cpp device/concat.cpp
......
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/softmax.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/hip.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
argument softmax(hipStream_t stream,
const migraphx::shape& output_shape,
std::vector<migraphx::argument> args,
int axis)
{
auto lens = output_shape.lens();
auto batch_lens = lens;
size_t n_dims = lens[axis];
batch_lens[axis] = 1;
migraphx::shape batch_shape{shape::int32_type, batch_lens};
visit_all(args.back(), args.front())([&](auto output, auto input) {
const auto* input_ptr = device_cast(input.data());
auto* output_ptr = device_cast(output.data());
visit_tensor_size(batch_shape.lens().size(), [&](auto n_dim) {
hip_tensor_descriptor<n_dim> desc_batch(batch_shape);
hip_tensor_descriptor<n_dim> desc_data(output_shape);
// each thread is for one item in the batch
gs_launch(stream, batch_shape.elements())([=](auto i) {
auto batch_idx = desc_batch.multi(i);
auto data_idx = batch_idx;
// get max
auto batch_max = input_ptr[desc_data.linear(batch_idx)];
for(std::size_t j = 1; j < n_dims; ++j)
{
data_idx[axis] = j;
batch_max = std::max(to_hip_type(batch_max),
to_hip_type(input_ptr[desc_data.linear(data_idx)]));
}
for(std::size_t j = 0; j < n_dims; ++j)
{
data_idx[axis] = j;
auto idx = desc_data.linear(data_idx);
output_ptr[idx] = input_ptr[idx] - batch_max;
}
for(std::size_t j = 0; j < n_dims; ++j)
{
data_idx[axis] = j;
auto idx = desc_data.linear(data_idx);
output_ptr[idx] = exp(to_hip_type(output_ptr[idx]));
}
auto batch_sum = output_ptr[desc_data.linear(batch_idx)];
for(std::size_t j = 1; j < n_dims; ++j)
{
data_idx[axis] = j;
batch_sum += output_ptr[desc_data.linear(data_idx)];
}
for(std::size_t j = 0; j < n_dims; ++j)
{
data_idx[axis] = j;
auto idx = desc_data.linear(data_idx);
output_ptr[idx] = output_ptr[idx] / batch_sum;
}
});
});
});
return args.back();
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_SOFTMAX_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_SOFTMAX_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
argument softmax(hipStream_t stream,
const migraphx::shape& output_shape,
std::vector<migraphx::argument> args,
int axis);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_SOFTMAX_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_SOFTMAX_HPP
#define MIGRAPHX_GUARD_RTGLIB_SOFTMAX_HPP #define MIGRAPHX_GUARD_RTGLIB_SOFTMAX_HPP
#include <migraphx/shape.hpp> #include <migraphx/gpu/lowering.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/softmax.hpp> #include <migraphx/op/softmax.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
#include <migraphx/gpu/device/add.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/context.hpp>
#include <utility>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -30,6 +44,26 @@ struct miopen_softmax ...@@ -30,6 +44,26 @@ struct miopen_softmax
} }
}; };
struct hip_softmax
{
op::softmax op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::softmax"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -99,7 +99,7 @@ struct miopen_apply ...@@ -99,7 +99,7 @@ struct miopen_apply
add_extend_op<miopen_gemm, op::dot>("dot"); add_extend_op<miopen_gemm, op::dot>("dot");
add_extend_op<miopen_contiguous, op::contiguous>("contiguous"); add_extend_op<miopen_contiguous, op::contiguous>("contiguous");
add_extend_op<hip_concat, op::concat>("concat"); add_extend_op<hip_concat, op::concat>("concat");
add_extend_op<miopen_softmax, op::softmax>("softmax"); add_extend_op<hip_softmax, op::softmax>("softmax");
add_extend_op<hip_logsoftmax, op::logsoftmax>("logsoftmax"); add_extend_op<hip_logsoftmax, op::logsoftmax>("logsoftmax");
add_extend_op<hip_gather, op::gather>("gather"); add_extend_op<hip_gather, op::gather>("gather");
add_extend_op<hip_pad, op::pad>("pad"); add_extend_op<hip_pad, op::pad>("pad");
......
#include <migraphx/gpu/softmax.hpp> #include <migraphx/gpu/softmax.hpp>
#include <migraphx/gpu/device/softmax.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
namespace migraphx { namespace migraphx {
...@@ -30,6 +31,19 @@ argument miopen_softmax::compute(context& ctx, ...@@ -30,6 +31,19 @@ argument miopen_softmax::compute(context& ctx,
return args[1]; return args[1];
} }
shape hip_softmax::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(2).standard();
return op.compute_shape({inputs.at(0)});
}
argument hip_softmax::compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args) const
{
return device::softmax(ctx.get_stream().get(), output_shape, args, op.axis);
}
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -24,7 +24,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -24,7 +24,7 @@ inline namespace MIGRAPHX_INLINE_NS {
struct tf_parser struct tf_parser
{ {
using attribute_map = std::unordered_map<std::string, tensorflow::AttrValue>; using attribute_map = std::unordered_map<std::string, tensorflow::AttrValue>;
using node_map = std::unordered_map<std::string, tensorflow::NodeDef>; using node_map = std::map<std::string, tensorflow::NodeDef>;
// using input_node_map = std::unordered_map<std::string, std::unordered_set<std::string>>; // using input_node_map = std::unordered_map<std::string, std::unordered_set<std::string>>;
using op_func = std::function<instruction_ref(attribute_map, std::vector<instruction_ref>)>; using op_func = std::function<instruction_ref(attribute_map, std::vector<instruction_ref>)>;
...@@ -277,29 +277,6 @@ struct tf_parser ...@@ -277,29 +277,6 @@ struct tf_parser
parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
op::convolution op; op::convolution op;
if(contains(attributes, "padding"))
{
const std::string& pad_mode = attributes.at("padding").s();
if(pad_mode.find("SAME") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::same;
}
else if(pad_mode.find("EXPLICIT") != std::string::npos)
{
std::vector<size_t> padding;
copy(attributes.at("explicit_paddings").list().i(), std::back_inserter(padding));
if(padding.size() != 4)
{
MIGRAPHX_THROW("padding should have 4 values");
}
if(padding[0] != padding[2] || padding[1] != padding[3])
{
MIGRAPHX_THROW("migraphx does not support asymetric padding");
}
op.padding[0] = padding[0];
op.padding[1] = padding[1];
}
}
if(contains(attributes, "strides")) if(contains(attributes, "strides"))
{ {
std::vector<size_t> stride; std::vector<size_t> stride;
...@@ -339,6 +316,34 @@ struct tf_parser ...@@ -339,6 +316,34 @@ struct tf_parser
} }
} }
if(contains(attributes, "padding"))
{
const std::string& pad_mode = attributes.at("padding").s();
if(pad_mode.find("SAME") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::same;
}
else if(pad_mode.find("VALID") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::valid;
}
else if(pad_mode.find("EXPLICIT") != std::string::npos)
{
std::vector<size_t> padding;
copy(attributes.at("explicit_paddings").list().i(), std::back_inserter(padding));
if(padding.size() != 4)
{
MIGRAPHX_THROW("padding should have 4 values");
}
if(padding[0] != padding[2] || padding[1] != padding[3])
{
MIGRAPHX_THROW("migraphx does not support asymetric padding");
}
op.padding[0] = padding[0];
op.padding[1] = padding[1];
}
}
return prog.add_instruction(op, {args[0], weights}); return prog.add_instruction(op, {args[0], weights});
} }
......
...@@ -929,6 +929,24 @@ TEST_CASE(maxpool_test) ...@@ -929,6 +929,24 @@ TEST_CASE(maxpool_test)
EXPECT(migraphx::verify_range(results_vector, c)); EXPECT(migraphx::verify_range(results_vector, c));
} }
TEST_CASE(softmax_simple_test)
{
migraphx::program p;
std::vector<float> a = {0.25, 0.75};
std::vector<float> s = {0.377541, 0.622459};
migraphx::shape a_shape{migraphx::shape::float_type, {1, 2}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
p.add_instruction(migraphx::op::softmax{1}, al);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector(2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
for(auto v : results_vector)
std::cout << v << "\t";
std::cout << std::endl;
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(softmax_test) TEST_CASE(softmax_test)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -569,13 +569,13 @@ struct test_sub2 : verify_program<test_sub2> ...@@ -569,13 +569,13 @@ struct test_sub2 : verify_program<test_sub2>
} }
}; };
struct test_softmax : verify_program<test_softmax> struct test_softmax1 : verify_program<test_softmax1>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {5, 3, 4, 2}}); auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {5, 3, 3, 4}});
p.add_instruction(migraphx::op::softmax{}, x); p.add_instruction(migraphx::op::softmax{0}, x);
return p; return p;
} }
}; };
...@@ -592,6 +592,25 @@ struct test_softmax2 : verify_program<test_softmax2> ...@@ -592,6 +592,25 @@ struct test_softmax2 : verify_program<test_softmax2>
} }
}; };
template <int Axis>
struct test_softmax : verify_program<test_softmax<Axis>>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 4, 5, 6}};
auto param = p.add_parameter("0", s);
p.add_instruction(migraphx::op::softmax{Axis}, param);
return p;
}
};
template struct test_softmax<0>;
template struct test_softmax<1>;
template struct test_softmax<2>;
template struct test_softmax<3>;
struct test_conv : verify_program<test_conv> struct test_conv : verify_program<test_conv>
{ {
migraphx::program create_program() const migraphx::program create_program() const
......
...@@ -346,53 +346,40 @@ TEST_CASE(gather) ...@@ -346,53 +346,40 @@ TEST_CASE(gather)
} }
} }
TEST_CASE(logsoftmax) template <class T>
void test_softmax_variations()
{ {
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 0; expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{0}, input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis},
input);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 1; expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{1}, input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis},
input);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 2; expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{2}, input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis},
input);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 3; expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{3}, input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis},
input);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 4; int axis = 4;
throws_shape(migraphx::op::logsoftmax{axis}, input); throws_shape(T{axis}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = -1;
throws_shape(migraphx::op::logsoftmax{axis}, input);
} }
} }
TEST_CASE(softmax) { test_softmax_variations<migraphx::op::softmax>(); }
TEST_CASE(logsoftmax) { test_softmax_variations<migraphx::op::logsoftmax>(); }
// 2 inputs arguments // 2 inputs arguments
TEST_CASE(matmul) TEST_CASE(matmul)
{ {
......
...@@ -178,9 +178,9 @@ TEST_CASE(mean_test) ...@@ -178,9 +178,9 @@ TEST_CASE(mean_test)
p.add_literal(l); p.add_literal(l);
migraphx::op::pooling op; migraphx::op::pooling op;
op.lengths = {16, 16}; op.lengths = {16, 16};
p.add_instruction(op, l0);
auto l3 = p.add_instruction(op, l0); auto l3 = p.add_instruction(op, l0);
p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3); p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3);
p.add_instruction(op, l0);
auto prog = migraphx::parse_tf("mean_test.pb", false); auto prog = migraphx::parse_tf("mean_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
...@@ -195,9 +195,9 @@ TEST_CASE(mean_test_nhwc) ...@@ -195,9 +195,9 @@ TEST_CASE(mean_test_nhwc)
p.add_literal(l); p.add_literal(l);
migraphx::op::pooling op; migraphx::op::pooling op;
op.lengths = {16, 16}; op.lengths = {16, 16};
p.add_instruction(op, l0);
auto l3 = p.add_instruction(op, l0); auto l3 = p.add_instruction(op, l0);
p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3); p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3);
p.add_instruction(op, l0);
auto prog = migraphx::parse_tf("mean_test_nhwc.pb", true); auto prog = migraphx::parse_tf("mean_test_nhwc.pb", true);
EXPECT(p == prog); EXPECT(p == prog);
...@@ -271,8 +271,8 @@ TEST_CASE(pooling_test) ...@@ -271,8 +271,8 @@ TEST_CASE(pooling_test)
max_pool_op.stride = {2, 2}; max_pool_op.stride = {2, 2};
avg_pool_op.lengths = {2, 2}; avg_pool_op.lengths = {2, 2};
max_pool_op.lengths = {2, 2}; max_pool_op.lengths = {2, 2};
p.add_instruction(max_pool_op, l0);
p.add_instruction(avg_pool_op, l0); p.add_instruction(avg_pool_op, l0);
p.add_instruction(max_pool_op, l0);
auto prog = migraphx::parse_tf("pooling_test.pb", true); auto prog = migraphx::parse_tf("pooling_test.pb", true);
EXPECT(p == prog); EXPECT(p == prog);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment