"include/vscode:/vscode.git/clone" did not exist on "efc6207bf23e2258ddb573b9a7b3e965bdb29c62"
Commit 992666e6 authored by Shucai Xiao, committed by mvermeulen

Improve operators for onnxruntime (#405)



* improve unsqueeze to support negative axis and scalar parsing

* clang format

* add a test example for the negative axis of unsqueeze

* improve the squeeze operator to support negative axis

* clang format

* fixed a small bug in the lrn implementation

* clang format

* support negative axis in argmax and argmin

* clang format

* improve flatten to support negative axis

* clang format

* change softmax/logsoftmax to support negative axis

* clang format

* improve transpose by adding default perm

* clang format

* add one more dimension for tensor size

* add one more dimension for tensor size

* disable conv ops fusion for non-symmetric cases

* clang format

* fixed review comments

* move computing axis from the device function to the compute function

* clang format

* move computing axis from device function to the operator computing function

* clang format
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
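
The recurring pattern in this change is ONNX-style axis normalization: an axis attribute in [-n_dim, n_dim) is mapped to its non-negative equivalent before use. A minimal standalone sketch of the idiom (illustrative only; normalize_axis is a hypothetical helper name, not MIGraphX source):

    #include <cstdint>
    #include <stdexcept>

    // Map an ONNX axis in [-n_dim, n_dim) to its non-negative equivalent.
    int64_t normalize_axis(int64_t axis, int64_t n_dim)
    {
        if(axis >= n_dim || axis < -n_dim)
            throw std::out_of_range("axis is out of range");
        return (axis < 0) ? axis + n_dim : axis;
    }
    // e.g. for a rank-4 tensor: normalize_axis(-1, 4) == 3, normalize_axis(2, 4) == 2.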
parent 2ee0f9e8
@@ -27,24 +27,29 @@ struct argmax
         check_shapes{inputs, *this}.has(1).standard();
         auto lens     = inputs[0].lens();
         int64_t n_dim = static_cast<int64_t>(lens.size());
-        if(axis >= n_dim || axis < 0)
+        if(axis >= n_dim || axis < -n_dim)
         {
             MIGRAPHX_THROW("ARGMAX: axis is out of range.");
         }
-        lens[axis] = 1;
+        int64_t tuned_axis = (axis < 0) ? axis + n_dim : axis;
+        lens[tuned_axis]   = 1;
         return {shape::int64_type, lens};
     }

     template <class T>
-    int64_t calc_argmax(T& input, std::vector<std::size_t>& indices, size_t item_num) const
+    int64_t calc_argmax(T& input,
+                        int64_t tuned_axis,
+                        std::vector<std::size_t>& indices,
+                        size_t item_num) const
     {
         auto max_val      = input(indices.begin(), indices.end());
         int64_t max_index = 0;
         for(std::size_t i = 1; i < item_num; ++i)
         {
-            indices[axis] = i;
+            indices[tuned_axis] = i;
             auto cur_val = input(indices.begin(), indices.end());
             if(max_val < cur_val)
             {
...
@@ -59,13 +64,15 @@ struct argmax
     argument compute(const shape& output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
-        auto batch_item_num = args.front().get_shape().lens()[axis];
+        auto n_dim          = args.front().get_shape().lens().size();
+        auto tuned_axis     = axis < 0 ? axis + n_dim : axis;
+        auto batch_item_num = args.front().get_shape().lens()[tuned_axis];
         result.visit([&](auto output) {
             args[0].visit([&](auto input) {
                 par_for(output_shape.elements(), [&](auto i) {
                     auto data_idx = output_shape.multi(i);
-                    output[i] = this->calc_argmax(input, data_idx, batch_item_num);
+                    output[i] = this->calc_argmax(input, tuned_axis, data_idx, batch_item_num);
                 });
             });
         });
...
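
The same reduction pattern, shown standalone: walk the multi-index along the tuned axis while holding the other coordinates fixed. A sketch under the assumption of a dense strided tensor (argmax_axis is a hypothetical helper, not MIGraphX code):

    #include <cstdint>
    #include <vector>

    // Sketch: argmax along one axis of a dense strided tensor, mirroring the
    // loop above but with a plain vector instead of a MIGraphX tensor view.
    int64_t argmax_axis(const std::vector<float>& data,
                        const std::vector<std::size_t>& lens,
                        const std::vector<std::size_t>& strides,
                        std::vector<std::size_t> idx, // multi-index with idx[axis] == 0
                        std::size_t axis)
    {
        auto offset = [&] {
            std::size_t o = 0;
            for(std::size_t d = 0; d < lens.size(); ++d)
                o += idx[d] * strides[d];
            return o;
        };
        float max_val     = data[offset()];
        int64_t max_index = 0;
        for(std::size_t i = 1; i < lens[axis]; ++i)
        {
            idx[axis] = i; // advance only along the reduction axis
            if(data[offset()] > max_val)
            {
                max_val   = data[offset()];
                max_index = static_cast<int64_t>(i);
            }
        }
        return max_index;
    }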
@@ -27,24 +27,28 @@ struct argmin
         check_shapes{inputs, *this}.has(1).standard();
         auto lens     = inputs[0].lens();
         int64_t n_dim = static_cast<int64_t>(lens.size());
-        if(axis >= n_dim || axis < 0)
+        if(axis >= n_dim || axis < -n_dim)
         {
             MIGRAPHX_THROW("ARGMIN: axis is out of range.");
         }
-        lens[axis] = 1;
+        int64_t tuned_axis = (axis < 0) ? axis + n_dim : axis;
+        lens[tuned_axis]   = 1;
         return {shape::int64_type, lens};
     }

     template <class T>
-    int64_t calc_argmin(T& input, std::vector<std::size_t>& indices, size_t item_num) const
+    int64_t calc_argmin(T& input,
+                        int64_t tuned_axis,
+                        std::vector<std::size_t>& indices,
+                        size_t item_num) const
     {
         auto min_val      = input(indices.begin(), indices.end());
         int64_t min_index = 0;
         for(std::size_t i = 1; i < item_num; ++i)
         {
-            indices[axis] = i;
+            indices[tuned_axis] = i;
             auto cur_val = input(indices.begin(), indices.end());
             if(min_val > cur_val)
             {
...
@@ -59,13 +63,15 @@ struct argmin
     argument compute(const shape& output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
-        std::size_t batch_item_num = args.front().get_shape().lens()[axis];
+        auto n_dim                 = args.front().get_shape().lens().size();
+        auto tuned_axis            = axis < 0 ? axis + n_dim : axis;
+        std::size_t batch_item_num = args.front().get_shape().lens()[tuned_axis];
         result.visit([&](auto output) {
             args[0].visit([&](auto input) {
                 par_for(output_shape.elements(), [&](auto i) {
                     auto data_idx = output_shape.multi(i);
-                    output[i] = this->calc_argmin(input, data_idx, batch_item_num);
+                    output[i] = this->calc_argmin(input, tuned_axis, data_idx, batch_item_num);
                 });
             });
         });
...
@@ -18,7 +18,7 @@ namespace op {

 struct flatten
 {
-    uint64_t axis = 0;
+    int64_t axis = 1;

     template <class Self, class F>
     static auto reflect(Self& self, F f)
...
@@ -31,15 +31,18 @@ struct flatten
     {
         check_shapes{inputs}.has(1);
         auto&& lens = inputs.front().lens();
-        if(axis > lens.size())
+        int64_t n_dim = static_cast<int64_t>(lens.size());
+        if(axis > n_dim or axis < -n_dim)
         {
-            MIGRAPHX_THROW("axis for flatten must be less than tensor rank");
+            MIGRAPHX_THROW("FLATTEN: axis for flatten is out of range");
         }
-        auto x =
-            std::accumulate(lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
-        auto y =
-            std::accumulate(lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{});
+        auto tuned_axis = (axis < 0) ? axis + n_dim : axis;
+        auto x = std::accumulate(
+            lens.begin(), lens.begin() + tuned_axis, std::size_t{1}, std::multiplies<>{});
+        auto y = std::accumulate(
+            lens.begin() + tuned_axis, lens.end(), std::size_t{1}, std::multiplies<>{});
         return {inputs.at(0).type(), {x, y}};
     }

     argument compute(shape output_shape, std::vector<argument> args) const
...
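
A worked example of the new flatten rule: with lens = {2, 3, 4, 5} and axis = -2, tuned_axis is 2, so the output shape is {2*3, 4*5} = {6, 20}. A minimal sketch of the computation:

    #include <cassert>
    #include <cstdint>
    #include <functional>
    #include <numeric>
    #include <vector>

    int main()
    {
        std::vector<std::size_t> lens = {2, 3, 4, 5};
        int64_t axis  = -2; // negative axis, as now allowed
        int64_t n_dim = static_cast<int64_t>(lens.size());
        auto tuned_axis = (axis < 0) ? axis + n_dim : axis;
        auto x = std::accumulate(
            lens.begin(), lens.begin() + tuned_axis, std::size_t{1}, std::multiplies<>{});
        auto y = std::accumulate(
            lens.begin() + tuned_axis, lens.end(), std::size_t{1}, std::multiplies<>{});
        assert(x == 6 && y == 20); // flatten({2,3,4,5}, axis=-2) -> {6, 20}
        return 0;
    }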
@@ -11,7 +11,7 @@ namespace op {

 struct logsoftmax
 {
-    int axis = 1;
+    int64_t axis = 1;

     template <class Self, class F>
     static auto reflect(Self& self, F f)
...
@@ -23,7 +23,8 @@ struct logsoftmax
     shape compute_shape(std::vector<shape> inputs) const
     {
         check_shapes{inputs}.has(1).standard();
-        if(axis < 0 || axis >= inputs[0].lens().size())
+        int64_t n_dim = static_cast<int64_t>(inputs[0].lens().size());
+        if(axis < -n_dim || axis >= n_dim)
         {
             MIGRAPHX_THROW("LogSoftMax: input axis value " + std::to_string(axis) +
                            " is out of range");
...
@@ -11,7 +11,7 @@ namespace op {

 struct softmax
 {
-    int axis = 1;
+    int64_t axis = 1;

     template <class Self, class F>
     static auto reflect(Self& self, F f)
...
@@ -23,7 +23,8 @@ struct softmax
     shape compute_shape(std::vector<shape> inputs) const
     {
         check_shapes{inputs}.has(1).standard();
-        if(axis < 0 || axis >= inputs[0].lens().size())
+        int64_t n_dim = inputs[0].lens().size();
+        if(axis < -n_dim || axis >= n_dim)
         {
             MIGRAPHX_THROW("SoftMax: input axis value " + std::to_string(axis) +
                            " is out of range");
...
@@ -33,13 +33,21 @@ struct squeeze
         auto input_shape = inputs[0];
         auto type        = input_shape.type();
         auto old_lens    = input_shape.lens();
-        if(std::any_of(
-               axes.begin(), axes.end(), [&](auto axis) { return input_shape.lens()[axis] != 1; }))
+        // change to support negative axis value
+        std::vector<int64_t> tuned_axes(axes.size());
+        std::transform(axes.begin(), axes.end(), tuned_axes.begin(), [&](auto i) {
+            return i >= 0 ? i : i + old_lens.size();
+        });
+        if(std::any_of(tuned_axes.begin(), tuned_axes.end(), [&](auto axis) {
+               return old_lens[axis] != 1;
+           }))
         {
             MIGRAPHX_THROW("squeeze axis dimension should be equal to 1");
         }
         std::vector<std::size_t> new_lens;
-        if(axes.empty())
+        if(tuned_axes.empty())
         {
             std::copy_if(old_lens.begin(),
                          old_lens.end(),
...
@@ -50,7 +58,7 @@ struct squeeze
         {
             for(std::size_t i = 0; i < old_lens.size(); i++)
             {
-                if(std::find(axes.begin(), axes.end(), i) == axes.end())
+                if(std::find(tuned_axes.begin(), tuned_axes.end(), i) == tuned_axes.end())
                 {
                     new_lens.push_back(old_lens[i]);
                 }
...
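
For squeeze, a negative axis is offset by the input rank before the size-1 check. For example, old_lens = {2, 1, 3, 1} with axes = {-1} tunes to {3} and yields {2, 1, 3}; a standalone sketch:

    #include <algorithm>
    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main()
    {
        std::vector<std::size_t> old_lens = {2, 1, 3, 1};
        std::vector<int64_t> axes         = {-1}; // refers to the last dimension
        std::vector<int64_t> tuned_axes(axes.size());
        std::transform(axes.begin(), axes.end(), tuned_axes.begin(), [&](auto i) {
            return i >= 0 ? i : i + static_cast<int64_t>(old_lens.size());
        });
        std::vector<std::size_t> new_lens;
        for(std::size_t i = 0; i < old_lens.size(); i++)
            if(std::find(tuned_axes.begin(), tuned_axes.end(), static_cast<int64_t>(i)) ==
               tuned_axes.end())
                new_lens.push_back(old_lens[i]); // keep every non-squeezed dimension
        assert((new_lens == std::vector<std::size_t>{2, 1, 3}));
        return 0;
    }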
@@ -34,13 +34,22 @@ struct transpose
         auto input_lens    = input.lens();
         auto input_strides = input.strides();
         auto t             = input.type();
-        if(dims.size() != input_lens.size())
+        auto tuned_dims = dims;
+        // if no perm is provided, reverse the dims
+        if(tuned_dims.empty())
+        {
+            tuned_dims.resize(input_lens.size());
+            std::iota(tuned_dims.begin(), tuned_dims.end(), 0);
+            std::reverse(tuned_dims.begin(), tuned_dims.end());
+        }
+        if(tuned_dims.size() != input_lens.size())
         {
             MIGRAPHX_THROW("Permutation has wrong number of axes");
         }
-        std::vector<int64_t> axes(dims.size());
+        std::vector<int64_t> axes(tuned_dims.size());
         std::iota(axes.begin(), axes.end(), 0);
-        if(!std::is_permutation(axes.begin(), axes.end(), dims.begin()))
+        if(!std::is_permutation(axes.begin(), axes.end(), tuned_dims.begin()))
         {
             MIGRAPHX_THROW("Invalid permutation");
         }
...
@@ -48,8 +57,8 @@ struct transpose
         std::vector<size_t> output_strides(input_lens.size());
         for(std::size_t i = 0; i < output_lens.size(); i++)
         {
-            output_lens[i]    = input_lens[dims[i]];
-            output_strides[i] = input_strides[dims[i]];
+            output_lens[i]    = input_lens[tuned_dims[i]];
+            output_strides[i] = input_strides[tuned_dims[i]];
         }
         return {t, output_lens, output_strides};
     }
...
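
With an empty perm, transpose now defaults to reversing the dimensions, which matches the ONNX Transpose default; a {2, 3, 4} input becomes {4, 3, 2}. A sketch:

    #include <algorithm>
    #include <cassert>
    #include <cstdint>
    #include <numeric>
    #include <vector>

    int main()
    {
        std::vector<std::size_t> input_lens = {2, 3, 4};
        std::vector<int64_t> tuned_dims(input_lens.size());
        std::iota(tuned_dims.begin(), tuned_dims.end(), 0);
        std::reverse(tuned_dims.begin(), tuned_dims.end()); // default perm: {2, 1, 0}
        std::vector<std::size_t> output_lens(input_lens.size());
        for(std::size_t i = 0; i < output_lens.size(); i++)
            output_lens[i] = input_lens[tuned_dims[i]];
        assert((output_lens == std::vector<std::size_t>{4, 3, 2}));
        return 0;
    }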
@@ -38,11 +38,18 @@ struct unsqueeze
             return shape{type, old_lens};
         std::size_t new_size = old_lens.size() + axes.size();
+        // if an axis is negative, tune it to the equivalent positive value
+        std::vector<int64_t> tuned_axes(axes.size());
+        std::transform(axes.begin(), axes.end(), tuned_axes.begin(), [new_size](auto i) {
+            return i >= 0 ? i : i + new_size;
+        });
+
         std::vector<std::size_t> new_lens(new_size);
         std::size_t p = 0;
         for(std::size_t i = 0; i < new_size; i++)
         {
-            if(std::find(axes.begin(), axes.end(), i) != axes.end())
+            if(std::find(tuned_axes.begin(), tuned_axes.end(), i) != tuned_axes.end())
             {
                 new_lens[i] = 1;
             }
...
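
Note that unsqueeze tunes a negative axis against new_size, the output rank (input rank plus the number of inserted axes), since the axis names a position in the output shape. With old_lens = {3, 4} and axes = {-1}, new_size is 3, tuned_axes is {2}, and the result is {3, 4, 1}; a sketch:

    #include <algorithm>
    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main()
    {
        std::vector<std::size_t> old_lens = {3, 4};
        std::vector<int64_t> axes         = {-1};
        std::size_t new_size = old_lens.size() + axes.size(); // output rank = 3
        std::vector<int64_t> tuned_axes(axes.size());
        std::transform(axes.begin(), axes.end(), tuned_axes.begin(), [new_size](auto i) {
            return i >= 0 ? i : i + static_cast<int64_t>(new_size);
        });
        std::vector<std::size_t> new_lens(new_size);
        std::size_t p = 0;
        for(std::size_t i = 0; i < new_size; i++)
            new_lens[i] = (std::find(tuned_axes.begin(), tuned_axes.end(),
                                     static_cast<int64_t>(i)) != tuned_axes.end())
                              ? 1             // inserted axis gets length 1
                              : old_lens[p++]; // otherwise copy the next input length
        assert((new_lens == std::vector<std::size_t>{3, 4, 1}));
        return 0;
    }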
@@ -231,8 +231,15 @@ struct onnx_parser
             auto s0       = arg0->get_shape().lens();
             auto s1       = arg1->get_shape().lens();
             auto out_lens = compute_broadcasted_lens(s0, s1);
-            auto l0       = prog.add_instruction(op::multibroadcast{out_lens}, arg0);
-            auto l1       = prog.add_instruction(op::multibroadcast{out_lens}, arg1);
+            auto l0 = arg0;
+            if(arg0->get_shape().lens() != out_lens)
+                l0 = prog.add_instruction(op::multibroadcast{out_lens}, arg0);
+
+            auto l1 = arg1;
+            if(arg1->get_shape().lens() != out_lens)
+                l1 = prog.add_instruction(op::multibroadcast{out_lens}, arg1);
+
             return prog.add_instruction(x, l0, l1);
         }
         else
...
@@ -283,7 +290,7 @@ struct onnx_parser
                    const attribute_map& attributes,
                    std::vector<instruction_ref> args)
     {
-        int axis = 1;
+        int64_t axis = 1;
         if(contains(attributes, "axis"))
         {
             axis = parse_value(attributes.at("axis")).at<int>();
...
@@ -463,7 +470,7 @@ struct onnx_parser
     instruction_ref
     parse_flatten(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
     {
-        uint64_t axis = 1;
+        int64_t axis = 1;
         if(contains(attributes, "axis"))
         {
             axis = parse_value(attributes.at("axis")).at<int>();
...
@@ -1696,6 +1703,9 @@ struct onnx_parser
             }
             return batch_size;
         });
+
+        if(dims.empty())
+            return {shape_type};
+
         return {shape_type, dims};
     }
...
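
The first hunk above avoids inserting a multibroadcast when an argument already has the broadcasted shape, so already-matching operands pass through unchanged. A minimal sketch of just that guard (broadcast_count is a hypothetical stand-in for counting the inserted instructions, not parser code):

    #include <cassert>
    #include <vector>

    // Count how many explicit multibroadcast instructions the guard would insert.
    int broadcast_count(const std::vector<std::size_t>& s0,
                        const std::vector<std::size_t>& s1,
                        const std::vector<std::size_t>& out_lens)
    {
        int n = 0;
        if(s0 != out_lens) ++n; // arg0 needs an explicit multibroadcast
        if(s1 != out_lens) ++n; // arg1 needs one too
        return n;
    }

    int main()
    {
        // {2,3} + {2,3}: shapes already match, no broadcast instructions at all.
        assert(broadcast_count({2, 3}, {2, 3}, {2, 3}) == 0);
        // {2,3} + {3}: only the second argument is broadcast up to {2,3}.
        assert(broadcast_count({2, 3}, {3}, {2, 3}) == 1);
        return 0;
    }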
@@ -144,13 +144,14 @@ struct cpu_lrn
         int height          = output_shape.lens()[2];
         int width           = output_shape.lens()[3];
         float alphaoverarea = op.alpha / float(op.size);
-        int radius          = (op.size - 1) / 2;
+        int radius_lower    = (op.size - 1) / 2;
+        int radius_upper    = op.size / 2 + 1;

         par_dfor(n_batch, height, width)([&](int b, int h, int w) {
             float scale = 0;
             dfor(channels)([&](int c) {
-                auto start = (c - radius) < 0 ? 0 : (c - radius);
-                auto end   = (c + radius) > channels ? channels : (c + radius);
+                auto start = (c - radius_lower) < 0 ? 0 : (c - radius_lower);
+                auto end   = (c + radius_upper) > channels ? channels : (c + radius_upper);
                 for(auto k = start; k < end; ++k)
                 {
                     scale += std::pow(input(b, k, h, w), 2);
...
@@ -599,8 +600,9 @@ struct cpu_softmax
     {
         argument result{output_shape};
         auto batch_lens = output_shape.lens();
-        std::size_t n_dims  = batch_lens[op.axis];
-        batch_lens[op.axis] = 1;
+        int64_t tuned_axis = (op.axis < 0) ? op.axis + args[0].get_shape().lens().size() : op.axis;
+        std::size_t n_dims     = batch_lens[tuned_axis];
+        batch_lens[tuned_axis] = 1;
         shape batch_shape{shape::int32_type, batch_lens};
         visit_all(result, args[0])([&](auto output, auto input) {
...
@@ -612,26 +614,26 @@ struct cpu_softmax
                 auto idx = batch_shape.multi(i);
                 for(std::size_t j = 0; j < n_dims; ++j)
                 {
-                    idx[op.axis] = j;
+                    idx[tuned_axis] = j;
                     batch_max[i] = std::max(batch_max[i], input(idx.begin(), idx.end()));
                 }
                 for(std::size_t j = 0; j < n_dims; ++j)
                 {
-                    idx[op.axis] = j;
+                    idx[tuned_axis] = j;
                     std::size_t index = output_shape.index(idx);
                     output[index]     = std::exp(input[index] - batch_max[i]);
                 }
                 for(std::size_t j = 0; j < n_dims; ++j)
                 {
-                    idx[op.axis] = j;
+                    idx[tuned_axis] = j;
                     batch_sum[i] += output(idx.begin(), idx.end());
                 }
                 for(std::size_t j = 0; j < n_dims; ++j)
                 {
-                    idx[op.axis] = j;
+                    idx[tuned_axis] = j;
                     output(idx.begin(), idx.end()) =
                         op.output()(output(idx.begin(), idx.end()), batch_sum[i]);
                 }
...
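
The LRN fix replaces the single symmetric radius with an asymmetric window: start = c - radius_lower (clamped at 0) and an exclusive end = c + radius_upper (clamped at channels). Unclamped, the window spans radius_lower + radius_upper = (size-1)/2 + size/2 + 1 = size channels, whereas the old code spanned only 2*radius = size-1 channels for odd size. A quick check:

    #include <cassert>

    int main()
    {
        for(int size = 1; size <= 7; ++size)
        {
            int radius_lower = (size - 1) / 2;
            int radius_upper = size / 2 + 1;
            // Exclusive end: the window is [c - radius_lower, c + radius_upper),
            // i.e. exactly `size` channels when no clamping occurs.
            assert(radius_lower + radius_upper == size);
        }
        return 0;
    }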
@@ -14,7 +14,9 @@ shape hip_argmax::compute_shape(const std::vector<shape>& inputs) const

 argument hip_argmax::compute(context& ctx, const shape&, const std::vector<argument>& args) const
 {
-    device::argmax(ctx.get_stream().get(), args.back(), args.front(), op.axis);
+    auto n_dim         = args.front().get_shape().lens().size();
+    int64_t tuned_axis = (op.axis < 0) ? op.axis + n_dim : op.axis;
+    device::argmax(ctx.get_stream().get(), args.back(), args.front(), tuned_axis);
     return args.back();
 }
...
@@ -14,7 +14,9 @@ shape hip_argmin::compute_shape(const std::vector<shape>& inputs) const

 argument hip_argmin::compute(context& ctx, const shape&, const std::vector<argument>& args) const
 {
-    device::argmin(ctx.get_stream().get(), args.back(), args.front(), op.axis);
+    auto n_dim         = args.front().get_shape().lens().size();
+    int64_t tuned_axis = (op.axis < 0) ? op.axis + n_dim : op.axis;
+    device::argmin(ctx.get_stream().get(), args.back(), args.front(), tuned_axis);
     return args.back();
 }
...
@@ -39,7 +39,12 @@ constexpr void visit_tensor_size(index_int n, F f)
         f(std::integral_constant<index_int, 5>{});
         break;
     }
-    default: throw std::runtime_error("Unknown tensor size");
+    case 6:
+    {
+        f(std::integral_constant<index_int, 6>{});
+        break;
+    }
+    default: throw std::runtime_error("Tensor size dim out of range");
     }
 }
...
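
visit_tensor_size dispatches a runtime rank to a compile-time constant so kernels can be instantiated per rank; the new case extends this to rank-6 tensors. A self-contained sketch of the same technique (visit_rank is a hypothetical simplified version, ranks 1-3 only):

    #include <cstddef>
    #include <iostream>
    #include <stdexcept>
    #include <type_traits>

    // Map a runtime rank onto a compile-time std::integral_constant so the
    // callback can be instantiated with the rank as a template parameter.
    template <class F>
    void visit_rank(std::size_t n, F f)
    {
        switch(n)
        {
        case 1: f(std::integral_constant<std::size_t, 1>{}); break;
        case 2: f(std::integral_constant<std::size_t, 2>{}); break;
        case 3: f(std::integral_constant<std::size_t, 3>{}); break;
        default: throw std::runtime_error("rank out of range");
        }
    }

    int main()
    {
        visit_rank(2, [](auto rank) {
            // rank() is a constant expression here, usable as a template argument.
            std::cout << "compiled for rank " << rank() << "\n";
        });
        return 0;
    }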
@@ -11,11 +11,10 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {

-void logsoftmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
+void logsoftmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
 {
-    auto lens                = result.get_shape().lens();
-    auto batch_lens          = lens;
-    index_int batch_item_num = lens[axis];
+    auto batch_lens          = result.get_shape().lens();
+    index_int batch_item_num = batch_lens[axis];
     batch_lens[axis]         = 1;
     migraphx::shape batch_shape{result.get_shape().type(), batch_lens};
...
@@ -12,11 +12,10 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {

-void softmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
+void softmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
 {
-    auto lens                = result.get_shape().lens();
-    auto batch_lens          = lens;
-    index_int batch_item_num = lens[axis];
+    auto batch_lens          = result.get_shape().lens();
+    index_int batch_item_num = batch_lens[axis];
     batch_lens[axis]         = 1;
     migraphx::shape batch_shape{result.get_shape().type(), batch_lens};
...
@@ -148,6 +148,12 @@ MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
         return false;
     if(wei.lens()[1] > 512 and conv.algo != miopenConvolutionFwdAlgoWinograd)
         return false;
+
+    // Do not fuse non-symmetric input
+    auto input_lens = ins->inputs().at(0)->get_shape().lens();
+    if(input_lens[2] != input_lens[3] or wei.lens()[2] != wei.lens()[3])
+        return false;
+
     auto op = conv.op;
     // Don't fuse winograd for non-3x3s since there is no fused winograd for those configs
     if(conv.algo == miopenConvolutionFwdAlgoWinograd and wei.lens()[2] != 3 and
...
@@ -72,9 +72,8 @@ template <class Op>
 void arg_op(Op op, hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
 {
     auto arg_shape = arg.get_shape();
-    auto lens             = arg_shape.lens();
-    auto batch_lens       = lens;
-    size_t batch_item_num = lens[axis];
+    auto batch_lens       = arg_shape.lens();
+    size_t batch_item_num = batch_lens[axis];
     batch_lens[axis] = 1;
     migraphx::shape batch_shape{arg_shape.type(), batch_lens};
...
@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {

-void logsoftmax(hipStream_t stream, const argument& result, const argument& arg, int axis);
+void logsoftmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis);

 } // namespace device
 } // namespace gpu
...
@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {

-void softmax(hipStream_t stream, const argument& result, const argument& arg, int axis);
+void softmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis);

 } // namespace device
 } // namespace gpu
...
@@ -18,7 +18,9 @@ shape hip_logsoftmax::compute_shape(const std::vector<shape>& inputs) const

 argument
 hip_logsoftmax::compute(context& ctx, const shape&, const std::vector<argument>& args) const
 {
-    device::logsoftmax(ctx.get_stream().get(), args.back(), args.front(), op.axis);
+    auto n_dim      = args.front().get_shape().lens().size();
+    auto tuned_axis = (op.axis < 0) ? op.axis + n_dim : op.axis;
+    device::logsoftmax(ctx.get_stream().get(), args.back(), args.front(), tuned_axis);
     return args.back();
 }
...