Commit 0f85317e authored by Shucai Xiao's avatar Shucai Xiao
Browse files

changes to make bert model work

parent a3aacad6
...@@ -48,6 +48,7 @@ struct dot ...@@ -48,6 +48,7 @@ struct dot
"} x {" + to_string_range(b.lens()) + "}"); "} x {" + to_string_range(b.lens()) + "}");
} }
// dims for batch should be standard
std::size_t dim_0 = a.lens().size() - 2; std::size_t dim_0 = a.lens().size() - 2;
std::size_t dim_1 = a.lens().size() - 1; std::size_t dim_1 = a.lens().size() - 1;
if(a.lens()[dim_1] != b.lens()[dim_0]) if(a.lens()[dim_1] != b.lens()[dim_0])
......
...@@ -691,14 +691,6 @@ struct onnx_parser ...@@ -691,14 +691,6 @@ struct onnx_parser
} }
} }
if(!bl1->get_shape().standard())
{
bl1 = prog.add_instruction(op::contiguous{}, bl1);
}
if(!bl0->get_shape().standard())
{
bl0 = prog.add_instruction(op::contiguous{}, bl0);
}
auto dot_res = prog.add_instruction(op::dot{1.0f, 0.0f}, bl0, bl1); auto dot_res = prog.add_instruction(op::dot{1.0f, 0.0f}, bl0, bl1);
int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size()); int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
if(is_a_prepended) if(is_a_prepended)
......
...@@ -170,7 +170,34 @@ rocblas_half to_rocblas_type(half x) { return reinterpret_cast<const rocblas_hal ...@@ -170,7 +170,34 @@ rocblas_half to_rocblas_type(half x) { return reinterpret_cast<const rocblas_hal
// Computes the output shape of the GPU gemm.
//
// The last entry of `inputs` is dropped before validation (presumably the
// pre-allocated output buffer -- TODO confirm against the caller); the
// remaining shapes must not be broadcasted and are forwarded to
// `op.compute_shape`.
//
// For operands with more than two dimensions, every leading (batch) stride
// must be greater than or equal to both of the two innermost strides;
// otherwise the batch dimensions are considered transposed and an exception
// is thrown.
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{
    // Both inputs[0] and inputs[1] are read below, and the last element is
    // stripped; fewer than two entries would index out of range.
    if(inputs.size() < 2)
    {
        MIGRAPHX_THROW("DOT: compute_shape requires at least two input shapes");
    }

    std::vector<shape> input_shapes(inputs.begin(), inputs.begin() + inputs.size() - 1);
    check_shapes{input_shapes}.not_broadcasted();

    // True when no batch stride is smaller than either of the last two
    // strides. The split point is derived from the rank of the strides being
    // checked, so each operand is validated against its own rank.
    const auto batch_not_transposed = [](const auto& strides) {
        if(strides.size() <= 2)
        {
            return true;
        }
        const auto matrix_begin = strides.end() - 2;
        return std::all_of(strides.begin(), matrix_begin, [&](auto batch_stride) {
            return std::all_of(matrix_begin, strides.end(), [&](auto matrix_stride) {
                return batch_stride >= matrix_stride;
            });
        });
    };

    const auto& a_strides = inputs[0].strides();
    if(!batch_not_transposed(a_strides))
    {
        MIGRAPHX_THROW("DOT: batch size of a {" + to_string_range(a_strides) +
                       "} is transposed!");
    }

    const auto& b_strides = inputs[1].strides();
    if(!batch_not_transposed(b_strides))
    {
        MIGRAPHX_THROW("DOT: batch size of b {" + to_string_range(b_strides) +
                       "} is transposed!");
    }

    return op.compute_shape(input_shapes);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment