Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
edc23800
Commit
edc23800
authored
Feb 11, 2022
by
Shucai Xiao
Browse files
change the data type for lens and strides from size_t to int in the shape class
parent
c7419a9c
Changes
63
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
224 additions
and
129 deletions
+224
-129
src/onnx/parse_imagescalar.cpp
src/onnx/parse_imagescalar.cpp
+1
-1
src/onnx/parse_matmul.cpp
src/onnx/parse_matmul.cpp
+2
-2
src/onnx/parse_multinomial.cpp
src/onnx/parse_multinomial.cpp
+2
-2
src/onnx/parse_nonzero.cpp
src/onnx/parse_nonzero.cpp
+7
-7
src/onnx/parse_onehot.cpp
src/onnx/parse_onehot.cpp
+1
-1
src/onnx/parse_pad.cpp
src/onnx/parse_pad.cpp
+2
-2
src/onnx/parse_pooling.cpp
src/onnx/parse_pooling.cpp
+3
-3
src/onnx/parse_shape.cpp
src/onnx/parse_shape.cpp
+2
-2
src/patch.patch
src/patch.patch
+95
-0
src/reduce_dims.cpp
src/reduce_dims.cpp
+11
-11
src/rewrite_batchnorm.cpp
src/rewrite_batchnorm.cpp
+3
-3
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+10
-10
src/shape.cpp
src/shape.cpp
+48
-48
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+11
-11
src/targets/gpu/device/include/migraphx/gpu/device/reduce.hpp
...targets/gpu/device/include/migraphx/gpu/device/reduce.hpp
+3
-3
src/targets/ref/gemm.cpp
src/targets/ref/gemm.cpp
+7
-7
src/tf/parse_conv.cpp
src/tf/parse_conv.cpp
+7
-7
src/tf/parse_depthwiseconv.cpp
src/tf/parse_depthwiseconv.cpp
+6
-6
src/tf/parse_expanddims.cpp
src/tf/parse_expanddims.cpp
+2
-2
src/tf/parse_onehot.cpp
src/tf/parse_onehot.cpp
+1
-1
No files found.
src/onnx/parse_imagescalar.cpp
View file @
edc23800
...
...
@@ -33,7 +33,7 @@ struct parse_imagescalar : op_parser<parse_imagescalar>
auto
input_type
=
input_shape
.
type
();
auto
scale_val
=
info
.
add_literal
(
literal
{
shape
{
input_type
},
{
scale
}});
auto
bias_vals
=
info
.
add_literal
(
literal
{
shape
{
input_type
,
{
bias
.
size
()}},
bias
});
auto
bias_vals
=
info
.
add_literal
(
literal
{
shape
{
input_type
,
{
static_cast
<
int
>
(
bias
.
size
()
)
}},
bias
});
auto
scale_tensor
=
info
.
add_instruction
(
migraphx
::
make_op
(
"scalar"
,
{{
"scalar_bcst_dims"
,
input_lens
}}),
scale_val
);
...
...
src/onnx/parse_matmul.cpp
View file @
edc23800
...
...
@@ -47,9 +47,9 @@ struct parse_matmul : op_parser<parse_matmul>
if
(
!
std
::
equal
(
l0_lens
.
rbegin
()
+
2
,
l0_lens
.
rend
(),
l1_lens
.
rbegin
()
+
2
,
l1_lens
.
rend
()))
{
auto
l0_it
=
l0_lens
.
begin
()
+
l0_lens
.
size
()
-
2
;
std
::
vector
<
std
::
size_
t
>
l0_broadcasted_lens
(
l0_lens
.
begin
(),
l0_it
);
std
::
vector
<
in
t
>
l0_broadcasted_lens
(
l0_lens
.
begin
(),
l0_it
);
auto
l1_it
=
l1_lens
.
begin
()
+
l1_lens
.
size
()
-
2
;
std
::
vector
<
std
::
size_
t
>
l1_broadcasted_lens
(
l1_lens
.
begin
(),
l1_it
);
std
::
vector
<
in
t
>
l1_broadcasted_lens
(
l1_lens
.
begin
(),
l1_it
);
auto
output_lens
=
compute_broadcasted_lens
(
l0_broadcasted_lens
,
l1_broadcasted_lens
);
l0_broadcasted_lens
=
output_lens
;
l0_broadcasted_lens
.
insert
(
l0_broadcasted_lens
.
end
(),
l0_it
,
l0_lens
.
end
());
...
...
src/onnx/parse_multinomial.cpp
View file @
edc23800
...
...
@@ -23,7 +23,7 @@ struct parse_multinomial : op_parser<parse_multinomial>
dtype
=
info
.
attributes
.
at
(
"dtype"
).
i
();
shape
::
type_t
output_type
=
get_type
(
dtype
);
size_
t
sample_size
=
1
;
in
t
sample_size
=
1
;
if
(
contains
(
info
.
attributes
,
"sample_size"
))
sample_size
=
info
.
attributes
.
at
(
"sample_size"
).
i
();
...
...
@@ -46,7 +46,7 @@ struct parse_multinomial : op_parser<parse_multinomial>
gen
.
seed
(
info
.
attributes
.
at
(
"seed"
).
f
());
std
::
uniform_real_distribution
<>
dis
(
0.0
,
1.0
);
size_
t
batch_size
=
args
[
0
]
->
get_shape
().
lens
().
front
();
in
t
batch_size
=
args
[
0
]
->
get_shape
().
lens
().
front
();
migraphx
::
shape
dist_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
sample_size
}};
std
::
vector
<
float
>
random_dist
(
batch_size
*
sample_size
);
...
...
src/onnx/parse_nonzero.cpp
View file @
edc23800
...
...
@@ -9,10 +9,10 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
onnx
{
template
<
class
T
>
static
std
::
vector
<
std
::
size_
t
>
nonzero_indices
(
const
std
::
vector
<
T
>&
data
)
static
std
::
vector
<
in
t
>
nonzero_indices
(
const
std
::
vector
<
T
>&
data
)
{
std
::
vector
<
std
::
size_
t
>
indices
;
for
(
std
::
size_
t
i
=
0
;
i
<
data
.
size
();
++
i
)
std
::
vector
<
in
t
>
indices
;
for
(
in
t
i
=
0
;
i
<
data
.
size
();
++
i
)
{
if
(
!
float_equal
(
data
[
i
],
0
))
indices
.
push_back
(
i
);
...
...
@@ -37,7 +37,7 @@ struct parse_nonzero : op_parser<parse_nonzero>
}
else
{
std
::
vector
<
std
::
size_
t
>
indices
;
std
::
vector
<
in
t
>
indices
;
data_arg
.
visit
([
&
](
auto
val
)
{
using
val_type
=
std
::
remove_cv_t
<
typename
decltype
(
val
)
::
value_type
>
;
std
::
vector
<
val_type
>
vec_data
;
...
...
@@ -46,13 +46,13 @@ struct parse_nonzero : op_parser<parse_nonzero>
});
shape
in_s
=
args
[
0
]
->
get_shape
();
shape
out_s
{
shape
::
int64_type
,
{
in_s
.
lens
().
size
(),
indices
.
size
()}};
shape
out_s
{
shape
::
int64_type
,
{
static_cast
<
int
>
(
in_s
.
lens
().
size
()
)
,
static_cast
<
int
>
(
indices
.
size
()
)
}};
std
::
vector
<
int64_t
>
out_data
(
out_s
.
elements
());
for
(
std
::
size_
t
i
=
0
;
i
<
indices
.
size
();
++
i
)
for
(
in
t
i
=
0
;
i
<
indices
.
size
();
++
i
)
{
auto
idx
=
in_s
.
multi
(
indices
[
i
]);
for
(
std
::
size_
t
j
=
0
;
j
<
in_s
.
lens
().
size
();
++
j
)
for
(
in
t
j
=
0
;
j
<
in_s
.
lens
().
size
();
++
j
)
{
out_data
[
out_s
.
index
({
j
,
i
})]
=
idx
[
j
];
}
...
...
src/onnx/parse_onehot.cpp
View file @
edc23800
...
...
@@ -20,7 +20,7 @@ struct parse_onehot : op_parser<parse_onehot>
{
migraphx
::
argument
depth_arg
=
args
[
1
]
->
eval
();
check_arg_empty
(
depth_arg
,
"PARSE_ONEHOT: depth - dynamic shape not supported"
);
size_
t
depth
=
depth_arg
.
at
<
size_
t
>
();
in
t
depth
=
depth_arg
.
at
<
in
t
>
();
int64_t
axis
=
-
1
;
if
(
contains
(
info
.
attributes
,
"axis"
))
...
...
src/onnx/parse_pad.cpp
View file @
edc23800
...
...
@@ -32,7 +32,7 @@ instruction_ref reflect_pad(const onnx_parser::node_info& info,
const
std
::
vector
<
int64_t
>&
pads
,
instruction_ref
input
)
{
size_
t
num_dims
=
pads
.
size
()
/
2
;
in
t
num_dims
=
pads
.
size
()
/
2
;
std
::
vector
<
int
>
ldims
(
pads
.
begin
(),
pads
.
begin
()
+
num_dims
);
std
::
vector
<
int
>
rdims
(
pads
.
begin
()
+
num_dims
,
pads
.
end
());
assert
(
ldims
.
size
()
==
rdims
.
size
());
...
...
@@ -50,7 +50,7 @@ instruction_ref reflect_pad(const onnx_parser::node_info& info,
continue
;
// calculate starts and ends for each iteration since shape may change
std
::
vector
<
size_
t
>
dims
=
input
->
get_shape
().
lens
();
std
::
vector
<
in
t
>
dims
=
input
->
get_shape
().
lens
();
std
::
vector
<
int64_t
>
starts
(
axes
.
size
(),
0
);
std
::
vector
<
int64_t
>
ends
(
dims
.
begin
(),
dims
.
end
());
std
::
vector
<
instruction_ref
>
slices
;
...
...
src/onnx/parse_pooling.cpp
View file @
edc23800
...
...
@@ -36,7 +36,7 @@ struct parse_pooling : op_parser<parse_pooling>
if
(
starts_with
(
opd
.
onnx_name
,
"Global"
))
{
values
[
"lengths"
]
=
std
::
vector
<
size_
t
>
(
in_lens
.
begin
()
+
2
,
in_lens
.
end
());
values
[
"lengths"
]
=
std
::
vector
<
in
t
>
(
in_lens
.
begin
()
+
2
,
in_lens
.
end
());
}
// does not support ceil_mode
...
...
@@ -86,7 +86,7 @@ struct parse_pooling : op_parser<parse_pooling>
// return paddings could be empty, then setting to 0 for no padding
cal_auto_padding_size
(
info
,
values
,
values
[
"lengths"
].
to_vector
<
std
::
size_
t
>
(),
values
[
"lengths"
].
to_vector
<
in
t
>
(),
{
1
,
1
},
in_lens
,
paddings
);
...
...
@@ -133,7 +133,7 @@ struct parse_pooling : op_parser<parse_pooling>
slice_end
.
begin
(),
[](
auto
i
,
auto
j
)
{
return
i
+
j
;
});
}
values
[
"padding"
]
=
std
::
vector
<
size_
t
>
(
paddings
.
begin
(),
paddings
.
end
());
values
[
"padding"
]
=
std
::
vector
<
in
t
>
(
paddings
.
begin
(),
paddings
.
end
());
check_asym_padding
(
info
,
l0
,
paddings
,
values
,
count_include_pad
,
pad_val
);
op
.
from_value
(
values
);
...
...
src/onnx/parse_shape.cpp
View file @
edc23800
...
...
@@ -20,9 +20,9 @@ struct parse_shape : op_parser<parse_shape>
{
if
(
args
.
size
()
!=
1
)
MIGRAPHX_THROW
(
"Shape: operator should have 1 operand"
);
std
::
vector
<
std
::
size_
t
>
arg_shape
=
args
[
0
]
->
get_shape
().
lens
();
std
::
vector
<
in
t
>
arg_shape
=
args
[
0
]
->
get_shape
().
lens
();
std
::
vector
<
int64_t
>
vec_shape
(
arg_shape
.
size
());
migraphx
::
shape
s
(
migraphx
::
shape
::
int64_type
,
{
arg_shape
.
size
()});
migraphx
::
shape
s
(
migraphx
::
shape
::
int64_type
,
{
static_cast
<
int
>
(
arg_shape
.
size
()
)
});
std
::
transform
(
arg_shape
.
begin
(),
arg_shape
.
end
(),
vec_shape
.
begin
(),
[](
auto
i
)
{
return
int64_t
(
i
);
});
...
...
src/patch.patch
0 → 100644
View file @
edc23800
diff --git a/src/include/migraphx/op/capture.hpp b/src/include/migraphx/op/capture.hpp
index f33eab9bb..80ffcbe6b 100644
--- a/src/include/migraphx/op/capture.hpp
+++ b/src/include/migraphx/op/capture.hpp
@@ -8,6 +8,7 @@
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
+#include <migraphx/context.hpp>
#include <cmath>
#include <utility>
@@ -29,7 +30,9 @@
struct capture
shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }
- argument compute(const shape&, std::vector<argument> args) const
+ argument compute(const shape&, std::vector<argument> args) const { return args.front(); }
+
+ argument compute(context&, const shape&, const std::vector<argument>& args) const
{
if(f)
{
diff --git a/src/include/migraphx/operation.hpp b/src/include/migraphx/operation.hpp
index 922eabd67..56108a871 100644
--- a/src/include/migraphx/operation.hpp
+++ b/src/include/migraphx/operation.hpp
@@ -271,25 +271,25 @@
auto compute_op(rank<3>,
template <class T, class F>
auto compute_op(rank<2>,
const T& x,
- context&,
+ context& ctx,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>&,
- F) -> decltype(x.compute(output, inputs))
+ F) -> decltype(x.compute(auto_any_cast(ctx), output, inputs))
{
- return x.compute(output, inputs);
+ return x.compute(auto_any_cast(ctx), output, inputs);
}
template <class T, class F>
auto compute_op(rank<1>,
const T& x,
- context& ctx,
+ context&,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>&,
- F) -> decltype(x.compute(auto_any_cast(ctx), output, inputs))
+ F) -> decltype(x.compute(output, inputs))
{
- return x.compute(auto_any_cast(ctx), output, inputs);
+ return x.compute(output, inputs);
}
template <class T, class F>
diff --git a/tools/include/operation.hpp b/tools/include/operation.hpp
index 0c49edfaf..ef9927cdc 100644
--- a/tools/include/operation.hpp
+++ b/tools/include/operation.hpp
@@ -271,25 +271,25 @@
auto compute_op(rank<3>,
template <class T, class F>
auto compute_op(rank<2>,
const T& x,
- context&,
+ context& ctx,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>&,
- F) -> decltype(x.compute(output, inputs))
+ F) -> decltype(x.compute(auto_any_cast(ctx), output, inputs))
{
- return x.compute(output, inputs);
+ return x.compute(auto_any_cast(ctx), output, inputs);
}
template <class T, class F>
auto compute_op(rank<1>,
const T& x,
- context& ctx,
+ context&,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>&,
- F) -> decltype(x.compute(auto_any_cast(ctx), output, inputs))
+ F) -> decltype(x.compute(output, inputs))
{
- return x.compute(auto_any_cast(ctx), output, inputs);
+ return x.compute(output, inputs);
}
template <class T, class F>
src/reduce_dims.cpp
View file @
edc23800
...
...
@@ -3,9 +3,9 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
bool
reduce_dim
(
std
::
vector
<
shape
>&
shapes
,
std
::
size_
t
n
)
bool
reduce_dim
(
std
::
vector
<
shape
>&
shapes
,
in
t
n
)
{
std
::
vector
<
std
::
size_
t
>
new_lens
;
std
::
vector
<
in
t
>
new_lens
;
for
(
const
auto
&
s
:
shapes
)
{
assert
(
n
<
s
.
lens
().
size
());
...
...
@@ -23,7 +23,7 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
}
if
(
new_lens
.
size
()
!=
shapes
.
size
())
return
false
;
std
::
size_
t
i
=
0
;
in
t
i
=
0
;
for
(
auto
&
s
:
shapes
)
{
auto
lens
=
s
.
lens
();
...
...
@@ -37,7 +37,7 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
return
true
;
}
std
::
size_
t
reduce_dim_all
(
std
::
vector
<
shape
>&
shapes
,
std
::
size_
t
n
)
in
t
reduce_dim_all
(
std
::
vector
<
shape
>&
shapes
,
in
t
n
)
{
while
(
reduce_dim
(
shapes
,
n
)
and
n
<
shapes
.
size
())
{
...
...
@@ -47,16 +47,16 @@ std::size_t reduce_dim_all(std::vector<shape>& shapes, std::size_t n)
}
void
reduce_dim_all
(
std
::
vector
<
shape
>&
shapes
)
{
std
::
size_
t
n
=
0
;
in
t
n
=
0
;
while
(
n
<
shapes
.
front
().
lens
().
size
()
-
1
)
n
=
reduce_dim_all
(
shapes
,
n
);
}
std
::
vector
<
std
::
size_
t
>
base_lens
(
const
std
::
vector
<
shape
>&
shapes
)
std
::
vector
<
in
t
>
base_lens
(
const
std
::
vector
<
shape
>&
shapes
)
{
return
std
::
accumulate
(
shapes
.
begin
()
+
1
,
shapes
.
end
(),
shapes
.
front
().
lens
(),
[](
auto
&&
lens
,
auto
&&
s
)
{
std
::
vector
<
std
::
size_
t
>
result
;
std
::
vector
<
in
t
>
result
;
const
auto
*
x
=
&
s
.
lens
();
const
auto
*
y
=
&
lens
;
if
(
x
->
size
()
>
y
->
size
())
...
...
@@ -69,12 +69,12 @@ std::vector<std::size_t> base_lens(const std::vector<shape>& shapes)
});
}
shape
mask_shape
(
const
shape
&
s
,
const
std
::
vector
<
std
::
size_
t
>&
lens
)
shape
mask_shape
(
const
shape
&
s
,
const
std
::
vector
<
in
t
>&
lens
)
{
assert
(
s
.
lens
().
size
()
==
lens
.
size
());
std
::
vector
<
std
::
size_
t
>
rstrides
(
lens
.
size
());
std
::
size_
t
stride
=
1
;
for
(
std
::
size_
t
i
=
lens
.
size
()
-
1
;
i
<
lens
.
size
();
i
--
)
std
::
vector
<
in
t
>
rstrides
(
lens
.
size
());
in
t
stride
=
1
;
for
(
in
t
i
=
lens
.
size
()
-
1
;
i
<
lens
.
size
();
i
--
)
{
if
(
lens
[
i
]
==
s
.
lens
()[
i
])
{
...
...
src/rewrite_batchnorm.cpp
View file @
edc23800
...
...
@@ -28,7 +28,7 @@ void rewrite_batchnorm::apply(module& p) const
if
(
any_of
({
gamma
,
bias
,
mean
,
variance
},
[](
auto
arg
)
{
return
arg
.
empty
();
}))
continue
;
std
::
vector
<
std
::
size_
t
>
lens
=
ins
->
inputs
()[
1
]
->
get_shape
().
lens
();
std
::
vector
<
in
t
>
lens
=
ins
->
inputs
()[
1
]
->
get_shape
().
lens
();
shape
s
{
ins
->
get_shape
().
type
(),
lens
};
// Get epsilon
auto
bn_op
=
any_cast
<
op
::
batch_norm_inference
>
(
ins
->
get_operator
());
...
...
@@ -39,8 +39,8 @@ void rewrite_batchnorm::apply(module& p) const
visit_all
(
gamma
,
bias
,
mean
,
variance
,
a
,
b
)(
[
&
](
auto
gamma2
,
auto
bias2
,
auto
mean2
,
auto
variance2
,
auto
a2
,
auto
b2
)
{
dfor
(
a
.
get_shape
().
elements
())(
[
&
](
std
::
size_
t
c
)
{
a2
[
c
]
=
gamma2
[
c
]
/
std
::
sqrt
(
variance2
[
c
]
+
epsilon
);
});
dfor
(
b
.
get_shape
().
elements
())([
&
](
std
::
size_
t
c
)
{
[
&
](
in
t
c
)
{
a2
[
c
]
=
gamma2
[
c
]
/
std
::
sqrt
(
variance2
[
c
]
+
epsilon
);
});
dfor
(
b
.
get_shape
().
elements
())([
&
](
in
t
c
)
{
b2
[
c
]
=
bias2
[
c
]
-
(
gamma2
[
c
]
*
mean2
[
c
]
/
std
::
sqrt
(
variance2
[
c
]
+
epsilon
));
});
});
...
...
src/rewrite_rnn.cpp
View file @
edc23800
...
...
@@ -60,8 +60,8 @@ void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
auto
args
=
ins
->
inputs
();
shape
seq_shape
=
args
[
0
]
->
get_shape
();
std
::
size_
t
hidden_size
=
args
[
1
]
->
get_shape
().
lens
()[
1
];
std
::
size_
t
batch_size
=
seq_shape
.
lens
()[
1
];
in
t
hidden_size
=
args
[
1
]
->
get_shape
().
lens
()[
1
];
in
t
batch_size
=
seq_shape
.
lens
()[
1
];
shape
::
type_t
type
=
seq_shape
.
type
();
migraphx
::
shape
ih_shape
{
type
,
{
1
,
batch_size
,
hidden_size
}};
std
::
vector
<
float
>
data
(
ih_shape
.
elements
(),
0
);
...
...
@@ -369,8 +369,8 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
auto
args
=
ins
->
inputs
();
shape
seq_shape
=
args
[
0
]
->
get_shape
();
std
::
size_
t
hidden_size
=
args
[
2
]
->
get_shape
().
lens
()[
2
];
std
::
size_
t
batch_size
=
seq_shape
.
lens
()[
1
];
in
t
hidden_size
=
args
[
2
]
->
get_shape
().
lens
()[
2
];
in
t
batch_size
=
seq_shape
.
lens
()[
1
];
shape
::
type_t
type
=
seq_shape
.
type
();
migraphx
::
shape
ih_shape
{
type
,
{
1
,
batch_size
,
hidden_size
}};
std
::
vector
<
float
>
data
(
ih_shape
.
elements
(),
0.0
);
...
...
@@ -754,8 +754,8 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
auto
args
=
ins
->
inputs
();
shape
seq_shape
=
args
[
0
]
->
get_shape
();
std
::
size_
t
hidden_size
=
args
[
2
]
->
get_shape
().
lens
()[
2
];
std
::
size_
t
batch_size
=
seq_shape
.
lens
()[
1
];
in
t
hidden_size
=
args
[
2
]
->
get_shape
().
lens
()[
2
];
in
t
batch_size
=
seq_shape
.
lens
()[
1
];
shape
::
type_t
type
=
seq_shape
.
type
();
migraphx
::
shape
ihc_shape
{
type
,
{
1
,
batch_size
,
hidden_size
}};
std
::
vector
<
float
>
ihc_data
(
ihc_shape
.
elements
(),
0.0
);
...
...
@@ -1195,7 +1195,7 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
// specifiy any actv func. If less than 46, use the
// algorithm in parse_lstm to make 6 actv functions
const
auto
&
actv_funcs
=
lstm_op
.
actv_funcs
;
std
::
size_
t
num_actv_funcs
=
actv_funcs
.
size
();
in
t
num_actv_funcs
=
actv_funcs
.
size
();
if
(
lstm_op
.
direction
==
op
::
rnn_direction
::
bidirectional
)
{
switch
(
num_actv_funcs
)
...
...
@@ -1295,7 +1295,7 @@ bool rewrite_rnn::is_variable_seq_lens(const module& prog, instruction_ref seq_l
return
is_var_lens
;
}
std
::
size_
t
in
t
rewrite_rnn
::
get_seq_len
(
const
module
&
prog
,
instruction_ref
input
,
instruction_ref
seq_lens
)
const
{
bool
is_var_lens
=
is_variable_seq_lens
(
prog
,
seq_lens
);
...
...
@@ -1304,7 +1304,7 @@ rewrite_rnn::get_seq_len(const module& prog, instruction_ref input, instruction_
if
(
!
is_var_lens
and
seq_lens
!=
prog
.
end
())
{
auto
arg_len
=
seq_lens
->
eval
();
std
::
vector
<
std
::
size_
t
>
vec_lens
;
std
::
vector
<
in
t
>
vec_lens
;
arg_len
.
visit
([
&
](
auto
l
)
{
vec_lens
.
assign
(
l
.
begin
(),
l
.
end
());
});
length
=
vec_lens
.
empty
()
?
length
:
vec_lens
[
0
];
}
...
...
@@ -1414,7 +1414,7 @@ instruction_ref rewrite_rnn::pad_hidden_states(module& prog,
{
auto
s
=
hs
->
get_shape
();
auto
pad_lens
=
s
.
lens
();
pad_lens
[
0
]
=
static_cast
<
std
::
size_
t
>
(
max_seq_len
-
seq_len
);
pad_lens
[
0
]
=
static_cast
<
in
t
>
(
max_seq_len
-
seq_len
);
shape
pad_s
{
s
.
type
(),
pad_lens
};
std
::
vector
<
float
>
pad_data
(
pad_s
.
elements
(),
0.0
f
);
auto
pl
=
prog
.
add_literal
(
pad_s
,
pad_data
.
begin
(),
pad_data
.
end
());
...
...
src/shape.cpp
View file @
edc23800
...
...
@@ -26,14 +26,14 @@ struct shape_impl
{
assert
(
t
!=
shape
::
tuple_type
);
}
shape_impl
(
shape
::
type_t
t
,
std
::
vector
<
std
::
size_
t
>
l
)
shape_impl
(
shape
::
type_t
t
,
std
::
vector
<
in
t
>
l
)
:
m_type
(
t
),
m_lens
(
std
::
move
(
l
)),
m_standard
(
true
)
{
assert
(
t
!=
shape
::
tuple_type
);
this
->
calculate_strides
();
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
}
shape_impl
(
shape
::
type_t
t
,
std
::
vector
<
std
::
size_
t
>
l
,
std
::
vector
<
std
::
size_
t
>
s
)
shape_impl
(
shape
::
type_t
t
,
std
::
vector
<
in
t
>
l
,
std
::
vector
<
in
t
>
s
)
:
m_type
(
t
),
m_lens
(
std
::
move
(
l
)),
m_strides
(
std
::
move
(
s
))
{
assert
(
t
!=
shape
::
tuple_type
);
...
...
@@ -46,8 +46,8 @@ struct shape_impl
shape_impl
(
const
std
::
vector
<
shape
>&
subs
)
:
m_type
(
shape
::
tuple_type
),
m_shapes
(
subs
)
{}
shape
::
type_t
m_type
;
std
::
vector
<
std
::
size_
t
>
m_lens
=
{};
std
::
vector
<
std
::
size_
t
>
m_strides
=
{};
std
::
vector
<
in
t
>
m_lens
=
{};
std
::
vector
<
in
t
>
m_strides
=
{};
std
::
vector
<
shape
>
m_shapes
=
{};
bool
m_standard
=
false
;
...
...
@@ -61,10 +61,10 @@ struct shape_impl
std
::
partial_sum
(
m_lens
.
rbegin
(),
m_lens
.
rend
()
-
1
,
m_strides
.
rbegin
()
+
1
,
std
::
multiplies
<
std
::
size_
t
>
());
std
::
multiplies
<
in
t
>
());
}
std
::
size_
t
element_space
()
const
in
t
element_space
()
const
{
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
if
(
m_lens
.
empty
())
...
...
@@ -72,19 +72,19 @@ struct shape_impl
return
std
::
inner_product
(
m_lens
.
begin
(),
m_lens
.
end
(),
m_strides
.
begin
(),
std
::
size_
t
{
0
},
std
::
plus
<
std
::
size_
t
>
{},
[](
std
::
size_t
l
,
std
::
size_
t
s
)
{
return
(
l
-
1
)
*
s
;
})
+
in
t
{
0
},
std
::
plus
<
in
t
>
{},
[](
int
l
,
in
t
s
)
{
return
(
l
-
1
)
*
s
;
})
+
1
;
}
std
::
size_
t
elements
()
const
in
t
elements
()
const
{
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
if
(
m_lens
.
empty
())
return
0
;
return
std
::
accumulate
(
m_lens
.
begin
(),
m_lens
.
end
(),
std
::
size_
t
{
1
},
std
::
multiplies
<
std
::
size_
t
>
());
m_lens
.
begin
(),
m_lens
.
end
(),
in
t
{
1
},
std
::
multiplies
<
in
t
>
());
}
};
...
...
@@ -124,11 +124,11 @@ std::string shape::cpp_type(shape::type_t t)
shape
::
shape
()
:
impl
(
shape_impl
::
default_shape
())
{}
shape
::
shape
(
type_t
t
)
:
impl
(
std
::
make_shared
<
shape_impl
>
(
t
))
{}
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_
t
>
l
)
shape
::
shape
(
type_t
t
,
std
::
vector
<
in
t
>
l
)
:
impl
(
std
::
make_shared
<
shape_impl
>
(
t
,
std
::
move
(
l
)))
{
}
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_
t
>
l
,
std
::
vector
<
std
::
size_
t
>
s
)
shape
::
shape
(
type_t
t
,
std
::
vector
<
in
t
>
l
,
std
::
vector
<
in
t
>
s
)
:
impl
(
std
::
make_shared
<
shape_impl
>
(
t
,
std
::
move
(
l
),
std
::
move
(
s
)))
{
}
...
...
@@ -136,7 +136,7 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
shape
::
shape
(
const
std
::
vector
<
shape
>&
subs
)
:
impl
(
std
::
make_shared
<
shape_impl
>
(
subs
))
{}
shape
shape
::
from_permutation
(
type_t
t
,
const
std
::
vector
<
std
::
size_
t
>&
l
,
const
std
::
vector
<
in
t
>&
l
,
const
std
::
vector
<
int64_t
>&
perm
)
{
auto
new_lens
=
reorder_dims
(
l
,
perm
);
...
...
@@ -146,14 +146,14 @@ shape shape::from_permutation(type_t t,
}
shape
::
type_t
shape
::
type
()
const
{
return
impl
->
m_type
;
}
const
std
::
vector
<
std
::
size_
t
>&
shape
::
lens
()
const
{
return
impl
->
m_lens
;
}
const
std
::
vector
<
std
::
size_
t
>&
shape
::
strides
()
const
{
return
impl
->
m_strides
;
}
std
::
size_
t
shape
::
elements
()
const
{
return
impl
->
elements
();
}
std
::
size_
t
shape
::
bytes
()
const
const
std
::
vector
<
in
t
>&
shape
::
lens
()
const
{
return
impl
->
m_lens
;
}
const
std
::
vector
<
in
t
>&
shape
::
strides
()
const
{
return
impl
->
m_strides
;
}
in
t
shape
::
elements
()
const
{
return
impl
->
elements
();
}
in
t
shape
::
bytes
()
const
{
if
(
this
->
sub_shapes
().
empty
())
{
std
::
size_
t
n
=
0
;
in
t
n
=
0
;
this
->
visit_type
([
&
](
auto
as
)
{
n
=
as
.
size
();
});
return
n
*
this
->
element_space
();
}
...
...
@@ -161,44 +161,44 @@ std::size_t shape::bytes() const
{
return
std
::
accumulate
(
this
->
sub_shapes
().
begin
(),
this
->
sub_shapes
().
end
(),
std
::
size_
t
{
0
},
in
t
{
0
},
[
&
](
auto
x
,
auto
y
)
{
return
x
+
y
.
bytes
();
});
}
}
std
::
size_
t
shape
::
type_size
()
const
in
t
shape
::
type_size
()
const
{
std
::
size_
t
n
=
0
;
in
t
n
=
0
;
if
(
this
->
sub_shapes
().
empty
())
this
->
visit_type
([
&
](
auto
as
)
{
n
=
as
.
size
();
});
return
n
;
}
std
::
size_
t
shape
::
index
(
std
::
initializer_list
<
std
::
size_
t
>
l
)
const
in
t
shape
::
index
(
std
::
initializer_list
<
in
t
>
l
)
const
{
assert
(
l
.
size
()
<=
this
->
lens
().
size
());
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
return
std
::
inner_product
(
l
.
begin
(),
l
.
end
(),
this
->
strides
().
begin
(),
std
::
size_
t
{
0
});
return
std
::
inner_product
(
l
.
begin
(),
l
.
end
(),
this
->
strides
().
begin
(),
in
t
{
0
});
}
std
::
size_
t
shape
::
index
(
const
std
::
vector
<
std
::
size_
t
>&
l
)
const
in
t
shape
::
index
(
const
std
::
vector
<
in
t
>&
l
)
const
{
assert
(
l
.
size
()
<=
this
->
lens
().
size
());
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
return
std
::
inner_product
(
l
.
begin
(),
l
.
end
(),
this
->
strides
().
begin
(),
std
::
size_
t
{
0
});
return
std
::
inner_product
(
l
.
begin
(),
l
.
end
(),
this
->
strides
().
begin
(),
in
t
{
0
});
}
std
::
size_
t
shape
::
index
(
std
::
size_
t
i
)
const
in
t
shape
::
index
(
in
t
i
)
const
{
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
if
(
this
->
standard
())
return
i
;
else
{
std
::
size_
t
s
=
1
;
std
::
size_
t
result
=
0
;
for
(
std
::
size_
t
j
=
0
;
j
<
this
->
lens
().
size
();
j
++
)
in
t
s
=
1
;
in
t
result
=
0
;
for
(
in
t
j
=
0
;
j
<
this
->
lens
().
size
();
j
++
)
{
const
std
::
size_
t
k
=
this
->
lens
().
size
()
-
j
-
1
;
const
std
::
size_
t
stride
=
this
->
strides
()[
k
];
const
std
::
size_
t
len
=
this
->
lens
()[
k
];
const
std
::
size_
t
idx
=
(
i
%
(
s
*
len
))
/
s
;
const
in
t
k
=
this
->
lens
().
size
()
-
j
-
1
;
const
in
t
stride
=
this
->
strides
()[
k
];
const
in
t
len
=
this
->
lens
()[
k
];
const
in
t
idx
=
(
i
%
(
s
*
len
))
/
s
;
result
+=
stride
*
idx
;
s
*=
len
;
}
...
...
@@ -206,17 +206,17 @@ std::size_t shape::index(std::size_t i) const
}
}
std
::
vector
<
std
::
size_
t
>
shape
::
multi
(
std
::
size_
t
i
)
const
std
::
vector
<
in
t
>
shape
::
multi
(
in
t
i
)
const
{
assert
(
this
->
standard
());
std
::
vector
<
std
::
size_
t
>
indices
(
lens
().
size
());
std
::
vector
<
in
t
>
indices
(
lens
().
size
());
multi_copy
(
i
,
indices
.
data
(),
indices
.
data
()
+
lens
().
size
());
return
indices
;
}
void
shape
::
multi_copy
(
std
::
size_t
i
,
std
::
size_
t
*
start
,
const
std
::
size_
t
*
end
)
const
void
shape
::
multi_copy
(
int
i
,
in
t
*
start
,
const
in
t
*
end
)
const
{
assert
(
this
->
standard
());
(
void
)
end
;
...
...
@@ -225,7 +225,7 @@ void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end
strides
().
end
(),
lens
().
begin
(),
start
,
[
&
](
std
::
size_t
stride
,
std
::
size_
t
len
)
{
[
&
](
int
stride
,
in
t
len
)
{
assert
(
len
>
0
and
stride
>
0
);
return
(
i
/
stride
)
%
len
;
});
...
...
@@ -241,12 +241,12 @@ bool shape::transposed() const
if
(
this
->
broadcasted
())
{
// TODO: Use a filter_iterator instead
std
::
vector
<
std
::
size_
t
>
s
;
std
::
vector
<
in
t
>
s
;
s
.
reserve
(
this
->
strides
().
size
());
std
::
copy_if
(
this
->
strides
().
begin
(),
this
->
strides
().
end
(),
std
::
back_inserter
(
s
),
[](
std
::
size_
t
x
)
{
return
x
!=
0
;
});
[](
in
t
x
)
{
return
x
!=
0
;
});
return
not
std
::
is_sorted
(
s
.
rbegin
(),
s
.
rend
());
}
else
...
...
@@ -260,8 +260,8 @@ bool shape::broadcasted() const
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
return
std
::
accumulate
(
this
->
strides
().
begin
(),
this
->
strides
().
end
(),
std
::
size_
t
{
1
},
std
::
multiplies
<
std
::
size_
t
>
())
==
0
;
in
t
{
1
},
std
::
multiplies
<
in
t
>
())
==
0
;
}
bool
shape
::
scalar
()
const
...
...
@@ -269,7 +269,7 @@ bool shape::scalar() const
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
// if any stride > 0, then accumulate will return false
return
this
->
sub_shapes
().
empty
()
and
std
::
accumulate
(
this
->
strides
().
begin
(),
this
->
strides
().
end
(),
std
::
size_
t
(
0
))
==
0
;
std
::
accumulate
(
this
->
strides
().
begin
(),
this
->
strides
().
end
(),
in
t
(
0
))
==
0
;
}
bool
shape
::
standard
()
const
{
return
impl
->
m_standard
;
}
...
...
@@ -282,19 +282,19 @@ shape shape::normalize_standard() const
return
*
this
;
}
shape
shape
::
with_lens
(
type_t
t
,
const
std
::
vector
<
std
::
size_
t
>&
l
)
const
shape
shape
::
with_lens
(
type_t
t
,
const
std
::
vector
<
in
t
>&
l
)
const
{
assert
(
l
.
size
()
==
this
->
lens
().
size
());
auto
perm
=
find_permutation
(
*
this
);
return
shape
::
from_permutation
(
t
,
l
,
perm
);
}
shape
shape
::
with_lens
(
const
std
::
vector
<
std
::
size_
t
>&
l
)
const
shape
shape
::
with_lens
(
const
std
::
vector
<
in
t
>&
l
)
const
{
return
this
->
with_lens
(
this
->
type
(),
l
);
}
std
::
size_
t
shape
::
element_space
()
const
{
return
impl
->
element_space
();
}
in
t
shape
::
element_space
()
const
{
return
impl
->
element_space
();
}
std
::
string
shape
::
type_string
()
const
{
return
name
(
this
->
type
());
}
...
...
@@ -351,8 +351,8 @@ void migraphx_from_value(const value& v, shape& s)
else
{
s
=
shape
{
shape
::
parse_type
(
t
),
v
.
at
(
"lens"
).
to_vector
<
std
::
size_
t
>
(),
v
.
at
(
"strides"
).
to_vector
<
std
::
size_
t
>
()};
v
.
at
(
"lens"
).
to_vector
<
in
t
>
(),
v
.
at
(
"strides"
).
to_vector
<
in
t
>
()};
}
}
...
...
src/simplify_algebra.cpp
View file @
edc23800
...
...
@@ -278,10 +278,10 @@ struct find_concat_op
}
template
<
class
Iterator
>
static
std
::
vector
<
std
::
size_
t
>
get_output_lens
(
Iterator
start
,
Iterator
last
,
std
::
size_
t
axis
)
static
std
::
vector
<
in
t
>
get_output_lens
(
Iterator
start
,
Iterator
last
,
in
t
axis
)
{
assert
(
start
!=
last
);
std
::
size_
t
dim
=
0
;
in
t
dim
=
0
;
for
(
auto
ins
:
range
(
start
,
last
))
{
dim
+=
ins
->
get_shape
().
lens
().
at
(
axis
);
...
...
@@ -323,7 +323,7 @@ struct find_concat_op
}
std
::
vector
<
instruction_ref
>
concats
;
for
(
std
::
size_
t
i
=
0
;
i
<
x
->
inputs
().
size
();
i
++
)
for
(
in
t
i
=
0
;
i
<
x
->
inputs
().
size
();
i
++
)
{
std
::
vector
<
instruction_ref
>
inputs
;
std
::
transform
(
start
,
last
,
std
::
back_inserter
(
inputs
),
[
&
](
auto
j
)
{
...
...
@@ -381,7 +381,7 @@ std::vector<instruction_ref> get_splits(instruction_ref ins)
result
.
begin
(),
result
.
end
(),
[
&
](
auto
x
,
auto
y
)
{
return
get_end
(
x
)
!=
get_start
(
y
);
});
if
(
it
!=
result
.
end
())
return
{};
for
(
std
::
size_
t
i
=
0
;
i
<
axes
.
size
();
i
++
)
for
(
in
t
i
=
0
;
i
<
axes
.
size
();
i
++
)
{
auto
axis
=
axes
[
i
];
if
(
ins
->
get_shape
().
lens
()[
axis
]
!=
get_slice
(
result
.
back
()).
ends
[
i
])
...
...
@@ -626,16 +626,16 @@ struct find_split_concat
}
};
bool
axis_equal
(
const
std
::
vector
<
std
::
size_
t
>&
x
,
const
std
::
vector
<
std
::
size_
t
>&
y
,
std
::
size_
t
axis
)
bool
axis_equal
(
const
std
::
vector
<
in
t
>&
x
,
const
std
::
vector
<
in
t
>&
y
,
in
t
axis
)
{
return
x
.
size
()
==
y
.
size
()
and
x
.
size
()
>
axis
and
std
::
equal
(
x
.
begin
(),
x
.
begin
()
+
axis
,
y
.
begin
())
and
std
::
equal
(
x
.
begin
()
+
axis
+
1
,
x
.
end
(),
y
.
begin
()
+
axis
+
1
);
}
bool
axis_shape_equal
(
const
shape
&
x
,
const
shape
&
y
,
std
::
size_
t
axis
)
bool
axis_shape_equal
(
const
shape
&
x
,
const
shape
&
y
,
in
t
axis
)
{
// TODO: Check strides
return
axis_equal
(
x
.
lens
(),
y
.
lens
(),
axis
);
...
...
@@ -654,7 +654,7 @@ struct find_add_convs
return
op
.
stride
[
0
]
==
op
.
stride
[
1
];
}
static
std
::
size_
t
compute_stride_factor
(
const
op
::
convolution
&
x
,
const
op
::
convolution
&
y
)
static
in
t
compute_stride_factor
(
const
op
::
convolution
&
x
,
const
op
::
convolution
&
y
)
{
if
(
not
symmetrical_strides
(
x
))
return
0
;
...
...
@@ -913,7 +913,7 @@ struct find_split_reshape
auto
axis
=
any_cast
<
op
::
slice
>
(
slc
->
get_operator
()).
axes
[
0
];
auto
slc_lens
=
slc
->
get_shape
().
lens
();
auto
slc_dim_size
=
std
::
accumulate
(
slc_lens
.
begin
()
+
axis
,
slc_lens
.
end
(),
1
,
std
::
multiplies
<
std
::
size_
t
>
());
slc_lens
.
begin
()
+
axis
,
slc_lens
.
end
(),
1
,
std
::
multiplies
<
in
t
>
());
// search the reshape output (standard shape) to decide which axis are
// in its output corresponding to the slc_dim_size
...
...
@@ -942,7 +942,7 @@ struct find_split_reshape
// replace the original reshape with slice
int64_t
start
=
0
;
for
(
std
::
size_
t
i
=
0
;
i
<
vec_rsp
.
size
();
++
i
)
for
(
in
t
i
=
0
;
i
<
vec_rsp
.
size
();
++
i
)
{
p
.
replace_instruction
(
vec_rsp
[
i
],
...
...
src/targets/gpu/device/include/migraphx/gpu/device/reduce.hpp
View file @
edc23800
...
...
@@ -174,14 +174,14 @@ __device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
}
constexpr
index_int
compute_block_size
(
index_int
n
,
index_int
max_block_size
)
{
size_
t
block_size
=
64
;
in
t
block_size
=
64
;
while
(
block_size
<
max_block_size
and
block_size
<
n
)
block_size
*=
2
;
return
block_size
;
}
inline
std
::
vector
<
index_int
>
get_reduce_lens
(
const
std
::
vector
<
size_
t
>&
input_lens
,
const
std
::
vector
<
size_
t
>&
output_lens
)
inline
std
::
vector
<
index_int
>
get_reduce_lens
(
const
std
::
vector
<
in
t
>&
input_lens
,
const
std
::
vector
<
in
t
>&
output_lens
)
{
std
::
vector
<
index_int
>
reduce_lens
;
std
::
transform
(
output_lens
.
begin
(),
...
...
src/targets/ref/gemm.cpp
View file @
edc23800
...
...
@@ -16,9 +16,9 @@ static auto make_mat(tensor_view<T> x)
{
const
auto
&
s
=
x
.
get_shape
();
// assert(s.lens().size() == 2);
std
::
size_
t
n_dims
=
s
.
lens
().
size
();
std
::
size_
t
dim_0
=
n_dims
-
2
;
std
::
size_
t
dim_1
=
n_dims
-
1
;
in
t
n_dims
=
s
.
lens
().
size
();
in
t
dim_0
=
n_dims
-
2
;
in
t
dim_1
=
n_dims
-
1
;
if
(
s
.
transposed
())
return
matrix
<
T
>
{
x
.
data
(),
s
.
lens
()[
dim_1
],
s
.
lens
()[
dim_0
],
s
.
strides
()[
dim_1
]};
return
matrix
<
T
>
{
x
.
data
(),
s
.
lens
()[
dim_0
],
s
.
lens
()[
dim_1
],
s
.
strides
()[
dim_0
]};
...
...
@@ -66,9 +66,9 @@ template <class T, class F>
void
migemm_impl
(
tensor_view
<
T
>
cmat
,
tensor_view
<
T
>
amat
,
tensor_view
<
T
>
bmat
,
F
alpha
,
F
beta
,
std
::
false_type
)
{
std
::
size_
t
n_dims
=
cmat
.
get_shape
().
lens
().
size
();
std
::
size_
t
dim_0
=
n_dims
-
2
;
std
::
size_
t
dim_1
=
n_dims
-
1
;
in
t
n_dims
=
cmat
.
get_shape
().
lens
().
size
();
in
t
dim_0
=
n_dims
-
2
;
in
t
dim_1
=
n_dims
-
1
;
auto
k
=
amat
.
get_shape
().
lens
()[
dim_1
];
assert
(
amat
.
get_shape
().
lens
()[
dim_1
]
==
bmat
.
get_shape
().
lens
()[
dim_0
]);
...
...
@@ -93,7 +93,7 @@ void migemm_impl(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat,
auto
lens
=
amat
.
get_shape
().
lens
();
bool
batch_mul
=
std
::
accumulate
(
lens
.
rbegin
()
+
2
,
lens
.
rend
(),
std
::
size_
t
{
1
},
std
::
multiplies
<
std
::
size_
t
>
())
==
1
;
lens
.
rbegin
()
+
2
,
lens
.
rend
(),
in
t
{
1
},
std
::
multiplies
<
in
t
>
())
==
1
;
if
(
batch_mul
)
{
migemm_impl
(
cmat
,
amat
,
bmat
,
alpha
,
beta
,
is_fast_gemm_type
<
T
>
{});
...
...
src/tf/parse_conv.cpp
View file @
edc23800
...
...
@@ -22,7 +22,7 @@ struct parse_conv : op_parser<parse_conv>
op
::
convolution
op
;
if
(
contains
(
info
.
attributes
,
"strides"
))
{
std
::
vector
<
size_
t
>
stride
;
std
::
vector
<
in
t
>
stride
;
copy
(
info
.
attributes
.
at
(
"strides"
).
list
().
i
(),
std
::
back_inserter
(
stride
));
parser
.
reorder_data
(
stride
);
if
(
stride
.
size
()
!=
4
)
...
...
@@ -34,7 +34,7 @@ struct parse_conv : op_parser<parse_conv>
}
if
(
contains
(
info
.
attributes
,
"dilations"
))
{
std
::
vector
<
size_
t
>
dilation
;
std
::
vector
<
in
t
>
dilation
;
copy
(
info
.
attributes
.
at
(
"dilations"
).
list
().
i
(),
std
::
back_inserter
(
dilation
));
parser
.
reorder_data
(
dilation
);
if
(
dilation
.
size
()
!=
4
)
...
...
@@ -53,16 +53,16 @@ struct parse_conv : op_parser<parse_conv>
if
(
pad_mode
.
find
(
"SAME"
)
!=
std
::
string
::
npos
)
{
op
.
padding_mode
=
op
::
padding_mode_t
::
same
;
std
::
vector
<
size_
t
>
weight_dims
=
weights
->
get_shape
().
lens
();
size_
t
weight_h
=
weight_dims
[
2
];
size_
t
weight_w
=
weight_dims
[
3
];
std
::
vector
<
in
t
>
weight_dims
=
weights
->
get_shape
().
lens
();
in
t
weight_h
=
weight_dims
[
2
];
in
t
weight_w
=
weight_dims
[
3
];
auto
input_dims
=
l0
->
get_shape
().
lens
();
std
::
vector
<
int64_t
>
pads
(
input_dims
.
size
());
calculate_padding
(
0
,
pads
,
input_dims
[
2
],
op
.
stride
[
0
],
op
.
dilation
[
0
],
weight_h
);
calculate_padding
(
1
,
pads
,
input_dims
[
3
],
op
.
stride
[
1
],
op
.
dilation
[
1
],
weight_w
);
op
.
padding
=
std
::
vector
<
size_
t
>
(
pads
.
begin
(),
pads
.
end
());
op
.
padding
=
std
::
vector
<
in
t
>
(
pads
.
begin
(),
pads
.
end
());
}
else
if
(
pad_mode
.
find
(
"VALID"
)
!=
std
::
string
::
npos
)
{
...
...
@@ -70,7 +70,7 @@ struct parse_conv : op_parser<parse_conv>
}
else
if
(
pad_mode
.
find
(
"EXPLICIT"
)
!=
std
::
string
::
npos
)
{
std
::
vector
<
size_
t
>
padding
;
std
::
vector
<
in
t
>
padding
;
copy
(
info
.
attributes
.
at
(
"explicit_paddings"
).
list
().
i
(),
std
::
back_inserter
(
padding
));
if
(
padding
.
size
()
!=
4
)
...
...
src/tf/parse_depthwiseconv.cpp
View file @
edc23800
...
...
@@ -20,12 +20,12 @@ struct parse_depthwiseconv : op_parser<parse_depthwiseconv>
std
::
vector
<
instruction_ref
>
args
)
const
{
op
::
convolution
op
;
size_
t
num_channels
=
args
[
0
]
->
get_shape
().
lens
()[
1
];
in
t
num_channels
=
args
[
0
]
->
get_shape
().
lens
()[
1
];
op
.
group
=
num_channels
;
if
(
contains
(
info
.
attributes
,
"strides"
))
{
std
::
vector
<
size_
t
>
stride
;
std
::
vector
<
in
t
>
stride
;
copy
(
info
.
attributes
.
at
(
"strides"
).
list
().
i
(),
std
::
back_inserter
(
stride
));
parser
.
reorder_data
(
stride
);
if
(
stride
.
size
()
!=
4
)
...
...
@@ -39,7 +39,7 @@ struct parse_depthwiseconv : op_parser<parse_depthwiseconv>
auto
weights
=
parser
.
to_kcxy
(
args
[
1
]);
if
(
contains
(
info
.
attributes
,
"dilations"
))
{
std
::
vector
<
size_
t
>
dilation
;
std
::
vector
<
in
t
>
dilation
;
copy
(
info
.
attributes
.
at
(
"dilations"
).
list
().
i
(),
std
::
back_inserter
(
dilation
));
parser
.
reorder_data
(
dilation
);
if
(
dilation
.
size
()
!=
4
)
...
...
@@ -58,9 +58,9 @@ struct parse_depthwiseconv : op_parser<parse_depthwiseconv>
if
(
pad_mode
.
find
(
"SAME"
)
!=
std
::
string
::
npos
)
{
op
.
padding_mode
=
op
::
padding_mode_t
::
same
;
std
::
vector
<
size_
t
>
weight_dims
=
weights
->
get_shape
().
lens
();
size_
t
weight_h
=
weight_dims
[
2
];
size_
t
weight_w
=
weight_dims
[
3
];
std
::
vector
<
in
t
>
weight_dims
=
weights
->
get_shape
().
lens
();
in
t
weight_h
=
weight_dims
[
2
];
in
t
weight_w
=
weight_dims
[
3
];
auto
input_dims
=
l0
->
get_shape
().
lens
();
std
::
vector
<
int64_t
>
pads
(
input_dims
.
size
());
...
...
src/tf/parse_expanddims.cpp
View file @
edc23800
...
...
@@ -17,9 +17,9 @@ struct parse_expanddims : op_parser<parse_expanddims>
const
tf_parser
::
node_info
&
info
,
std
::
vector
<
instruction_ref
>
args
)
const
{
std
::
vector
<
size_
t
>
input_dims
=
args
[
0
]
->
get_shape
().
lens
();
std
::
vector
<
in
t
>
input_dims
=
args
[
0
]
->
get_shape
().
lens
();
std
::
vector
<
int64_t
>
new_dims
(
input_dims
.
begin
(),
input_dims
.
end
());
size_
t
num_dims
=
input_dims
.
size
();
in
t
num_dims
=
input_dims
.
size
();
int32_t
dim
=
args
[
1
]
->
eval
().
at
<
int32_t
>
();
if
(
dim
<
0
)
...
...
src/tf/parse_onehot.cpp
View file @
edc23800
...
...
@@ -17,7 +17,7 @@ struct parse_onehot : op_parser<parse_onehot>
tf_parser
::
node_info
info
,
std
::
vector
<
instruction_ref
>
args
)
const
{
size_
t
depth
=
static_cast
<
size_
t
>
(
args
[
1
]
->
eval
().
at
<
int32_t
>
());
in
t
depth
=
static_cast
<
in
t
>
(
args
[
1
]
->
eval
().
at
<
int32_t
>
());
int64_t
axis
=
-
1
;
float
on_value
=
args
[
2
]
->
eval
().
at
<
float
>
();
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment