Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
aa521f17
Commit
aa521f17
authored
Feb 26, 2019
by
Shucai Xiao
Browse files
merge from branch gather_operator
parents
0b769919
4c1a1d63
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
224 additions
and
44 deletions
+224
-44
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+32
-26
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+14
-2
src/targets/gpu/device/gather.cpp
src/targets/gpu/device/gather.cpp
+16
-12
src/targets/gpu/device/logsoftmax.cpp
src/targets/gpu/device/logsoftmax.cpp
+1
-2
test/cpu_ops_test.cpp
test/cpu_ops_test.cpp
+42
-0
test/gpu/miopen.cpp
test/gpu/miopen.cpp
+51
-0
test/onnx/constant_scalar.onnx
test/onnx/constant_scalar.onnx
+7
-0
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+9
-0
test/op_shape_test.cpp
test/op_shape_test.cpp
+52
-2
No files found.
src/include/migraphx/operators.hpp
View file @
aa521f17
...
@@ -757,43 +757,49 @@ struct gather
...
@@ -757,43 +757,49 @@ struct gather
// negative axis means counting dimensions from back
// negative axis means counting dimensions from back
int
axis_index
=
(
axis
<
0
)
?
(
n_dim
+
axis
)
:
axis
;
int
axis_index
=
(
axis
<
0
)
?
(
n_dim
+
axis
)
:
axis
;
auto
type
=
inputs
[
0
].
type
();
auto
type
=
inputs
[
0
].
type
();
lens
[
axis_index
]
=
inputs
[
1
].
elements
();
lens
.
erase
(
lens
.
begin
()
+
axis_index
);
if
(
!
inputs
[
1
].
scalar
())
return
{
type
,
lens
};
{
}
auto
ind_lens
=
inputs
[
1
].
lens
();
lens
.
insert
(
lens
.
begin
()
+
axis_index
,
ind_lens
.
begin
(),
ind_lens
.
end
());
}
template
<
class
T
>
// for scalar output
void
compute_index
(
const
T
&
out_idx
,
if
(
lens
.
empty
())
const
int
axis_index
,
const
std
::
vector
<
std
::
size_t
>&
vec_indices
,
const
std
::
size_t
max_dim
,
T
&
in_idx
)
const
{
in_idx
=
out_idx
;
std
::
size_t
idx
=
vec_indices
.
at
(
out_idx
[
axis_index
]);
if
(
idx
>=
max_dim
)
{
{
MIGRAPHX_THROW
(
"Gather: indices are out of range in input tensor"
)
;
return
{
type
,
{
1
},
{
0
}}
;
}
}
in_idx
[
axis_index
]
=
idx
;
return
{
type
,
lens
};
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
output_shape
};
argument
result
{
output_shape
};
// negative axis means counting dimensions from back
// negative axis means counting dimensions from back
int
axis_index
=
(
axis
<
0
)
?
(
output_shape
.
lens
().
size
()
+
axis
)
:
axis
;
int
axis_index
=
(
axis
<
0
)
?
static_cast
<
int
>
(
args
[
0
].
get_shape
().
lens
().
size
()
+
axis
)
:
axis
;
// max dimension in axis
// max dimension in axis
std
::
size_t
max_dim
=
args
[
0
].
get_shape
().
lens
()[
axis_index
];
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
data
)
{
std
::
vector
<
std
::
size_t
>
vec_indices
;
args
[
1
].
visit
([
&
](
auto
indices
)
{
args
[
1
].
visit
([
&
](
auto
indices
)
{
vec_indices
.
assign
(
indices
.
begin
(),
indices
.
end
());
});
if
(
output_shape
.
scalar
())
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
{
std
::
vector
<
std
::
size_t
>
in_idx
;
output
[
0
]
=
data
[
indices
.
front
()];
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
}
this
->
compute_index
(
idx
,
axis_index
,
vec_indices
,
max_dim
,
in_idx
);
else
output
(
idx
.
begin
(),
idx
.
end
())
=
input
(
in_idx
.
begin
(),
in_idx
.
end
());
{
auto
out_lens
=
data
.
get_shape
().
lens
();
out_lens
[
axis_index
]
=
indices
.
get_shape
().
elements
();
migraphx
::
shape
out_comp_shape
{
data
.
get_shape
().
type
(),
out_lens
};
shape_for_each
(
out_comp_shape
,
[
&
](
const
auto
&
out_idx
)
{
auto
data_idx
=
out_idx
;
data_idx
[
axis_index
]
=
indices
[
data_idx
[
axis_index
]];
output
[
out_comp_shape
.
index
(
out_idx
.
begin
(),
out_idx
.
end
())]
=
data
(
data_idx
.
begin
(),
data_idx
.
end
());
});
}
});
});
});
});
...
...
src/onnx/onnx.cpp
View file @
aa521f17
...
@@ -448,7 +448,15 @@ struct onnx_parser
...
@@ -448,7 +448,15 @@ struct onnx_parser
attribute_map
attributes
,
attribute_map
attributes
,
const
std
::
vector
<
instruction_ref
>&
)
const
std
::
vector
<
instruction_ref
>&
)
{
{
literal
v
=
parse_value
(
attributes
.
at
(
"value"
));
literal
v
=
parse_value
(
attributes
.
at
(
"value"
));
auto
dim_size
=
attributes
.
at
(
"value"
).
t
().
dims_size
();
// if dim_size is 0, it is a scalar
if
(
dim_size
==
0
)
{
migraphx
::
shape
scalar_shape
{
v
.
get_shape
().
type
(),
{
1
},
{
0
}};
return
prog
.
add_literal
(
migraphx
::
literal
{
scalar_shape
,
v
.
data
()});
}
return
prog
.
add_literal
(
v
);
return
prog
.
add_literal
(
v
);
}
}
...
@@ -475,6 +483,7 @@ struct onnx_parser
...
@@ -475,6 +483,7 @@ struct onnx_parser
{
{
transb
=
parse_value
(
attributes
.
at
(
"transB"
)).
at
<
bool
>
();
transb
=
parse_value
(
attributes
.
at
(
"transB"
)).
at
<
bool
>
();
}
}
std
::
vector
<
int64_t
>
perm
=
{
1
,
0
};
std
::
vector
<
int64_t
>
perm
=
{
1
,
0
};
auto
l1
=
(
transa
)
?
prog
.
add_instruction
(
op
::
transpose
{
perm
},
args
[
0
])
:
args
[
0
];
auto
l1
=
(
transa
)
?
prog
.
add_instruction
(
op
::
transpose
{
perm
},
args
[
0
])
:
args
[
0
];
auto
l2
=
(
transb
)
?
prog
.
add_instruction
(
op
::
transpose
{
perm
},
args
[
1
])
:
args
[
1
];
auto
l2
=
(
transb
)
?
prog
.
add_instruction
(
op
::
transpose
{
perm
},
args
[
1
])
:
args
[
1
];
...
@@ -495,7 +504,10 @@ struct onnx_parser
...
@@ -495,7 +504,10 @@ struct onnx_parser
return
add_broadcastable_binary_op
(
l3
,
l4
,
op
::
add
{});
return
add_broadcastable_binary_op
(
l3
,
l4
,
op
::
add
{});
}
}
}
}
return
prog
.
add_instruction
(
op
::
dot
{
alpha
,
beta
},
l1
,
l2
);
auto
dot_res
=
prog
.
add_instruction
(
op
::
dot
{
alpha
,
beta
},
l1
,
l2
);
return
dot_res
;
}
}
instruction_ref
instruction_ref
...
...
src/targets/gpu/device/gather.cpp
View file @
aa521f17
...
@@ -16,20 +16,24 @@ argument gather(hipStream_t stream,
...
@@ -16,20 +16,24 @@ argument gather(hipStream_t stream,
std
::
vector
<
migraphx
::
argument
>
args
,
std
::
vector
<
migraphx
::
argument
>
args
,
int
axis
)
int
axis
)
{
{
int
axis_index
=
(
axis
<
0
)
?
(
axis
+
outpu
t_shape
.
lens
().
size
())
:
axis
;
int
axis_index
=
(
axis
<
0
)
?
(
axis
+
args
[
0
].
ge
t_shape
()
.
lens
().
size
())
:
axis
;
visit_all
(
args
.
back
(),
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
visit_all
(
args
.
back
(),
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
std
::
size_t
nelements
=
output_shape
.
elements
();
std
::
size_t
nelements
=
output_shape
.
elements
();
args
[
1
].
visit
([
=
](
auto
indices
)
{
args
[
1
].
visit
([
&
](
auto
indices
)
{
visit_tensor_size
(
output_shape
.
lens
().
size
(),
[
&
](
auto
ndim
)
{
const
auto
*
indices_ptr
=
device_cast
(
indices
.
data
());
const
auto
*
indices_ptr
=
device_cast
(
indices
.
data
());
auto
*
out_ptr
=
device_cast
(
output
.
data
());
auto
*
outptr
=
device_cast
(
output
.
data
());
const
auto
*
in_ptr
=
device_cast
(
input
.
data
());
const
auto
*
inptr
=
device_cast
(
input
.
data
());
auto
&
input_shape
=
args
[
0
].
get_shape
();
hip_tensor_descriptor
<
ndim
>
desc_input
(
input
.
get_shape
());
auto
lens
=
input_shape
.
lens
();
hip_tensor_descriptor
<
ndim
>
desc_output
(
output
.
get_shape
());
lens
[
axis_index
]
=
args
[
1
].
get_shape
().
elements
();
gs_launch
(
stream
,
nelements
)([
&
](
auto
i
)
{
migraphx
::
shape
out_comp_shape
{
output_shape
.
type
(),
lens
};
auto
lens
=
desc_output
.
multi
(
i
);
visit_tensor_size
(
out_comp_shape
.
lens
().
size
(),
[
&
](
auto
n_out_dim
)
{
lens
[
axis_index
]
=
indices_ptr
[
lens
[
axis_index
]];
hip_tensor_descriptor
<
n_out_dim
>
desc_input
(
input_shape
);
outptr
[
i
]
=
inptr
[
desc_input
.
linear
(
lens
)];
hip_tensor_descriptor
<
n_out_dim
>
desc_output
(
out_comp_shape
);
gs_launch
(
stream
,
nelements
)([
=
](
auto
ii
)
{
auto
in_idx
=
desc_output
.
multi
(
ii
);
in_idx
[
axis_index
]
=
indices_ptr
[
in_idx
[
axis_index
]];
out_ptr
[
ii
]
=
in_ptr
[
desc_input
.
linear
(
in_idx
)];
});
});
});
});
});
});
...
...
src/targets/gpu/device/logsoftmax.cpp
View file @
aa521f17
...
@@ -45,13 +45,12 @@ argument logsoftmax(hipStream_t stream,
...
@@ -45,13 +45,12 @@ argument logsoftmax(hipStream_t stream,
output_ptr
[
ind
]
=
input_ptr
[
ind
]
-
batch_max
;
output_ptr
[
ind
]
=
input_ptr
[
ind
]
-
batch_max
;
}
}
auto
batch_sum
=
output_ptr
[
row_start
];
auto
batch_sum
=
::
exp
(
to_hip_type
(
output_ptr
[
row_start
]
))
;
for
(
std
::
size_t
j
=
1
;
j
<
n_dims
;
++
j
)
for
(
std
::
size_t
j
=
1
;
j
<
n_dims
;
++
j
)
{
{
auto
ind
=
row_start
+
j
;
auto
ind
=
row_start
+
j
;
batch_sum
+=
::
exp
(
to_hip_type
(
output_ptr
[
ind
]));
batch_sum
+=
::
exp
(
to_hip_type
(
output_ptr
[
ind
]));
}
}
batch_sum
=
::
log
(
to_hip_type
(
batch_sum
));
batch_sum
=
::
log
(
to_hip_type
(
batch_sum
));
for
(
std
::
size_t
j
=
0
;
j
<
n_dims
;
++
j
)
for
(
std
::
size_t
j
=
0
;
j
<
n_dims
;
++
j
)
...
...
test/cpu_ops_test.cpp
View file @
aa521f17
...
@@ -164,6 +164,48 @@ TEST_CASE(gather_test)
...
@@ -164,6 +164,48 @@ TEST_CASE(gather_test)
result
.
visit
([
&
](
auto
output
)
{
res_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
res_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
res_data
,
golden
));
EXPECT
(
migraphx
::
verify_range
(
res_data
,
golden
));
}
}
{
migraphx
::
program
p
;
std
::
vector
<
float
>
data
(
3
*
3
);
std
::
iota
(
data
.
begin
(),
data
.
end
(),
0.5
);
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
,
3
}};
auto
a0
=
p
.
add_literal
(
migraphx
::
literal
{
s
,
data
});
// scalar index
migraphx
::
shape
s_indices
{
migraphx
::
shape
::
int32_type
,
{
1
},
{
0
}};
std
::
vector
<
int
>
indices
{
0
};
auto
a1
=
p
.
add_literal
(
migraphx
::
literal
{
s_indices
,
indices
});
int
axis
=
-
1
;
p
.
add_instruction
(
migraphx
::
op
::
gather
{
axis
},
a0
,
a1
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
float
>
res_data
{};
std
::
vector
<
float
>
golden
=
{
0.5
f
,
3.5
f
,
6.5
f
};
result
.
visit
([
&
](
auto
output
)
{
res_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
res_data
,
golden
));
}
{
migraphx
::
program
p
;
std
::
vector
<
float
>
data
(
3
);
std
::
iota
(
data
.
begin
(),
data
.
end
(),
0.5
);
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
}};
auto
a0
=
p
.
add_literal
(
migraphx
::
literal
{
s
,
data
});
// scalar index
migraphx
::
shape
s_indices
{
migraphx
::
shape
::
int32_type
,
{
1
},
{
0
}};
std
::
vector
<
int
>
indices
{
0
};
auto
a1
=
p
.
add_literal
(
migraphx
::
literal
{
s_indices
,
indices
});
int
axis
=
-
1
;
p
.
add_instruction
(
migraphx
::
op
::
gather
{
axis
},
a0
,
a1
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
float
>
res_data
{};
std
::
vector
<
float
>
golden
=
{
0.5
f
};
result
.
visit
([
&
](
auto
output
)
{
res_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
res_data
,
golden
));
}
}
}
TEST_CASE
(
squeeze_test
)
TEST_CASE
(
squeeze_test
)
...
...
test/gpu/miopen.cpp
View file @
aa521f17
...
@@ -1068,6 +1068,54 @@ struct test_gather_neg_axis
...
@@ -1068,6 +1068,54 @@ struct test_gather_neg_axis
}
}
};
};
struct
test_gather_scalar_output
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
}};
migraphx
::
shape
s_indices
{
migraphx
::
shape
::
int32_type
,
{
1
},
{
0
}};
std
::
vector
<
int
>
indices
{
1
};
auto
a0
=
p
.
add_parameter
(
"data"
,
s
);
auto
a1
=
p
.
add_literal
(
migraphx
::
literal
{
s_indices
,
indices
});
int
axis
=
0
;
p
.
add_instruction
(
migraphx
::
op
::
gather
{
axis
},
a0
,
a1
);
return
p
;
}
};
struct
test_gather_scalar_index
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
,
3
}};
migraphx
::
shape
s_indices
{
migraphx
::
shape
::
int32_type
,
{
1
},
{
0
}};
std
::
vector
<
int
>
indices
{
1
};
auto
a0
=
p
.
add_parameter
(
"data"
,
s
);
auto
a1
=
p
.
add_literal
(
migraphx
::
literal
{
s_indices
,
indices
});
int
axis
=
-
1
;
p
.
add_instruction
(
migraphx
::
op
::
gather
{
axis
},
a0
,
a1
);
return
p
;
}
};
struct
test_gather_1d_index
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
,
3
}};
migraphx
::
shape
s_indices
{
migraphx
::
shape
::
int32_type
,
{
1
}};
std
::
vector
<
int
>
indices
{
1
};
auto
a0
=
p
.
add_parameter
(
"data"
,
s
);
auto
a1
=
p
.
add_literal
(
migraphx
::
literal
{
s_indices
,
indices
});
int
axis
=
-
1
;
p
.
add_instruction
(
migraphx
::
op
::
gather
{
axis
},
a0
,
a1
);
return
p
;
}
};
void
manual_identity
()
void
manual_identity
()
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
...
@@ -2904,6 +2952,9 @@ int main()
...
@@ -2904,6 +2952,9 @@ int main()
verify_program
<
test_slice
>
();
verify_program
<
test_slice
>
();
verify_program
<
test_gather
>
();
verify_program
<
test_gather
>
();
verify_program
<
test_gather_neg_axis
>
();
verify_program
<
test_gather_neg_axis
>
();
verify_program
<
test_gather_scalar_output
>
();
verify_program
<
test_gather_scalar_index
>
();
verify_program
<
test_gather_1d_index
>
();
verify_program
<
test_rnn_forward
>
();
verify_program
<
test_rnn_forward
>
();
verify_program
<
test_rnn_forward10
>
();
verify_program
<
test_rnn_forward10
>
();
verify_program
<
test_rnn_reverse
>
();
verify_program
<
test_rnn_reverse
>
();
...
...
test/onnx/constant_scalar.onnx
0 → 100644
View file @
aa521f17
shape-gather-example:O
2value"Constant*
value**Bconst_tensor constantb
z
B
\ No newline at end of file
test/onnx/onnx_test.cpp
View file @
aa521f17
...
@@ -521,6 +521,15 @@ TEST_CASE(constant_test)
...
@@ -521,6 +521,15 @@ TEST_CASE(constant_test)
EXPECT
(
p
==
prog
);
EXPECT
(
p
==
prog
);
}
}
TEST_CASE
(
constant_test_scalar
)
{
migraphx
::
program
p
;
p
.
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
1
},
{
0
}},
{
1
}});
auto
prog
=
migraphx
::
parse_onnx
(
"constant_scalar.onnx"
);
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
constant_fill_test
)
TEST_CASE
(
constant_fill_test
)
{
{
{
{
...
...
test/op_shape_test.cpp
View file @
aa521f17
...
@@ -235,7 +235,7 @@ TEST_CASE(gather)
...
@@ -235,7 +235,7 @@ TEST_CASE(gather)
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
indices
{
migraphx
::
shape
::
int32_type
,
{
2
,
3
}};
migraphx
::
shape
indices
{
migraphx
::
shape
::
int32_type
,
{
2
,
3
}};
int
axis
=
1
;
int
axis
=
1
;
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
6
,
4
,
5
}},
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
2
,
3
,
4
,
5
}},
migraphx
::
op
::
gather
{
axis
},
migraphx
::
op
::
gather
{
axis
},
input
,
input
,
indices
);
indices
);
...
@@ -245,7 +245,57 @@ TEST_CASE(gather)
...
@@ -245,7 +245,57 @@ TEST_CASE(gather)
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
indices
{
migraphx
::
shape
::
int32_type
,
{
2
,
3
}};
migraphx
::
shape
indices
{
migraphx
::
shape
::
int32_type
,
{
2
,
3
}};
int
axis
=
-
4
;
int
axis
=
-
4
;
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
6
,
3
,
4
,
5
}},
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
3
,
4
,
5
}},
migraphx
::
op
::
gather
{
axis
},
input
,
indices
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
indices
{
migraphx
::
shape
::
int32_type
,
{
1
}};
int
axis
=
-
4
;
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
4
,
5
}},
migraphx
::
op
::
gather
{
axis
},
input
,
indices
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
indices
{
migraphx
::
shape
::
int32_type
,
{
1
},
{
0
}};
int
axis
=
-
4
;
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
5
}},
migraphx
::
op
::
gather
{
axis
},
input
,
indices
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
indices
{
migraphx
::
shape
::
int32_type
,
{
1
},
{
0
}};
int
axis
=
3
;
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
}},
migraphx
::
op
::
gather
{
axis
},
input
,
indices
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
3
}};
migraphx
::
shape
indices
{
migraphx
::
shape
::
int32_type
,
{
1
},
{
0
}};
int
axis
=
0
;
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
},
{
0
}},
migraphx
::
op
::
gather
{
axis
},
input
,
indices
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
3
}};
migraphx
::
shape
indices
{
migraphx
::
shape
::
int32_type
,
{
1
}};
int
axis
=
0
;
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
}},
migraphx
::
op
::
gather
{
axis
},
migraphx
::
op
::
gather
{
axis
},
input
,
input
,
indices
);
indices
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment