gaoqiong / MIGraphX · Commits

Commit c984fd24 (unverified)
Authored Mar 04, 2019 by Paul Fultz II; committed via GitHub on Mar 04, 2019
Parents: fa983162, 8274cb47

    Merge branch 'develop' into tf_pb

The commit changes 98 files in total; this page shows the first 20 of them, with 898 additions and 105 deletions (+898 -105).
Files shown on this page:

    src/eliminate_contiguous.cpp                        +7    -0
    src/include/migraphx/literal.hpp                    +2    -2
    src/include/migraphx/operators.hpp                  +97   -28
    src/include/migraphx/rewrite_rnn.hpp                +12   -0
    src/onnx/onnx.cpp                                   +208  -18
    src/program.cpp                                     +1    -0
    src/rewrite_rnn.cpp                                 +507  -2
    src/shape.cpp                                       +1    -1
    src/targets/cpu/gemm.cpp                            +35   -14
    src/targets/cpu/include/migraphx/cpu/target.hpp     +1    -0
    src/targets/cpu/target.cpp                          +1    -0
    src/targets/gpu/abs.cpp                             +1    -4
    src/targets/gpu/batchnorm.cpp                       +1    -4
    src/targets/gpu/concat.cpp                          +1    -4
    src/targets/gpu/contiguous.cpp                      +2    -4
    src/targets/gpu/convolution.cpp                     +2    -4
    src/targets/gpu/device/gather.cpp                   +15   -11
    src/targets/gpu/elu.cpp                             +1    -4
    src/targets/gpu/fuse_ops.cpp                        +1    -0
    src/targets/gpu/gather.cpp                          +2    -5
src/eliminate_contiguous.cpp (+7 -0)

```diff
@@ -27,6 +27,13 @@ void eliminate_contiguous::apply(program& p) const
 {
     for(auto ins : iterator_for(p))
     {
+        // skip the reshape operator for now, since there is a bug
+        // for the transpose followed by a reshape
+        if(ins->name() == "reshape")
+        {
+            continue;
+        }
+
         // Make a copy so we can modify it while we iterate
         auto args = ins->inputs();
         for(auto arg : ins->inputs())
```
src/include/migraphx/literal.hpp (+2 -2)

```diff
@@ -22,8 +22,8 @@ struct literal : raw_data<literal>
 {
     literal() {}
 
-    template <class U, class T = deduce<U>>
-    literal(U x) : buffer(make_shared_array<char>(sizeof(T))), m_shape(shape::get_type<T>{})
+    template <class U, class T = deduce<U>, shape::type_t ShapeType = shape::get_type<T>{}>
+    literal(U x) : buffer(make_shared_array<char>(sizeof(T))), m_shape(ShapeType)
     {
         static_assert(std::is_trivially_copyable<T>{}, "Literals can only be trivial types");
         *(reinterpret_cast<T*>(buffer.get())) = x;
```
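A note on the change above: it moves the shape-type lookup into a defaulted non-type template parameter, presumably so that constructing a `literal` from an unsupported element type now fails at overload resolution (SFINAE) rather than inside the constructor body. Here is a minimal standalone sketch of that technique; it is not MIGraphX code, and `type_t`/`get_type` are hypothetical stand-ins for `shape::type_t` and `shape::get_type<T>`:

```cpp
#include <type_traits>

enum class type_t { float_type, int32_type };

template <class T> struct get_type; // left undefined for unsupported types
template <> struct get_type<float> : std::integral_constant<type_t, type_t::float_type> {};
template <> struct get_type<int>   : std::integral_constant<type_t, type_t::int32_type> {};

struct lit
{
    // get_type<T>::value only exists for specialized types, so overload
    // resolution silently drops this constructor for anything else.
    template <class U, class T = std::decay_t<U>, type_t Type = get_type<T>::value>
    explicit lit(U) : m_type(Type) {}

    type_t m_type;
};

int main()
{
    lit a{1.0f};   // ok: m_type == type_t::float_type
    // lit b{"x"}; // would not compile: get_type<const char*> is undefined
    return a.m_type == type_t::float_type ? 0 : 1;
}
```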
src/include/migraphx/operators.hpp (+97 -28)

```diff
@@ -758,42 +758,48 @@ struct gather
         int axis_index = (axis < 0) ? (n_dim + axis) : axis;
         auto type      = inputs[0].type();
-        lens[axis_index] = inputs[1].elements();
+        lens.erase(lens.begin() + axis_index);
+        if(!inputs[1].scalar())
+        {
+            auto ind_lens = inputs[1].lens();
+            lens.insert(lens.begin() + axis_index, ind_lens.begin(), ind_lens.end());
+        }
+
+        // for scalar output
+        if(lens.empty())
+        {
+            return {type};
+        }
+
         return {type, lens};
     }
 
-    template <class T>
-    void compute_index(const T& out_idx,
-                       const int axis_index,
-                       const std::vector<std::size_t>& vec_indices,
-                       const std::size_t max_dim,
-                       T& in_idx) const
-    {
-        in_idx          = out_idx;
-        std::size_t idx = vec_indices.at(out_idx[axis_index]);
-        if(idx >= max_dim)
-        {
-            MIGRAPHX_THROW("Gather: indices are out of range in input tensor");
-        }
-        in_idx[axis_index] = idx;
-    }
-
     argument compute(const shape& output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
         // negative axis means counting dimensions from back
-        int axis_index = (axis < 0) ? (output_shape.lens().size() + axis) : axis;
-        // max dimension in axis
-        std::size_t max_dim = args[0].get_shape().lens()[axis_index];
-        visit_all(result, args[0])([&](auto output, auto input) {
-            std::vector<std::size_t> vec_indices;
-            args[1].visit(
-                [&](auto indices) { vec_indices.assign(indices.begin(), indices.end()); });
-            std::vector<std::size_t> in_idx;
-            shape_for_each(output.get_shape(), [&](const auto& idx) {
-                this->compute_index(idx, axis_index, vec_indices, max_dim, in_idx);
-                output(idx.begin(), idx.end()) = input(in_idx.begin(), in_idx.end());
-            });
-        });
+        int axis_index =
+            (axis < 0) ? static_cast<int>(args[0].get_shape().lens().size() + axis) : axis;
+
+        visit_all(result, args[0])([&](auto output, auto data) {
+            args[1].visit([&](auto indices) {
+                if(output_shape.scalar())
+                {
+                    output[0] = data[indices.front()];
+                }
+                else
+                {
+                    auto out_lens        = data.get_shape().lens();
+                    out_lens[axis_index] = indices.get_shape().elements();
+                    migraphx::shape out_comp_shape{data.get_shape().type(), out_lens};
+                    shape_for_each(out_comp_shape, [&](const auto& out_idx) {
+                        auto data_idx        = out_idx;
+                        data_idx[axis_index] = indices[data_idx[axis_index]];
+                        output[out_comp_shape.index(out_idx.begin(), out_idx.end())] =
+                            data(data_idx.begin(), data_idx.end());
+                    });
+                }
+            });
+        });
```
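The new `compute_shape` follows the usual ONNX/numpy gather rule: the axis dimension of the data shape is replaced by the full shape of the indices, and a scalar index removes the axis entirely. A minimal standalone sketch of just that rule (not MIGraphX code):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

std::vector<std::size_t> gather_dims(std::vector<std::size_t> data_lens,
                                     const std::vector<std::size_t>& index_lens,
                                     std::size_t axis,
                                     bool scalar_index)
{
    // drop the gathered axis, then splice in the index dims (if any)
    data_lens.erase(data_lens.begin() + axis);
    if(!scalar_index)
        data_lens.insert(data_lens.begin() + axis, index_lens.begin(), index_lens.end());
    return data_lens; // an empty result means a scalar output
}

int main()
{
    // data {3, 4, 5}, indices {2, 6} along axis 1 -> output {3, 2, 6, 5}
    assert((gather_dims({3, 4, 5}, {2, 6}, 1, false) ==
            std::vector<std::size_t>{3, 2, 6, 5}));
    // a scalar index along axis 0 of a 1-D input -> scalar output (empty dims)
    assert(gather_dims({7}, {}, 0, true).empty());
    return 0;
}
```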
```diff
@@ -820,10 +826,22 @@ struct dot
         const shape& b = inputs.at(1);
         auto t         = a.type();
-        if(a.lens()[1] != b.lens()[0])
+
+        // according to the specification of numpy.matmul(),
+        // inputs with more than 2 shape dims are acceptable
+        // as long as the dim values are the same in the two inputs
+        if(!std::equal(a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2))
+        {
+            MIGRAPHX_THROW("DOT: dim values mismatch");
+        }
+
+        std::size_t dim_0 = a.lens().size() - 2;
+        std::size_t dim_1 = a.lens().size() - 1;
+        if(a.lens()[dim_1] != b.lens()[dim_0])
             MIGRAPHX_THROW("Inner dimensions do not match: {" + to_string_range(a.lens()) +
                            "} x {" + to_string_range(b.lens()) + "}");
-        return {t, {a.lens()[0], b.lens()[1]}};
+
+        auto out_lens   = a.lens();
+        out_lens[dim_1] = b.lens()[dim_1];
+        return {t, out_lens};
     }
 };
```
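The shape rule this enforces: the leading (batch) dims of the two inputs must match exactly, the inner dims must agree, and the output copies `a`'s dims with the last one replaced by `b`'s last dim. A standalone sketch (not MIGraphX code; like the original, it assumes both inputs have the same rank, at least 2):

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

std::vector<std::size_t> dot_dims(const std::vector<std::size_t>& a,
                                  const std::vector<std::size_t>& b)
{
    // all dims except the trailing two must match exactly
    if(!std::equal(a.rbegin() + 2, a.rend(), b.rbegin() + 2))
        throw std::runtime_error("DOT: batch dims mismatch");
    // inner dims must agree, as in an ordinary matrix product
    if(a[a.size() - 1] != b[b.size() - 2])
        throw std::runtime_error("DOT: inner dimensions do not match");
    auto out          = a;
    out[a.size() - 1] = b[b.size() - 1];
    return out;
}

int main()
{
    // {2, 3, 4, 5} x {2, 3, 5, 6} -> {2, 3, 4, 6}
    assert((dot_dims({2, 3, 4, 5}, {2, 3, 5, 6}) ==
            std::vector<std::size_t>{2, 3, 4, 6}));
    return 0;
}
```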
```diff
@@ -1259,6 +1277,57 @@ struct gru
     }
 };
 
+struct lstm
+{
+    std::size_t hidden_size = 1;
+    std::vector<operation> actv_funcs{sigmoid{}, tanh{}, tanh{}};
+    rnn_direction direction = rnn_direction::forward;
+    float clip              = 0.0f;
+    int input_forget        = 0;
+
+    std::string name() const { return "lstm"; }
+
+    shape compute_shape(std::vector<shape> inputs) const
+    {
+        auto in_dims     = inputs[0].lens();
+        auto hidden_dims = inputs[2].lens();
+        if(hidden_size != hidden_dims[2])
+        {
+            MIGRAPHX_THROW("LSTM: hidden size mismatch in attribute and input");
+        }
+
+        std::size_t num_directions = 1;
+        if(direction == rnn_direction::bidirectional)
+        {
+            num_directions = 2;
+        }
+
+        if(num_directions != hidden_dims[0])
+        {
+            MIGRAPHX_THROW("LSTM: num_direction does not match the direction attribute");
+        }
+
+        std::vector<std::size_t> out_dims(in_dims);
+        out_dims.insert(out_dims.begin() + 1, num_directions);
+        out_dims.back() = hidden_size;
+
+        return {inputs[0].type(), out_dims};
+    }
+};
+
+struct lstm_last_cell_output
+{
+    std::string name() const { return "lstm_last_cell_output"; }
+    shape compute_shape(std::vector<shape> inputs) const
+    {
+        check_shapes{inputs, *this}.has(1);
+        auto dims = inputs[0].lens();
+
+        // remove the first dimension; the remaining dims are the output shape
+        dims.erase(dims.begin());
+        return {inputs[0].type(), dims};
+    }
+};
+
 struct undefined
 {
     std::string name() const { return "undefined"; }
```
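The `lstm` output shape mirrors the ONNX convention: a `num_directions` dim is inserted after `seq_len` and the last dim becomes `hidden_size`, so `{seq_len, batch, input_size}` becomes `{seq_len, num_directions, batch, hidden_size}`. A small sketch of just that computation (not MIGraphX code):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

std::vector<std::size_t> lstm_dims(std::vector<std::size_t> in_dims,
                                   std::size_t num_directions,
                                   std::size_t hidden_size)
{
    in_dims.insert(in_dims.begin() + 1, num_directions);
    in_dims.back() = hidden_size;
    return in_dims;
}

int main()
{
    // seq_len = 10, batch = 4, input_size = 8, bidirectional, hidden_size = 16
    assert((lstm_dims({10, 4, 8}, 2, 16) ==
            std::vector<std::size_t>{10, 2, 4, 16}));
    return 0;
}
```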
src/include/migraphx/rewrite_rnn.hpp (+12 -0)

```diff
@@ -45,6 +45,18 @@ struct rewrite_rnn
                        const operation& actv_func2) const;
     std::vector<operation> gru_actv_funcs(instruction_ref ins) const;
 
+    // for lstm operators
+    void apply_lstm(program& prog, instruction_ref ins) const;
+    std::vector<instruction_ref> lstm_cell(bool is_forward,
+                                           program& prog,
+                                           instruction_ref ins,
+                                           std::vector<instruction_ref> inputs,
+                                           const operation& actv_func1,
+                                           const operation& actv_func2,
+                                           const operation& actv_func3) const;
+    std::vector<operation> lstm_actv_funcs(instruction_ref ins) const;
+
 };
 
 } // namespace MIGRAPHX_INLINE_NS
```
src/onnx/onnx.cpp (+208 -18)

```diff
@@ -89,6 +89,7 @@ struct onnx_parser
         add_mem_op("Transpose", &onnx_parser::parse_transpose);
         add_mem_op("RNN", &onnx_parser::parse_rnn);
         add_mem_op("GRU", &onnx_parser::parse_gru);
+        add_mem_op("LSTM", &onnx_parser::parse_lstm);
         add_mem_op("Pad", &onnx_parser::parse_pad);
 
         // init the activation function map
```
```diff
@@ -354,7 +355,9 @@ struct onnx_parser
         }
         if(args.size() == 2)
         {
-            literal s = args[1]->get_literal();
+            auto s = args[1]->eval();
+            if(s.empty())
+                MIGRAPHX_THROW("Dynamic shape is not supported.");
             s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
         }
         return prog.add_instruction(op, args[0]);
```
```diff
@@ -434,6 +437,14 @@ struct onnx_parser
                    const std::vector<instruction_ref>&)
     {
         literal v = parse_value(attributes.at("value"));
+        auto dim_size = attributes.at("value").t().dims_size();
+        // if dim_size is 0, it is a scalar
+        if(dim_size == 0)
+        {
+            migraphx::shape scalar_shape{v.get_shape().type()};
+            return prog.add_literal(migraphx::literal{scalar_shape, v.data()});
+        }
+
         return prog.add_literal(v);
     }
```
```diff
@@ -460,7 +471,12 @@ struct onnx_parser
         {
             transb = parse_value(attributes.at("transB")).at<bool>();
         }
-        std::vector<int64_t> perm = {1, 0};
+
+        std::vector<int64_t> perm(args[0]->get_shape().lens().size());
+        std::iota(perm.begin(), perm.end(), int64_t{0});
+        // swap the last two elements
+        std::swap(*perm.rbegin(), *(perm.rbegin() + 1));
+
         auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0];
         auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
         if(args.size() == 3)
```
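This generalizes Gemm's transpose from a fixed 2-D `{1, 0}` permutation to "transpose the trailing matrix dims of a tensor of any rank". A quick standalone sketch of the permutation being built (not MIGraphX code):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

int main()
{
    std::size_t rank = 4;
    std::vector<int64_t> perm(rank);
    std::iota(perm.begin(), perm.end(), int64_t{0}); // identity: {0, 1, 2, 3}
    std::swap(*perm.rbegin(), *(perm.rbegin() + 1)); // swap last two: {0, 1, 3, 2}
    assert((perm == std::vector<int64_t>{0, 1, 3, 2}));
    return 0;
}
```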
```diff
@@ -749,15 +765,17 @@ struct onnx_parser
         {
             auto names = attributes.at("activations").strings();
             vec_names.clear();
-            for_each(names.begin(), names.end(), [&](auto& fn) { vec_names.push_back(fn); });
+            vec_names.resize(names.size());
+            std::copy(names.begin(), names.end(), vec_names.begin());
         }
 
-        for_each(vec_names.begin(), vec_names.end(), [&](auto& fn) {
-            if(map_actv_funcs.count(fn) == 0)
-            {
-                MIGRAPHX_THROW("RNN: activation function " + std::string(fn) + " not supported");
-            }
+        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
+            return (map_actv_funcs.count(name) == 0);
         });
+        if(name_it != vec_names.end())
+        {
+            MIGRAPHX_THROW("RNN: activation function " + std::string(*name_it) +
+                           " not supported");
+        }
 
         // bidirectional case should have two activation functions.
         // one is for forward, and the other is for reverse.
```
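The pattern the diff switches to here (and again in the GRU and LSTM parsers below): locate the first unsupported name with `std::find_if` and throw once outside the loop, naming the offending entry, instead of throwing from inside a `for_each` callback. A minimal standalone sketch of that validation idiom (not MIGraphX code):

```cpp
#include <algorithm>
#include <iostream>
#include <set>
#include <string>
#include <vector>

int main()
{
    const std::set<std::string> supported{"sigmoid", "tanh", "relu"};
    std::vector<std::string> names{"sigmoid", "softsign"};

    // first entry that is not in the supported set, if any
    auto it = std::find_if(names.begin(), names.end(), [&](const auto& name) {
        return supported.count(name) == 0;
    });
    if(it != names.end())
        std::cout << "activation function " << *it << " not supported\n";
    return 0;
}
```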
```diff
@@ -839,8 +857,7 @@ struct onnx_parser
             auto names = attributes.at("activations").strings();
             vec_names.clear();
             vec_names.resize(names.size());
-            std::transform(
-                names.begin(), names.end(), vec_names.begin(), [](auto& str) { return str; });
+            std::copy(names.begin(), names.end(), vec_names.begin());
         }
 
         // need 4 activation functions
```
```diff
@@ -878,12 +895,13 @@ struct onnx_parser
             }
         }
 
-        for_each(vec_names.begin(), vec_names.end(), [&](auto& name) {
-            if(map_actv_funcs.count(name) == 0)
-            {
-                MIGRAPHX_THROW("GRU: activation function " + std::string(name) +
-                               " not supported");
-            }
+        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
+            return (map_actv_funcs.count(name) == 0);
         });
+        if(name_it != vec_names.end())
+        {
+            MIGRAPHX_THROW("GRU: activation function " + std::string(*name_it) +
+                           " not supported");
+        }
 
         std::vector<operation> vec_actv_funcs(vec_names.size());
         std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(),
                        [&](auto& name) {
```
```diff
@@ -920,6 +938,178 @@ struct onnx_parser
         return {hidden_states, last_output};
     }
 
+    std::vector<instruction_ref> parse_lstm(const std::string&,
+                                            attribute_map attributes,
+                                            std::vector<instruction_ref> args)
+    {
+        migraphx::shape input_shape = args[0]->get_shape();
+        std::size_t hidden_size     = args[2]->get_shape().lens()[2];
+
+        if(contains(attributes, "hidden_size"))
+        {
+            std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
+            if(hidden_size != hidden_size_att)
+            {
+                MIGRAPHX_THROW("LSTM: hidden size mismatch in input and attribute");
+            }
+        }
+
+        // Handling of direction to be added later
+        std::string direction{"forward"};
+        if(contains(attributes, "direction"))
+        {
+            direction = attributes.at("direction").s();
+        }
+
+        op::rnn_direction dirct = op::rnn_direction::forward;
+        if(direction == "bidirectional")
+        {
+            dirct = op::rnn_direction::bidirectional;
+        }
+        else if(direction == "reverse")
+        {
+            dirct = op::rnn_direction::reverse;
+        }
+        else if(direction == "forward")
+        {
+            dirct = op::rnn_direction::forward;
+        }
+        else
+        {
+            MIGRAPHX_THROW("LSTM: incorrect direction attribute");
+        }
+
+        std::vector<std::string> vec_names = {"sigmoid", "tanh", "tanh"};
+        if(contains(attributes, "activations"))
+        {
+            auto names = attributes.at("activations").strings();
+            vec_names.clear();
+            vec_names.resize(names.size());
+            std::copy(names.begin(), names.end(), vec_names.begin());
+        }
+
+        // need 6 activation functions for the bidirectional direction
+        if(dirct == op::rnn_direction::bidirectional)
+        {
+            // 6 activation functions are used in the bidirectional
+            // scenario. No spec is provided in onnx::operator, so we
+            // use the following algorithm: if 1 actv function is provided,
+            // repeat the 1st one six times. If 2 actv functions are provided,
+            // repeat the 2nd once, then repeat all three once more.
+            // If 3 actv funcs are provided, repeat all three once.
+            // The same algorithm is used when 4, 5, or 6 actv functions
+            // are provided. This may need to change later.
+            switch(vec_names.size())
+            {
+            case 1:
+                vec_names = {vec_names.at(0),
+                             vec_names.at(0),
+                             vec_names.at(0),
+                             vec_names.at(0),
+                             vec_names.at(0),
+                             vec_names.at(0)};
+                break;
+            case 2:
+                // repeat the 2nd actv func once, then repeat all three another time
+                vec_names = {vec_names.at(0),
+                             vec_names.at(1),
+                             vec_names.at(1),
+                             vec_names.at(0),
+                             vec_names.at(1),
+                             vec_names.at(1)};
+                break;
+            case 3:
+                // repeat all three actv funcs once
+                vec_names = {vec_names.at(0),
+                             vec_names.at(1),
+                             vec_names.at(2),
+                             vec_names.at(0),
+                             vec_names.at(1),
+                             vec_names.at(2)};
+                break;
+            case 4:
+                vec_names = {vec_names.at(0),
+                             vec_names.at(1),
+                             vec_names.at(2),
+                             vec_names.at(3),
+                             vec_names.at(3),
+                             vec_names.at(3)};
+                break;
+            case 5:
+                vec_names = {vec_names.at(0),
+                             vec_names.at(1),
+                             vec_names.at(2),
+                             vec_names.at(3),
+                             vec_names.at(4),
+                             vec_names.at(4)};
+                break;
+            default: break;
+            }
+        }
+        else
+        {
+            switch(vec_names.size())
+            {
+            case 1: vec_names = {vec_names.at(0), vec_names.at(0), vec_names.at(0)}; break;
+            case 2:
+                // repeat the 2nd actv func once, so we have 3 actv funcs
+                vec_names = {vec_names.at(0), vec_names.at(1), vec_names.at(1)};
+                break;
+            default: break;
+            }
+        }
+
+        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
+            return (map_actv_funcs.count(name) == 0);
+        });
+        if(name_it != vec_names.end())
+        {
+            MIGRAPHX_THROW("LSTM: activation function " + std::string(*name_it) +
+                           " not supported");
+        }
+
+        std::vector<operation> vec_actv_funcs(vec_names.size());
+        std::transform(vec_names.begin(),
+                       vec_names.end(),
+                       vec_actv_funcs.begin(),
+                       [&](auto& name) { return map_actv_funcs[name]; });
+
+        float clip = 0.0;
+        if(contains(attributes, "clip"))
+        {
+            clip = parse_value(attributes.at("clip")).at<float>();
+        }
+
+        int input_forget = 0;
+        if(contains(attributes, "input_forget"))
+        {
+            input_forget = parse_value(attributes.at("input_forget")).at<int>();
+        }
+
+        // append undefined operators so there are 8 arguments
+        if(args.size() < 8)
+        {
+            auto ins = prog.add_instruction(op::undefined{});
+            args.insert(args.end(), 8 - args.size(), ins);
+        }
+
+        // first output: concatenation of hidden states
+        auto hidden_states = prog.add_instruction(
+            op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, std::move(args));
+
+        // second output: last lstm output
+        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
+
+        // third output: last cell output
+        auto last_cell_output =
+            prog.add_instruction(op::lstm_last_cell_output{}, hidden_states);
+
+        return {hidden_states, last_output, last_cell_output};
+    }
+
     void parse_from(std::istream& is)
     {
         onnx::ModelProto model;
```
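The default-expansion rule described in the comment above can be stated compactly: pad the user-supplied activation list out to 3 functions (unidirectional) or 6 (bidirectional) by repeating the last given entry, and mirror short lists' forward triple into the reverse slots. A standalone sketch that reproduces the same results as the switch statements (not MIGraphX code):

```cpp
#include <cassert>
#include <string>
#include <vector>

std::vector<std::string> expand_actvs(std::vector<std::string> v, bool bidirectional)
{
    // grow to a full forward set of 3 by repeating the last entry
    while(v.size() < 3)
        v.push_back(v.back());
    if(!bidirectional)
        return {v[0], v[1], v[2]};
    // for 1-3 given names, reuse the forward triple for the reverse direction;
    // for 4-5, keep repeating the last entry until there are 6
    if(v.size() == 3)
        return {v[0], v[1], v[2], v[0], v[1], v[2]};
    while(v.size() < 6)
        v.push_back(v.back());
    return v;
}

int main()
{
    assert((expand_actvs({"sigmoid"}, false) ==
            std::vector<std::string>{"sigmoid", "sigmoid", "sigmoid"}));
    assert((expand_actvs({"sigmoid", "tanh"}, true) ==
            std::vector<std::string>{
                "sigmoid", "tanh", "tanh", "sigmoid", "tanh", "tanh"}));
    return 0;
}
```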
```diff
@@ -960,9 +1150,9 @@ struct onnx_parser
                 instructions[name] = prog.add_parameter(name, s);
             }
         }
-        for(auto&& p : nodes)
+        for(auto&& output : graph.output())
         {
-            this->parse_node(p.first);
+            this->parse_node(output.name());
         }
     }
```
src/program.cpp (+1 -0)

```diff
@@ -2,6 +2,7 @@
 #include <migraphx/stringutils.hpp>
 #include <migraphx/instruction.hpp>
 #include <migraphx/operators.hpp>
+#include <migraphx/target.hpp>
 #include <migraphx/env.hpp>
 #include <migraphx/ranges.hpp>
 #include <migraphx/time.hpp>
```
src/rewrite_rnn.cpp (+507 -2)

This diff is collapsed on the page (too large to display inline).
src/shape.cpp (+1 -1)

```diff
@@ -19,7 +19,7 @@ struct shape_impl
     shape_impl() : m_type(shape::float_type), m_standard(false) {}
 
-    shape_impl(shape::type_t t) : m_type(t), m_lens({1}), m_strides({1}), m_standard(true) {}
+    shape_impl(shape::type_t t) : m_type(t), m_lens({1}), m_strides({0}), m_standard(true) {}
 
     shape_impl(shape::type_t t, std::vector<std::size_t> l)
         : m_type(t), m_lens(std::move(l)), m_standard(true)
     {
```
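A note on this one-character change ({1} to {0} for the single-element stride): a zero stride maps every index onto the same element, the usual representation for scalar and broadcast dimensions, and it is presumably what lets checks like `inputs[1].scalar()` in the gather diff above recognize a scalar. A minimal sketch of the offset arithmetic (not MIGraphX code):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

std::size_t offset(const std::vector<std::size_t>& idx,
                   const std::vector<std::size_t>& strides)
{
    std::size_t r = 0;
    for(std::size_t d = 0; d < idx.size(); d++)
        r += idx[d] * strides[d];
    return r;
}

int main()
{
    // with strides {0}, every index resolves to element 0
    assert(offset({0}, {0}) == 0);
    assert(offset({2}, {0}) == 0);
    // with strides {1}, index i resolves to element i
    assert(offset({2}, {1}) == 2);
    return 0;
}
```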
src/targets/cpu/gemm.cpp (+35 -14)

```diff
 #include <migraphx/cpu/gemm.hpp>
 #include <migraphx/dfor.hpp>
 #include <migraphx/requires.hpp>
+#include <migraphx/shape_for_each.hpp>
 #include <blaze/math/CustomMatrix.h>
 
 namespace migraphx {
```
```diff
@@ -14,10 +15,13 @@ template <class T>
 static auto make_mat(tensor_view<T> x)
 {
     const auto& s = x.get_shape();
-    assert(s.lens().size() == 2);
+    // assert(s.lens().size() == 2);
+    std::size_t n_dims = s.lens().size();
+    std::size_t dim_0  = n_dims - 2;
+    std::size_t dim_1  = n_dims - 1;
     if(s.transposed())
-        return matrix<T>{x.data(), s.lens()[1], s.lens()[0], s.strides()[1]};
-    return matrix<T>{x.data(), s.lens()[0], s.lens()[1], s.strides()[0]};
+        return matrix<T>{x.data(), s.lens()[dim_1], s.lens()[dim_0], s.strides()[dim_1]};
+    return matrix<T>{x.data(), s.lens()[dim_0], s.lens()[dim_1], s.strides()[dim_0]};
 }
 
 template <class T, class F>
```
```diff
@@ -64,18 +68,24 @@ void migemm_impl(tensor_view<T> cmat,
                  float beta,
                  std::false_type)
 {
-    auto m = cmat.get_shape().lens()[0];
-    auto n = cmat.get_shape().lens()[1];
-    auto k = amat.get_shape().lens()[1];
+    std::size_t n_dims = cmat.get_shape().lens().size();
+    std::size_t dim_0  = n_dims - 2;
+    std::size_t dim_1  = n_dims - 1;
+    auto k             = amat.get_shape().lens()[dim_1];
 
-    assert(amat.get_shape().lens()[1] == bmat.get_shape().lens()[0]);
-    assert(m == amat.get_shape().lens()[0]);
-    assert(n == bmat.get_shape().lens()[1]);
-    dfor(m, n)([&](auto ii, auto jj) {
-        double s = cmat(ii, jj) * beta;
-        dfor(k)([&](auto kk) { s += amat(ii, kk) * bmat(kk, jj); });
-        cmat(ii, jj) = alpha * s;
+    assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]);
+    assert(cmat.get_shape().lens()[dim_0] == amat.get_shape().lens()[dim_0]);
+    assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]);
+
+    shape_for_each(cmat.get_shape(), [&](const auto& c_idx) {
+        auto a_idx = c_idx;
+        auto b_idx = c_idx;
+        double s   = 0.0;
+        dfor(k)([&](auto kk) {
+            a_idx[dim_1] = b_idx[dim_0] = kk;
+            s += amat(a_idx.begin(), a_idx.end()) * bmat(b_idx.begin(), b_idx.end());
+        });
+        cmat(c_idx.begin(), c_idx.end()) =
+            alpha * s + cmat(c_idx.begin(), c_idx.end()) * beta;
     });
 }
```
```diff
@@ -83,7 +93,18 @@ template <class T>
 void migemm_impl(
     tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, float alpha, float beta)
 {
-    migemm_impl(cmat, amat, bmat, alpha, beta, is_fast_gemm_type<T>{});
+    auto lens      = amat.get_shape().lens();
+    bool batch_mul =
+        std::accumulate(
+            lens.begin(), lens.end(), std::size_t{1}, std::multiplies<std::size_t>()) ==
+        (*lens.rbegin()) * (*(lens.rbegin() + 1));
+    if(batch_mul)
+    {
+        migemm_impl(cmat, amat, bmat, alpha, beta, is_fast_gemm_type<T>{});
+    }
+    else
+    {
+        migemm_impl(cmat, amat, bmat, alpha, beta, std::false_type{});
+    }
 }
 
 void migemm(
```
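Despite the name `batch_mul`, the condition above is true exactly when the tensor is effectively 2-D: the product of all dims equals the product of the last two dims only when every leading (batch) dim is 1, in which case the fast blaze-backed gemm still applies. A standalone sketch of that test (not MIGraphX code):

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

bool effectively_2d(const std::vector<std::size_t>& lens)
{
    auto total = std::accumulate(
        lens.begin(), lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
    return total == (*lens.rbegin()) * (*(lens.rbegin() + 1));
}

int main()
{
    assert(effectively_2d({1, 1, 4, 5})); // all batch dims are 1 -> fast gemm
    assert(!effectively_2d({2, 4, 5}));   // real batch dim -> element-wise path
    return 0;
}
```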
src/targets/cpu/include/migraphx/cpu/target.hpp (+1 -0)

```diff
@@ -7,6 +7,7 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 
+struct pass;
 namespace cpu {
 
 struct target
```
src/targets/cpu/target.cpp (+1 -0)

```diff
 #include <migraphx/cpu/target.hpp>
 #include <migraphx/cpu/lowering.hpp>
+#include <migraphx/pass.hpp>
 #include <migraphx/auto_contiguous.hpp>
 #include <migraphx/rewrite_rnn.hpp>
 #include <migraphx/dead_code_elimination.hpp>
```
src/targets/gpu/abs.cpp (+1 -4)

```diff
 #include <migraphx/gpu/abs.hpp>
-#include <migraphx/operators.hpp>
+#include <migraphx/gpu/context.hpp>
-#include <migraphx/manage_ptr.hpp>
-#include <migraphx/gpu/miopen.hpp>
-#include <utility>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
```
src/targets/gpu/batchnorm.cpp (+1 -4)

```diff
 #include <migraphx/gpu/batchnorm.hpp>
-#include <migraphx/operators.hpp>
+#include <migraphx/gpu/context.hpp>
-#include <migraphx/manage_ptr.hpp>
-#include <migraphx/gpu/miopen.hpp>
-#include <utility>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
```
src/targets/gpu/concat.cpp (+1 -4)

```diff
 #include <migraphx/gpu/concat.hpp>
-#include <migraphx/operators.hpp>
+#include <migraphx/gpu/context.hpp>
-#include <migraphx/manage_ptr.hpp>
-#include <migraphx/gpu/miopen.hpp>
 #include <migraphx/gpu/device/concat.hpp>
-#include <utility>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
```
src/targets/gpu/contiguous.cpp (+2 -4)

```diff
 #include <migraphx/gpu/contiguous.hpp>
-#include <migraphx/operators.hpp>
+#include <migraphx/gpu/context.hpp>
-#include <migraphx/manage_ptr.hpp>
+#include <migraphx/gpu/device/contiguous.hpp>
-#include <migraphx/gpu/miopen.hpp>
-#include <utility>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
```
src/targets/gpu/convolution.cpp (+2 -4)

```diff
 #include <migraphx/gpu/convolution.hpp>
-#include <migraphx/operators.hpp>
+#include <migraphx/gpu/context.hpp>
-#include <migraphx/manage_ptr.hpp>
+#include <migraphx/generate.hpp>
-#include <migraphx/gpu/miopen.hpp>
-#include <utility>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
```
src/targets/gpu/device/gather.cpp (+15 -11)

```diff
@@ -16,20 +16,24 @@ argument gather(hipStream_t stream,
                 std::vector<migraphx::argument> args,
                 int axis)
 {
-    int axis_index = (axis < 0) ? (axis + output_shape.lens().size()) : axis;
+    int axis_index = (axis < 0) ? (axis + args[0].get_shape().lens().size()) : axis;
     visit_all(args.back(), args[0])([&](auto output, auto input) {
         std::size_t nelements = output_shape.elements();
         args[1].visit([&](auto indices) {
-            visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
-                const auto* indices_ptr = device_cast(indices.data());
-                auto* outptr            = device_cast(output.data());
-                const auto* inptr       = device_cast(input.data());
-                hip_tensor_descriptor<ndim> desc_input(input.get_shape());
-                hip_tensor_descriptor<ndim> desc_output(output.get_shape());
-                gs_launch(stream, nelements)([=](auto i) {
-                    auto lens        = desc_output.multi(i);
-                    lens[axis_index] = indices_ptr[lens[axis_index]];
-                    outptr[i]        = inptr[desc_input.linear(lens)];
+            const auto* indices_ptr = device_cast(indices.data());
+            auto* out_ptr           = device_cast(output.data());
+            const auto* in_ptr      = device_cast(input.data());
+            auto& input_shape       = args[0].get_shape();
+            auto lens               = input_shape.lens();
+            lens[axis_index]        = args[1].get_shape().elements();
+            migraphx::shape out_comp_shape{output_shape.type(), lens};
+            visit_tensor_size(out_comp_shape.lens().size(), [&](auto n_out_dim) {
+                hip_tensor_descriptor<n_out_dim> desc_input(input_shape);
+                hip_tensor_descriptor<n_out_dim> desc_output(out_comp_shape);
+                gs_launch(stream, nelements)([=](auto ii) {
+                    auto in_idx        = desc_output.multi(ii);
+                    in_idx[axis_index] = indices_ptr[in_idx[axis_index]];
+                    out_ptr[ii]        = in_ptr[desc_input.linear(in_idx)];
                 });
             });
         });
```
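Per output element the rewritten kernel does a three-step index translation: convert the flat output index to a multi-index in the comparison shape, remap the axis coordinate through the indices array, then linearize into the input. A CPU sketch of the same round trip with row-major helpers standing in for `hip_tensor_descriptor` (not MIGraphX code):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using dims = std::vector<std::size_t>;

// row-major multi-index of flat index i in shape `lens`
dims multi(std::size_t i, const dims& lens)
{
    dims idx(lens.size());
    for(std::size_t d = lens.size(); d > 0; d--)
    {
        idx[d - 1] = i % lens[d - 1];
        i /= lens[d - 1];
    }
    return idx;
}

// row-major flat index of multi-index `idx` in shape `lens`
std::size_t linear(const dims& idx, const dims& lens)
{
    std::size_t r = 0;
    for(std::size_t d = 0; d < lens.size(); d++)
        r = r * lens[d] + idx[d];
    return r;
}

int main()
{
    dims in_lens{3, 4};             // input shape
    std::vector<int> indices{2, 0}; // gather along axis 1
    std::size_t axis = 1;

    dims out_lens  = in_lens;
    out_lens[axis] = indices.size(); // output comparison shape {3, 2}

    // output element 3 -> multi {1, 1} -> remap axis: {1, 0} -> input flat 4
    auto idx  = multi(3, out_lens);
    idx[axis] = static_cast<std::size_t>(indices[idx[axis]]);
    assert(linear(idx, in_lens) == 4);
    return 0;
}
```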
src/targets/gpu/elu.cpp (+1 -4)

```diff
 #include <migraphx/gpu/elu.hpp>
-#include <migraphx/operators.hpp>
+#include <migraphx/gpu/context.hpp>
-#include <migraphx/manage_ptr.hpp>
-#include <migraphx/gpu/miopen.hpp>
-#include <utility>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
```
src/targets/gpu/fuse_ops.cpp (+1 -0)

```diff
@@ -3,6 +3,7 @@
 #include <migraphx/gpu/miopen.hpp>
 #include <migraphx/gpu/convolution.hpp>
 #include <migraphx/gpu/device/add_relu.hpp>
+#include <migraphx/gpu/device/add.hpp>
 #include <migraphx/instruction.hpp>
 
 namespace migraphx {
```
src/targets/gpu/gather.cpp (+2 -5)

```diff
 #include <migraphx/gpu/gather.hpp>
-#include <migraphx/operators.hpp>
+#include <migraphx/gpu/context.hpp>
-#include <migraphx/manage_ptr.hpp>
+#include <migraphx/gpu/device/gather.hpp>
-#include <migraphx/gpu/miopen.hpp>
-#include <migraphx/gpu/device/concat.hpp>
-#include <utility>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
```