Commit 0d877311 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fix review comments

parent 95cdb221
......@@ -172,14 +172,14 @@ shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
std::vector<shape> input_shapes(inputs.begin(), inputs.begin() + inputs.size() - 1);
check_shapes{input_shapes}.not_broadcasted();
auto a_strides = inputs[0].strides();
auto dim_0 = a_strides.size() - 2;
if(a_strides.size() > 2)
{
if(!std::all_of(a_strides.begin(), a_strides.begin() + dim_0, [&](auto batch_size) {
return std::all_of(a_strides.begin() + dim_0, a_strides.end(), [&](auto data_size) {
return batch_size >= data_size;
});
}))
auto dim_1 = a_strides.size() - 1;
auto dim_0 = dim_1 - 1;
auto matrix_size = std::max(a_strides[dim_0], a_strides[1]);
if (std::adjacent_find(a_strides.begin(), a_strides.begin() + dim_0, [&](auto i, auto j) {
return (i < j or i < matrix_size or j < matrix_size);
}) != a_strides.begin() + dim_0)
{
MIGRAPHX_THROW("DOT: batch size of a {" + to_string_range(a_strides) +
"} is transposed!");
......@@ -189,11 +189,12 @@ shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
auto b_strides = inputs[1].strides();
if(b_strides.size() > 2)
{
if(!std::all_of(b_strides.begin(), b_strides.begin() + dim_0, [&](auto batch_size) {
return std::all_of(b_strides.begin() + dim_0, b_strides.end(), [&](auto data_size) {
return batch_size >= data_size;
});
}))
auto dim_1 = b_strides.size() - 1;
auto dim_0 = dim_1 - 1;
auto matrix_size = std::max(b_strides[dim_0], b_strides[1]);
if (std::adjacent_find(b_strides.begin(), b_strides.begin() + dim_0, [&](auto i, auto j) {
return (i < j or i < matrix_size or j < matrix_size);
}) != b_strides.begin() + dim_0)
{
MIGRAPHX_THROW("DOT: batch size of b {" + to_string_range(b_strides) +
"} is transposed!");
......
......@@ -721,13 +721,18 @@ struct tf_parser
const attribute_map& attributes,
std::vector<instruction_ref> args)
{
int axis = 1;
int axis = -1;
auto num_dims = args[0]->get_shape().lens().size();
if(contains(attributes, "axis"))
{
axis = static_cast<int>(attributes.at("axis").i());
}
if(axis < 0)
{
axis += num_dims;
}
return prog.add_instruction(Op{axis}, std::move(args));
return prog.add_instruction(Op{axis}, make_contiguous(args[0]));
}
instruction_ref parse_squeeze(const std::string&,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment