Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
edc23800
Commit
edc23800
authored
Feb 11, 2022
by
Shucai Xiao
Browse files
change the data type for lens and strides from size_t to int in the shape class
parent
c7419a9c
Changes
63
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
78 additions
and
78 deletions
+78
-78
src/common.cpp
src/common.cpp
+4
-4
src/eliminate_allocation.cpp
src/eliminate_allocation.cpp
+2
-2
src/eliminate_concat.cpp
src/eliminate_concat.cpp
+2
-2
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+8
-8
src/include/migraphx/common.hpp
src/include/migraphx/common.hpp
+2
-2
src/include/migraphx/normalize_attributes.hpp
src/include/migraphx/normalize_attributes.hpp
+1
-1
src/include/migraphx/onnx.hpp
src/include/migraphx/onnx.hpp
+3
-3
src/include/migraphx/op/argmax.hpp
src/include/migraphx/op/argmax.hpp
+2
-2
src/include/migraphx/op/argmin.hpp
src/include/migraphx/op/argmin.hpp
+3
-3
src/include/migraphx/op/broadcast.hpp
src/include/migraphx/op/broadcast.hpp
+2
-2
src/include/migraphx/op/concat.hpp
src/include/migraphx/op/concat.hpp
+8
-8
src/include/migraphx/op/convolution.hpp
src/include/migraphx/op/convolution.hpp
+8
-8
src/include/migraphx/op/deconvolution.hpp
src/include/migraphx/op/deconvolution.hpp
+11
-11
src/include/migraphx/op/flatten.hpp
src/include/migraphx/op/flatten.hpp
+2
-2
src/include/migraphx/op/gru.hpp
src/include/migraphx/op/gru.hpp
+3
-3
src/include/migraphx/op/im2col.hpp
src/include/migraphx/op/im2col.hpp
+5
-5
src/include/migraphx/op/load.hpp
src/include/migraphx/op/load.hpp
+1
-1
src/include/migraphx/op/lstm.hpp
src/include/migraphx/op/lstm.hpp
+3
-3
src/include/migraphx/op/multinomial.hpp
src/include/migraphx/op/multinomial.hpp
+4
-4
src/include/migraphx/op/nonzero.hpp
src/include/migraphx/op/nonzero.hpp
+4
-4
No files found.
src/common.cpp
View file @
edc23800
...
@@ -20,15 +20,15 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -20,15 +20,15 @@ inline namespace MIGRAPHX_INLINE_NS {
// In this case we need to broadcast the (:,:,1:,:) axis
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
// output_lens = (3,2,7,5)
std
::
vector
<
std
::
size_
t
>
compute_broadcasted_lens
(
std
::
vector
<
std
::
size_
t
>
s0
,
std
::
vector
<
in
t
>
compute_broadcasted_lens
(
std
::
vector
<
in
t
>
s0
,
std
::
vector
<
std
::
size_
t
>
s1
)
std
::
vector
<
in
t
>
s1
)
{
{
if
(
s0
==
s1
)
if
(
s0
==
s1
)
return
s0
;
return
s0
;
if
(
s0
.
size
()
>
s1
.
size
())
if
(
s0
.
size
()
>
s1
.
size
())
s0
.
swap
(
s1
);
s0
.
swap
(
s1
);
std
::
vector
<
std
::
size_
t
>
out_lens
(
s1
);
std
::
vector
<
in
t
>
out_lens
(
s1
);
auto
offset
=
s1
.
size
()
-
s0
.
size
();
auto
offset
=
s1
.
size
()
-
s0
.
size
();
std
::
transform
(
std
::
transform
(
s0
.
begin
(),
s0
.
end
(),
s1
.
begin
()
+
offset
,
out_lens
.
begin
()
+
offset
,
[
&
](
auto
a
,
auto
b
)
{
s0
.
begin
(),
s0
.
end
(),
s1
.
begin
()
+
offset
,
out_lens
.
begin
()
+
offset
,
[
&
](
auto
a
,
auto
b
)
{
...
@@ -43,7 +43,7 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
...
@@ -43,7 +43,7 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
return
out_lens
;
return
out_lens
;
}
}
std
::
vector
<
std
::
size_
t
>
compute_common_lens
(
const
std
::
vector
<
shape
>&
shapes
)
std
::
vector
<
in
t
>
compute_common_lens
(
const
std
::
vector
<
shape
>&
shapes
)
{
{
assert
(
not
shapes
.
empty
());
assert
(
not
shapes
.
empty
());
return
transform_accumulate
(
shapes
.
begin
()
+
1
,
return
transform_accumulate
(
shapes
.
begin
()
+
1
,
...
...
src/eliminate_allocation.cpp
View file @
edc23800
...
@@ -17,8 +17,8 @@ void eliminate_allocation::apply(module& p) const
...
@@ -17,8 +17,8 @@ void eliminate_allocation::apply(module& p) const
{
{
assert
(
alignment
>
0
);
assert
(
alignment
>
0
);
std
::
size_
t
n
=
0
;
in
t
n
=
0
;
std
::
vector
<
std
::
pair
<
instruction_ref
,
std
::
size_
t
>>
allocs
;
std
::
vector
<
std
::
pair
<
instruction_ref
,
in
t
>>
allocs
;
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
p
))
{
{
if
(
ins
->
name
()
!=
allocation_op
)
if
(
ins
->
name
()
!=
allocation_op
)
...
...
src/eliminate_concat.cpp
View file @
edc23800
...
@@ -36,7 +36,7 @@ void eliminate_concat::apply(module& p) const
...
@@ -36,7 +36,7 @@ void eliminate_concat::apply(module& p) const
// we only need to check the first input
// we only need to check the first input
auto
lens
=
ins
->
inputs
().
front
()
->
get_shape
().
lens
();
auto
lens
=
ins
->
inputs
().
front
()
->
get_shape
().
lens
();
auto
concat_op
=
concat_opt
.
get_concat
(
ins
->
get_operator
());
auto
concat_op
=
concat_opt
.
get_concat
(
ins
->
get_operator
());
std
::
size_
t
axis_index
=
tune_axis
(
lens
.
size
(),
concat_op
.
axis
,
concat_op
.
name
());
in
t
axis_index
=
tune_axis
(
lens
.
size
(),
concat_op
.
axis
,
concat_op
.
name
());
if
(
axis_index
==
0
||
if
(
axis_index
==
0
||
std
::
all_of
(
lens
.
begin
(),
lens
.
begin
()
+
axis_index
,
[](
auto
x
)
{
return
x
==
1
;
}))
std
::
all_of
(
lens
.
begin
(),
lens
.
begin
()
+
axis_index
,
[](
auto
x
)
{
return
x
==
1
;
}))
{
{
...
@@ -70,7 +70,7 @@ void eliminate_concat::apply(module& p) const
...
@@ -70,7 +70,7 @@ void eliminate_concat::apply(module& p) const
auto
first
=
sorted_allocations
.
front
();
auto
first
=
sorted_allocations
.
front
();
auto
super
=
p
.
move_instruction
(
last
,
first
);
auto
super
=
p
.
move_instruction
(
last
,
first
);
// Replace each allocation with a load
// Replace each allocation with a load
std
::
size_
t
offset
=
0
;
in
t
offset
=
0
;
for
(
auto
alloc
:
allocations
)
for
(
auto
alloc
:
allocations
)
{
{
op
::
load
op
{
alloc
->
get_shape
(),
offset
};
op
::
load
op
{
alloc
->
get_shape
(),
offset
};
...
...
src/include/migraphx/check_shapes.hpp
View file @
edc23800
...
@@ -39,7 +39,7 @@ struct check_shapes
...
@@ -39,7 +39,7 @@ struct check_shapes
return
name
+
": "
;
return
name
+
": "
;
}
}
std
::
size_
t
size
()
const
in
t
size
()
const
{
{
if
(
begin
==
end
)
if
(
begin
==
end
)
return
0
;
return
0
;
...
@@ -57,14 +57,14 @@ struct check_shapes
...
@@ -57,14 +57,14 @@ struct check_shapes
return
*
this
;
return
*
this
;
}
}
const
check_shapes
&
nelements
(
std
::
size_
t
n
)
const
const
check_shapes
&
nelements
(
in
t
n
)
const
{
{
if
(
!
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
elements
()
==
n
;
}))
if
(
!
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
elements
()
==
n
;
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes must have only "
+
std
::
to_string
(
n
)
+
" elements"
);
MIGRAPHX_THROW
(
prefix
()
+
"Shapes must have only "
+
std
::
to_string
(
n
)
+
" elements"
);
return
*
this
;
return
*
this
;
}
}
const
check_shapes
&
only_dims
(
std
::
size_
t
n
)
const
const
check_shapes
&
only_dims
(
in
t
n
)
const
{
{
assert
(
begin
!=
nullptr
);
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
assert
(
end
!=
nullptr
);
...
@@ -76,7 +76,7 @@ struct check_shapes
...
@@ -76,7 +76,7 @@ struct check_shapes
return
*
this
;
return
*
this
;
}
}
const
check_shapes
&
max_ndims
(
std
::
size_
t
n
)
const
const
check_shapes
&
max_ndims
(
in
t
n
)
const
{
{
assert
(
begin
!=
nullptr
);
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
assert
(
end
!=
nullptr
);
...
@@ -89,7 +89,7 @@ struct check_shapes
...
@@ -89,7 +89,7 @@ struct check_shapes
return
*
this
;
return
*
this
;
}
}
const
check_shapes
&
min_ndims
(
std
::
size_
t
n
)
const
const
check_shapes
&
min_ndims
(
in
t
n
)
const
{
{
assert
(
begin
!=
nullptr
);
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
assert
(
end
!=
nullptr
);
...
@@ -179,7 +179,7 @@ struct check_shapes
...
@@ -179,7 +179,7 @@ struct check_shapes
return
*
this
;
return
*
this
;
}
}
const
check_shapes
&
elements
(
std
::
size_
t
n
)
const
const
check_shapes
&
elements
(
in
t
n
)
const
{
{
if
(
!
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
elements
()
==
n
;
}))
if
(
!
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
elements
()
==
n
;
}))
MIGRAPHX_THROW
(
prefix
()
+
"Wrong number of elements"
);
MIGRAPHX_THROW
(
prefix
()
+
"Wrong number of elements"
);
...
@@ -230,13 +230,13 @@ struct check_shapes
...
@@ -230,13 +230,13 @@ struct check_shapes
check_shapes
slice
(
long
start
,
long
last
)
const
{
return
{
get
(
start
),
get
(
last
),
name
};
}
check_shapes
slice
(
long
start
,
long
last
)
const
{
return
{
get
(
start
),
get
(
last
),
name
};
}
private:
private:
static
bool
batch_not_transposed_strides
(
const
std
::
vector
<
std
::
size_
t
>&
strides
)
static
bool
batch_not_transposed_strides
(
const
std
::
vector
<
in
t
>&
strides
)
{
{
if
(
strides
.
size
()
<=
2
)
if
(
strides
.
size
()
<=
2
)
return
true
;
return
true
;
auto
dim_0
=
strides
.
size
()
-
2
;
auto
dim_0
=
strides
.
size
()
-
2
;
auto
matrix_size
=
std
::
max
(
strides
[
dim_0
],
strides
[
dim_0
+
1
]);
auto
matrix_size
=
std
::
max
(
strides
[
dim_0
],
strides
[
dim_0
+
1
]);
std
::
vector
<
std
::
size_
t
>
batch
(
strides
.
begin
(),
strides
.
begin
()
+
dim_0
);
std
::
vector
<
in
t
>
batch
(
strides
.
begin
(),
strides
.
begin
()
+
dim_0
);
if
(
std
::
all_of
(
batch
.
begin
(),
batch
.
end
(),
[
&
](
auto
i
)
{
return
(
i
<
matrix_size
);
}))
if
(
std
::
all_of
(
batch
.
begin
(),
batch
.
end
(),
[
&
](
auto
i
)
{
return
(
i
<
matrix_size
);
}))
{
{
return
false
;
return
false
;
...
...
src/include/migraphx/common.hpp
View file @
edc23800
...
@@ -11,8 +11,8 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -11,8 +11,8 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
module
;
struct
module
;
struct
operation
;
struct
operation
;
std
::
vector
<
std
::
size_
t
>
compute_broadcasted_lens
(
std
::
vector
<
std
::
size_
t
>
s0
,
std
::
vector
<
in
t
>
compute_broadcasted_lens
(
std
::
vector
<
in
t
>
s0
,
std
::
vector
<
std
::
size_
t
>
s1
);
std
::
vector
<
in
t
>
s1
);
shape
common_shape
(
const
std
::
vector
<
shape
>&
shapes
);
shape
common_shape
(
const
std
::
vector
<
shape
>&
shapes
);
instruction_ref
insert_common_op
(
module
&
m
,
instruction_ref
insert_common_op
(
module
&
m
,
...
...
src/include/migraphx/normalize_attributes.hpp
View file @
edc23800
...
@@ -19,7 +19,7 @@ struct select_dependent_type
...
@@ -19,7 +19,7 @@ struct select_dependent_type
template
<
class
T
,
class
...
Ts
>
template
<
class
T
,
class
...
Ts
>
using
dependent_type
=
typename
select_dependent_type
<
T
,
Ts
...
>::
type
;
using
dependent_type
=
typename
select_dependent_type
<
T
,
Ts
...
>::
type
;
bool
normalize_attributes
(
operation
&
op
,
const
std
::
vector
<
std
::
size_
t
>&
lens
);
bool
normalize_attributes
(
operation
&
op
,
const
std
::
vector
<
in
t
>&
lens
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/onnx.hpp
View file @
edc23800
...
@@ -11,9 +11,9 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -11,9 +11,9 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
onnx_options
struct
onnx_options
{
{
/// default batch size to use (if not specified in onnx file)
/// default batch size to use (if not specified in onnx file)
std
::
size_
t
default_dim_value
=
1
;
in
t
default_dim_value
=
1
;
/// Explicitly specify the dims of an input
/// Explicitly specify the dims of an input
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_
t
>>
map_input_dims
=
{};
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
in
t
>>
map_input_dims
=
{};
/// Continue parsing onnx file if an unknown operator is found
/// Continue parsing onnx file if an unknown operator is found
bool
skip_unknown_operators
=
false
;
bool
skip_unknown_operators
=
false
;
/// Print program if an error occurs
/// Print program if an error occurs
...
@@ -29,7 +29,7 @@ program parse_onnx(const std::string& name, const onnx_options& = onnx_options{}
...
@@ -29,7 +29,7 @@ program parse_onnx(const std::string& name, const onnx_options& = onnx_options{}
program
parse_onnx_buffer
(
const
std
::
string
&
buffer
,
const
onnx_options
&
options
);
program
parse_onnx_buffer
(
const
std
::
string
&
buffer
,
const
onnx_options
&
options
);
/// Create a program from an onnx buffer
/// Create a program from an onnx buffer
program
parse_onnx_buffer
(
const
void
*
data
,
std
::
size_
t
size
,
const
onnx_options
&
options
);
program
parse_onnx_buffer
(
const
void
*
data
,
in
t
size
,
const
onnx_options
&
options
);
std
::
vector
<
std
::
string
>
get_onnx_operators
();
std
::
vector
<
std
::
string
>
get_onnx_operators
();
...
...
src/include/migraphx/op/argmax.hpp
View file @
edc23800
...
@@ -44,11 +44,11 @@ struct argmax
...
@@ -44,11 +44,11 @@ struct argmax
}
}
template
<
class
T
>
template
<
class
T
>
int64_t
calc_argmax
(
T
&
input
,
std
::
vector
<
std
::
size_
t
>&
indices
,
size_
t
item_num
)
const
int64_t
calc_argmax
(
T
&
input
,
std
::
vector
<
in
t
>&
indices
,
in
t
item_num
)
const
{
{
auto
max_val
=
input
(
indices
.
begin
(),
indices
.
end
());
auto
max_val
=
input
(
indices
.
begin
(),
indices
.
end
());
int64_t
max_index
=
0
;
int64_t
max_index
=
0
;
for
(
std
::
size_
t
i
=
1
;
i
<
item_num
;
++
i
)
for
(
in
t
i
=
1
;
i
<
item_num
;
++
i
)
{
{
indices
[
axis
]
=
i
;
indices
[
axis
]
=
i
;
auto
cur_val
=
input
(
indices
.
begin
(),
indices
.
end
());
auto
cur_val
=
input
(
indices
.
begin
(),
indices
.
end
());
...
...
src/include/migraphx/op/argmin.hpp
View file @
edc23800
...
@@ -44,11 +44,11 @@ struct argmin
...
@@ -44,11 +44,11 @@ struct argmin
}
}
template
<
class
T
>
template
<
class
T
>
int64_t
calc_argmin
(
T
&
input
,
std
::
vector
<
std
::
size_
t
>&
indices
,
size_
t
item_num
)
const
int64_t
calc_argmin
(
T
&
input
,
std
::
vector
<
in
t
>&
indices
,
in
t
item_num
)
const
{
{
auto
min_val
=
input
(
indices
.
begin
(),
indices
.
end
());
auto
min_val
=
input
(
indices
.
begin
(),
indices
.
end
());
int64_t
min_index
=
0
;
int64_t
min_index
=
0
;
for
(
std
::
size_
t
i
=
1
;
i
<
item_num
;
++
i
)
for
(
in
t
i
=
1
;
i
<
item_num
;
++
i
)
{
{
indices
[
axis
]
=
i
;
indices
[
axis
]
=
i
;
auto
cur_val
=
input
(
indices
.
begin
(),
indices
.
end
());
auto
cur_val
=
input
(
indices
.
begin
(),
indices
.
end
());
...
@@ -65,7 +65,7 @@ struct argmin
...
@@ -65,7 +65,7 @@ struct argmin
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
output_shape
};
argument
result
{
output_shape
};
std
::
size_
t
batch_item_num
=
args
.
front
().
get_shape
().
lens
()[
axis
];
in
t
batch_item_num
=
args
.
front
().
get_shape
().
lens
()[
axis
];
result
.
visit
([
&
](
auto
output
)
{
result
.
visit
([
&
](
auto
output
)
{
args
[
0
].
visit
([
&
](
auto
input
)
{
args
[
0
].
visit
([
&
](
auto
input
)
{
...
...
src/include/migraphx/op/broadcast.hpp
View file @
edc23800
...
@@ -25,7 +25,7 @@ namespace op {
...
@@ -25,7 +25,7 @@ namespace op {
struct
broadcast
struct
broadcast
{
{
uint64_t
axis
=
0
;
uint64_t
axis
=
0
;
std
::
vector
<
std
::
size_
t
>
broadcast_lens
;
std
::
vector
<
in
t
>
broadcast_lens
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
...
@@ -39,7 +39,7 @@ struct broadcast
...
@@ -39,7 +39,7 @@ struct broadcast
auto
input
=
inputs
.
at
(
0
);
auto
input
=
inputs
.
at
(
0
);
auto
t
=
input
.
type
();
auto
t
=
input
.
type
();
std
::
vector
<
size_
t
>
bcast_strides
(
broadcast_lens
.
size
(),
0
);
std
::
vector
<
in
t
>
bcast_strides
(
broadcast_lens
.
size
(),
0
);
// the broacast op is deprecated now, so not handling the negative
// the broacast op is deprecated now, so not handling the negative
// value of axis anymore
// value of axis anymore
if
(
axis
>=
broadcast_lens
.
size
())
if
(
axis
>=
broadcast_lens
.
size
())
...
...
src/include/migraphx/op/concat.hpp
View file @
edc23800
...
@@ -37,12 +37,12 @@ struct concat
...
@@ -37,12 +37,12 @@ struct concat
}
}
std
::
string
name
()
const
{
return
"concat"
;
}
std
::
string
name
()
const
{
return
"concat"
;
}
std
::
vector
<
std
::
size_
t
>
compute_offsets
(
const
shape
&
output_shape
,
std
::
vector
<
in
t
>
compute_offsets
(
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
const
std
::
vector
<
argument
>&
args
)
const
{
{
auto
n_dims
=
args
[
0
].
get_shape
().
lens
().
size
();
auto
n_dims
=
args
[
0
].
get_shape
().
lens
().
size
();
std
::
vector
<
std
::
size_
t
>
offsets
;
std
::
vector
<
in
t
>
offsets
;
std
::
vector
<
std
::
size_
t
>
offset
(
n_dims
,
0
);
std
::
vector
<
in
t
>
offset
(
n_dims
,
0
);
offset
[
axis
]
=
0
;
offset
[
axis
]
=
0
;
for
(
const
auto
&
arg
:
args
)
for
(
const
auto
&
arg
:
args
)
{
{
...
@@ -60,7 +60,7 @@ struct concat
...
@@ -60,7 +60,7 @@ struct concat
const
auto
&
first_shape_lens
=
inputs
.
front
().
lens
();
const
auto
&
first_shape_lens
=
inputs
.
front
().
lens
();
const
auto
&
type
=
inputs
.
front
().
type
();
const
auto
&
type
=
inputs
.
front
().
type
();
for
(
std
::
size_
t
l
=
0
;
l
<
first_shape_lens
.
size
();
l
++
)
for
(
in
t
l
=
0
;
l
<
first_shape_lens
.
size
();
l
++
)
{
{
if
(
l
!=
axis
)
if
(
l
!=
axis
)
{
{
...
@@ -72,13 +72,13 @@ struct concat
...
@@ -72,13 +72,13 @@ struct concat
}
}
}
}
}
}
std
::
size_
t
new_dim_axis
=
0
;
in
t
new_dim_axis
=
0
;
for
(
const
auto
&
input
:
inputs
)
for
(
const
auto
&
input
:
inputs
)
{
{
const
auto
&
lens
=
input
.
lens
();
const
auto
&
lens
=
input
.
lens
();
new_dim_axis
+=
lens
[
axis
];
new_dim_axis
+=
lens
[
axis
];
}
}
std
::
vector
<
std
::
size_
t
>
new_lens
;
std
::
vector
<
in
t
>
new_lens
;
std
::
copy
(
first_shape_lens
.
begin
(),
first_shape_lens
.
end
(),
std
::
back_inserter
(
new_lens
));
std
::
copy
(
first_shape_lens
.
begin
(),
first_shape_lens
.
end
(),
std
::
back_inserter
(
new_lens
));
new_lens
[
axis
]
=
new_dim_axis
;
new_lens
[
axis
]
=
new_dim_axis
;
return
shape
::
from_permutation
(
type
,
new_lens
,
find_permutation
(
inputs
));
return
shape
::
from_permutation
(
type
,
new_lens
,
find_permutation
(
inputs
));
...
@@ -86,8 +86,8 @@ struct concat
...
@@ -86,8 +86,8 @@ struct concat
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
output_shape
};
argument
result
{
output_shape
};
std
::
vector
<
std
::
size_
t
>
coffsets
=
compute_offsets
(
output_shape
,
args
);
std
::
vector
<
in
t
>
coffsets
=
compute_offsets
(
output_shape
,
args
);
for
(
std
::
size_
t
l
=
0
;
l
<
args
.
size
();
l
++
)
for
(
in
t
l
=
0
;
l
<
args
.
size
();
l
++
)
{
{
auto
argl
=
args
[
l
];
auto
argl
=
args
[
l
];
visit_all
(
result
,
argl
)([
&
](
auto
output
,
auto
input
)
{
visit_all
(
result
,
argl
)([
&
](
auto
output
,
auto
input
)
{
...
...
src/include/migraphx/op/convolution.hpp
View file @
edc23800
...
@@ -20,9 +20,9 @@ namespace op {
...
@@ -20,9 +20,9 @@ namespace op {
struct
convolution
struct
convolution
{
{
std
::
vector
<
std
::
size_
t
>
padding
=
{
0
,
0
};
std
::
vector
<
in
t
>
padding
=
{
0
,
0
};
std
::
vector
<
std
::
size_
t
>
stride
=
{
1
,
1
};
std
::
vector
<
in
t
>
stride
=
{
1
,
1
};
std
::
vector
<
std
::
size_
t
>
dilation
=
{
1
,
1
};
std
::
vector
<
in
t
>
dilation
=
{
1
,
1
};
int
group
=
1
;
int
group
=
1
;
padding_mode_t
padding_mode
=
default_
;
padding_mode_t
padding_mode
=
default_
;
...
@@ -64,7 +64,7 @@ struct convolution
...
@@ -64,7 +64,7 @@ struct convolution
const
shape
&
input
=
inputs
.
at
(
0
);
const
shape
&
input
=
inputs
.
at
(
0
);
const
shape
&
weights
=
inputs
.
at
(
1
);
const
shape
&
weights
=
inputs
.
at
(
1
);
size_
t
kdims
=
input_size
-
2
;
in
t
kdims
=
input_size
-
2
;
if
(
kdims
!=
this
->
kdims
())
if
(
kdims
!=
this
->
kdims
())
{
{
MIGRAPHX_THROW
(
"convolution: input k-dims does not match attribute size"
);
MIGRAPHX_THROW
(
"convolution: input k-dims does not match attribute size"
);
...
@@ -73,14 +73,14 @@ struct convolution
...
@@ -73,14 +73,14 @@ struct convolution
if
(
input
.
lens
().
at
(
1
)
!=
(
weights
.
lens
().
at
(
1
)
*
group
))
if
(
input
.
lens
().
at
(
1
)
!=
(
weights
.
lens
().
at
(
1
)
*
group
))
MIGRAPHX_THROW
(
"CONVOLUTION: Mismatch channel numbers"
);
MIGRAPHX_THROW
(
"CONVOLUTION: Mismatch channel numbers"
);
std
::
vector
<
size_
t
>
output_lens
{
input
.
lens
()[
0
],
weights
.
lens
()[
0
]};
std
::
vector
<
in
t
>
output_lens
{
input
.
lens
()[
0
],
weights
.
lens
()[
0
]};
for
(
size_
t
i
=
0
;
i
<
kdims
;
i
++
)
for
(
in
t
i
=
0
;
i
<
kdims
;
i
++
)
{
{
auto
padding_factor
=
2
*
padding
[
i
];
auto
padding_factor
=
2
*
padding
[
i
];
if
(
padding_size
==
2
*
kdims
)
if
(
padding_size
==
2
*
kdims
)
padding_factor
=
padding
[
i
]
+
padding
[
i
+
kdims
];
padding_factor
=
padding
[
i
]
+
padding
[
i
+
kdims
];
output_lens
.
push_back
(
std
::
size_
t
(
std
::
max
<
std
::
ptrdiff_t
>
(
output_lens
.
push_back
(
in
t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
1
,
(
input
.
lens
()[
i
+
2
]
-
(
1
+
dilation
[
i
]
*
(
weights
.
lens
()[
i
+
2
]
-
1
))
+
(
input
.
lens
()[
i
+
2
]
-
(
1
+
dilation
[
i
]
*
(
weights
.
lens
()[
i
+
2
]
-
1
))
+
padding_factor
)
/
padding_factor
)
/
...
@@ -91,7 +91,7 @@ struct convolution
...
@@ -91,7 +91,7 @@ struct convolution
return
inputs
[
0
].
with_lens
(
output_lens
);
return
inputs
[
0
].
with_lens
(
output_lens
);
}
}
size_
t
kdims
()
const
in
t
kdims
()
const
{
{
check_attribute_size
();
check_attribute_size
();
return
stride
.
size
();
return
stride
.
size
();
...
...
src/include/migraphx/op/deconvolution.hpp
View file @
edc23800
...
@@ -20,9 +20,9 @@ namespace op {
...
@@ -20,9 +20,9 @@ namespace op {
struct
deconvolution
struct
deconvolution
{
{
std
::
vector
<
std
::
size_
t
>
padding
=
{
0
,
0
};
std
::
vector
<
in
t
>
padding
=
{
0
,
0
};
std
::
vector
<
std
::
size_
t
>
stride
=
{
1
,
1
};
std
::
vector
<
in
t
>
stride
=
{
1
,
1
};
std
::
vector
<
std
::
size_
t
>
dilation
=
{
1
,
1
};
std
::
vector
<
in
t
>
dilation
=
{
1
,
1
};
padding_mode_t
padding_mode
=
default_
;
padding_mode_t
padding_mode
=
default_
;
int
group
=
1
;
int
group
=
1
;
...
@@ -54,17 +54,17 @@ struct deconvolution
...
@@ -54,17 +54,17 @@ struct deconvolution
const
shape
&
input
=
inputs
.
at
(
0
);
const
shape
&
input
=
inputs
.
at
(
0
);
const
shape
&
weights
=
inputs
.
at
(
1
);
const
shape
&
weights
=
inputs
.
at
(
1
);
size_
t
kdims
=
input
.
lens
().
size
()
-
2
;
in
t
kdims
=
input
.
lens
().
size
()
-
2
;
if
(
kdims
!=
this
->
kdims
())
if
(
kdims
!=
this
->
kdims
())
{
{
MIGRAPHX_THROW
(
"deconvolution: input k-dims does not match attribute size"
);
MIGRAPHX_THROW
(
"deconvolution: input k-dims does not match attribute size"
);
}
}
std
::
vector
<
size_
t
>
output_lens
{
input
.
lens
()[
0
],
weights
.
lens
()[
1
]};
std
::
vector
<
in
t
>
output_lens
{
input
.
lens
()[
0
],
weights
.
lens
()[
1
]};
for
(
size_
t
i
=
0
;
i
<
kdims
;
i
++
)
for
(
in
t
i
=
0
;
i
<
kdims
;
i
++
)
{
{
output_lens
.
push_back
(
std
::
size_
t
(
std
::
max
<
std
::
ptrdiff_t
>
(
output_lens
.
push_back
(
in
t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
1
,
stride
[
i
]
*
(
input
.
lens
()[
i
+
2
]
-
1
)
+
stride
[
i
]
*
(
input
.
lens
()[
i
+
2
]
-
1
)
+
((
weights
.
lens
()[
i
+
2
]
-
1
)
*
dilation
[
i
]
+
1
)
-
2
*
padding
[
i
])));
((
weights
.
lens
()[
i
+
2
]
-
1
)
*
dilation
[
i
]
+
1
)
-
2
*
padding
[
i
])));
...
@@ -91,7 +91,7 @@ struct deconvolution
...
@@ -91,7 +91,7 @@ struct deconvolution
auto
out_lens
=
output_shape
.
lens
();
auto
out_lens
=
output_shape
.
lens
();
std
::
vector
<
std
::
size_
t
>
win_size
{
in_c
};
std
::
vector
<
in
t
>
win_size
{
in_c
};
std
::
copy
(
in_lens
.
begin
()
+
2
,
in_lens
.
end
(),
std
::
back_inserter
(
win_size
));
std
::
copy
(
in_lens
.
begin
()
+
2
,
in_lens
.
end
(),
std
::
back_inserter
(
win_size
));
std
::
copy
(
wei
.
begin
()
+
2
,
wei
.
end
(),
std
::
back_inserter
(
win_size
));
std
::
copy
(
wei
.
begin
()
+
2
,
wei
.
end
(),
std
::
back_inserter
(
win_size
));
shape
win_shape
{
output_shape
.
type
(),
win_size
};
shape
win_shape
{
output_shape
.
type
(),
win_size
};
...
@@ -105,7 +105,7 @@ struct deconvolution
...
@@ -105,7 +105,7 @@ struct deconvolution
auto
wei_dims_start
=
idx_win
.
begin
()
+
kdims
+
1
;
auto
wei_dims_start
=
idx_win
.
begin
()
+
kdims
+
1
;
std
::
vector
<
std
::
ptrdiff_t
>
win_start
;
std
::
vector
<
std
::
ptrdiff_t
>
win_start
;
for
(
std
::
size_
t
n
=
0
;
n
<
kdims
;
++
n
)
for
(
in
t
n
=
0
;
n
<
kdims
;
++
n
)
{
{
win_start
.
push_back
(
std
::
ptrdiff_t
(
*
(
input_dims_start
+
n
)
*
stride
[
n
])
-
win_start
.
push_back
(
std
::
ptrdiff_t
(
*
(
input_dims_start
+
n
)
*
stride
[
n
])
-
std
::
ptrdiff_t
(
padding
[
n
]));
std
::
ptrdiff_t
(
padding
[
n
]));
...
@@ -116,7 +116,7 @@ struct deconvolution
...
@@ -116,7 +116,7 @@ struct deconvolution
std
::
vector
<
std
::
ptrdiff_t
>
idx_out
{
o
,
in_ch
};
std
::
vector
<
std
::
ptrdiff_t
>
idx_out
{
o
,
in_ch
};
for
(
size_
t
n
=
0
;
n
<
kdims
;
n
++
)
for
(
in
t
n
=
0
;
n
<
kdims
;
n
++
)
{
{
idx_out
.
push_back
(
win_start
[
n
]
+
*
(
wei_dims_start
+
n
)
*
dilation
[
n
]);
idx_out
.
push_back
(
win_start
[
n
]
+
*
(
wei_dims_start
+
n
)
*
dilation
[
n
]);
}
}
...
@@ -147,7 +147,7 @@ struct deconvolution
...
@@ -147,7 +147,7 @@ struct deconvolution
return
result
;
return
result
;
}
}
size_
t
kdims
()
const
in
t
kdims
()
const
{
{
check_attribute_size
();
check_attribute_size
();
return
stride
.
size
();
return
stride
.
size
();
...
...
src/include/migraphx/op/flatten.hpp
View file @
edc23800
...
@@ -42,9 +42,9 @@ struct flatten
...
@@ -42,9 +42,9 @@ struct flatten
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard
();
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard
();
auto
&&
lens
=
inputs
.
front
().
lens
();
auto
&&
lens
=
inputs
.
front
().
lens
();
auto
x
=
auto
x
=
std
::
accumulate
(
lens
.
begin
(),
lens
.
begin
()
+
axis
,
std
::
size_
t
{
1
},
std
::
multiplies
<>
{});
std
::
accumulate
(
lens
.
begin
(),
lens
.
begin
()
+
axis
,
in
t
{
1
},
std
::
multiplies
<>
{});
auto
y
=
auto
y
=
std
::
accumulate
(
lens
.
begin
()
+
axis
,
lens
.
end
(),
std
::
size_
t
{
1
},
std
::
multiplies
<>
{});
std
::
accumulate
(
lens
.
begin
()
+
axis
,
lens
.
end
(),
in
t
{
1
},
std
::
multiplies
<>
{});
return
{
inputs
.
at
(
0
).
type
(),
{
x
,
y
}};
return
{
inputs
.
at
(
0
).
type
(),
{
x
,
y
}};
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
...
...
src/include/migraphx/op/gru.hpp
View file @
edc23800
...
@@ -21,7 +21,7 @@ namespace op {
...
@@ -21,7 +21,7 @@ namespace op {
struct
gru
struct
gru
{
{
std
::
size_
t
hidden_size
=
1
;
in
t
hidden_size
=
1
;
std
::
vector
<
operation
>
actv_funcs
{
sigmoid
{},
tanh
{}};
std
::
vector
<
operation
>
actv_funcs
{
sigmoid
{},
tanh
{}};
rnn_direction
direction
=
rnn_direction
::
forward
;
rnn_direction
direction
=
rnn_direction
::
forward
;
float
clip
=
0.0
f
;
float
clip
=
0.0
f
;
...
@@ -47,7 +47,7 @@ struct gru
...
@@ -47,7 +47,7 @@ struct gru
MIGRAPHX_THROW
(
"GRU: hidden size mismatch in attribute and input"
);
MIGRAPHX_THROW
(
"GRU: hidden size mismatch in attribute and input"
);
}
}
std
::
size_
t
num_directions
=
1
;
in
t
num_directions
=
1
;
if
(
direction
==
rnn_direction
::
bidirectional
)
if
(
direction
==
rnn_direction
::
bidirectional
)
{
{
num_directions
=
2
;
num_directions
=
2
;
...
@@ -58,7 +58,7 @@ struct gru
...
@@ -58,7 +58,7 @@ struct gru
MIGRAPHX_THROW
(
"GRU: num_direction does not match the direction attribute"
);
MIGRAPHX_THROW
(
"GRU: num_direction does not match the direction attribute"
);
}
}
std
::
vector
<
std
::
size_
t
>
out_dims
(
in_dims
);
std
::
vector
<
in
t
>
out_dims
(
in_dims
);
out_dims
.
insert
(
out_dims
.
begin
()
+
1
,
num_directions
);
out_dims
.
insert
(
out_dims
.
begin
()
+
1
,
num_directions
);
out_dims
.
back
()
=
hidden_size
;
out_dims
.
back
()
=
hidden_size
;
...
...
src/include/migraphx/op/im2col.hpp
View file @
edc23800
...
@@ -14,9 +14,9 @@ namespace op {
...
@@ -14,9 +14,9 @@ namespace op {
struct
im2col
struct
im2col
{
{
std
::
vector
<
std
::
size_
t
>
padding
{
0
,
0
};
std
::
vector
<
in
t
>
padding
{
0
,
0
};
std
::
vector
<
std
::
size_
t
>
stride
{
1
,
1
};
std
::
vector
<
in
t
>
stride
{
1
,
1
};
std
::
vector
<
std
::
size_
t
>
dilation
{
1
,
1
};
std
::
vector
<
in
t
>
dilation
{
1
,
1
};
padding_mode_t
padding_mode
=
default_
;
padding_mode_t
padding_mode
=
default_
;
...
@@ -52,11 +52,11 @@ struct im2col
...
@@ -52,11 +52,11 @@ struct im2col
padding_h
=
padding
[
0
]
+
padding
[
2
];
padding_h
=
padding
[
0
]
+
padding
[
2
];
padding_w
=
padding
[
1
]
+
padding
[
3
];
padding_w
=
padding
[
1
]
+
padding
[
3
];
}
}
auto
output_height
=
std
::
size_
t
(
std
::
max
<
std
::
ptrdiff_t
>
(
auto
output_height
=
in
t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
1
,
(
input
.
lens
()[
2
]
-
(
1
+
dilation
[
0
]
*
(
kernel_height
-
1
))
+
padding_h
)
/
stride
[
0
]
+
(
input
.
lens
()[
2
]
-
(
1
+
dilation
[
0
]
*
(
kernel_height
-
1
))
+
padding_h
)
/
stride
[
0
]
+
1
));
1
));
auto
output_width
=
std
::
size_
t
(
std
::
max
<
std
::
ptrdiff_t
>
(
auto
output_width
=
in
t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
1
,
(
input
.
lens
()[
3
]
-
(
1
+
dilation
[
1
]
*
(
kernel_width
-
1
))
+
padding_w
)
/
stride
[
1
]
+
(
input
.
lens
()[
3
]
-
(
1
+
dilation
[
1
]
*
(
kernel_width
-
1
))
+
padding_w
)
/
stride
[
1
]
+
1
));
1
));
...
...
src/include/migraphx/op/load.hpp
View file @
edc23800
...
@@ -17,7 +17,7 @@ namespace op {
...
@@ -17,7 +17,7 @@ namespace op {
struct
load
struct
load
{
{
shape
s
;
shape
s
;
std
::
size_
t
offset
=
0
;
in
t
offset
=
0
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
...
...
src/include/migraphx/op/lstm.hpp
View file @
edc23800
...
@@ -21,7 +21,7 @@ namespace op {
...
@@ -21,7 +21,7 @@ namespace op {
struct
lstm
struct
lstm
{
{
std
::
size_
t
hidden_size
=
1
;
in
t
hidden_size
=
1
;
std
::
vector
<
operation
>
actv_funcs
{
sigmoid
{},
tanh
{},
tanh
{}};
std
::
vector
<
operation
>
actv_funcs
{
sigmoid
{},
tanh
{},
tanh
{}};
rnn_direction
direction
=
rnn_direction
::
forward
;
rnn_direction
direction
=
rnn_direction
::
forward
;
float
clip
=
0.0
f
;
float
clip
=
0.0
f
;
...
@@ -47,7 +47,7 @@ struct lstm
...
@@ -47,7 +47,7 @@ struct lstm
MIGRAPHX_THROW
(
"LSTM: hidden size mismatch in attribute and input"
);
MIGRAPHX_THROW
(
"LSTM: hidden size mismatch in attribute and input"
);
}
}
std
::
size_
t
num_directions
=
1
;
in
t
num_directions
=
1
;
if
(
direction
==
rnn_direction
::
bidirectional
)
if
(
direction
==
rnn_direction
::
bidirectional
)
{
{
num_directions
=
2
;
num_directions
=
2
;
...
@@ -58,7 +58,7 @@ struct lstm
...
@@ -58,7 +58,7 @@ struct lstm
MIGRAPHX_THROW
(
"LSTM: num_direction does not match the direction attribute"
);
MIGRAPHX_THROW
(
"LSTM: num_direction does not match the direction attribute"
);
}
}
std
::
vector
<
std
::
size_
t
>
out_dims
(
in_dims
);
std
::
vector
<
in
t
>
out_dims
(
in_dims
);
out_dims
.
insert
(
out_dims
.
begin
()
+
1
,
num_directions
);
out_dims
.
insert
(
out_dims
.
begin
()
+
1
,
num_directions
);
out_dims
.
back
()
=
hidden_size
;
out_dims
.
back
()
=
hidden_size
;
...
...
src/include/migraphx/op/multinomial.hpp
View file @
edc23800
...
@@ -24,7 +24,7 @@ struct multinomial
...
@@ -24,7 +24,7 @@ struct multinomial
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
).
only_dims
(
2
);
check_shapes
{
inputs
,
*
this
}.
has
(
2
).
only_dims
(
2
);
size_
t
sample_size
=
inputs
.
back
().
lens
().
back
();
in
t
sample_size
=
inputs
.
back
().
lens
().
back
();
if
(
not
contains
({
shape
::
int32_type
,
shape
::
int64_type
},
dtype
))
if
(
not
contains
({
shape
::
int32_type
,
shape
::
int64_type
},
dtype
))
MIGRAPHX_THROW
(
MIGRAPHX_THROW
(
...
@@ -36,9 +36,9 @@ struct multinomial
...
@@ -36,9 +36,9 @@ struct multinomial
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
output_shape
};
argument
result
{
output_shape
};
size_
t
batch_size
=
output_shape
.
lens
().
front
();
in
t
batch_size
=
output_shape
.
lens
().
front
();
size_
t
class_size
=
args
[
0
].
get_shape
().
lens
().
back
();
in
t
class_size
=
args
[
0
].
get_shape
().
lens
().
back
();
size_
t
sample_size
=
output_shape
.
lens
().
back
();
in
t
sample_size
=
output_shape
.
lens
().
back
();
visit_all
(
args
[
0
],
args
[
1
])([
&
](
auto
cdf
,
auto
dist
)
{
visit_all
(
args
[
0
],
args
[
1
])([
&
](
auto
cdf
,
auto
dist
)
{
result
.
visit
([
&
](
auto
output
)
{
result
.
visit
([
&
](
auto
output
)
{
...
...
src/include/migraphx/op/nonzero.hpp
View file @
edc23800
...
@@ -21,15 +21,15 @@ struct nonzero
...
@@ -21,15 +21,15 @@ struct nonzero
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard
();
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard
();
auto
elem_num
=
inputs
[
0
].
elements
();
auto
elem_num
=
inputs
[
0
].
elements
();
auto
dim_num
=
inputs
[
0
].
lens
().
size
();
int
dim_num
=
inputs
[
0
].
lens
().
size
();
std
::
vector
<
std
::
size_
t
>
out_lens
=
{
dim_num
,
elem_num
};
std
::
vector
<
in
t
>
out_lens
=
{
dim_num
,
elem_num
};
return
{
shape
::
int64_type
,
out_lens
};
return
{
shape
::
int64_type
,
out_lens
};
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
std
::
vector
<
std
::
vector
<
std
::
size_
t
>>
vec_idx
;
std
::
vector
<
std
::
vector
<
in
t
>>
vec_idx
;
auto
s
=
args
.
front
().
get_shape
();
auto
s
=
args
.
front
().
get_shape
();
args
.
front
().
visit
([
&
](
auto
v
)
{
args
.
front
().
visit
([
&
](
auto
v
)
{
shape_for_each
(
s
,
[
&
](
auto
idx
)
{
shape_for_each
(
s
,
[
&
](
auto
idx
)
{
...
@@ -44,7 +44,7 @@ struct nonzero
...
@@ -44,7 +44,7 @@ struct nonzero
result
.
visit
([
&
](
auto
output
)
{
result
.
visit
([
&
](
auto
output
)
{
std
::
fill
(
output
.
begin
(),
output
.
end
(),
0
);
std
::
fill
(
output
.
begin
(),
output
.
end
(),
0
);
par_for
(
vec_idx
.
size
(),
[
&
](
auto
i
)
{
par_for
(
vec_idx
.
size
(),
[
&
](
auto
i
)
{
for
(
std
::
size_
t
j
=
0
;
j
<
vec_idx
.
front
().
size
();
++
j
)
for
(
in
t
j
=
0
;
j
<
vec_idx
.
front
().
size
();
++
j
)
{
{
output
[
output_shape
.
index
({
j
,
i
})]
=
vec_idx
[
i
][
j
];
output
[
output_shape
.
index
({
j
,
i
})]
=
vec_idx
[
i
][
j
];
}
}
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment