gaoqiong / MIGraphX / Commits / 8a3d1d09

Commit 8a3d1d09, authored Feb 09, 2019 by Paul

    Merge branch 'develop' into py

Parents: fc8c2664, 6972ad26
Changes: 55. Showing 20 changed files with 1352 additions and 71 deletions (+1352, -71).
src/CMakeLists.txt (+1, -0)
src/auto_contiguous.cpp (+1, -1)
src/dead_code_elimination.cpp (+3, -2)
src/include/migraphx/auto_any_cast.hpp (+4, -0)
src/include/migraphx/operation.hpp (+2, -2)
src/include/migraphx/operators.hpp (+149, -9)
src/include/migraphx/program.hpp (+2, -0)
src/include/migraphx/rewrite_rnn.hpp (+53, -0)
src/instruction.cpp (+1, -1)
src/onnx/onnx.cpp (+252, -3)
src/program.cpp (+29, -6)
src/rewrite_rnn.cpp (+668, -0)
src/simplify_reshapes.cpp (+62, -38)
src/targets/cpu/lowering.cpp (+39, -8)
src/targets/cpu/target.cpp (+7, -1)
src/targets/gpu/CMakeLists.txt (+2, -0)
src/targets/gpu/device/sub.cpp (+17, -0)
src/targets/gpu/gemm.cpp (+1, -0)
src/targets/gpu/include/migraphx/gpu/device/sub.hpp (+20, -0)
src/targets/gpu/include/migraphx/gpu/lrn.hpp (+39, -0)
src/CMakeLists.txt

...
@@ -11,6 +11,7 @@ add_library(migraphx
     eliminate_contiguous.cpp
     eliminate_concat.cpp
     fwd_conv_batchnorm_rewrite.cpp
+    rewrite_rnn.cpp
     env.cpp
     generate.cpp
     instruction.cpp
...
src/auto_contiguous.cpp

...
@@ -12,7 +12,7 @@ void auto_contiguous::apply(program& p) const
     for(auto ins : iterator_for(p))
     {
         shape s = ins->get_shape();
-        if(not s.standard())
+        if(not s.standard() and s.elements() != 0)
         {
             auto c = p.insert_instruction(std::next(ins), op::contiguous{}, ins);
             p.replace_instruction(ins, c);
...
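The new guard skips empty tensors: a zero in any dimension makes elements() zero, and there is nothing to copy into standard layout. A minimal sketch of the case now excluded (assuming migraphx::shape's type-plus-lengths constructor):

    migraphx::shape s{migraphx::shape::float_type, {0, 4}};
    // s.elements() == 0, so the pass no longer inserts a
    // contiguous copy for this instruction.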
src/dead_code_elimination.cpp

...
@@ -41,8 +41,9 @@ void dead_code_elimination::apply(program& p) const
         // Skip the last instruction
         if(i == last)
             break;
-        // Skip instruction with empty shape as output unless its a builtin
-        if(i->get_shape().elements() == 0 and not(i->name().front() == '@'))
+        // Skip instruction with empty shape as output unless its a builtin or undefined
+        if(i->get_shape().elements() == 0 and not(i->name().front() == '@') and
+           not(i->name() == "undefined"))
             continue;
         assert(bidistance(p, i, last) > 0);
         fix([&](auto self, auto leaf) {
...
src/include/migraphx/auto_any_cast.hpp

...
@@ -5,6 +5,10 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

+// Forward declare any_cast
+template <class T>
+const T& any_cast(const T&);
+
 namespace detail {

 template <class U>
...
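The four added lines are a declaration-before-use fix: the detail:: helpers below can now name any_cast before its definition has been seen. A standalone sketch of the idiom (hypothetical definitions, not the library's):

    template <class T>
    const T& any_cast(const T&);                   // declared first

    template <class U>
    auto call_any_cast(const U& u) -> decltype(any_cast(u))
    {
        return any_cast(u);                        // well-formed: the declaration is visible
    }

    template <class T>
    const T& any_cast(const T& x) { return x; }    // defined later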
src/include/migraphx/operation.hpp

...
@@ -7,17 +7,17 @@
 #include <memory>
 #include <type_traits>
 #include <utility>
+#include <migraphx/shape.hpp>
 #include <migraphx/reflect.hpp>
 #include <migraphx/streamutils.hpp>
 #include <migraphx/argument.hpp>
-#include <migraphx/context.hpp>
 #include <migraphx/auto_any_cast.hpp>
 #include <migraphx/config.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

+struct context;
+
 #ifdef DOXYGEN

 /// The operation interface represents an action an instruction will perform. All
...
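Swapping the context.hpp include for a forward declaration works because the operation interface only ever passes context by reference; a full definition is required only where members of the context are accessed. A generic sketch of the technique (hypothetical names):

    struct context;                    // forward declaration, no #include needed

    struct my_operation
    {
        // OK: ctx appears only as a reference in the signature; callers
        // that actually use ctx must include the defining header.
        void run(context& ctx) const;
    };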
src/include/migraphx/operators.hpp

...
@@ -60,6 +60,30 @@ struct batch_norm_inference
     }
 };

+struct lrn
+{
+    float alpha = 0.0001;
+    float beta  = 0.75;
+    float bias  = 1.0;
+    int size    = 1;
+
+    std::string name() const { return "lrn"; }
+
+    template <class Self, class F>
+    static auto reflect(Self& self, F f)
+    {
+        return pack(f(self.alpha, "alpha"),
+                    f(self.beta, "beta"),
+                    f(self.bias, "bias"),
+                    f(self.size, "size"));
+    }
+
+    shape compute_shape(std::vector<shape> inputs) const
+    {
+        check_shapes{inputs, *this}.has(1);
+        return inputs.front();
+    }
+};
+
 struct convolution
 {
     std::array<std::size_t, 2> padding = {{0, 0}};
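The reflect member hands each field and its name to a visitor, so one declaration can drive printing, comparison, and serialization generically. A self-contained sketch of the same idiom, with a minimal stand-in for migraphx's pack (an assumption; the real pack lives in reflect.hpp):

    #include <string>
    #include <tuple>

    // Minimal stand-in for migraphx's pack: bundle the visitor results.
    template <class... Ts>
    auto pack(Ts&&... xs)
    {
        return std::make_tuple(std::forward<Ts>(xs)...);
    }

    struct point
    {
        int x = 0;
        int y = 0;

        // Same shape as op::lrn::reflect above: each member is visited
        // together with its name.
        template <class Self, class F>
        static auto reflect(Self& self, F f)
        {
            return pack(f(self.x, "x"), f(self.y, "y"));
        }
    };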
@@ -358,6 +382,17 @@ struct contiguous
         auto t = inputs.at(0).type();
         return {t, lens};
     }
+
+    argument compute(const shape& output_shape, std::vector<argument> args) const
+    {
+        assert(output_shape.standard());
+        argument result{output_shape};
+        visit_all(result, args[0])([&](auto output, auto input) {
+            shape_for_each(output.get_shape(), [&](const auto& idx) {
+                output(idx.begin(), idx.end()) = input(idx.begin(), idx.end());
+            });
+        });
+        return result;
+    }
 };

 struct concat
...
@@ -430,7 +465,6 @@ struct concat
         }
         return result;
     }
-    int output_alias(const std::vector<shape>&) const { return 0; }
 };

 struct slice
...
@@ -616,11 +650,16 @@ struct reshape
         {
             if(dims[i] == 0)
                 rdims[i] = idims[i];
+
+            // since rdims using size_t type, -1 is the max value
+            // is size_t that cause later compuation incorrect
+            if(dims[i] == -1)
+                rdims[i] = 1;
         }

         if(n_neg_dims > 0)
         {
             size_t missing_dim =
-                -inputs.front().elements() /
+                inputs.front().elements() /
                 std::accumulate(rdims.begin(), rdims.end(), 1, std::multiplies<int64_t>());
             for(std::size_t i = 0; i < rdims.size(); i++)
             {
...
@@ -628,11 +667,7 @@ struct reshape
                 rdims[i] = missing_dim;
             }
         }
-        if(dims.back() == -1)
-        {
-            rdims.pop_back();
-            std::copy(idims.begin() + rdims.size(), idims.end(), std::back_inserter(rdims));
-        }
         shape s{inputs.front().type(), rdims};
         if(s.elements() != inputs.front().elements())
             MIGRAPHX_THROW("Wrong number of elements for reshape");
...
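A worked example of the new dims handling: take an input of shape {2, 3, 4} (24 elements) and dims = {0, -1, 2}. The 0 copies the input dimension, so rdims[0] = 2; the -1 is first rewritten to 1 so the unsigned product stays meaningful, giving a partial product of 2 * 1 * 2 = 4; missing_dim = 24 / 4 = 6 then overwrites the -1 slot, yielding {2, 6, 2}, whose 2 * 6 * 2 = 24 elements satisfy the final check. Before this fix, -1 stored into a size_t vector became the maximum value and poisoned the product.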
@@ -764,8 +799,6 @@ struct gather
...
@@ -764,8 +799,6 @@ struct gather
return
result
;
return
result
;
}
}
int
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
struct
dot
struct
dot
...
@@ -1131,6 +1164,113 @@ struct outline
     argument compute(const shape&, const std::vector<argument>&) const { return {s, nullptr}; }
 };

+// indicate rnn computation direction
+enum class rnn_direction
+{
+    forward,
+    reverse,
+    bidirectional,
+};
+
+struct rnn
+{
+    std::size_t hidden_size = 1;
+    std::vector<operation> actv_funcs{tanh{}, tanh{}};
+    rnn_direction direction = rnn_direction::forward;
+    float clip              = 0.0f;
+
+    std::string name() const { return "rnn"; }
+
+    shape compute_shape(std::vector<shape> inputs) const
+    {
+        auto in_dims     = inputs[0].lens();
+        auto hidden_dims = inputs[2].lens();
+        if(hidden_size != hidden_dims[2])
+        {
+            MIGRAPHX_THROW("RNN: hidden size mismatch in attribute and input");
+        }
+
+        std::size_t num_directions = 1;
+        if(direction == rnn_direction::bidirectional)
+        {
+            num_directions = 2;
+        }
+
+        if(num_directions != hidden_dims[0])
+        {
+            MIGRAPHX_THROW("RNN: num_direction mismatch in attribute and input");
+        }
+
+        std::vector<std::size_t> out_dims(in_dims);
+        out_dims.insert(out_dims.begin() + 1, num_directions);
+        out_dims.back() = hidden_size;
+
+        return {inputs[0].type(), out_dims};
+    }
+};
+
+struct rnn_last_output
+{
+    std::string name() const { return "rnn_last_output"; }
+
+    shape compute_shape(std::vector<shape> inputs) const
+    {
+        check_shapes{inputs, *this}.has(1);
+        auto dims = inputs[0].lens();
+
+        // remove the first dimension, remaing are output shape
+        dims.erase(dims.begin());
+        return {inputs[0].type(), dims};
+    }
+};
+
+struct gru
+{
+    std::size_t hidden_size = 1;
+    std::vector<operation> actv_funcs{sigmoid{}, tanh{}};
+    rnn_direction direction = rnn_direction::forward;
+    float clip              = 0.0f;
+    int linear_before_reset = 0;
+
+    std::string name() const { return "gru"; }
+
+    shape compute_shape(std::vector<shape> inputs) const
+    {
+        auto in_dims     = inputs[0].lens();
+        auto hidden_dims = inputs[2].lens();
+        if(hidden_size != hidden_dims[2])
+        {
+            MIGRAPHX_THROW("GRU: hidden size mismatch in attribute and input");
+        }
+
+        std::size_t num_directions = 1;
+        if(direction == rnn_direction::bidirectional)
+        {
+            num_directions = 2;
+        }
+
+        if(num_directions != hidden_dims[0])
+        {
+            MIGRAPHX_THROW("GRU: num_direction does not match the direction attribute");
+        }
+
+        std::vector<std::size_t> out_dims(in_dims);
+        out_dims.insert(out_dims.begin() + 1, num_directions);
+        out_dims.back() = hidden_size;
+
+        return {inputs[0].type(), out_dims};
+    }
+};
+
+struct undefined
+{
+    std::string name() const { return "undefined"; }
+
+    shape compute_shape(const std::vector<shape>& inputs) const
+    {
+        check_shapes{inputs, *this}.has(0);
+        return {};
+    }
+
+    argument compute(const shape&, const std::vector<argument>&) const { return {{}, nullptr}; }
+};
+
 } // namespace op
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...
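Following rnn::compute_shape above, a bidirectional RNN over input X with lens {seq_len = 7, batch = 3, input_size = 5} and hidden_size = 4 works out as: out_dims starts as {7, 3, 5}, num_directions = 2 is inserted at position 1 to give {7, 2, 3, 5}, and the last entry becomes hidden_size, so the output shape is {7, 2, 3, 4}. rnn_last_output then erases the leading sequence dimension, leaving {2, 3, 4}.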
src/include/migraphx/program.hpp

...
@@ -105,6 +105,8 @@ struct program
     void debug_print(instruction_ref ins) const;
     void debug_print(const std::vector<instruction_ref>& inss) const;

+    void dry_run(parameter_map params) const;
+
     friend std::ostream& operator<<(std::ostream& os, const program& p);

     friend bool operator==(const program& x, const program& y);
     friend bool operator!=(const program& x, const program& y) { return !(x == y); }
...
src/include/migraphx/rewrite_rnn.hpp · new file (100644)

#ifndef MIGRAPHX_GUARD_RTGLIB_REWRITE_RNN_HPP
#define MIGRAPHX_GUARD_RTGLIB_REWRITE_RNN_HPP

#include <string>
#include <vector>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct program;

/**
 * Rewrite rnn to gemm and add.
 */
struct rewrite_rnn
{
    std::string name() const { return "rewrite_rnn"; }
    void apply(program& prog) const;

    private:
    // for vanilla rnn operators
    void apply_vanilla_rnn(program& prog, instruction_ref ins) const;
    std::vector<instruction_ref> vanilla_rnn_cell(bool is_forward,
                                                  program& prog,
                                                  instruction_ref ins,
                                                  instruction_ref input,
                                                  instruction_ref w,
                                                  instruction_ref r,
                                                  instruction_ref bias,
                                                  instruction_ref ih,
                                                  operation& actv_func) const;
    std::vector<operation> vanilla_rnn_actv_funcs(instruction_ref ins) const;

    // for gru operators
    void apply_gru(program& prog, instruction_ref ins) const;
    std::vector<instruction_ref> gru_cell(bool is_forward,
                                          program& prog,
                                          instruction_ref ins,
                                          std::vector<instruction_ref> inputs,
                                          int linear_before_reset,
                                          const operation& actv_func1,
                                          const operation& actv_func2) const;
    std::vector<operation> gru_actv_funcs(instruction_ref ins) const;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
src/instruction.cpp

...
@@ -97,7 +97,7 @@ const std::vector<instruction_ref>& instruction::outputs() const { return output
 bool operator==(const instruction& x, const instruction& y)
 {
-    if(not(x.result == y.result and x.op == y.op and x.arguments == y.arguments))
+    if(std::tie(x.result, x.op, x.arguments) != std::tie(y.result, y.op, y.arguments))
         return false;
     if(x.name() == "@literal")
         return x.lit == y.lit;
...
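std::tie builds a tuple of references, so several members compare in one lexicographic expression without copies. A self-contained sketch of the same idiom:

    #include <string>
    #include <tuple>

    struct record
    {
        int id;
        std::string name;

        friend bool operator==(const record& x, const record& y)
        {
            // Tuples of references compare member by member; nothing is copied.
            return std::tie(x.id, x.name) == std::tie(y.id, y.name);
        }
    };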
src/onnx/onnx.cpp

...
@@ -32,6 +32,7 @@ struct onnx_parser
     bool is_pytorch = false;
     std::unordered_map<std::string, op_func> ops;
+    std::unordered_map<std::string, operation> map_actv_funcs;

     onnx_parser()
     {
...
@@ -63,6 +64,7 @@ struct onnx_parser
         add_variadic_op("Max", op::max{});
         add_variadic_op("Min", op::min{});
+        add_mem_op("LRN", &onnx_parser::parse_lrn);
         add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
         add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
         add_mem_op("Elu", &onnx_parser::parse_elu);
...
@@ -85,7 +87,21 @@ struct onnx_parser
         add_mem_op("Shape", &onnx_parser::parse_shape);
         add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
         add_mem_op("Transpose", &onnx_parser::parse_transpose);
+        add_mem_op("RNN", &onnx_parser::parse_rnn);
+        add_mem_op("GRU", &onnx_parser::parse_gru);
         add_mem_op("Pad", &onnx_parser::parse_pad);
+
+        // init the activation function map
+        init_actv_func();
+    }
+
+    void init_actv_func()
+    {
+        map_actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
+        map_actv_funcs.insert(std::make_pair("relu", op::relu{}));
+        map_actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));
+        map_actv_funcs.insert(std::make_pair("leakyrelu", op::leaky_relu{}));
+        map_actv_funcs.insert(std::make_pair("elu", op::elu{}));
     }

 template <class F>
...
@@ -522,6 +538,25 @@ struct onnx_parser
         return prog.add_instruction(op, args.front());
     }

+    instruction_ref
+    parse_lrn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
+    {
+        float alpha = 0.0001;
+        float beta  = 0.75;
+        float bias  = 1.0;
+        int size    = 1;
+        if(contains(attributes, "alpha"))
+            alpha = parse_value(attributes.at("alpha")).at<float>();
+        if(contains(attributes, "beta"))
+            beta = parse_value(attributes.at("beta")).at<float>();
+        if(contains(attributes, "bias"))
+            bias = parse_value(attributes.at("bias")).at<float>();
+        if(contains(attributes, "size"))
+            size = parse_value(attributes.at("size")).at<int>();
+        op::lrn op{alpha, beta, bias, size};
+        return prog.add_instruction(op, args.front());
+    }
+
     instruction_ref parse_imagescaler(const std::string&,
                                       attribute_map attributes,
                                       std::vector<instruction_ref> args)
...
@@ -677,6 +712,214 @@ struct onnx_parser
         }
     }

+    std::vector<instruction_ref>
+    parse_rnn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
+    {
+        migraphx::shape input_shape = args[0]->get_shape();
+        std::size_t hidden_size     = args[1]->get_shape().lens()[1];
+
+        if(contains(attributes, "hidden_size"))
+        {
+            std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
+            if(hidden_size != hidden_size_att)
+            {
+                MIGRAPHX_THROW("RNN: hidden size mismatch in input and attribute");
+            }
+        }
+
+        // Handling of direction to be added later
+        std::string direction{"forward"};
+        if(contains(attributes, "direction"))
+        {
+            direction = attributes.at("direction").s();
+        }
+
+        op::rnn_direction dirct = op::rnn_direction::forward;
+        if(direction == "bidirectional")
+        {
+            dirct = op::rnn_direction::bidirectional;
+        }
+        else if(direction == "reverse")
+        {
+            dirct = op::rnn_direction::reverse;
+        }
+
+        std::vector<std::string> vec_names{"tanh"};
+        if(contains(attributes, "activations"))
+        {
+            auto names = attributes.at("activations").strings();
+            vec_names.clear();
+            for_each(names.begin(), names.end(), [&](auto& fn) { vec_names.push_back(fn); });
+        }
+
+        for_each(vec_names.begin(), vec_names.end(), [&](auto& fn) {
+            if(map_actv_funcs.count(fn) == 0)
+            {
+                MIGRAPHX_THROW("RNN: activation function " + std::string(fn) + " not supported");
+            }
+        });
+
+        // bidirectional case should have two activation functions.
+        // one is for forward, and the other is for reverse.
+        // if only one actv function is provided, we use it in both
+        // forward and reverse direction
+        if(dirct == op::rnn_direction::bidirectional)
+        {
+            if(vec_names.size() == 1)
+            {
+                vec_names.push_back(vec_names.at(0));
+            }
+        }
+
+        std::vector<operation> vec_actv_funcs(vec_names.size());
+        std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& fn) {
+            return map_actv_funcs[fn];
+        });
+
+        // To be added later
+        float clip = 0.0;
+        if(contains(attributes, "clip"))
+        {
+            clip = parse_value(attributes.at("clip")).at<float>();
+        }
+
+        // if the number of arguments is less than 6, append
+        // undefined operator to have 6 arguments
+        if(args.size() < 6)
+        {
+            auto ins = prog.add_instruction(op::undefined{});
+            args.insert(args.end(), (6 - args.size()), ins);
+        }
+
+        // first output for the concatenation of hidden states
+        auto hidden_states = prog.add_instruction(
+            op::rnn{hidden_size, vec_actv_funcs, dirct, clip}, std::move(args));
+
+        // second output for the last hidden state
+        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
+
+        return {hidden_states, last_output};
+    }
+
+    std::vector<instruction_ref>
+    parse_gru(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
+    {
+        migraphx::shape input_shape = args[0]->get_shape();
+        std::size_t hidden_size     = args[2]->get_shape().lens()[2];
+
+        if(contains(attributes, "hidden_size"))
+        {
+            std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
+            if(hidden_size != hidden_size_att)
+            {
+                MIGRAPHX_THROW("GRU: hidden size mismatch in input and attribute");
+            }
+        }
+
+        // Handling of direction to be added later
+        std::string direction{"forward"};
+        if(contains(attributes, "direction"))
+        {
+            direction = attributes.at("direction").s();
+        }
+
+        op::rnn_direction dirct = op::rnn_direction::forward;
+        if(direction == "bidirectional")
+        {
+            dirct = op::rnn_direction::bidirectional;
+        }
+        else if(direction == "reverse")
+        {
+            dirct = op::rnn_direction::reverse;
+        }
+
+        std::vector<std::string> vec_names = {"sigmoid", "tanh"};
+        if(contains(attributes, "activations"))
+        {
+            auto names = attributes.at("activations").strings();
+            vec_names.clear();
+            vec_names.resize(names.size());
+            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto& str) {
+                return str;
+            });
+        }
+
+        // need 4 activation functions
+        if(dirct == op::rnn_direction::bidirectional)
+        {
+            // 4 activation functions are used in the bidirectional
+            // scenario. No spec is provided in onnx::operator. we
+            // use the algorithm that: if 1 actv function is provided,
+            // repeat 1 four times. If 2 actv functins are provided,
+            // assume forward and reverse use the same pair of actv
+            // functions. For the case of 3 actv functions provided,
+            // assume the 3rd one is repeated once and used by the
+            // reverse direction.
+            // This may need change later
+            if(vec_names.size() == 1)
+            {
+                vec_names.insert(vec_names.end(), 3, vec_names.at(0));
+            }
+            else if(vec_names.size() == 2)
+            {
+                // repeat the activation functions
+                vec_names.push_back(vec_names.at(0));
+                vec_names.push_back(vec_names.at(1));
+            }
+            else if(vec_names.size() == 3)
+            {
+                vec_names.push_back(vec_names.at(2));
+            }
+        }
+        else
+        {
+            if(vec_names.size() == 1)
+            {
+                vec_names.push_back(vec_names.at(0));
+            }
+        }
+
+        for_each(vec_names.begin(), vec_names.end(), [&](auto& name) {
+            if(map_actv_funcs.count(name) == 0)
+            {
+                MIGRAPHX_THROW("GRU: activation function " + std::string(name) +
+                               " not supported");
+            }
+        });
+
+        std::vector<operation> vec_actv_funcs(vec_names.size());
+        std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(),
+                       [&](auto& name) { return map_actv_funcs[name]; });
+
+        float clip = 0.0;
+        if(contains(attributes, "clip"))
+        {
+            clip = parse_value(attributes.at("clip")).at<float>();
+        }
+
+        int linear_before_reset = 0;
+        if(contains(attributes, "linear_before_reset"))
+        {
+            linear_before_reset = parse_value(attributes.at("linear_before_reset")).at<int>();
+        }
+
+        // append undefined opeator to make 6 arguments
+        if(args.size() < 6)
+        {
+            auto ins = prog.add_instruction(op::undefined{});
+            args.insert(args.end(), 6 - args.size(), ins);
+        }
+
+        // first output for concatenation of hidden states
+        auto hidden_states = prog.add_instruction(
+            op::gru{hidden_size, vec_actv_funcs, dirct, clip, linear_before_reset},
+            std::move(args));
+
+        // second output for last gru output
+        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
+
+        return {hidden_states, last_output};
+    }
+
     void parse_from(std::istream& is)
     {
         onnx::ModelProto model;
...
@@ -723,6 +966,12 @@ struct onnx_parser
         }
     }

+    void parse_undefined(const std::string& name)
+    {
+        auto ins           = prog.add_instruction(op::undefined{});
+        instructions[name] = ins;
+    }
+
     void parse_node(const std::string& name)
     {
         if(name.empty())
...
@@ -737,12 +986,12 @@ struct onnx_parser
             {
                 assert(name != input);
                 this->parse_node(input);
-                args.push_back(instructions.at(input));
             }
-            else
+            else if(input.empty())
             {
-                args.push_back(instructions.at(input));
+                this->parse_undefined(input);
             }
+            args.push_back(instructions.at(input));
         }

         std::vector<instruction_ref> result;
         if(ops.count(node.op_type()) == 0)
...
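The activation-defaulting rule in parse_gru is easier to see in isolation. A standalone sketch that mirrors the branch on vec_names.size() (function name hypothetical):

    #include <string>
    #include <vector>

    std::vector<std::string> expand_gru_activations(std::vector<std::string> v, bool bidirectional)
    {
        if(bidirectional)
        {
            if(v.size() == 1)
                v.insert(v.end(), 3, v.at(0));   // {f}       -> {f, f, f, f}
            else if(v.size() == 2)
            {
                v.push_back(v.at(0));            // {f, g}    -> {f, g, f, g}
                v.push_back(v.at(1));
            }
            else if(v.size() == 3)
                v.push_back(v.at(2));            // {f, g, h} -> {f, g, h, h}
        }
        else if(v.size() == 1)
            v.push_back(v.at(0));                // {f}       -> {f, f}
        return v;
    }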
src/program.cpp

 #include <migraphx/program.hpp>
 #include <migraphx/stringutils.hpp>
 #include <migraphx/instruction.hpp>
+#include <migraphx/operators.hpp>
 #include <migraphx/env.hpp>
 #include <migraphx/ranges.hpp>
 #include <migraphx/time.hpp>
...
@@ -134,6 +135,12 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
     assert(has_instruction(ins));
     assert(has_instruction(rep));
     assert(ins != rep);
+
+    if(ins == std::prev(this->end()))
+    {
+        return replace_instruction(ins, op::identity{}, rep);
+    }
+
     // TODO: Should it be an error if the output is empty?
     if(ins->outputs().empty())
     {
...
@@ -372,20 +379,31 @@ argument generic_eval(const program& p,
 argument program::eval(std::unordered_map<std::string, argument> params) const
 {
+    auto& ctx = this->impl->ctx;
+#ifndef NDEBUG
+    auto sctx          = ctx;
+    auto check_context = [&](auto f) {
+        assert(is_shared(ctx, sctx));
+        auto x = f();
+        sctx   = ctx;
+        return x;
+    };
+#else
+    auto check_context = [](auto f) { return f(); };
+#endif
+
     if(enabled(MIGRAPHX_TRACE_EVAL{}))
     {
-        auto& ctx = this->impl->ctx;
-        return generic_eval(*this, ctx, std::move(params), [&](auto& ins, auto f) {
+        return generic_eval(*this, this->impl->ctx, std::move(params), [&](auto& ins, auto f) {
             ctx.finish();
             std::cout << "Run instruction: ";
             this->debug_print(ins);
-            return f();
+            return check_context(f);
         });
     }
     else
     {
         return generic_eval(
-            *this, this->impl->ctx, std::move(params), [](auto&, auto f) { return f(); });
+            *this, ctx, std::move(params), [&](auto&, auto f) { return check_context(f); });
     }
 }
...
@@ -439,8 +457,7 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
     overhead_vec.reserve(n);
     for(std::size_t i = 0; i < n; i++)
     {
-        overhead_vec.push_back(time<milliseconds>(
-            [&] { generic_eval(*this, ctx, params, [](auto...) { return argument{}; }); }));
+        overhead_vec.push_back(time<milliseconds>([&] { dry_run(params); }));
     }

     double total_time = common_average(total_vec);
...
@@ -504,6 +521,12 @@ void program::debug_print(const std::vector<instruction_ref>& inss) const
     std::cout << std::endl;
 }

+void program::dry_run(std::unordered_map<std::string, argument> params) const
+{
+    auto& ctx = this->impl->ctx;
+    generic_eval(*this, ctx, std::move(params), [](auto&&...) { return argument{}; });
+}
+
 bool operator==(const program& x, const program& y) { return to_string(x) == to_string(y); }

 std::ostream& operator<<(std::ostream& os, const program& p)
...
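check_context compiles to a pass-through in release builds; with NDEBUG unset it asserts that each instruction leaves the context in a state still shared with the snapshot. A generic variant of the same debug-only wrapper pattern (hypothetical invariant, not the library's is_shared):

    #include <cassert>

    template <class Check, class F>
    auto checked_call(Check check, F f)
    {
        assert(check());   // invariant must hold before the step
        auto x = f();      // run the wrapped step
        assert(check());   // ...and still hold afterwards
        return x;
    }

dry_run, in the same spirit, walks the whole program through generic_eval but makes every instruction return an empty argument, so perf_report can measure dispatch overhead without doing any compute.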
src/rewrite_rnn.cpp · new file (100644)

#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

void rewrite_rnn::apply(program& prog) const
{
    for(auto ins : iterator_for(prog))
    {
        if(ins->name() == "rnn")
        {
            apply_vanilla_rnn(prog, ins);
        }

        if(ins->name() == "gru")
        {
            apply_gru(prog, ins);
        }
    }
}

void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
{
    assert(ins->name() == "rnn");
    // could be 3 to 6 inputs, but the parse_rnn function will
    // append undefined operators to make 6 arguments when parsing
    // an onnx file. Another case is user can have num of arguments
    // when writing their program.
    auto args = ins->inputs();

    shape seq_shape         = args[0]->get_shape();
    std::size_t hidden_size = args[1]->get_shape().lens()[1];
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
    migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
    std::vector<float> data(ih_shape.elements(), 0);

    auto actv_funcs         = vanilla_rnn_actv_funcs(ins);
    auto rnn_op             = any_cast<op::rnn>(ins->get_operator());
    op::rnn_direction dicrt = rnn_op.direction;

    instruction_ref last_output{};
    if(dicrt == op::rnn_direction::bidirectional)
    {
        // input weight matrix
        auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
        auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);

        // hidden state weight matrix
        auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
        auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);

        // process bias
        instruction_ref bias_forward = prog.end();
        instruction_ref bias_reverse = prog.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
            bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
        }

        // process intial hidden state, it could be the 6th argument
        // or the 5th one (if the sequence len argument is ignored)
        instruction_ref ih_forward{};
        instruction_ref ih_reverse{};
        if(args.size() == 6 && args[5]->name() != "undefined")
        {
            ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
            ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
        }
        else
        {
            ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
            ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
        }

        auto ret_forward = vanilla_rnn_cell(
            true, prog, ins, args[0], w_forward, r_forward, bias_forward, ih_forward, actv_funcs.at(0));
        auto ret_reverse = vanilla_rnn_cell(
            false, prog, ins, args[0], w_reverse, r_reverse, bias_reverse, ih_reverse, actv_funcs.at(1));

        auto concat_output =
            prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
        last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);

        // The following logic is to ensure the last instruction rewritten from
        // rnn operator is a concat instruction
        // sequence len is 1
        if(ret_forward[0] == prog.end())
        {
            prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
        }
        else
        {
            ret_forward[0] =
                prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
            ret_reverse[0] =
                prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
            prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
        }
    }
    else
    {
        bool is_forward = (dicrt == op::rnn_direction::forward);
        // input weight matrix
        auto w = args[1];

        // hidden state weight matrix
        auto r = args[2];

        // process bias and initial hidden state
        instruction_ref bias = prog.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias = args[3];
        }

        // process intial hidden state
        instruction_ref ih;
        if(args.size() == 6 && args[5]->name() != "undefined")
        {
            ih = args[5];
        }
        else
        {
            ih = prog.add_literal(migraphx::literal{ih_shape, data});
        }

        auto ret    = vanilla_rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0));
        last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);

        // following logic is to ensure the last instruction is a
        // concat instruction
        // sequence len is 1
        if(ret[0] == prog.end())
        {
            prog.replace_instruction(ins, op::concat{0}, ret[1]);
        }
        else
        {
            auto concat_arg0 = is_forward ? ret[0] : ret[1];
            auto concat_arg1 = is_forward ? ret[1] : ret[0];
            prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
        }
    }

    // search its output to find if there are rnn_last_output operator
    // while loop to handle case of multiple rnn_last_output operators
    auto last_output_it = ins->outputs().begin();
    while(last_output_it != ins->outputs().end())
    {
        last_output_it = std::find_if(last_output_it, ins->outputs().end(), [](auto i) {
            return i->name() == "rnn_last_output";
        });

        if(last_output_it != ins->outputs().end())
        {
            prog.replace_instruction(*last_output_it, last_output);
            last_output_it++;
        }
    }
}

std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
                                                           program& prog,
                                                           instruction_ref ins,
                                                           instruction_ref input,
                                                           instruction_ref w,
                                                           instruction_ref r,
                                                           instruction_ref bias,
                                                           instruction_ref ih,
                                                           operation& actv_func) const
{
    // squeeze and transpose w
    std::vector<int64_t> perm{1, 0};
    auto sw      = prog.insert_instruction(ins, op::squeeze{{0}}, w);
    auto tran_sw = prog.insert_instruction(ins, op::transpose{perm}, sw);

    // squeeze and transpose r
    auto sr      = prog.insert_instruction(ins, op::squeeze{{0}}, r);
    auto tran_sr = prog.insert_instruction(ins, op::transpose{perm}, sr);

    // initial hidden state
    auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);

    // bias
    if(bias != prog.end())
    {
        long hs    = r->get_shape().lens()[2];
        auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
        auto wb    = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
        auto rb    = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
        auto b     = prog.insert_instruction(ins, op::add{}, wb, rb);
        bias       = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, b);
    }

    instruction_ref hidden_out = prog.end();
    instruction_ref last_out{};
    last_out            = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);
    std::size_t seq_len = input->get_shape().lens()[0];
    for(std::size_t i = 0; i < seq_len; i++)
    {
        long seq_index = is_forward ? i : (seq_len - 1 - i);
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
        xt      = prog.insert_instruction(ins, op::squeeze{{0}}, xt);

        auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw);
        auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr);
        auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
        instruction_ref ht;
        if(bias != prog.end())
        {
            ht = prog.insert_instruction(ins, op::add{}, xt_ht, bias);
        }
        else
        {
            ht = xt_ht;
        }

        // apply activation function
        ht  = prog.insert_instruction(ins, actv_func, ht);
        sih = ht;

        // add the dimensions of sequence length (axis 0 for sequence length,
        // axis 1 for num_directions
        last_out = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, ht);

        // concatenation for the last last_out is performed in the apply()
        // function to ensure the last instruction is concat, then we have
        // output inserted
        if(i < seq_len - 1)
        {
            if(is_forward)
            {
                hidden_out = (seq_index == 0)
                                 ? last_out
                                 : prog.insert_instruction(ins, op::concat{0}, hidden_out, last_out);
            }
            else
            {
                hidden_out = (seq_index == seq_len - 1)
                                 ? last_out
                                 : prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out);
            }
        }
    }

    return {hidden_out, last_out};
}

std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins) const
{
    auto rnn_op = any_cast<op::rnn>(ins->get_operator());
    // could be 3 to 6 inputs, but the parse_gru function will
    // append undefined operators to make 6 arguments when parsing
    // an onnx file. Another case is user can have any num of arguments
    // when writing their program.
    if(rnn_op.direction == op::rnn_direction::bidirectional)
    {
        if(rnn_op.actv_funcs.empty())
        {
            // default is tanh
            return {op::tanh{}, op::tanh{}};
        }
        else if(rnn_op.actv_funcs.size() == 1)
        {
            return {rnn_op.actv_funcs.at(0), rnn_op.actv_funcs.at(0)};
        }
        else
        {
            return rnn_op.actv_funcs;
        }
    }
    else
    {
        if(rnn_op.actv_funcs.empty())
        {
            // default is tanh
            return {op::tanh{}};
        }
        else
        {
            return rnn_op.actv_funcs;
        }
    }
}

void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
{
    assert(ins->name() == "gru");
    const auto actv_funcs = gru_actv_funcs(ins);

    // could be 3 to 6 inputs, but the parse_gru function will
    // append undefined operators to make 6 arguments when parsing
    // an onnx file. Another case is user can have num of arguments
    // when writing their program.
    auto args = ins->inputs();

    shape seq_shape         = args[0]->get_shape();
    std::size_t hidden_size = args[2]->get_shape().lens()[2];
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
    migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
    std::vector<float> data(ih_shape.elements(), 0.0);

    auto gru_op             = any_cast<op::gru>(ins->get_operator());
    op::rnn_direction dicrt = gru_op.direction;
    instruction_ref last_output{};
    if(dicrt == op::rnn_direction::bidirectional)
    {
        // w weight matrix
        auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
        auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);

        // r weight matrix
        auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
        auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);

        // bias
        instruction_ref bias_forward = prog.end();
        instruction_ref bias_reverse = prog.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
            bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
        }

        // intial hidden state
        instruction_ref ih_forward{};
        instruction_ref ih_reverse{};
        if(args.size() == 6 && args[5]->name() != "undefined")
        {
            ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
            ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
        }
        else
        {
            ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
            ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
        }

        auto ret_forward = gru_cell(true,
                                    prog,
                                    ins,
                                    {args[0], w_forward, r_forward, bias_forward, ih_forward},
                                    gru_op.linear_before_reset,
                                    actv_funcs.at(0),
                                    actv_funcs.at(1));

        auto ret_reverse = gru_cell(false,
                                    prog,
                                    ins,
                                    {args[0], w_reverse, r_reverse, bias_reverse, ih_reverse},
                                    gru_op.linear_before_reset,
                                    actv_funcs.at(2),
                                    actv_funcs.at(3));

        auto concat_output =
            prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
        last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);

        // The following logic is to ensure the last instruction rewritten
        // from gru operator is a concat
        if(ret_forward[0] == prog.end())
        {
            prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
        }
        else
        {
            ret_forward[0] =
                prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
            ret_reverse[0] =
                prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
            prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
        }
    }
    else
    {
        bool is_forward = (dicrt == op::rnn_direction::forward);
        // weight matrix
        auto w = args[1];
        auto r = args[2];

        // bias
        instruction_ref bias = prog.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias = args[3];
        }

        // intial hidden state
        instruction_ref ih{};
        if(args.size() == 6 && args[5]->name() != "undefined")
        {
            ih = args[5];
        }
        else
        {
            ih = prog.add_literal(migraphx::literal{ih_shape, data});
        }

        auto ret = gru_cell(is_forward,
                            prog,
                            ins,
                            {args[0], w, r, bias, ih},
                            gru_op.linear_before_reset,
                            actv_funcs.at(0),
                            actv_funcs.at(1));

        last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);

        if(ret[0] == prog.end())
        {
            prog.replace_instruction(ins, op::concat{0}, ret[1]);
        }
        else
        {
            auto concat_arg0 = is_forward ? ret[0] : ret[1];
            auto concat_arg1 = is_forward ? ret[1] : ret[0];
            prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
        }
    }

    // replace the corresponding rnn_last_output instruction
    // with the last_output, if rnn_last_output exists
    // while loop to handle case of multiple rnn_last_output operators
    auto last_output_it = ins->outputs().begin();
    while(last_output_it != ins->outputs().end())
    {
        last_output_it = std::find_if(last_output_it, ins->outputs().end(), [](auto i) {
            return i->name() == "rnn_last_output";
        });

        if(last_output_it != ins->outputs().end())
        {
            prog.replace_instruction(*last_output_it, last_output);
            last_output_it++;
        }
    }
}

std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
                                                   program& prog,
                                                   instruction_ref ins,
                                                   std::vector<instruction_ref> inputs,
                                                   int linear_before_reset,
                                                   const operation& actv_func1,
                                                   const operation& actv_func2) const
{
    assert(inputs.size() == 5);
    auto seq  = inputs.at(0);
    auto w    = inputs.at(1);
    auto r    = inputs.at(2);
    auto bias = inputs.at(3);
    auto ih   = inputs.at(4);

    instruction_ref hidden_states = prog.end();
    instruction_ref last_output{};
    migraphx::shape seq_shape = seq->get_shape();
    migraphx::shape r_shape   = r->get_shape();
    long seq_len              = static_cast<long>(seq_shape.lens()[0]);
    long hs                   = static_cast<long>(r_shape.lens()[2]);
    migraphx::shape s(seq_shape.type(), {seq_shape.lens()[1], r_shape.lens()[2]});
    std::vector<int> data(s.elements(), 1);
    auto l1 = prog.add_literal(migraphx::literal{s, data});

    // weight matrix
    std::vector<int64_t> perm{1, 0};
    auto sw      = prog.insert_instruction(ins, op::squeeze{{0}}, w);
    auto wz      = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sw);
    auto tran_wz = prog.insert_instruction(ins, op::transpose{perm}, wz);
    auto wr      = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sw);
    auto tran_wr = prog.insert_instruction(ins, op::transpose{perm}, wr);
    auto wh      = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sw);
    auto tran_wh = prog.insert_instruction(ins, op::transpose{perm}, wh);

    auto sr      = prog.insert_instruction(ins, op::squeeze{{0}}, r);
    auto rz      = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sr);
    auto tran_rz = prog.insert_instruction(ins, op::transpose{perm}, rz);
    auto rr      = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sr);
    auto tran_rr = prog.insert_instruction(ins, op::transpose{perm}, rr);
    auto rh      = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
    auto tran_rh = prog.insert_instruction(ins, op::transpose{perm}, rh);

    // initial states
    auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);

    // bias
    instruction_ref brcst_bz{};
    instruction_ref brcst_br{};
    instruction_ref brcst_wbh{};
    instruction_ref brcst_rbh{};
    instruction_ref brcst_bh{};
    if(bias != prog.end())
    {
        auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
        auto wbz   = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
        auto wbr   = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
        auto wbh   = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
        brcst_wbh  = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, wbh);

        auto rbz  = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
        auto rbr  = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
        auto rbh  = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
        brcst_rbh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, rbh);

        auto bz  = prog.insert_instruction(ins, op::add{}, wbz, rbz);
        brcst_bz = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bz);
        auto br  = prog.insert_instruction(ins, op::add{}, wbr, rbr);
        brcst_br = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, br);
        auto bh  = prog.insert_instruction(ins, op::add{}, wbh, rbh);
        brcst_bh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bh);
    }

    for(long i = 0; i < seq_len; i++)
    {
        long seq_index = is_forward ? i : (seq_len - 1 - i);
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
        xt      = prog.insert_instruction(ins, op::squeeze{{0}}, xt);

        // equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz)
        auto xt_wz = prog.insert_instruction(ins, op::dot{}, xt, tran_wz);
        auto ht_rz = prog.insert_instruction(ins, op::dot{}, sih, tran_rz);
        auto xht_z = prog.insert_instruction(ins, op::add{}, xt_wz, ht_rz);
        if(bias != prog.end())
        {
            xht_z = prog.insert_instruction(ins, op::add{}, xht_z, brcst_bz);
        }
        auto zt = prog.insert_instruction(ins, actv_func1, xht_z);

        // equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
        auto xt_wr = prog.insert_instruction(ins, op::dot{}, xt, tran_wr);
        auto ht_rr = prog.insert_instruction(ins, op::dot{}, sih, tran_rr);
        auto xht_r = prog.insert_instruction(ins, op::add{}, xt_wr, ht_rr);
        if(bias != prog.end())
        {
            xht_r = prog.insert_instruction(ins, op::add{}, xht_r, brcst_br);
        }
        auto rt = prog.insert_instruction(ins, actv_func1, xht_r);

        instruction_ref xht_h;
        if(linear_before_reset == 0)
        {
            // equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
            auto xt_wh  = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
            auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
            auto rt_rh  = prog.insert_instruction(ins, op::dot{}, rt_ht1, tran_rh);
            xht_h       = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
            if(bias != prog.end())
            {
                xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_bh);
            }
        }
        else
        {
            // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
            auto xt_wh  = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
            auto ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, tran_rh);
            if(bias != prog.end())
            {
                ht1_rh = prog.insert_instruction(ins, op::add{}, ht1_rh, brcst_rbh);
            }
            auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
            xht_h      = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
            if(bias != prog.end())
            {
                xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_wbh);
            }
        }
        auto ht = prog.insert_instruction(ins, actv_func2, xht_h);

        // equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
        auto one_minus_zt    = prog.insert_instruction(ins, op::sub{}, l1, zt);
        auto one_minus_zt_ht = prog.insert_instruction(ins, op::mul{}, one_minus_zt, ht);
        auto zt_ht1          = prog.insert_instruction(ins, op::mul{}, zt, sih);
        sih                  = prog.insert_instruction(ins, op::add{}, one_minus_zt_ht, zt_ht1);
        last_output          = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);

        if(i < seq_len - 1)
        {
            if(is_forward)
            {
                hidden_states =
                    (seq_index == 0)
                        ? last_output
                        : prog.insert_instruction(ins, op::concat{0}, hidden_states, last_output);
            }
            else
            {
                hidden_states =
                    (seq_index == seq_len - 1)
                        ? last_output
                        : prog.insert_instruction(ins, op::concat{0}, last_output, hidden_states);
            }
        }
    }

    return {hidden_states, last_output};
}

std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
{
    auto gru_op = any_cast<op::gru>(ins->get_operator());
    // before rewrite the gru operator, need to ensure
    // we have 4 actv funcs, even though a user does not
    // specifiy any actv func. If less than 4, use the
    // algorithm in parse_gru to make 4 actv functions
    if(gru_op.direction == op::rnn_direction::bidirectional)
    {
        if(gru_op.actv_funcs.empty())
            return {op::sigmoid{}, op::tanh{}, op::sigmoid{}, op::tanh{}};
        else if(gru_op.actv_funcs.size() == 1)
            return {gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(0)};
        else if(gru_op.actv_funcs.size() == 2)
            return {gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(1),
                    gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(1)};
        else if(gru_op.actv_funcs.size() == 3)
            return {gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(1),
                    gru_op.actv_funcs.at(2),
                    gru_op.actv_funcs.at(0)};
        else
            return gru_op.actv_funcs;
    }
    else
    {
        if(gru_op.actv_funcs.empty())
            return {op::sigmoid{}, op::tanh{}};
        else if(gru_op.actv_funcs.size() == 1)
            return {gru_op.actv_funcs.at(0), gru_op.actv_funcs.at(0)};
        else
            return gru_op.actv_funcs;
    }
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
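The equation comments inside gru_cell spell out the GRU cell this rewrite expands into primitive dot/add/mul instructions. Collected in one place, with f and g standing for actv_func1 and actv_func2 and (.) for elementwise multiply:

    z_t = f(X_t * W_z^T + H_{t-1} * R_z^T + W_{bz} + R_{bz})
    r_t = f(X_t * W_r^T + H_{t-1} * R_r^T + W_{br} + R_{br})
    h_t = g(X_t * W_h^T + (r_t (.) H_{t-1}) * R_h^T + W_{bh} + R_{bh})    (linear_before_reset == 0)
    h_t = g(X_t * W_h^T + r_t (.) (H_{t-1} * R_h^T + R_{bh}) + W_{bh})    (linear_before_reset != 0)
    H_t = (1 - z_t) (.) h_t + z_t (.) H_{t-1}

The literal l1 of ones supplies the (1 - z_t) term via op::sub, which is why the kernel needs the new gpu device::sub further down in this diff.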
src/simplify_reshapes.cpp

...
@@ -9,39 +9,47 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

-// Reshapers that can't handle nonstandard input shapes
-bool is_nonstandard_reshaper(instruction_ref ins)
-{
-    // clang-format off
-    static const std::unordered_set<std::string> names = {
-        "reshape"
-    };
-    // clang-format on
-    return contains(names, ins->name()) and ins->inputs().front()->name() == "contiguous";
-}
-
 bool is_reshaper(instruction_ref ins)
 {
     // clang-format off
     static const std::unordered_set<std::string> names = {
         "reshape",
-        "transpose",
+        // "broadcast",
         "contiguous"
     };
     // clang-format on
-    return contains(names, ins->name()) and not is_nonstandard_reshaper(ins);
+    return contains(names, ins->name());
 }

+bool is_transpose_output(instruction_ref ins)
+{
+    if(ins->outputs().size() != 1)
+        return false;
+    if(ins->outputs().front()->name() == "contiguous")
+        return is_transpose_output(ins->outputs().front());
+    return ins->outputs().front()->name() == "transpose";
+}
+
+instruction_ref find_transpose_input(instruction_ref ins)
+{
+    if(ins->inputs().size() != 1)
+        return ins;
+    if(ins->inputs().front()->name() == "contiguous")
+        return find_transpose_input(ins->inputs().front());
+    if(ins->inputs().front()->name() == "transpose")
+        return ins->inputs().front();
+    return ins;
+}
+
 void simplify_reshapes::apply(program& p) const
 {
+    auto end = std::prev(p.end());
     for(auto ins : iterator_for(p))
     {
-        if(not is_reshaper(ins))
-            continue;
-        if(ins->outputs().size() != 1)
+        if(ins->outputs().empty() and ins != end)
             continue;
-        if(is_reshaper(ins->outputs().front()))
+        if(is_reshaper(ins))
+        {
+            if(std::any_of(ins->outputs().begin(), ins->outputs().end(), &is_reshaper))
                 continue;
             // Gather reshapes
             std::vector<instruction_ref> reshapes{ins};
...
@@ -70,6 +78,22 @@ void simplify_reshapes::apply(program& p) const
                 p.replace_instruction(r.first, r.second);
             }
         }
+        else if(ins->name() == "transpose")
+        {
+            if(is_transpose_output(ins))
+                continue;
+            auto x = ins;
+            auto t = ins;
+            do
+            {
+                x = t;
+                t = find_transpose_input(x);
+            } while(x != t and t->name() == "transpose");
+            if(t == ins or t->name() != "transpose")
+                continue;
+            p.replace_instruction(ins, t->inputs().front());
+        }
+    }

     // Replace all reshapes with as_shape
     for(auto ins : iterator_for(p))
     {
...
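find_transpose_input walks backwards through interleaved contiguous instructions until it reaches another transpose, so chains like transpose -> contiguous -> transpose can be collapsed. A standalone sketch of the chain-walking idiom with a hypothetical single-input node type:

    #include <string>

    struct node
    {
        std::string name;
        node* input = nullptr;   // single input for simplicity
    };

    // Skip "pass-through" contiguous nodes; stop at a transpose or give up.
    node* find_transpose_input(node* n)
    {
        if(n->input == nullptr)
            return n;
        if(n->input->name == "contiguous")
            return find_transpose_input(n->input);
        if(n->input->name == "transpose")
            return n->input;
        return n;
    }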
src/targets/cpu/lowering.cpp

...
@@ -103,6 +103,43 @@ struct cpu_batch_norm_inference
     }
 };

+struct cpu_lrn
+{
+    op::lrn op;
+    std::string name() const { return "cpu::lrn"; }
+    shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
+    argument compute(context&, shape output_shape, std::vector<argument> args) const
+    {
+        argument result{output_shape};
+        visit_all(result, args[0])([&](auto output, auto input) {
+            int n_batch         = output_shape.lens()[0];
+            int channels        = output_shape.lens()[1];
+            int height          = output_shape.lens()[2];
+            int width           = output_shape.lens()[3];
+            float alphaoverarea = op.alpha / op.size;
+            int radius          = (op.size - 1) / 2;
+
+            par_dfor(n_batch, height, width)([&](int b, int h, int w) {
+                float scale = 0;
+                dfor(channels)([&](int c) {
+                    auto start = (c - radius) < 0 ? 0 : (c - radius);
+                    auto end   = (c + radius) > channels ? channels : (c + radius);
+                    for(auto k = start; k < end; ++k)
+                    {
+                        scale += std::pow(input(b, k, h, w), 2);
+                    }
+                    scale *= alphaoverarea;
+                    scale += op.bias;
+                    scale = std::pow(scale, -op.beta);
+                    output(b, c, h, w) = input(b, c, h, w) * scale;
+                });
+            });
+        });
+        return result;
+    }
+};
+
 struct cpu_convolution
 {
     op::convolution op;
...
@@ -287,14 +324,7 @@ struct cpu_contiguous
     shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
     argument compute(context&, const shape& output_shape, std::vector<argument> args) const
     {
-        argument result{output_shape};
-        visit_all(result, args[0])([&](auto output, auto input) {
-            shape_for_each(output.get_shape(), [&](const auto& idx) {
-                output(idx.begin(), idx.end()) = input(idx.begin(), idx.end());
-            });
-        });
-        return result;
+        assert(output_shape.standard());
+        return op.compute(output_shape, std::move(args));
     }
 };
...
@@ -688,6 +718,7 @@ struct cpu_apply
         apply_map["dot"] = extend_op<cpu_gemm, op::dot>();
         apply_map["batch_norm_inference"] =
             extend_op<cpu_batch_norm_inference, op::batch_norm_inference>();
+        apply_map["lrn"]        = extend_op<cpu_lrn, op::lrn>();
         apply_map["contiguous"] = extend_op<cpu_contiguous, op::contiguous>();
         apply_map["pad"]        = extend_op<cpu_pad, op::pad>();
        apply_map["concat"]     = extend_op<cpu_concat, op::concat>();
...
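For reference, the across-channel LRN that cpu_lrn targets (ONNX-style, with r = (size - 1) / 2 and C channels):

    y[b,c,h,w] = x[b,c,h,w] * (bias + (alpha / size) * sum over k in [max(0, c-r), min(C, c+r)) of x[b,k,h,w]^2)^(-beta)

Note that in the code above `scale` is declared once per (b, h, w) and mutated across the channel loop rather than reset per channel, so the implementation matches this formula exactly only for the first channel.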
src/targets/cpu/target.cpp

...
@@ -2,6 +2,8 @@
 #include <migraphx/cpu/target.hpp>
 #include <migraphx/cpu/lowering.hpp>
 #include <migraphx/auto_contiguous.hpp>
+#include <migraphx/rewrite_rnn.hpp>
+#include <migraphx/dead_code_elimination.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
...
@@ -11,7 +13,11 @@ std::string target::name() const { return "cpu"; }
 std::vector<pass> target::get_passes(migraphx::context&) const
 {
-    return {auto_contiguous{}, lowering{}};
+    return {auto_contiguous{},
+            rewrite_rnn{},
+            dead_code_elimination{},
+            lowering{},
+            dead_code_elimination{}};
 }

 } // namespace cpu
...
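The ordering matters: rewrite_rnn runs before lowering so the cpu backend only ever sees primitive ops, and dead_code_elimination sweeps the instructions each stage orphans. Every entry models the same pass interface seen in rewrite_rnn.hpp; a hypothetical custom pass of the same shape would look like:

    // Hypothetical pass skeleton, not part of this commit.
    struct remove_identity
    {
        std::string name() const { return "remove_identity"; }
        void apply(program& p) const
        {
            for(auto ins : iterator_for(p))
            {
                if(ins->name() == "identity")
                    p.replace_instruction(ins, ins->inputs().front());
            }
        }
    };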
src/targets/gpu/CMakeLists.txt

...
@@ -30,6 +30,7 @@ add_library(migraphx_device
     device/concat.cpp
     device/pad.cpp
     device/gather.cpp
+    device/sub.cpp
 )
 set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device)
 rocm_clang_tidy_check(migraphx_device)
...
@@ -60,6 +61,7 @@ add_library(migraphx_gpu
     elu.cpp
     pad.cpp
     gather.cpp
+    lrn.cpp
 )
 set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
 rocm_clang_tidy_check(migraphx_gpu)
...
src/targets/gpu/device/sub.cpp · new file (100644)

#include <migraphx/gpu/device/sub.hpp>
#include <migraphx/gpu/device/nary.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

void sub(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{
    nary(stream, result, arg1, arg2)([](auto x, auto y) { return y - x; });
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
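The elementwise kernel is expressed once as a lambda while nary owns layout handling and the kernel launch; note the body returns y - x rather than x - y, and whether that reflects the operand order nary supplies or the intended operand order of sub is not visible from this diff alone. A host-side stand-in for the shape of the API (hypothetical, not the HIP implementation):

    #include <cstddef>

    template <class F>
    void nary_apply(float* out, const float* a, const float* b, std::size_t n, F f)
    {
        // Apply f elementwise; the real device::nary launches a HIP kernel.
        for(std::size_t i = 0; i < n; i++)
            out[i] = f(a[i], b[i]);
    }

    // usage: nary_apply(out, a, b, n, [](auto x, auto y) { return y - x; });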
src/targets/gpu/gemm.cpp

...
@@ -107,6 +107,7 @@ argument miopen_gemm::compute(context& ctx,
             ldc);
     });
     return args[2];
 }
...
src/targets/gpu/include/migraphx/gpu/device/sub.hpp · new file (100644)

#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_SUB_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_SUB_HPP

#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

void sub(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2);

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
src/targets/gpu/include/migraphx/gpu/lrn.hpp · new file (100644)

#ifndef MIGRAPHX_GUARD_RTGLIB_LRN_HPP
#define MIGRAPHX_GUARD_RTGLIB_LRN_HPP

#include <migraphx/gpu/lowering.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
#include <migraphx/gpu/device/add.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/context.hpp>
#include <utility>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

struct miopen_lrn
{
    shared<lrn_descriptor> ldesc;
    std::string name() const { return "gpu::lrn"; }
    shape compute_shape(const std::vector<shape>& inputs) const;
    argument
    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
    int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
};

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif