Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
13845447
Commit
13845447
authored
Jan 25, 2019
by
Khalique
Browse files
Merge branch 'develop' of
https://github.com/ROCmSoftwarePlatform/AMDMIGraphX
into pad_op
parents
752fa7cf
b5090737
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
144 additions
and
56 deletions
+144
-56
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+18
-9
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+51
-28
src/targets/gpu/device/gather.cpp
src/targets/gpu/device/gather.cpp
+5
-4
src/targets/gpu/include/migraphx/gpu/device/gather.hpp
src/targets/gpu/include/migraphx/gpu/device/gather.hpp
+1
-1
test/cpu_ops_test.cpp
test/cpu_ops_test.cpp
+24
-4
test/gpu/miopen.cpp
test/gpu/miopen.cpp
+21
-3
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+5
-5
test/op_shape_test.cpp
test/op_shape_test.cpp
+19
-2
No files found.
src/include/migraphx/operators.hpp
View file @
13845447
...
@@ -671,49 +671,58 @@ struct as_shape
...
@@ -671,49 +671,58 @@ struct as_shape
struct
gather
struct
gather
{
{
std
::
size_
t
axis
=
0
;
in
t
axis
=
0
;
std
::
string
name
()
const
{
return
"gather"
;
}
std
::
string
name
()
const
{
return
"gather"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
auto
lens
=
inputs
[
0
].
lens
();
auto
lens
=
inputs
[
0
].
lens
();
if
(
axis
>=
lens
.
size
())
int
n_dim
=
static_cast
<
int
>
(
lens
.
size
());
if
(
axis
>=
n_dim
||
axis
<
-
n_dim
)
{
{
MIGRAPHX_THROW
(
"Gather
,
axis is out of range."
);
MIGRAPHX_THROW
(
"Gather
:
axis is out of range."
);
}
}
// negative axis means counting dimensions from back
int
axis_index
=
(
axis
<
0
)
?
(
n_dim
+
axis
)
:
axis
;
auto
type
=
inputs
[
0
].
type
();
auto
type
=
inputs
[
0
].
type
();
lens
[
axis
]
=
inputs
[
1
].
elements
();
lens
[
axis
_index
]
=
inputs
[
1
].
elements
();
return
{
type
,
lens
};
return
{
type
,
lens
};
}
}
template
<
class
T
>
template
<
class
T
>
void
compute_index
(
const
T
&
out_idx
,
void
compute_index
(
const
T
&
out_idx
,
const
int
axis_index
,
const
std
::
vector
<
std
::
size_t
>&
vec_indices
,
const
std
::
vector
<
std
::
size_t
>&
vec_indices
,
const
std
::
size_t
max_dim
,
const
std
::
size_t
max_dim
,
T
&
in_idx
)
const
T
&
in_idx
)
const
{
{
in_idx
=
out_idx
;
in_idx
=
out_idx
;
std
::
size_t
idx
=
vec_indices
.
at
(
out_idx
[
axis
]);
std
::
size_t
idx
=
vec_indices
.
at
(
out_idx
[
axis
_index
]);
if
(
idx
>=
max_dim
)
if
(
idx
>=
max_dim
)
{
{
MIGRAPHX_THROW
(
"Gather: indices are out of range in input tensor"
);
MIGRAPHX_THROW
(
"Gather: indices are out of range in input tensor"
);
}
}
in_idx
[
axis
]
=
idx
;
in_idx
[
axis
_index
]
=
idx
;
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
output_shape
};
argument
result
{
output_shape
};
// negative axis means counting dimensions from back
int
axis_index
=
(
axis
<
0
)
?
(
output_shape
.
lens
().
size
()
+
axis
)
:
axis
;
// max dimension in axis
// max dimension in axis
std
::
size_t
max_dim
=
args
[
0
].
get_shape
().
lens
()[
axis
];
std
::
size_t
max_dim
=
args
[
0
].
get_shape
().
lens
()[
axis
_index
];
std
::
vector
<
std
::
size_t
>
vec_indices
;
std
::
vector
<
std
::
size_t
>
vec_indices
;
args
[
1
].
visit
([
&
](
auto
indices
)
{
vec_indices
.
assign
(
indices
.
begin
(),
indices
.
end
());
});
args
[
1
].
visit
([
&
](
auto
indices
)
{
vec_indices
.
assign
(
indices
.
begin
(),
indices
.
end
());
});
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
std
::
vector
<
std
::
size_t
>
in_idx
;
std
::
vector
<
std
::
size_t
>
in_idx
;
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
this
->
compute_index
(
idx
,
vec_indices
,
max_dim
,
in_idx
);
this
->
compute_index
(
idx
,
axis_index
,
vec_indices
,
max_dim
,
in_idx
);
output
(
idx
.
begin
(),
idx
.
end
())
=
input
(
in_idx
.
begin
(),
in_idx
.
end
());
output
(
idx
.
begin
(),
idx
.
end
())
=
input
(
in_idx
.
begin
(),
in_idx
.
end
());
});
});
});
});
...
...
src/onnx/onnx.cpp
View file @
13845447
...
@@ -24,7 +24,8 @@ struct onnx_parser
...
@@ -24,7 +24,8 @@ struct onnx_parser
{
{
using
attribute_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
AttributeProto
>
;
using
attribute_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
AttributeProto
>
;
using
node_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
NodeProto
>
;
using
node_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
NodeProto
>
;
using
op_func
=
std
::
function
<
instruction_ref
(
attribute_map
,
std
::
vector
<
instruction_ref
>
)
>
;
using
op_func
=
std
::
function
<
std
::
vector
<
instruction_ref
>
(
attribute_map
,
std
::
vector
<
instruction_ref
>
)
>
;
node_map
nodes
;
node_map
nodes
;
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
program
prog
=
program
();
program
prog
=
program
();
...
@@ -89,6 +90,15 @@ struct onnx_parser
...
@@ -89,6 +90,15 @@ struct onnx_parser
template
<
class
F
>
template
<
class
F
>
void
add_op
(
std
::
string
name
,
F
f
)
void
add_op
(
std
::
string
name
,
F
f
)
{
ops
.
emplace
(
name
,
[
=
](
auto
&&
...
xs
)
{
return
std
::
vector
<
instruction_ref
>
{
f
(
std
::
forward
<
decltype
(
xs
)
>
(
xs
)...)};
});
}
// Multi output op
template
<
class
F
>
void
add_multi_op
(
std
::
string
name
,
F
f
)
{
{
ops
.
emplace
(
name
,
f
);
ops
.
emplace
(
name
,
f
);
}
}
...
@@ -96,7 +106,7 @@ struct onnx_parser
...
@@ -96,7 +106,7 @@ struct onnx_parser
template
<
class
F
>
template
<
class
F
>
void
add_mem_op
(
std
::
string
name
,
F
f
)
void
add_mem_op
(
std
::
string
name
,
F
f
)
{
{
ops
.
emplace
(
name
,
[
=
](
auto
&&
...
xs
)
{
add_op
(
name
,
[
=
](
auto
&&
...
xs
)
{
return
std
::
mem_fn
(
f
)(
*
this
,
name
,
std
::
forward
<
decltype
(
xs
)
>
(
xs
)...);
return
std
::
mem_fn
(
f
)(
*
this
,
name
,
std
::
forward
<
decltype
(
xs
)
>
(
xs
)...);
});
});
}
}
...
@@ -104,7 +114,7 @@ struct onnx_parser
...
@@ -104,7 +114,7 @@ struct onnx_parser
template
<
class
T
>
template
<
class
T
>
void
add_binary_op
(
std
::
string
name
,
T
x
)
void
add_binary_op
(
std
::
string
name
,
T
x
)
{
{
ops
.
emplace
(
name
,
[
this
,
x
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
add_op
(
name
,
[
this
,
x
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
if
(
args
.
size
()
!=
2
)
if
(
args
.
size
()
!=
2
)
MIGRAPHX_THROW
(
"binary operators should have 2 operands"
);
MIGRAPHX_THROW
(
"binary operators should have 2 operands"
);
if
(
contains
(
attributes
,
"broadcast"
)
and
contains
(
attributes
,
"axis"
))
if
(
contains
(
attributes
,
"broadcast"
)
and
contains
(
attributes
,
"axis"
))
...
@@ -173,7 +183,7 @@ struct onnx_parser
...
@@ -173,7 +183,7 @@ struct onnx_parser
template
<
class
T
>
template
<
class
T
>
void
add_generic_op
(
std
::
string
name
,
T
x
)
void
add_generic_op
(
std
::
string
name
,
T
x
)
{
{
ops
.
emplace
(
name
,
[
this
,
x
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
add_op
(
name
,
[
this
,
x
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
return
prog
.
add_instruction
(
x
,
args
);
return
prog
.
add_instruction
(
x
,
args
);
});
});
}
}
...
@@ -181,7 +191,7 @@ struct onnx_parser
...
@@ -181,7 +191,7 @@ struct onnx_parser
template
<
class
T
>
template
<
class
T
>
void
add_variadic_op
(
std
::
string
name
,
T
x
)
void
add_variadic_op
(
std
::
string
name
,
T
x
)
{
{
ops
.
emplace
(
name
,
[
this
,
x
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
add_op
(
name
,
[
this
,
x
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
return
std
::
accumulate
(
std
::
next
(
args
.
begin
()),
return
std
::
accumulate
(
std
::
next
(
args
.
begin
()),
args
.
end
(),
args
.
end
(),
args
.
front
(),
args
.
front
(),
...
@@ -361,7 +371,7 @@ struct onnx_parser
...
@@ -361,7 +371,7 @@ struct onnx_parser
instruction_ref
instruction_ref
parse_gather
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
parse_gather
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
{
std
::
size_
t
axis
=
0
;
in
t
axis
=
0
;
if
(
contains
(
attributes
,
"axis"
))
if
(
contains
(
attributes
,
"axis"
))
{
{
axis
=
parse_value
(
attributes
.
at
(
"axis"
)).
at
<
int
>
();
axis
=
parse_value
(
attributes
.
at
(
"axis"
)).
at
<
int
>
();
...
@@ -666,7 +676,7 @@ struct onnx_parser
...
@@ -666,7 +676,7 @@ struct onnx_parser
}
}
else
else
{
{
throw
std
::
runtime_error
(
"Failed reading"
);
MIGRAPHX_THROW
(
"Failed reading
onnx file.
"
);
}
}
}
}
...
@@ -696,7 +706,7 @@ struct onnx_parser
...
@@ -696,7 +706,7 @@ struct onnx_parser
}
}
for
(
auto
&&
p
:
nodes
)
for
(
auto
&&
p
:
nodes
)
{
{
this
->
parse_node
(
get_name
(
p
.
second
)
);
this
->
parse_node
(
p
.
first
);
}
}
}
}
...
@@ -712,23 +722,37 @@ struct onnx_parser
...
@@ -712,23 +722,37 @@ struct onnx_parser
{
{
if
(
nodes
.
count
(
input
)
>
0
)
if
(
nodes
.
count
(
input
)
>
0
)
{
{
auto
&&
iname
=
get_name
(
nodes
.
at
(
input
));
assert
(
name
!=
input
);
assert
(
name
!=
iname
);
this
->
parse_node
(
input
);
this
->
parse_node
(
iname
);
args
.
push_back
(
instructions
.
at
(
input
));
args
.
push_back
(
instructions
.
at
(
iname
));
}
}
else
else
{
{
args
.
push_back
(
instructions
.
at
(
input
));
args
.
push_back
(
instructions
.
at
(
input
));
}
}
}
}
std
::
vector
<
instruction_ref
>
result
;
if
(
ops
.
count
(
node
.
op_type
())
==
0
)
if
(
ops
.
count
(
node
.
op_type
())
==
0
)
{
{
instructions
[
name
]
=
prog
.
add_instruction
(
unknown
{
node
.
op_type
()},
args
);
result
.
push_back
(
prog
.
add_instruction
(
unknown
{
node
.
op_type
()},
args
)
)
;
}
}
else
else
{
{
instructions
[
name
]
=
ops
[
node
.
op_type
()](
get_attributes
(
node
),
args
);
result
=
ops
[
node
.
op_type
()](
get_attributes
(
node
),
args
);
}
// Even no output nodes produce output in migraphx
if
(
node
.
output
().
empty
()
and
result
.
size
()
==
1
)
{
instructions
[
name
]
=
result
.
front
();
}
else
{
assert
(
node
.
output
().
size
()
>=
result
.
size
());
std
::
transform
(
result
.
begin
(),
result
.
end
(),
node
.
output
().
begin
(),
std
::
inserter
(
instructions
,
instructions
.
end
()),
[](
auto
&&
x
,
auto
&&
y
)
{
return
std
::
make_pair
(
y
,
x
);
});
}
}
}
}
}
}
...
@@ -743,25 +767,24 @@ struct onnx_parser
...
@@ -743,25 +767,24 @@ struct onnx_parser
return
result
;
return
result
;
}
}
static
std
::
string
get_name
(
const
onnx
::
NodeProto
&
node
)
{
if
(
node
.
name
().
empty
())
{
std
::
string
generated
=
"migraphx_unnamed_node"
;
return
std
::
accumulate
(
node
.
output
().
begin
(),
node
.
output
().
end
(),
generated
,
[](
auto
x
,
auto
y
)
{
return
x
+
"_"
+
y
;
});
}
return
node
.
name
();
}
static
node_map
get_nodes
(
const
onnx
::
GraphProto
&
graph
)
static
node_map
get_nodes
(
const
onnx
::
GraphProto
&
graph
)
{
{
std
::
unordered_map
<
std
::
string
,
onnx
::
NodeProto
>
result
;
std
::
unordered_map
<
std
::
string
,
onnx
::
NodeProto
>
result
;
std
::
size_t
n
=
0
;
for
(
auto
&&
node
:
graph
.
node
())
for
(
auto
&&
node
:
graph
.
node
())
{
{
result
[
get_name
(
node
)]
=
node
;
if
(
node
.
output
().
empty
())
{
if
(
node
.
name
().
empty
())
{
result
[
"migraphx_unamed_node_"
+
std
::
to_string
(
n
)]
=
node
;
n
++
;
}
else
{
result
[
node
.
name
()]
=
node
;
}
}
for
(
auto
&&
output
:
node
.
output
())
for
(
auto
&&
output
:
node
.
output
())
{
{
result
[
output
]
=
node
;
result
[
output
]
=
node
;
...
...
src/targets/gpu/device/gather.cpp
View file @
13845447
...
@@ -14,8 +14,9 @@ namespace device {
...
@@ -14,8 +14,9 @@ namespace device {
argument
gather
(
hipStream_t
stream
,
argument
gather
(
hipStream_t
stream
,
const
migraphx
::
shape
&
output_shape
,
const
migraphx
::
shape
&
output_shape
,
std
::
vector
<
migraphx
::
argument
>
args
,
std
::
vector
<
migraphx
::
argument
>
args
,
std
::
size_
t
axis
)
in
t
axis
)
{
{
int
axis_index
=
(
axis
<
0
)
?
(
axis
+
output_shape
.
lens
().
size
())
:
axis
;
visit_all
(
args
.
back
(),
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
visit_all
(
args
.
back
(),
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
std
::
size_t
nelements
=
output_shape
.
elements
();
std
::
size_t
nelements
=
output_shape
.
elements
();
args
[
1
].
visit
([
&
](
auto
indices
)
{
args
[
1
].
visit
([
&
](
auto
indices
)
{
...
@@ -27,7 +28,7 @@ argument gather(hipStream_t stream,
...
@@ -27,7 +28,7 @@ argument gather(hipStream_t stream,
hip_tensor_descriptor
<
ndim
>
desc_output
(
output
.
get_shape
());
hip_tensor_descriptor
<
ndim
>
desc_output
(
output
.
get_shape
());
gs_launch
(
stream
,
nelements
)([
=
](
auto
i
)
{
gs_launch
(
stream
,
nelements
)([
=
](
auto
i
)
{
auto
lens
=
desc_output
.
multi
(
i
);
auto
lens
=
desc_output
.
multi
(
i
);
lens
[
axis
]
=
indices_ptr
[
lens
[
axis
]];
lens
[
axis
_index
]
=
indices_ptr
[
lens
[
axis
_index
]];
outptr
[
i
]
=
inptr
[
desc_input
.
linear
(
lens
)];
outptr
[
i
]
=
inptr
[
desc_input
.
linear
(
lens
)];
});
});
});
});
...
...
src/targets/gpu/include/migraphx/gpu/device/gather.hpp
View file @
13845447
...
@@ -13,7 +13,7 @@ namespace device {
...
@@ -13,7 +13,7 @@ namespace device {
argument
gather
(
hipStream_t
stream
,
argument
gather
(
hipStream_t
stream
,
const
migraphx
::
shape
&
output_shape
,
const
migraphx
::
shape
&
output_shape
,
std
::
vector
<
migraphx
::
argument
>
args
,
std
::
vector
<
migraphx
::
argument
>
args
,
std
::
size_
t
axis
);
in
t
axis
);
}
// namespace device
}
// namespace device
}
// namespace gpu
}
// namespace gpu
...
...
test/cpu_ops_test.cpp
View file @
13845447
...
@@ -113,7 +113,7 @@ TEST_CASE(gather_test)
...
@@ -113,7 +113,7 @@ TEST_CASE(gather_test)
migraphx
::
shape
s_indices
{
migraphx
::
shape
::
int32_type
,
{
1
,
2
}};
migraphx
::
shape
s_indices
{
migraphx
::
shape
::
int32_type
,
{
1
,
2
}};
std
::
vector
<
int
>
indices
{
0
,
2
};
std
::
vector
<
int
>
indices
{
0
,
2
};
auto
a1
=
p
.
add_literal
(
migraphx
::
literal
{
s_indices
,
indices
});
auto
a1
=
p
.
add_literal
(
migraphx
::
literal
{
s_indices
,
indices
});
std
::
size_
t
axis
=
0
;
in
t
axis
=
0
;
p
.
add_instruction
(
migraphx
::
op
::
gather
{
axis
},
a0
,
a1
);
p
.
add_instruction
(
migraphx
::
op
::
gather
{
axis
},
a0
,
a1
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
auto
result
=
p
.
eval
({});
...
@@ -133,7 +133,27 @@ TEST_CASE(gather_test)
...
@@ -133,7 +133,27 @@ TEST_CASE(gather_test)
migraphx
::
shape
s_indices
{
migraphx
::
shape
::
int32_type
,
{
1
,
2
}};
migraphx
::
shape
s_indices
{
migraphx
::
shape
::
int32_type
,
{
1
,
2
}};
std
::
vector
<
int
>
indices
{
0
,
2
};
std
::
vector
<
int
>
indices
{
0
,
2
};
auto
a1
=
p
.
add_literal
(
migraphx
::
literal
{
s_indices
,
indices
});
auto
a1
=
p
.
add_literal
(
migraphx
::
literal
{
s_indices
,
indices
});
std
::
size_t
axis
=
1
;
int
axis
=
1
;
p
.
add_instruction
(
migraphx
::
op
::
gather
{
axis
},
a0
,
a1
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
float
>
res_data
(
4
*
5
);
std
::
vector
<
float
>
golden
=
{
0.5
f
,
2.5
f
,
3.5
f
,
5.5
f
,
6.5
f
,
8.5
f
};
result
.
visit
([
&
](
auto
output
)
{
res_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
res_data
,
golden
));
}
{
migraphx
::
program
p
;
std
::
vector
<
float
>
data
(
3
*
3
);
std
::
iota
(
data
.
begin
(),
data
.
end
(),
0.5
);
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
,
3
}};
auto
a0
=
p
.
add_literal
(
migraphx
::
literal
{
s
,
data
});
migraphx
::
shape
s_indices
{
migraphx
::
shape
::
int32_type
,
{
1
,
2
}};
std
::
vector
<
int
>
indices
{
0
,
2
};
auto
a1
=
p
.
add_literal
(
migraphx
::
literal
{
s_indices
,
indices
});
int
axis
=
-
1
;
p
.
add_instruction
(
migraphx
::
op
::
gather
{
axis
},
a0
,
a1
);
p
.
add_instruction
(
migraphx
::
op
::
gather
{
axis
},
a0
,
a1
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
auto
result
=
p
.
eval
({});
...
...
test/gpu/miopen.cpp
View file @
13845447
...
@@ -963,7 +963,23 @@ struct test_gather
...
@@ -963,7 +963,23 @@ struct test_gather
std
::
vector
<
int
>
indices
{
1
,
2
,
2
,
1
};
std
::
vector
<
int
>
indices
{
1
,
2
,
2
,
1
};
auto
a0
=
p
.
add_parameter
(
"data"
,
s
);
auto
a0
=
p
.
add_parameter
(
"data"
,
s
);
auto
a1
=
p
.
add_literal
(
migraphx
::
literal
{
s_indices
,
indices
});
auto
a1
=
p
.
add_literal
(
migraphx
::
literal
{
s_indices
,
indices
});
std
::
size_t
axis
=
0
;
int
axis
=
0
;
p
.
add_instruction
(
migraphx
::
op
::
gather
{
axis
},
a0
,
a1
);
return
p
;
}
};
struct
test_gather_neg_axis
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
,
3
}};
migraphx
::
shape
s_indices
{
migraphx
::
shape
::
int32_type
,
{
2
,
2
}};
std
::
vector
<
int
>
indices
{
1
,
2
,
2
,
1
};
auto
a0
=
p
.
add_parameter
(
"data"
,
s
);
auto
a1
=
p
.
add_literal
(
migraphx
::
literal
{
s_indices
,
indices
});
int
axis
=
-
1
;
p
.
add_instruction
(
migraphx
::
op
::
gather
{
axis
},
a0
,
a1
);
p
.
add_instruction
(
migraphx
::
op
::
gather
{
axis
},
a0
,
a1
);
return
p
;
return
p
;
}
}
...
@@ -1111,4 +1127,6 @@ int main()
...
@@ -1111,4 +1127,6 @@ int main()
verify_program
<
test_conv_bn_relu_pooling
>
();
verify_program
<
test_conv_bn_relu_pooling
>
();
verify_program
<
test_conv_bn_relu_pooling2
>
();
verify_program
<
test_conv_bn_relu_pooling2
>
();
verify_program
<
test_slice
>
();
verify_program
<
test_slice
>
();
verify_program
<
test_gather
>
();
verify_program
<
test_gather_neg_axis
>
();
}
}
test/onnx/onnx_test.cpp
View file @
13845447
...
@@ -417,7 +417,7 @@ TEST_CASE(gather_test)
...
@@ -417,7 +417,7 @@ TEST_CASE(gather_test)
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
l0
=
p
.
add_parameter
(
"data"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
5
,
6
}});
auto
l0
=
p
.
add_parameter
(
"data"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
5
,
6
}});
auto
l1
=
p
.
add_parameter
(
"indices"
,
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
3
}});
auto
l1
=
p
.
add_parameter
(
"indices"
,
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
3
}});
std
::
size_
t
axis
=
1
;
in
t
axis
=
1
;
p
.
add_instruction
(
migraphx
::
op
::
gather
{
axis
},
l0
,
l1
);
p
.
add_instruction
(
migraphx
::
op
::
gather
{
axis
},
l0
,
l1
);
auto
prog
=
migraphx
::
parse_onnx
(
"gather_test.onnx"
);
auto
prog
=
migraphx
::
parse_onnx
(
"gather_test.onnx"
);
...
@@ -432,7 +432,7 @@ TEST_CASE(shape_gather_test)
...
@@ -432,7 +432,7 @@ TEST_CASE(shape_gather_test)
p
.
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
int64_type
,
{
3
}},
l0
->
get_shape
().
lens
());
p
.
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
int64_type
,
{
3
}},
l0
->
get_shape
().
lens
());
migraphx
::
shape
const_shape
{
migraphx
::
shape
::
int32_type
,
{
1
}};
migraphx
::
shape
const_shape
{
migraphx
::
shape
::
int32_type
,
{
1
}};
auto
l2
=
p
.
add_literal
(
migraphx
::
literal
{
const_shape
,
{
1
}});
auto
l2
=
p
.
add_literal
(
migraphx
::
literal
{
const_shape
,
{
1
}});
std
::
size_
t
axis
=
0
;
in
t
axis
=
0
;
p
.
add_instruction
(
migraphx
::
op
::
gather
{
axis
},
l1
,
l2
);
p
.
add_instruction
(
migraphx
::
op
::
gather
{
axis
},
l1
,
l2
);
auto
prog
=
migraphx
::
parse_onnx
(
"shape_gather.onnx"
);
auto
prog
=
migraphx
::
parse_onnx
(
"shape_gather.onnx"
);
...
...
test/op_shape_test.cpp
View file @
13845447
...
@@ -217,7 +217,7 @@ TEST_CASE(gather)
...
@@ -217,7 +217,7 @@ TEST_CASE(gather)
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
indices
{
migraphx
::
shape
::
int32_type
,
{
2
,
3
}};
migraphx
::
shape
indices
{
migraphx
::
shape
::
int32_type
,
{
2
,
3
}};
std
::
size_
t
axis
=
1
;
in
t
axis
=
1
;
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
6
,
4
,
5
}},
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
6
,
4
,
5
}},
migraphx
::
op
::
gather
{
axis
},
migraphx
::
op
::
gather
{
axis
},
input
,
input
,
...
@@ -227,7 +227,24 @@ TEST_CASE(gather)
...
@@ -227,7 +227,24 @@ TEST_CASE(gather)
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
indices
{
migraphx
::
shape
::
int32_type
,
{
2
,
3
}};
migraphx
::
shape
indices
{
migraphx
::
shape
::
int32_type
,
{
2
,
3
}};
std
::
size_t
axis
=
4
;
int
axis
=
-
4
;
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
6
,
3
,
4
,
5
}},
migraphx
::
op
::
gather
{
axis
},
input
,
indices
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
indices
{
migraphx
::
shape
::
int32_type
,
{
2
,
3
}};
int
axis
=
4
;
throws_shape
(
migraphx
::
op
::
gather
{
axis
},
input
,
indices
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
indices
{
migraphx
::
shape
::
int32_type
,
{
2
,
3
}};
int
axis
=
-
5
;
throws_shape
(
migraphx
::
op
::
gather
{
axis
},
input
,
indices
);
throws_shape
(
migraphx
::
op
::
gather
{
axis
},
input
,
indices
);
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment