Commit 6001b750 authored by Shucai Xiao
Browse files

fix review comments

parent 52a96f69
......@@ -167,38 +167,27 @@ rb_type<T>* to_rocblas_type(T* x)
// Overload for a single `half` value: views the by-value argument's storage as a
// `rocblas_half` and returns a copy of it.
rocblas_half to_rocblas_type(half x)
{
    const auto& reinterpreted = reinterpret_cast<const rocblas_half&>(x);
    return reinterpreted;
}
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
void miopen_gemm::batch_not_transposed(const std::vector<std::size_t>& strides) const
{
std::vector<shape> input_shapes(inputs.begin(), inputs.begin() + inputs.size() - 1);
check_shapes{input_shapes}.not_broadcasted();
auto a_strides = inputs[0].strides();
auto dim_0 = a_strides.size() - 2;
if(a_strides.size() > 2)
if (strides.size() <= 2) return;
auto dim_0 = strides.size() - 2;
auto matrix_size = std::max(strides[dim_0], strides[dim_0 + 1]);
std::vector<std::size_t> batch(strides.begin(), strides.begin() + dim_0);
if(std::adjacent_find(batch.begin(), batch.end(), [&](auto i, auto j) {
return (i < j or i < matrix_size or j < matrix_size);
}) != batch.end())
{
if(!std::all_of(a_strides.begin(), a_strides.begin() + dim_0, [&](auto batch_size) {
return std::all_of(a_strides.begin() + dim_0, a_strides.end(), [&](auto data_size) {
return batch_size >= data_size;
});
}))
{
MIGRAPHX_THROW("DOT: batch size of a {" + to_string_range(a_strides) +
"} is transposed!");
}
MIGRAPHX_THROW("DOT: batch size of a {" + to_string_range(strides) +
"} is transposed!");
}
}
auto b_strides = inputs[1].strides();
if(b_strides.size() > 2)
{
if(!std::all_of(b_strides.begin(), b_strides.begin() + dim_0, [&](auto batch_size) {
return std::all_of(b_strides.begin() + dim_0, b_strides.end(), [&](auto data_size) {
return batch_size >= data_size;
});
}))
{
MIGRAPHX_THROW("DOT: batch size of b {" + to_string_range(b_strides) +
"} is transposed!");
}
}
// Computes the output shape of the GEMM.
//
// The trailing entry of `inputs` is dropped before validation. The remaining shapes
// must not be broadcasted, and the first two must pass batch_not_transposed() on their
// strides. The result is delegated to the wrapped operator's compute_shape().
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{
    // Every shape except the last one.
    std::vector<shape> operands(inputs.begin(), inputs.end() - 1);
    check_shapes{operands}.not_broadcasted();

    // Validate the batch layout of the two matrix operands, in order.
    for(std::size_t idx = 0; idx < 2; ++idx)
    {
        batch_not_transposed(inputs[idx].strides());
    }

    return op.compute_shape(operands);
}
......
......@@ -24,6 +24,7 @@ struct miopen_gemm
shape compute_shape(const std::vector<shape>& inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
void batch_not_transposed(const std::vector<std::size_t>& strides) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment