Commit 8a079721 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'logsoftmax_operator' into seq2seq_example

parents 077970f7 07214d76
...@@ -974,6 +974,22 @@ struct softmax ...@@ -974,6 +974,22 @@ struct softmax
} }
}; };
struct logsoftmax
{
int axis = 1;
std::string name() const { return "logsoftmax"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
if(axis < 0 || axis >= inputs[0].lens().size())
{
MIGRAPHX_THROW("LogSoftMax: input axis value " + std::to_string(axis) +
" is out of range");
}
return inputs.at(0);
}
};
struct flatten struct flatten
{ {
uint64_t axis = 0; uint64_t axis = 0;
......
...@@ -79,6 +79,7 @@ struct onnx_parser ...@@ -79,6 +79,7 @@ struct onnx_parser
add_mem_op("Gemm", &onnx_parser::parse_gemm); add_mem_op("Gemm", &onnx_parser::parse_gemm);
add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm); add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
add_mem_op("Softmax", &onnx_parser::parse_softmax); add_mem_op("Softmax", &onnx_parser::parse_softmax);
add_mem_op("LogSoftmax", &onnx_parser::parse_logsoftmax);
add_mem_op("Squeeze", &onnx_parser::parse_squeeze); add_mem_op("Squeeze", &onnx_parser::parse_squeeze);
add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze); add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
add_mem_op("Slice", &onnx_parser::parse_slice); add_mem_op("Slice", &onnx_parser::parse_slice);
...@@ -228,6 +229,19 @@ struct onnx_parser ...@@ -228,6 +229,19 @@ struct onnx_parser
return prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1])}}, s); return prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1])}}, s);
} }
instruction_ref parse_logsoftmax(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
{
int axis = 1;
if(contains(attributes, "axis"))
{
axis = parse_value(attributes.at("axis")).at<int>();
}
return prog.add_instruction(op::logsoftmax{axis}, std::move(args));
}
instruction_ref instruction_ref
parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
......
...@@ -613,6 +613,63 @@ struct softmax2d ...@@ -613,6 +613,63 @@ struct softmax2d
} }
}; };
struct cpu_logsoftmax
{
op::logsoftmax op;
std::string name() const { return "cpu::logsoftmax"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
template <typename T>
std::size_t compute_batch_index(const T& idx, shape& batch_shape, int axis) const
{
std::vector<std::size_t> batch_idx(idx.begin(), idx.begin() + axis);
return batch_shape.index(batch_idx.begin(), batch_idx.end());
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto lens = output_shape.lens();
std::vector<std::size_t> batch_lens(lens.begin(), lens.begin() + op.axis);
shape batch_shape{migraphx::shape::uint32_type, batch_lens};
// use float for now, need to change later
visit_all(result, args[0])([&](auto output, auto input) {
using value_type = typename decltype(input)::value_type;
std::vector<value_type> batch_max(batch_shape.elements(),
std::numeric_limits<value_type>::lowest());
shape_for_each(output_shape, [&](auto idx) {
auto index = compute_batch_index(idx, batch_shape, op.axis);
batch_max[index] = std::max(batch_max[index], input(idx.begin(), idx.end()));
});
shape_for_each(output_shape, [&](auto idx) {
auto index = compute_batch_index(idx, batch_shape, op.axis);
output(idx.begin(), idx.end()) = input(idx.begin(), idx.end()) - batch_max[index];
});
std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0));
shape_for_each(output_shape, [&](auto idx) {
auto index = compute_batch_index(idx, batch_shape, op.axis);
batch_sum[index] += std::exp(output(idx.begin(), idx.end()));
});
for(std::size_t i = 0; i < batch_sum.size(); ++i)
{
batch_sum[i] = std::log(batch_sum[i]);
}
shape_for_each(output_shape, [&](auto idx) {
auto index = compute_batch_index(idx, batch_shape, op.axis);
output(idx.begin(), idx.end()) =
input(idx.begin(), idx.end()) - batch_max[index] - batch_sum[index];
});
});
return result;
}
};
struct add_op struct add_op
{ {
std::string name() const { return "add"; } std::string name() const { return "add"; }
...@@ -723,6 +780,7 @@ struct cpu_apply ...@@ -723,6 +780,7 @@ struct cpu_apply
apply_map["pad"] = extend_op<cpu_pad, op::pad>(); apply_map["pad"] = extend_op<cpu_pad, op::pad>();
apply_map["concat"] = extend_op<cpu_concat, op::concat>(); apply_map["concat"] = extend_op<cpu_concat, op::concat>();
apply_map["gather"] = extend_op<cpu_gather, op::gather>(); apply_map["gather"] = extend_op<cpu_gather, op::gather>();
apply_map["logsoftmax"] = extend_op<cpu_logsoftmax, op::logsoftmax>();
apply_map["leaky_relu"] = extend_op<cpu_unary<leaky_relu_op>, op::leaky_relu>(); apply_map["leaky_relu"] = extend_op<cpu_unary<leaky_relu_op>, op::leaky_relu>();
apply_map["elu"] = extend_op<cpu_unary<elu_op>, op::elu>(); apply_map["elu"] = extend_op<cpu_unary<elu_op>, op::elu>();
apply_map["identity"] = simple_op<cpu_unary<identity_op>>(); apply_map["identity"] = simple_op<cpu_unary<identity_op>>();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment