"example/vscode:/vscode.git/clone" did not exist on "11edd0f045361b8c1a443fab303f29983ed43a57"
Commit be99b85d authored by Khalique

Merge branch 'test_bert' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into bert_ops

parents 8819f4dc b0e589fd
@@ -461,8 +461,16 @@ struct onnx_parser
         op::reshape op;
         if(args.size() == 1)
         {
-            literal s = parse_value(attributes.at("shape"));
-            s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
+            if(contains(attributes, "shape"))
+            {
+                literal s = parse_value(attributes.at("shape"));
+                s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
+            }
+            else
+            {
+                MIGRAPHX_THROW(
+                    "Parse_reshape: shape attribute is needed when only one argument is provided!");
+            }
         }
         if(args.size() == 2)
         {
@@ -470,6 +478,12 @@ struct onnx_parser
             check_arg_empty(s, "Reshape: dynamic shape is not supported");
             s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
         }
+
+        if(!args[0]->get_shape().standard())
+        {
+            args[0] = prog.add_instruction(op::contiguous{}, args[0]);
+        }
+
        return prog.add_instruction(op, args[0]);
    }
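
The contiguous guard above exists because op::reshape only rewrites dimension metadata over a densely packed buffer: it is valid only when the input shape is standard (packed, row-major). A non-standard view such as a transpose must be copied into packed order first. Below is a minimal sketch of the underlying layout issue (plain C++; the tensor, lens, and strides are illustrative, not taken from the MIGraphX sources):

    #include <cstddef>
    #include <iostream>
    #include <vector>

    int main()
    {
        // A 2x3 tensor stored row-major: lens {2, 3}, strides {3, 1}.
        std::vector<int> data{0, 1, 2, 3, 4, 5};

        // Its transpose is a view with lens {3, 2} and strides {1, 3}; the
        // strides are no longer packed in descending order, so the shape is
        // not "standard" and logical order differs from buffer order.
        std::size_t lens[2]    = {3, 2};
        std::size_t strides[2] = {1, 3};

        // Walking the view logically prints 0 3 1 4 2 5, not the buffer
        // order 0 1 2 3 4 5 that a metadata-only reshape would assume,
        // hence the op::contiguous copy the parser now inserts first.
        for(std::size_t r = 0; r < lens[0]; ++r)
            for(std::size_t c = 0; c < lens[1]; ++c)
                std::cout << data[r * strides[0] + c * strides[1]] << ' ';
        std::cout << '\n';
    }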
@@ -972,7 +986,6 @@ struct onnx_parser
         std::vector<std::size_t> dims;
         arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
         auto out_lens = compute_broadcasted_lens(in_lens, dims);
-
         return prog.add_instruction(op::multibroadcast{out_lens}, args[0]);
     }
@@ -25,7 +25,7 @@ argument gather(hipStream_t stream, argument result, argument arg1, argument arg
     arg2.visit([&](auto indices) {
         const auto* indices_ptr = device_cast(indices.data());
         auto* output_ptr        = device_cast(output.data());
-        gs_launch(stream, nelements)([=](auto i) {
+        gs_launch(stream, nelements, 256)([=](auto i) {
             auto idx        = out_comp.multi(i);
             idx[axis_index] = indices_ptr[idx[axis_index]];
             output_ptr[i]   = input[idx];
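
The only change in this hunk pins the gather kernel's launch to a block size of 256 threads via gs_launch's third argument. As a hedged sketch, assuming gs_launch wraps the usual HIP grid-stride pattern (the kernel below is a simplified 1-D stand-in with hypothetical names, not the MIGraphX implementation):

    #include <hip/hip_runtime.h>
    #include <cstddef>

    // Simplified 1-D stand-in for the gather body; the real kernel maps the
    // flat index i to a multi-index and rewrites the gather axis.
    __global__ void gather_like(const int* indices, const float* in, float* out, std::size_t n)
    {
        std::size_t stride = static_cast<std::size_t>(blockDim.x) * gridDim.x;
        for(std::size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += stride)
            out[i] = in[indices[i]];
    }

    // Launching with the block size fixed to 256, as in the changed call:
    //   hipLaunchKernelGGL(gather_like,
    //                      dim3((n + 255) / 256), dim3(256), 0, stream,
    //                      indices, in, out, n);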
@@ -167,10 +167,28 @@ rb_type<T>* to_rocblas_type(T* x)
 rocblas_half to_rocblas_type(half x) { return reinterpret_cast<const rocblas_half&>(x); }
 
+void miopen_gemm::batch_not_transposed(const std::vector<std::size_t>& strides) const
+{
+    if(strides.size() <= 2)
+        return;
+
+    auto dim_0       = strides.size() - 2;
+    auto matrix_size = std::max(strides[dim_0], strides[dim_0 + 1]);
+    std::vector<std::size_t> batch(strides.begin(), strides.begin() + dim_0);
+    if(std::adjacent_find(batch.begin(), batch.end(), [&](auto i, auto j) {
+           return (i < j or i < matrix_size or j < matrix_size);
+       }) != batch.end())
+    {
+        MIGRAPHX_THROW("DOT: batch size of a {" + to_string_range(strides) + "} is transposed!");
+    }
+}
+
 shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
 {
     std::vector<shape> input_shapes(inputs.begin(), inputs.begin() + inputs.size() - 1);
     check_shapes{input_shapes}.not_broadcasted();
+    batch_not_transposed(inputs[0].strides());
+    batch_not_transposed(inputs[1].strides());
     return op.compute_shape(input_shapes);
 }
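
The new batch_not_transposed check walks the batch strides left to right and requires each adjacent pair to be non-increasing and no smaller than the matrix footprint (the larger of the two innermost strides); otherwise a batch axis has been transposed into or below the matrix dimensions, and the DOT implementation rejects the layout. A standalone worked example of the same predicate (standard C++; the sample stride vectors are illustrative):

    #include <algorithm>
    #include <cstddef>
    #include <iostream>
    #include <vector>

    bool batch_transposed(const std::vector<std::size_t>& strides)
    {
        if(strides.size() <= 2)
            return false; // a plain matrix has no batch dimensions to check
        auto dim_0       = strides.size() - 2;
        auto matrix_size = std::max(strides[dim_0], strides[dim_0 + 1]);
        std::vector<std::size_t> batch(strides.begin(), strides.begin() + dim_0);
        // A batch stride pair is bad if it increases left to right or falls
        // inside the matrix footprint; both mean a batch axis was transposed.
        return std::adjacent_find(batch.begin(), batch.end(), [&](auto i, auto j) {
                   return (i < j or i < matrix_size or j < matrix_size);
               }) != batch.end();
    }

    int main()
    {
        // Standard strides for lens {2, 3, 4, 5}: batch strides {60, 20}
        // sit above matrix_size = max(5, 1) = 5 and decrease, so it passes.
        std::cout << batch_transposed({60, 20, 5, 1}) << '\n'; // prints 0

        // Swapping the two batch axes gives {20, 60, 5, 1}: 20 < 60 trips
        // the adjacent_find predicate, so the layout is flagged.
        std::cout << batch_transposed({20, 60, 5, 1}) << '\n'; // prints 1
    }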
@@ -24,6 +24,7 @@ struct miopen_gemm
     shape compute_shape(const std::vector<shape>& inputs) const;
     argument
     compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
+    void batch_not_transposed(const std::vector<std::size_t>& strides) const;
     std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
     {
         return shapes.size() - 1;
@@ -447,6 +447,21 @@ TEST_CASE(reshape_test)
     EXPECT(p == prog);
 }
 
+TEST_CASE(reshape_non_standard)
+{
+    migraphx::program p;
+    migraphx::op::reshape op;
+    std::vector<int64_t> reshape_dims{4, 3, 2};
+    migraphx::shape s{migraphx::shape::float_type, {2, 3, 4}};
+    auto x      = p.add_parameter("x", s);
+    auto tran_x = p.add_instruction(migraphx::op::transpose{{0, 2, 1}}, x);
+    auto cont_x = p.add_instruction(migraphx::op::contiguous{}, tran_x);
+    p.add_instruction(migraphx::op::reshape{{4, 3, 2}}, cont_x);
+
+    auto prog = migraphx::parse_onnx("reshape_non_standard.onnx");
+
+    EXPECT(p == prog);
+}
+
 TEST_CASE(shape_test)
 {
     migraphx::program p;