Commit 6001b750 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fix review comments

parent 52a96f69
...@@ -167,38 +167,27 @@ rb_type<T>* to_rocblas_type(T* x) ...@@ -167,38 +167,27 @@ rb_type<T>* to_rocblas_type(T* x)
rocblas_half to_rocblas_type(half x) { return reinterpret_cast<const rocblas_half&>(x); } rocblas_half to_rocblas_type(half x) { return reinterpret_cast<const rocblas_half&>(x); }
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const void miopen_gemm::batch_not_transposed(const std::vector<std::size_t>& strides) const
{ {
std::vector<shape> input_shapes(inputs.begin(), inputs.begin() + inputs.size() - 1); if (strides.size() <= 2) return;
check_shapes{input_shapes}.not_broadcasted(); auto dim_0 = strides.size() - 2;
auto a_strides = inputs[0].strides(); auto matrix_size = std::max(strides[dim_0], strides[dim_0 + 1]);
auto dim_0 = a_strides.size() - 2; std::vector<std::size_t> batch(strides.begin(), strides.begin() + dim_0);
if(a_strides.size() > 2) if(std::adjacent_find(batch.begin(), batch.end(), [&](auto i, auto j) {
{ return (i < j or i < matrix_size or j < matrix_size);
if(!std::all_of(a_strides.begin(), a_strides.begin() + dim_0, [&](auto batch_size) { }) != batch.end())
return std::all_of(a_strides.begin() + dim_0, a_strides.end(), [&](auto data_size) {
return batch_size >= data_size;
});
}))
{ {
MIGRAPHX_THROW("DOT: batch size of a {" + to_string_range(a_strides) + MIGRAPHX_THROW("DOT: batch size of a {" + to_string_range(strides) +
"} is transposed!"); "} is transposed!");
} }
} }
auto b_strides = inputs[1].strides(); shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
if(b_strides.size() > 2) {
{ std::vector<shape> input_shapes(inputs.begin(), inputs.begin() + inputs.size() - 1);
if(!std::all_of(b_strides.begin(), b_strides.begin() + dim_0, [&](auto batch_size) { check_shapes{input_shapes}.not_broadcasted();
return std::all_of(b_strides.begin() + dim_0, b_strides.end(), [&](auto data_size) { batch_not_transposed(inputs[0].strides());
return batch_size >= data_size; batch_not_transposed(inputs[1].strides());
});
}))
{
MIGRAPHX_THROW("DOT: batch size of b {" + to_string_range(b_strides) +
"} is transposed!");
}
}
return op.compute_shape(input_shapes); return op.compute_shape(input_shapes);
} }
......
...@@ -24,6 +24,7 @@ struct miopen_gemm ...@@ -24,6 +24,7 @@ struct miopen_gemm
shape compute_shape(const std::vector<shape>& inputs) const; shape compute_shape(const std::vector<shape>& inputs) const;
argument argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const; compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
void batch_not_transposed(const std::vector<std::size_t>& strides) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{ {
return shapes.size() - 1; return shapes.size() - 1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment