gaoqiong / MIGraphX / Commits

Commit 099e9ce8
Authored Jun 24, 2019 by Shucai Xiao
Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into argmax_min
Parents: 274c772b, 15eb1987
Showing 18 changed files with 263 additions and 311 deletions (+263 −311)
src/eliminate_pad.cpp                                          +1   −3
src/include/migraphx/op/binary.hpp                             +3   −1
src/include/migraphx/op/convolution.hpp                        +18  −45
src/include/migraphx/op/pooling.hpp                            +7   −37
src/include/migraphx/pad_calc.hpp                              +13  −2
src/include/migraphx/stringutils.hpp                           +2   −0
src/onnx/CMakeLists.txt                                        +1   −1
src/onnx/onnx.cpp                                              +12  −4
src/py/migraphx_py.cpp                                         +7   −1
src/rewrite_rnn.cpp                                            +94  −173
src/targets/gpu/device/include/migraphx/gpu/device/types.hpp   +2   −2
src/targets/gpu/device/pad.cpp                                 +11  −1
src/targets/gpu/include/migraphx/gpu/softmax.hpp               +1   −1
src/tf/CMakeLists.txt                                          +1   −1
src/tf/tf.cpp                                                  +67  −14
test/CMakeLists.txt                                            +2   −2
test/eliminate_pad_test.cpp                                    +0   −19
test/gpu/miopen.cpp                                            +21  −4
src/eliminate_pad.cpp

@@ -44,8 +44,6 @@ void eliminate_pad::update_op(T,
     std::array<size_t, 2> new_pads{static_cast<size_t>(pads[2]), static_cast<size_t>(pads[3])};

     T op = any_cast<T>(ins->get_operator());
-    if(op.padding_mode != op::padding_mode_t::default_)
-        return;
     op.padding = new_pads;
     std::vector<instruction_ref> new_inputs{ins->inputs()};
src/include/migraphx/op/binary.hpp

@@ -28,8 +28,10 @@ struct binary : op_name<Derived>
     argument compute(const shape& output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
+        auto s1 = args[0].get_shape();
+        auto s2 = args[1].get_shape();
         visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
-            if(input1.get_shape().packed() and input2.get_shape().packed())
+            if(s1 == s2 and input1.get_shape().packed() and input2.get_shape().packed())
             {
                 std::transform(input1.begin(),
                                input1.end(),
src/include/migraphx/op/convolution.hpp

@@ -44,8 +44,7 @@ struct convolution
         const shape& input   = inputs.at(0);
         const shape& weights = inputs.at(1);
         auto t               = input.type();
-        if(padding_mode == default_)
-        {
         return {t,
                 {
                     input.lens()[0],
@@ -64,32 +63,6 @@ struct convolution
                         1)),
                 }};
-        }
-        else if(padding_mode == same)
-        {
-            return {t,
-                    {input.lens()[0],
-                     weights.lens()[0],
-                     static_cast<std::size_t>(
-                         std::ceil(static_cast<double>(input.lens()[2]) / stride[0])),
-                     static_cast<std::size_t>(
-                         std::ceil(static_cast<double>(input.lens()[3]) / stride[1]))}};
-        }
-        else if(padding_mode == valid)
-        {
-            return {t,
-                    {input.lens()[0],
-                     weights.lens()[0],
-                     static_cast<std::size_t>(std::ceil(
-                         static_cast<double>(input.lens()[2] - weights.lens()[2] + 1) / stride[0])),
-                     static_cast<std::size_t>(std::ceil(
-                         static_cast<double>(input.lens()[3] - weights.lens()[3] + 1) / stride[1]))}};
-        }
-        else
-        {
-            MIGRAPHX_THROW("Invalid padding mode");
-        }
     }
 };
 } // namespace op
src/include/migraphx/op/pooling.hpp

@@ -48,52 +48,22 @@ struct pooling
         assert(lengths[0] <= (input.lens()[2] + 2 * padding[0]));
         assert(lengths[1] <= (input.lens()[3] + 2 * padding[1]));

-        if(padding_mode == default_)
-        {
         return {t,
                 {
                     input.lens()[0],
                     input.lens()[1],
                     std::size_t(std::max<std::ptrdiff_t>(
                         1,
                         floor_divide<std::ptrdiff_t>(
                             input.lens()[2] + 2 * padding[0] - lengths[0],
                             stride[0]) +
                             1)),
                     std::size_t(std::max<std::ptrdiff_t>(
                         1,
                         floor_divide<std::ptrdiff_t>(
                             input.lens()[3] + 2 * padding[1] - lengths[1],
                             stride[1]) +
                             1)),
                 }};
-        }
-        else if(padding_mode == same)
-        {
-            return {t,
-                    {input.lens()[0],
-                     input.lens()[1],
-                     ceil_divide<std::size_t>(input.lens()[2], stride[0]),
-                     ceil_divide<std::size_t>(input.lens()[3], stride[1])}};
-        }
-        else if(padding_mode == valid)
-        {
-            return {t,
-                    {input.lens()[0],
-                     input.lens()[1],
-                     std::size_t(std::max<std::ptrdiff_t>(
-                         1,
-                         floor_divide<std::ptrdiff_t>(input.lens()[2] - lengths[0], stride[0]) + 1)),
-                     std::size_t(std::max<std::ptrdiff_t>(
-                         1,
-                         floor_divide<std::ptrdiff_t>(input.lens()[3] - lengths[1], stride[1]) + 1)),
-                    }};
-        }
-        else
-        {
-            MIGRAPHX_THROW("Invalid padding mode");
-        }
     }
 };
 } // namespace op
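As a quick check of the retained default-padding branch, the spatial output size follows the floor_divide expression above. The sketch below is plain C++ with a simplified stand-in for migraphx's floor_divide (adequate for the non-negative values used here) and works one dimension by hand:

#include <algorithm>
#include <cstddef>
#include <iostream>

// Simplified stand-in for migraphx's floor_divide; valid for non-negative inputs.
long floor_divide(long a, long b) { return a / b; }

int main()
{
    // Example values: input extent 7, padding 1, window length 3, stride 2.
    long input = 7, padding = 1, length = 3, stride = 2;
    // Same expression as the default-padding branch of pooling::compute_shape.
    std::size_t out =
        std::size_t(std::max<long>(1, floor_divide(input + 2 * padding - length, stride) + 1));
    std::cout << out << '\n'; // prints 4
    return 0;
}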
src/include/migraphx/pad_calc.hpp

@@ -2,13 +2,24 @@
 #define MIGRAPHX_GUARD_OPERATORS_PAD_CALC_HPP

 #include <utility>
+#include <cstdint>
+#include <vector>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

-inline std::size_t calculate_padding(std::size_t weight_dim, std::size_t dilation)
+inline void calculate_padding(int64_t idx,
+                              std::vector<int64_t>& pads,
+                              int64_t input_dim,
+                              int64_t stride,
+                              int64_t dilation,
+                              int64_t weight_dim)
 {
-    return (dilation * (weight_dim - 1)) / 2;
+    int64_t output_dim = input_dim / stride;
+    int64_t pad        = std::max(static_cast<int64_t>(0),
+                           (output_dim - 1) * stride + dilation * weight_dim - input_dim);
+    pads[idx]     = pad / 2;
+    pads[idx + 2] = pad - pad / 2;
 }

 } // namespace MIGRAPHX_INLINE_NS
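A small, self-contained illustration of the new calculate_padding signature (the helper is restated locally rather than included, so the numbers can be traced): with input_dim = 4, stride = 1, dilation = 1 and weight_dim = 2 the total pad is 1, which is split asymmetrically as 0 before and 1 after; it is exactly this asymmetric case that the TF parser below turns into an explicit op::pad instruction.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Local restatement of the helper above, for illustration only.
void calculate_padding(int64_t idx, std::vector<int64_t>& pads, int64_t input_dim,
                       int64_t stride, int64_t dilation, int64_t weight_dim)
{
    int64_t output_dim = input_dim / stride; // 4
    int64_t pad =
        std::max<int64_t>(0, (output_dim - 1) * stride + dilation * weight_dim - input_dim); // 1
    pads[idx]     = pad / 2;       // leading pad: 0
    pads[idx + 2] = pad - pad / 2; // trailing pad: 1
}

int main()
{
    std::vector<int64_t> pads(4, 0);        // {top, left, bottom, right}
    calculate_padding(0, pads, 4, 1, 1, 2); // height
    calculate_padding(1, pads, 4, 1, 1, 2); // width
    for(auto p : pads)
        std::cout << p << ' '; // prints: 0 0 1 1
    std::cout << '\n';
    return 0;
}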
src/include/migraphx/stringutils.hpp

@@ -52,6 +52,8 @@ inline std::string transform_string(std::string s, F f)

 inline std::string to_upper(std::string s) { return transform_string(std::move(s), ::toupper); }

+inline std::string to_lower(std::string s) { return transform_string(std::move(s), ::tolower); }
+
 inline bool starts_with(const std::string& value, const std::string& prefix)
 {
     if(prefix.size() > value.size())
src/onnx/CMakeLists.txt

@@ -19,7 +19,7 @@ rocm_install_targets(
 add_executable(read_onnx read_onnx.cpp)
 rocm_clang_tidy_check(read_onnx)
-target_link_libraries(read_onnx migraphx_onnx)
+target_link_libraries(read_onnx migraphx_cpu migraphx_onnx)
 if(MIGRAPHX_ENABLE_GPU)
src/onnx/onnx.cpp

@@ -100,6 +100,7 @@ struct onnx_parser
     void init_actv_func()
     {
+        // Support name format of all lower case or the first letter capital
         map_actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
         map_actv_funcs.insert(std::make_pair("relu", op::relu{}));
         map_actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));

@@ -352,7 +353,8 @@ struct onnx_parser
         {
             // insert zeros for pad op (args[0] has 4 dims)
             padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
-            l0      = prog.add_instruction(op::pad{padding}, l0);
+            l0      = prog.add_instruction(
+                op::pad{padding, std::numeric_limits<float>::lowest()}, l0);
         }
         else
         {

@@ -870,7 +872,9 @@ struct onnx_parser
             auto names = attributes.at("activations").strings();
             vec_names.clear();
             vec_names.resize(names.size());
-            std::copy(names.begin(), names.end(), vec_names.begin());
+            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
+                return to_lower(name);
+            });
         }
         auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {

@@ -961,7 +965,9 @@ struct onnx_parser
             auto names = attributes.at("activations").strings();
             vec_names.clear();
             vec_names.resize(names.size());
-            std::copy(names.begin(), names.end(), vec_names.begin());
+            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
+                return to_lower(name);
+            });
         }
         // need 4 activation functions

@@ -1088,7 +1094,9 @@ struct onnx_parser
             auto names = attributes.at("activations").strings();
             vec_names.clear();
             vec_names.resize(names.size());
-            std::copy(names.begin(), names.end(), vec_names.begin());
+            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
+                return to_lower(name);
+            });
         }
         // need 6 activation functions for bidirectional directions
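The switch from std::copy to std::transform with to_lower means the ONNX "activations" attribute is matched case-insensitively ("Tanh" and "tanh" both resolve). A minimal stand-alone sketch of that normalization, using only the standard library and a local stand-in for migraphx::to_lower:

#include <algorithm>
#include <cctype>
#include <iostream>
#include <string>
#include <vector>

// Local stand-in for migraphx's to_lower from stringutils.hpp.
std::string to_lower(std::string s)
{
    std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) { return std::tolower(c); });
    return s;
}

int main()
{
    std::vector<std::string> names = {"Sigmoid", "Tanh", "tanh"};
    std::vector<std::string> vec_names(names.size());
    // Same pattern as the parser: lower-case every attribute value before the lookup.
    std::transform(names.begin(), names.end(), vec_names.begin(),
                   [](const std::string& name) { return to_lower(name); });
    for(const auto& n : vec_names)
        std::cout << n << '\n'; // sigmoid, tanh, tanh
    return 0;
}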
src/py/migraphx_py.cpp

@@ -8,6 +8,7 @@
 #include <migraphx/stringutils.hpp>
 #include <migraphx/tf.hpp>
 #include <migraphx/onnx.hpp>
+#include <migraphx/type_name.hpp>
 #ifdef HAVE_GPU
 #include <migraphx/gpu/target.hpp>

@@ -101,8 +102,13 @@ migraphx::shape to_shape(const py::buffer_info& info)
             t = as.type_enum();
             n = sizeof(as());
         }
     });
+    if(n == 0)
+    {
+        MIGRAPHX_THROW("MIGRAPHX PYTHON: Unsupported data type" + info.format);
+    }
     auto strides = info.strides;
     std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto i) -> std::size_t {
         return n > 0 ? i / n : 0;
src/rewrite_rnn.cpp

@@ -205,16 +205,18 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
     // initial hidden state
     auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
+    auto sih_lens = sih->get_shape().lens();

     // bias
+    instruction_ref bb{};
     if(bias != prog.end())
     {
-        long hs    = r->get_shape().lens()[2];
+        long hs    = static_cast<long>(r->get_shape().lens()[2]);
         auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
         auto wb    = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
         auto rb    = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
-        auto b     = prog.insert_instruction(ins, op::add{}, wb, rb);
-        bias       = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape().lens()}, b);
+        auto wrb   = prog.insert_instruction(ins, op::add{}, wb, rb);
+        bb         = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, wrb);
     }

     instruction_ref hidden_out = prog.end();

@@ -228,19 +230,14 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
         xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
         auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw);
         auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr);
-        auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
-        instruction_ref ht;
         if(bias != prog.end())
         {
-            ht = prog.insert_instruction(ins, op::add{}, xt_ht, bias);
+            xt_wi = prog.insert_instruction(ins, op::add{}, xt_wi, bb);
         }
-        else
-        {
-            ht = xt_ht;
-        }
+        auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);

         // apply activation function
-        ht  = prog.insert_instruction(ins, actv_func, ht);
+        auto ht = prog.insert_instruction(ins, actv_func, xt_ht);
         sih = ht;

         // add the dimensions of sequence length (axis 0 for sequence length,

@@ -485,62 +482,41 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
     long hs = static_cast<long>(r_shape.lens()[2]);
     migraphx::shape s(seq_shape.type(), {seq_shape.lens()[1], r_shape.lens()[2]});
-    std::vector<int> data(s.elements(), 1);
+    std::vector<float> data(s.elements(), 1.0f);
     auto l1 = prog.add_literal(migraphx::literal{s, data});

-    // weight matrix
+    // w matrix squeeze to 2-dim and do a transpose
     std::vector<int64_t> perm{1, 0};
     auto sw      = prog.insert_instruction(ins, op::squeeze{{0}}, w);
-    auto wz      = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sw);
-    auto tran_wz = prog.insert_instruction(ins, op::transpose{perm}, wz);
-    auto wr      = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sw);
-    auto tran_wr = prog.insert_instruction(ins, op::transpose{perm}, wr);
-    auto wh      = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sw);
-    auto tran_wh = prog.insert_instruction(ins, op::transpose{perm}, wh);
+    auto tw      = prog.insert_instruction(ins, op::transpose{perm}, sw);

+    // r slide to two part, zr and h
     auto sr      = prog.insert_instruction(ins, op::squeeze{{0}}, r);
-    auto rz      = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sr);
-    auto tran_rz = prog.insert_instruction(ins, op::transpose{perm}, rz);
-    auto rr      = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sr);
-    auto tran_rr = prog.insert_instruction(ins, op::transpose{perm}, rr);
+    auto rzr     = prog.insert_instruction(ins, op::slice{{0}, {0}, {2 * hs}}, sr);
+    auto trzr    = prog.insert_instruction(ins, op::transpose{perm}, rzr);
     auto rh      = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
-    auto tran_rh = prog.insert_instruction(ins, op::transpose{perm}, rh);
+    auto trh     = prog.insert_instruction(ins, op::transpose{perm}, rh);

     // initial states
     auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
+    size_t bs = ih->get_shape().lens()[1];

     // bias
-    instruction_ref brcst_bz{};
-    instruction_ref brcst_br{};
-    instruction_ref brcst_wbh{};
-    instruction_ref brcst_rbh{};
-    instruction_ref brcst_bh{};
+    instruction_ref bwb{};
+    instruction_ref brb_zr{};
+    instruction_ref brb_h{};
     if(bias != prog.end())
     {
-        auto broadcast_lens = sih->get_shape().lens();
         auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
-        auto wbz   = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
-        auto wbr   = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
-        auto wbh   = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
-        brcst_wbh  = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, wbh);
-        auto rbz   = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
-        auto rbr   = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
-        auto rbh   = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
-        brcst_rbh  = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, rbh);
-        auto bz    = prog.insert_instruction(ins, op::add{}, wbz, rbz);
-        brcst_bz   = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, bz);
-        auto br    = prog.insert_instruction(ins, op::add{}, wbr, rbr);
-        brcst_br   = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, br);
-        auto bh    = prog.insert_instruction(ins, op::add{}, wbh, rbh);
-        brcst_bh   = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, bh);
+        auto wb    = prog.insert_instruction(ins, op::slice{{0}, {0}, {3 * hs}}, sbias);
+        bwb        = prog.insert_instruction(
+            ins, op::broadcast{1, {bs, static_cast<size_t>(3 * hs)}}, wb);
+        auto rb_zr = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {5 * hs}}, sbias);
+        auto rb_h  = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
+        brb_zr     = prog.insert_instruction(
+            ins, op::broadcast{1, {bs, static_cast<size_t>(2 * hs)}}, rb_zr);
+        brb_h      = prog.insert_instruction(
+            ins, op::broadcast{1, {bs, static_cast<size_t>(hs)}}, rb_h);
     }

     for(long i = 0; i < seq_len; i++)

@@ -549,56 +525,58 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
         auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
         xt      = prog.insert_instruction(ins, op::squeeze{{0}}, xt);

-        // equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz)
-        auto xt_wz = prog.insert_instruction(ins, op::dot{}, xt, tran_wz);
-        auto ht_rz = prog.insert_instruction(ins, op::dot{}, sih, tran_rz);
-        auto xht_z = prog.insert_instruction(ins, op::add{}, xt_wz, ht_rz);
+        auto xt_w    = prog.insert_instruction(ins, op::dot{}, xt, tw);
+        auto ih1_rzr = prog.insert_instruction(ins, op::dot{}, sih, trzr);
         if(bias != prog.end())
         {
-            xht_z = prog.insert_instruction(ins, op::add{}, xht_z, brcst_bz);
+            xt_w    = prog.insert_instruction(ins, op::add{}, xt_w, bwb);
+            ih1_rzr = prog.insert_instruction(ins, op::add{}, ih1_rzr, brb_zr);
         }
-        auto zt = prog.insert_instruction(ins, actv_func1, xht_z);

-        // equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
-        auto xt_wr = prog.insert_instruction(ins, op::dot{}, xt, tran_wr);
-        auto ht_rr = prog.insert_instruction(ins, op::dot{}, sih, tran_rr);
-        auto xht_r = prog.insert_instruction(ins, op::add{}, xt_wr, ht_rr);
-        if(bias != prog.end())
-        {
-            xht_r = prog.insert_instruction(ins, op::add{}, xht_r, brcst_br);
-        }
-        auto rt = prog.insert_instruction(ins, actv_func1, xht_r);
+        auto xw_z    = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_w);
+        auto xw_r    = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, xt_w);
+        auto xw_h    = prog.insert_instruction(ins, op::slice{{1}, {2 * hs}, {3 * hs}}, xt_w);
+        auto hr_z    = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, ih1_rzr);
+        auto hr_r    = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, ih1_rzr);
+        auto xw_hr_z = prog.insert_instruction(ins, op::add{}, xw_z, hr_z);
+        auto zt      = prog.insert_instruction(ins, actv_func1, xw_hr_z);
+        auto xw_hr_r = prog.insert_instruction(ins, op::add{}, xw_r, hr_r);
+        auto rt      = prog.insert_instruction(ins, actv_func1, xw_hr_r);

-        instruction_ref xht_h;
+        instruction_ref hr_h{};
         if(linear_before_reset == 0)
         {
             // equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
-            auto xt_wh  = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
             auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
-            auto rt_rh  = prog.insert_instruction(ins, op::dot{}, rt_ht1, tran_rh);
-            xht_h       = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
             if(bias != prog.end())
             {
-                xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_bh);
+                hr_h = prog.insert_instruction(ins, op::dot{}, rt_ht1, trh, brb_h);
             }
+            else
+            {
+                hr_h = prog.insert_instruction(ins, op::dot{}, rt_ht1, trh);
+            }
         }
         else
         {
             // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
-            auto xt_wh  = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
-            auto ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, tran_rh);
+            instruction_ref ht1_rh{};
             if(bias != prog.end())
             {
-                ht1_rh = prog.insert_instruction(ins, op::add{}, ht1_rh, brcst_rbh);
+                ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, trh, brb_h);
             }
-            auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
-            xht_h      = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
-            if(bias != prog.end())
+            else
             {
-                xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_wbh);
+                ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, trh);
             }
+            hr_h = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
         }
-        auto ht = prog.insert_instruction(ins, actv_func2, xht_h);
+        auto xw_hr_h = prog.insert_instruction(ins, op::add{}, xw_h, hr_h);
+        auto ht      = prog.insert_instruction(ins, actv_func2, xw_hr_h);

         // equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
         auto one_minus_zt = prog.insert_instruction(ins, op::sub{}, l1, zt);

@@ -913,35 +891,16 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
     migraphx::shape r_shape = r->get_shape();
     long seq_len = static_cast<long>(seq_shape.lens()[0]);
     long hs      = static_cast<long>(r_shape.lens()[2]);
+    auto bs      = ih->get_shape().lens()[1];
     std::vector<int64_t> perm{1, 0};

-    // w matrix
+    // w matrix, squeeze and transpose
     auto sw      = prog.insert_instruction(ins, op::squeeze{{0}}, w);
-    auto wi      = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sw);
-    auto tran_wi = prog.insert_instruction(ins, op::transpose{perm}, wi);
-    auto wo      = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sw);
-    auto tran_wo = prog.insert_instruction(ins, op::transpose{perm}, wo);
-    auto wf      = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sw);
-    auto tran_wf = prog.insert_instruction(ins, op::transpose{perm}, wf);
-    auto wc      = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sw);
-    auto tran_wc = prog.insert_instruction(ins, op::transpose{perm}, wc);
+    auto tsw     = prog.insert_instruction(ins, op::transpose{perm}, sw);

-    // r matrix
+    // r matrix, squeeze and transpose
     auto sr      = prog.insert_instruction(ins, op::squeeze{{0}}, r);
-    auto ri      = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sr);
-    auto tran_ri = prog.insert_instruction(ins, op::transpose{perm}, ri);
-    auto ro      = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sr);
-    auto tran_ro = prog.insert_instruction(ins, op::transpose{perm}, ro);
-    auto rf      = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
-    auto tran_rf = prog.insert_instruction(ins, op::transpose{perm}, rf);
-    auto rc      = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sr);
-    auto tran_rc = prog.insert_instruction(ins, op::transpose{perm}, rc);
+    auto tsr     = prog.insert_instruction(ins, op::transpose{perm}, sr);

     // initial hidden state
     auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);

@@ -951,40 +910,23 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
     auto ic_lens = sic->get_shape().lens();

     // bias
-    instruction_ref bi_brcst{};
-    instruction_ref bo_brcst{};
-    instruction_ref bf_brcst{};
-    instruction_ref bc_brcst{};
+    instruction_ref wrb{};
     if(bias != prog.end())
     {
         auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
-        auto bxi   = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
-        auto bhi   = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
-        auto bi    = prog.insert_instruction(ins, op::add{}, bxi, bhi);
-        bi_brcst   = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, bi);
-        auto bxo   = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
-        auto bho   = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
-        auto bo    = prog.insert_instruction(ins, op::add{}, bxo, bho);
-        bo_brcst   = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, bo);
-        auto bxf   = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
-        auto bhf   = prog.insert_instruction(ins, op::slice{{0}, {6 * hs}, {7 * hs}}, sbias);
-        auto bf    = prog.insert_instruction(ins, op::add{}, bxf, bhf);
-        bf_brcst   = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, bf);
-        auto bxc   = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
-        auto bhc   = prog.insert_instruction(ins, op::slice{{0}, {7 * hs}, {8 * hs}}, sbias);
-        auto bc    = prog.insert_instruction(ins, op::add{}, bxc, bhc);
-        bc_brcst   = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, bc);
+        auto ub_wb  = prog.insert_instruction(ins, op::slice{{0}, {0}, {4 * hs}}, sbias);
+        auto ub_rb  = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {8 * hs}}, sbias);
+        auto ub_wrb = prog.insert_instruction(ins, op::add{}, ub_wb, ub_rb);
+        wrb         = prog.insert_instruction(
+            ins, op::broadcast{1, {bs, 4 * static_cast<size_t>(hs)}}, ub_wrb);
     }

     // peep hole
     instruction_ref pphi_brcst{};
     instruction_ref ppho_brcst{};
     instruction_ref pphf_brcst{};
     if(pph != prog.end())
     {
         auto spph = prog.insert_instruction(ins, op::squeeze{{0}}, pph);

@@ -1004,44 +946,31 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
         auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
         xt      = prog.insert_instruction(ins, op::squeeze{{0}}, xt);

-        // equation it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
-        auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_wi);
-        auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_ri);
-        auto it_before_actv = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
-        if(pph != prog.end())
-        {
-            auto pphi_ct   = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic);
-            it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, pphi_ct);
-        }
+        auto xt_tsw  = prog.insert_instruction(ins, op::dot{}, xt, tsw);
+        auto sih_tsr = prog.insert_instruction(ins, op::dot{}, sih, tsr);
+        auto xt_sih  = prog.insert_instruction(ins, op::add{}, xt_tsw, sih_tsr);
         if(bias != prog.end())
         {
-            it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, bi_brcst);
+            xt_sih = prog.insert_instruction(ins, op::add{}, xt_sih, wrb);
         }
-        auto it = prog.insert_instruction(ins, actv_func1, it_before_actv);

-        // equation ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
-        auto xt_wf = prog.insert_instruction(ins, op::dot{}, xt, tran_wf);
-        auto ht_rf = prog.insert_instruction(ins, op::dot{}, sih, tran_rf);
-        auto ft_before_actv = prog.insert_instruction(ins, op::add{}, xt_wf, ht_rf);
+        auto it_before_actv = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_sih);
+        auto ot_before_actv = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, xt_sih);
+        auto ft_before_actv =
+            prog.insert_instruction(ins, op::slice{{1}, {2 * hs}, {3 * hs}}, xt_sih);
+        auto ct_before_actv =
+            prog.insert_instruction(ins, op::slice{{1}, {3 * hs}, {4 * hs}}, xt_sih);
         if(pph != prog.end())
         {
+            auto pphi_ct   = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic);
+            it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, pphi_ct);
             auto pphf_ct   = prog.insert_instruction(ins, op::mul{}, pphf_brcst, sic);
             ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, pphf_ct);
         }
-        if(bias != prog.end())
-        {
-            ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, bf_brcst);
-        }
+        auto it = prog.insert_instruction(ins, actv_func1, it_before_actv);
         auto ft = prog.insert_instruction(ins, actv_func1, ft_before_actv);

-        // equation ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
-        auto xt_wc = prog.insert_instruction(ins, op::dot{}, xt, tran_wc);
-        auto ht_rc = prog.insert_instruction(ins, op::dot{}, sih, tran_rc);
-        auto ct_before_actv = prog.insert_instruction(ins, op::add{}, xt_wc, ht_rc);
-        if(bias != prog.end())
-        {
-            ct_before_actv = prog.insert_instruction(ins, op::add{}, ct_before_actv, bc_brcst);
-        }
         auto ct = prog.insert_instruction(ins, actv_func2, ct_before_actv);

         // equation Ct = ft (.) Ct-1 + it (.) ct

@@ -1050,19 +979,11 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
         auto cellt       = prog.insert_instruction(ins, op::add{}, ft_cell, it_ct);
         last_cell_output = cellt;

-        // ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
-        auto xt_wo = prog.insert_instruction(ins, op::dot{}, xt, tran_wo);
-        auto ht_ro = prog.insert_instruction(ins, op::dot{}, sih, tran_ro);
-        auto ot_before_actv = prog.insert_instruction(ins, op::add{}, xt_wo, ht_ro);
         if(pph != prog.end())
         {
             auto ppho_cellt = prog.insert_instruction(ins, op::mul{}, ppho_brcst, cellt);
             ot_before_actv  = prog.insert_instruction(ins, op::add{}, ot_before_actv, ppho_cellt);
         }
-        if(bias != prog.end())
-        {
-            ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, bo_brcst);
-        }
         auto ot = prog.insert_instruction(ins, actv_func1, ot_before_actv);

         // Ht = ot (.) h(Ct)
src/targets/gpu/device/include/migraphx/gpu/device/types.hpp

@@ -64,9 +64,9 @@ host_type<T>* host_cast(T* x)
 }

 template <class T>
-device_type<T> device_cast(T x)
+device_type<T> device_cast(const T& x)
 {
-    return reinterpret_cast<device_type<T>>(x);
+    return reinterpret_cast<const device_type<T>&>(x);
 }

 template <class T>
src/targets/gpu/device/pad.cpp

@@ -4,6 +4,7 @@
 #include <migraphx/gpu/device/pad.hpp>
 #include <migraphx/gpu/device/tensor.hpp>
 #include <migraphx/gpu/device/launch.hpp>
+#include <migraphx/float_equal.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

@@ -14,8 +15,17 @@ argument
 pad(hipStream_t stream, argument result, argument arg1, float value, std::vector<std::int64_t> pads)
 {
     std::size_t nelements = arg1.get_shape().elements();
-    nary(stream, result)([=] { return value; });
+    visit_all(result)([&](auto output) {
+        auto* outptr = device_cast(output.data());
+        using type   = typename decltype(output)::value_type;
+        device_type<type> device_val = value;
+        if(float_equal(value, std::numeric_limits<float>::lowest()))
+        {
+            device_val = device_cast(std::numeric_limits<type>::lowest());
+        }
+        gs_launch(stream, result.get_shape().elements())([=](auto i) { outptr[i] = device_val; });
+    });
     visit_all(result, arg1)([&](auto output, auto input) {
         visit_tensor_size(result.get_shape().lens().size(), [&](auto ndim) {
             std::size_t offsets[ndim];
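The added fill loop lets a caller pass std::numeric_limits<float>::lowest() as a sentinel so the padded region is filled with the lowest value of the output element type (useful when a max-pooling window overlaps padding, as in the TF parser change below). A host-side sketch of just that substitution, hedged as an illustration rather than the device code itself:

#include <cstdint>
#include <iostream>
#include <limits>

// Illustration of the sentinel check; an exact comparison is sufficient because the
// sentinel is the exact value std::numeric_limits<float>::lowest().
template <class T>
T resolve_pad_value(float value)
{
    if(value == std::numeric_limits<float>::lowest())
        return std::numeric_limits<T>::lowest(); // e.g. -128 for int8
    return static_cast<T>(value);
}

int main()
{
    std::cout << int(resolve_pad_value<std::int8_t>(std::numeric_limits<float>::lowest()))
              << '\n';                                   // prints -128
    std::cout << resolve_pad_value<float>(0.0f) << '\n'; // prints 0
    return 0;
}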
src/targets/gpu/include/migraphx/gpu/softmax.hpp

@@ -34,7 +34,7 @@ struct miopen_softmax
         return migraphx::reflect(self.op, f);
     }

-    std::string name() const { return "gpu::softmax"; }
+    std::string name() const { return "miopen::softmax"; }
     shape compute_shape(const std::vector<shape>& inputs) const;
     argument
     compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
src/tf/CMakeLists.txt

@@ -31,7 +31,7 @@ rocm_install_targets(
 add_executable(read_tf read_tf.cpp)
 rocm_clang_tidy_check(read_tf)
-target_link_libraries(read_tf migraphx_tf)
+target_link_libraries(read_tf migraphx_tf migraphx_cpu)
 if(MIGRAPHX_ENABLE_GPU)
 add_executable(verify_tf verify_tf.cpp)
src/tf/tf.cpp

@@ -317,6 +317,7 @@ struct tf_parser
             }
         }
+        auto l0 = args[0];
         if(contains(attributes, "padding"))
         {
             const std::string& pad_mode = attributes.at("padding").s();

@@ -326,8 +327,24 @@ struct tf_parser
                 std::vector<size_t> weight_dims = weights->get_shape().lens();
                 size_t weight_h = weight_dims[2];
                 size_t weight_w = weight_dims[3];
-                op.padding[0] = calculate_padding(weight_h, op.dilation[0]);
-                op.padding[1] = calculate_padding(weight_w, op.dilation[1]);
+                auto input_dims = l0->get_shape().lens();
+                size_t input_h  = input_dims[2];
+                size_t input_w  = input_dims[3];
+                std::vector<int64_t> pads(input_dims.size());
+                calculate_padding(0, pads, input_h, op.stride[0], op.dilation[0], weight_h);
+                calculate_padding(1, pads, input_w, op.stride[1], op.dilation[1], weight_w);
+                if(pads[0] != pads[2] || pads[1] != pads[3])
+                {
+                    std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
+                    l0 = prog.add_instruction(migraphx::op::pad{padding}, l0);
+                }
+                else
+                {
+                    op.padding[0] = pads[0];
+                    op.padding[1] = pads[1];
+                }
             }
             else if(pad_mode.find("VALID") != std::string::npos)
             {

@@ -350,7 +367,7 @@ struct tf_parser
             }
         }
-        return prog.add_instruction(op, {args[0], weights});
+        return prog.add_instruction(op, {l0, weights});
     }

     instruction_ref parse_depthwiseconv(const std::string&,

@@ -400,17 +417,35 @@ struct tf_parser
             }
         }
+        auto l0 = args[0];
         if(contains(attributes, "padding"))
         {
             const std::string& pad_mode = attributes.at("padding").s();
+            if(pad_mode.find("SAME") != std::string::npos)
+            {
+                op.padding_mode = op::padding_mode_t::same;
                 std::vector<size_t> weight_dims = weights->get_shape().lens();
                 size_t weight_h = weight_dims[2];
                 size_t weight_w = weight_dims[3];
-            if(pad_mode.find("SAME") != std::string::npos)
-            {
-                op.padding_mode = op::padding_mode_t::same;
-                op.padding[0]   = calculate_padding(weight_h, op.dilation[0]);
-                op.padding[1]   = calculate_padding(weight_w, op.dilation[1]);
+                auto input_dims = l0->get_shape().lens();
+                size_t input_h  = input_dims[2];
+                size_t input_w  = input_dims[3];
+                std::vector<int64_t> pads(input_dims.size());
+                calculate_padding(0, pads, input_h, op.stride[0], op.dilation[0], weight_h);
+                calculate_padding(1, pads, input_w, op.stride[1], op.dilation[1], weight_w);
+                if(pads[0] != pads[2] || pads[1] != pads[3])
+                {
+                    std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
+                    l0 = prog.add_instruction(migraphx::op::pad{padding}, l0);
+                }
+                else
+                {
+                    op.padding[0] = pads[0];
+                    op.padding[1] = pads[1];
+                }
             }
             else if(pad_mode.find("VALID") != std::string::npos)
             {

@@ -432,7 +467,7 @@ struct tf_parser
         auto cweights    = prog.add_instruction(op::contiguous{}, weights);
         auto new_weights = prog.add_instruction(op::reshape{new_weights_shape}, cweights);
-        return prog.add_instruction(op, {args[0], new_weights});
+        return prog.add_instruction(op, {l0, new_weights});
     }

@@ -567,21 +602,39 @@ struct tf_parser
             op.lengths[0] = ksize[2];
             op.lengths[1] = ksize[3];
         }
+        auto l0 = args[0];
         if(contains(attributes, "padding"))
         {
             const std::string& pad_mode = attributes.at("padding").s();
             if(pad_mode.find("SAME") != std::string::npos)
             {
                 op.padding_mode = op::padding_mode_t::same;
-                op.padding[0]   = calculate_padding(op.lengths[0], 1);
-                op.padding[1]   = calculate_padding(op.lengths[1], 1);
+                auto input_dims = l0->get_shape().lens();
+                size_t input_h  = input_dims[2];
+                size_t input_w  = input_dims[3];
+                std::vector<int64_t> pads(input_dims.size());
+                calculate_padding(0, pads, input_h, op.stride[0], 1, op.lengths[0]);
+                calculate_padding(1, pads, input_w, op.stride[1], 1, op.lengths[1]);
+                if(pads[0] != pads[2] || pads[1] != pads[3])
+                {
+                    std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
+                    l0 = prog.add_instruction(
+                        migraphx::op::pad{padding, std::numeric_limits<float>::lowest()}, l0);
+                }
+                else
+                {
+                    op.padding[0] = pads[0];
+                    op.padding[1] = pads[1];
+                }
             }
             else if(pad_mode.find("VALID") != std::string::npos)
             {
                 op.padding_mode = op::padding_mode_t::valid;
             }
         }
-        return prog.add_instruction(op, args[0]);
+        return prog.add_instruction(op, l0);
     }

     instruction_ref
test/CMakeLists.txt

@@ -119,7 +119,7 @@ foreach(ONNX_TEST ${ONNX_TESTS})
     set(TEST_NAME test_${BASE_NAME})
     add_executable(${TEST_NAME} ${TES_ONNX_DIR}/${ONNX_TEST})
     rocm_clang_tidy_check(${TEST_NAME})
-    target_link_libraries(${TEST_NAME} migraphx_onnx)
+    target_link_libraries(${TEST_NAME} migraphx_onnx migraphx_cpu)
     target_include_directories(${TEST_NAME} PUBLIC include)
     add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}> WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/onnx)
     add_dependencies(tests ${TEST_NAME})

@@ -129,7 +129,7 @@ endforeach()
 # tf test
 add_executable(test_tf tf/tf_test.cpp)
 rocm_clang_tidy_check(test_tf)
-target_link_libraries(test_tf migraphx_tf)
+target_link_libraries(test_tf migraphx_tf migraphx_cpu)
 target_include_directories(test_tf PUBLIC include)
 add_test(NAME test_tf COMMAND $<TARGET_FILE:test_tf> WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/tf)
 add_dependencies(tests test_tf)
test/eliminate_pad_test.cpp

@@ -83,23 +83,4 @@ TEST_CASE(rewrite_test_asymmetric)
         p.begin(), p.end(), [](const migraphx::instruction& ins) { return ins.name() == "pad"; }));
 }

-TEST_CASE(rewrite_test_same_padding)
-{
-    migraphx::program p;
-    size_t img_dim[2] = {2, 2};
-    size_t channels   = 1;
-    std::vector<int32_t> input(channels * img_dim[0] * img_dim[1]);
-    std::iota(input.begin(), input.end(), 0);
-
-    migraphx::shape s_img{migraphx::shape::int32_type, {1, channels, img_dim[0], img_dim[1]}};
-    auto l_img      = p.add_literal(migraphx::literal{s_img, input});
-    auto padded_img = p.add_instruction(migraphx::op::pad{{0, 0, 1, 1, 0, 0, 1, 1}}, l_img);
-
-    create_conv(padded_img, channels, p, migraphx::op::padding_mode_t::same);
-
-    p.compile(eliminate_pad_target{});
-
-    EXPECT(std::any_of(
-        p.begin(), p.end(), [](const migraphx::instruction& ins) { return ins.name() == "pad"; }));
-}
-
 int main(int argc, const char* argv[]) { test::run(argc, argv); }
test/gpu/miopen.cpp

@@ -1460,6 +1460,22 @@ struct test_pad : verify_program<test_pad>
     }
 };

+struct test_pad_int8 : verify_program<test_pad_int8>
+{
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        std::vector<int8_t> data0 = {0, 1, 2, 3};
+        migraphx::shape s0{migraphx::shape::float_type, {2, 2}};
+        auto l0 = p.add_literal(migraphx::literal{s0, data0});
+        migraphx::op::pad op{};
+        op.value = std::numeric_limits<int8_t>::lowest();
+        op.pads  = {0, 0, 1, 1};
+        p.add_instruction(op, l0);
+        return p;
+    }
+};
+
 struct test_pooling_autopad : verify_program<test_pooling_autopad>
 {
     migraphx::program create_program() const

@@ -2650,7 +2666,8 @@ struct test_lstm_forward_last : verify_program<test_lstm_forward_last>
         auto und    = p.add_instruction(migraphx::op::undefined{});
         auto output = p.add_instruction(
-            migraphx::op::gru{hidden_size,
+            migraphx::op::lstm{hidden_size,
                 {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
                 migraphx::op::rnn_direction::forward,
                 clip},