gaoqiong / MIGraphX · Commits

Commit 511c8d8f
Authored Jun 25, 2019 by Paul
Merge from develop
Parents: 9b7c44ab 2a2c146c
Changes: 62

Showing 20 changed files with 612 additions and 626 deletions (+612 -626)
src/include/migraphx/ranges.hpp                                      +22    -0
src/include/migraphx/raw_data.hpp                                    +19    -0
src/include/migraphx/requires.hpp                                     +7   -24
src/include/migraphx/streamutils.hpp                                  +2    -0
src/include/migraphx/stringutils.hpp                                  +2    -0
src/onnx/CMakeLists.txt                                               +1    -1
src/onnx/onnx.cpp                                                    +12    -4
src/py/CMakeLists.txt                                                 +1    -6
src/py/migraphx_py.cpp                                                +7    -7
src/rewrite_rnn.cpp                                                  +94  -173
src/targets/cpu/lowering.cpp                                         +54   -48
src/targets/gpu/CMakeLists.txt                                        +1    -0
src/targets/gpu/device/concat.cpp                                    +10   -12
src/targets/gpu/device/gather.cpp                                    +18   -23
src/targets/gpu/device/include/migraphx/gpu/device/array.hpp         +60    -0
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp          +95  -271
src/targets/gpu/device/include/migraphx/gpu/device/shape.hpp         +89    -0
src/targets/gpu/device/include/migraphx/gpu/device/tensor.hpp        +13   -54
src/targets/gpu/device/include/migraphx/gpu/device/tensor_view.hpp   +59    -0
src/targets/gpu/device/include/migraphx/gpu/device/types.hpp         +46    -3
src/include/migraphx/ranges.hpp  (+22 -0)

@@ -33,6 +33,10 @@ auto generic_find_impl(rank<0>, C&& c, const T& x)
     return std::find(c.begin(), c.end(), x);
 }
 
+struct empty
+{
+};
+
 } // namespace detail
 
 template <class C, class T>

@@ -71,6 +75,12 @@ bool all_of(const std::initializer_list<T>& c, const Predicate& p)
     return std::all_of(c.begin(), c.end(), p);
 }
 
+template <class Predicate>
+bool all_of(detail::empty, const Predicate&)
+{
+    return true;
+}
+
 template <class C, class Predicate>
 bool any_of(const C& c, const Predicate& p)
 {

@@ -83,6 +93,12 @@ bool any_of(const std::initializer_list<T>& c, const Predicate& p)
     return std::any_of(c.begin(), c.end(), p);
 }
 
+template <class Predicate>
+bool any_of(detail::empty, const Predicate&)
+{
+    return false;
+}
+
 template <class C, class Predicate>
 bool none_of(const C& c, const Predicate& p)
 {

@@ -95,6 +111,12 @@ bool none_of(const std::initializer_list<T>& c, const Predicate& p)
     return std::none_of(c.begin(), c.end(), p);
 }
 
+template <class Predicate>
+bool none_of(detail::empty, const Predicate&)
+{
+    return true;
+}
+
 template <class Range, class Iterator>
 void copy(Range&& r, Iterator it)
 {
...
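A note on the detail::empty overloads added above: a bare braced list {} cannot deduce std::initializer_list<T>, so without them a call such as all_of({}, pred) would fail to compile. The new overloads give {} somewhere to land and return the vacuous answer. A minimal standalone sketch of the idiom (illustrative only, not the MIGraphX headers themselves):

#include <algorithm>
#include <initializer_list>

namespace detail {
struct empty
{
};
} // namespace detail

template <class T, class Predicate>
bool all_of(const std::initializer_list<T>& c, const Predicate& p)
{
    return std::all_of(c.begin(), c.end(), p);
}

// Selected when the caller writes a bare `{}`, which cannot deduce
// std::initializer_list<T>; an empty range is vacuously true.
template <class Predicate>
bool all_of(detail::empty, const Predicate&)
{
    return true;
}

int main()
{
    auto is_positive = [](int x) { return x > 0; };
    bool a = all_of({1, 2, 3}, is_positive); // initializer_list overload
    bool b = all_of({}, is_positive);        // detail::empty overload -> true
    return (a && b) ? 0 : 1;
}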
src/include/migraphx/raw_data.hpp  (+19 -0)

@@ -212,6 +212,25 @@ auto visit_all(T&& x, Ts&&... xs)
     };
 }
 
+template <class T>
+auto visit_all(const std::vector<T>& x)
+{
+    auto&& s = x.front().get_shape();
+    if(!std::all_of(
+           x.begin(), x.end(), [&](const T& y) { return y.get_shape().type() == s.type(); }))
+        MIGRAPHX_THROW("Types must be the same");
+    return [&](auto v) {
+        s.visit_type([&](auto as) {
+            using type = typename decltype(as)::type;
+            std::vector<tensor_view<type>> result;
+            std::transform(x.begin(), x.end(), std::back_inserter(result), [&](const auto& y) {
+                return make_view(y.get_shape(), as.from(y.data()));
+            });
+            v(result);
+        });
+    };
+}
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...
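The new vector overload lets callers visit a whole homogeneous batch of tensors with one type dispatch. A hedged usage sketch, assuming a vector of migraphx::argument values of a common element type (the function name sum_all is illustrative only):

#include <migraphx/argument.hpp>
#include <migraphx/raw_data.hpp>
#include <numeric>
#include <vector>

double sum_all(const std::vector<migraphx::argument>& tensors)
{
    double total = 0.0;
    // Throws "Types must be the same" if element types differ; otherwise the
    // callback receives a std::vector<tensor_view<type>> of typed views.
    migraphx::visit_all(tensors)([&](auto views) {
        for(auto&& v : views)
            total += std::accumulate(v.begin(), v.end(), 0.0);
    });
    return total;
}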
src/include/migraphx/requires.hpp  (+7 -24)

@@ -15,35 +15,18 @@ struct and_ : std::is_same<and_<Bs...>, and_<(Bs || true)...>> // NOLINT
 template <bool B>
 using bool_c = std::integral_constant<bool, B>;
 
-template <int N>
-struct requires_enum
-{
-    enum e
-    {
-        a = 0
-    };
-};
-
-#define MIGRAPHX_REQUIRES_CAT(x, y) x##y
+#define MIGRAPHX_REQUIRES_PRIMITIVE_CAT(x, y) x##y
+#define MIGRAPHX_REQUIRES_CAT(x, y) MIGRAPHX_REQUIRES_PRIMITIVE_CAT(x, y)
+
+#define MIGRAPHX_REQUIRES_VAR() MIGRAPHX_REQUIRES_CAT(PrivateRequires, __LINE__)
 
 #ifdef CPPCHECK
 #define MIGRAPHX_REQUIRES(...) class = void
 #else
-#if 0
-// TODO: This currently crashed on clang
-#define MIGRAPHX_REQUIRES(...)                                                           \
-    typename migraphx::requires_enum<__LINE__>::e MIGRAPHX_REQUIRES_CAT(                 \
-        PrivateRequires, __LINE__) = migraphx::requires_enum<__LINE__>::a,               \
-             class                 = typename std::enable_if<and_<__VA_ARGS__,           \
-                                          MIGRAPHX_REQUIRES_CAT(PrivateRequires,         \
-                                                                __LINE__) ==             \
-                                              migraphx::requires_enum<__LINE__>::a>{}>::type
-#else
-#define MIGRAPHX_REQUIRES(...)                                                           \
-    typename migraphx::requires_enum<__LINE__>::e MIGRAPHX_REQUIRES_CAT(                 \
-        PrivateRequires, __LINE__) = migraphx::requires_enum<__LINE__>::a,               \
-    class = typename std::enable_if<and_<__VA_ARGS__>{}>::type
-#endif
+#define MIGRAPHX_REQUIRES(...)                                                           \
+    bool MIGRAPHX_REQUIRES_VAR() = true,                                                 \
+         typename std::enable_if<(MIGRAPHX_REQUIRES_VAR() &&                             \
+                                  (migraphx::and_<__VA_ARGS__>{})),                      \
+                                 int>::type = 0
 #endif
 
 } // namespace MIGRAPHX_INLINE_NS
...
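The rewritten MIGRAPHX_REQUIRES trades the per-line enum token for a defaulted bool template parameter whose name is made unique per use site by pasting __LINE__. A standalone sketch of the same SFINAE idiom, with the macro expanded by hand:

#include <type_traits>

// A defaulted bool non-type parameter (unique per use site via __LINE__
// pasting in the real macro) keeps the enable_if condition dependent, so it
// is re-checked at each instantiation rather than at definition time.
template <class T,
          bool PrivateRequires = true,
          typename std::enable_if<(PrivateRequires && std::is_integral<T>{}),
                                  int>::type = 0>
T twice(T x)
{
    return x + x;
}

int main()
{
    return twice(21) == 42 ? 0 : 1; // OK: int is integral
    // twice(1.5); // would be rejected by SFINAE, as MIGRAPHX_REQUIRES intends
}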
src/include/migraphx/streamutils.hpp  (+2 -0)

@@ -42,7 +42,9 @@ template <class Range>
 auto stream_write_value_impl(rank<1>, std::ostream& os, const Range& r)
     -> decltype(r.begin(), r.end(), void())
 {
+    os << "{";
     os << stream_range(r);
+    os << "}";
 }
 
 template <class T>
...
src/include/migraphx/stringutils.hpp  (+2 -0)

@@ -52,6 +52,8 @@ inline std::string transform_string(std::string s, F f)
 
 inline std::string to_upper(std::string s) { return transform_string(std::move(s), ::toupper); }
 
+inline std::string to_lower(std::string s) { return transform_string(std::move(s), ::tolower); }
+
 inline bool starts_with(const std::string& value, const std::string& prefix)
 {
     if(prefix.size() > value.size())
...
src/onnx/CMakeLists.txt  (+1 -1)

@@ -19,7 +19,7 @@ rocm_install_targets(
 add_executable(read_onnx read_onnx.cpp)
 rocm_clang_tidy_check(read_onnx)
-target_link_libraries(read_onnx migraphx_onnx)
+target_link_libraries(read_onnx migraphx_cpu migraphx_onnx)
 if(MIGRAPHX_ENABLE_GPU)
...
src/onnx/onnx.cpp  (+12 -4)

@@ -100,6 +100,7 @@ struct onnx_parser
     void init_actv_func()
     {
+        // Support name format of all lower case or the first letter capital
         map_actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
         map_actv_funcs.insert(std::make_pair("relu", op::relu{}));
         map_actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));

@@ -352,7 +353,8 @@ struct onnx_parser
         {
             // insert zeros for pad op (args[0] has 4 dims)
             padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
-            l0      = prog.add_instruction(op::pad{padding}, l0);
+            l0      = prog.add_instruction(
+                op::pad{padding, std::numeric_limits<float>::lowest()}, l0);
         }
         else
         {

@@ -870,7 +872,9 @@ struct onnx_parser
             auto names = attributes.at("activations").strings();
             vec_names.clear();
             vec_names.resize(names.size());
-            std::copy(names.begin(), names.end(), vec_names.begin());
+            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
+                return to_lower(name);
+            });
         }
 
         auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {

@@ -961,7 +965,9 @@ struct onnx_parser
             auto names = attributes.at("activations").strings();
             vec_names.clear();
             vec_names.resize(names.size());
-            std::copy(names.begin(), names.end(), vec_names.begin());
+            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
+                return to_lower(name);
+            });
         }
 
         // need 4 activation functions

@@ -1088,7 +1094,9 @@ struct onnx_parser
             auto names = attributes.at("activations").strings();
             vec_names.clear();
             vec_names.resize(names.size());
-            std::copy(names.begin(), names.end(), vec_names.begin());
+            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
+                return to_lower(name);
+            });
         }
 
         // need 6 activation functions for bidirectional directions
...
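All three RNN/GRU/LSTM parser sites now lower-case the ONNX "activations" attribute before lookup, so "Tanh" and "tanh" both match the lower-case keys registered in map_actv_funcs. A minimal sketch of that normalization (to_lower here mirrors the helper added to stringutils.hpp):

#include <algorithm>
#include <cctype>
#include <string>
#include <vector>

std::string to_lower(std::string s)
{
    std::transform(s.begin(), s.end(), s.begin(),
                   [](unsigned char c) { return std::tolower(c); });
    return s;
}

int main()
{
    std::vector<std::string> names{"Tanh", "Sigmoid", "RELU"};
    std::vector<std::string> vec_names(names.size());
    std::transform(names.begin(), names.end(), vec_names.begin(),
                   [](auto name) { return to_lower(name); });
    // vec_names == {"tanh", "sigmoid", "relu"}
    return vec_names[0] == "tanh" ? 0 : 1;
}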
src/py/CMakeLists.txt  (+1 -6)

@@ -12,12 +12,7 @@ if(MIGRAPHX_ENABLE_PYTHON)
         C_VISIBILITY_PRESET hidden
         CXX_VISIBILITY_PRESET hidden
     )
-    if(MIGRAPHX_ENABLE_TF)
-        target_link_libraries(migraphx_py PRIVATE migraphx migraphx_tf migraphx_cpu)
-        target_compile_definitions(migraphx_py PRIVATE -DENABLE_TF)
-    else()
-        target_link_libraries(migraphx_py PRIVATE migraphx migraphx_onnx migraphx_cpu)
-    endif()
+    target_link_libraries(migraphx_py PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_cpu)
     if(MIGRAPHX_ENABLE_GPU)
         target_link_libraries(migraphx_py PRIVATE migraphx_gpu)
         target_compile_definitions(migraphx_py PRIVATE -DHAVE_GPU)
...
src/py/migraphx_py.cpp  (+7 -7)

@@ -6,11 +6,9 @@
 #include <migraphx/generate.hpp>
 #include <migraphx/cpu/target.hpp>
 #include <migraphx/stringutils.hpp>
-#ifdef ENABLE_TF
 #include <migraphx/tf.hpp>
-#else
 #include <migraphx/onnx.hpp>
-#endif
+#include <migraphx/type_name.hpp>
 
 #ifdef HAVE_GPU
 #include <migraphx/gpu/target.hpp>

@@ -104,8 +102,13 @@ migraphx::shape to_shape(const py::buffer_info& info)
             t = as.type_enum();
             n = sizeof(as());
         }
     });
 
+    if(n == 0)
+    {
+        MIGRAPHX_THROW("MIGRAPHX PYTHON: Unsupported data type" + info.format);
+    }
+
     auto strides = info.strides;
     std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto i) -> std::size_t {
         return n > 0 ? i / n : 0;

@@ -161,16 +164,13 @@ PYBIND11_MODULE(migraphx, m)
         .def("__ne__", std::not_equal_to<migraphx::program>{})
         .def("__repr__", [](const migraphx::program& p) { return migraphx::to_string(p); });
 
-#ifdef ENABLE_TF
     m.def("parse_tf",
           &migraphx::parse_tf,
           "Parse tf protobuf (default format is nhwc)",
           py::arg("filename"),
           py::arg("is_nhwc") = true);
-#else
     m.def("parse_onnx", &migraphx::parse_onnx);
-#endif
 
     m.def("get_target", [](const std::string& name) -> migraphx::target {
         if(name == "cpu")
             return migraphx::cpu::target{};
...
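Context for the to_shape change: pybind11 reports buffer strides in bytes, while migraphx shapes use element strides, so each stride is divided by the element size n deduced from the buffer's format string; the new guard rejects formats that matched no migraphx type (n stays 0). A minimal sketch of just that conversion, with hypothetical helper names:

#include <cstddef>
#include <stdexcept>
#include <vector>

std::vector<std::size_t> byte_to_element_strides(std::vector<std::size_t> strides,
                                                 std::size_t n)
{
    // n == 0 means the buffer format matched no known element type.
    if(n == 0)
        throw std::runtime_error("Unsupported data type");
    for(auto& s : strides)
        s /= n;
    return strides;
}

int main()
{
    // float32 matrix of shape {2, 3}: byte strides {12, 4} -> element strides {3, 1}
    auto s = byte_to_element_strides({12, 4}, sizeof(float));
    return (s[0] == 3 && s[1] == 1) ? 0 : 1;
}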
src/rewrite_rnn.cpp  (+94 -173)

@@ -204,17 +204,19 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
     auto tran_sr = prog.insert_instruction(ins, op::transpose{perm}, sr);
 
     // initial hidden state
     auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
+    auto sih_lens = sih->get_shape().lens();
 
     // bias
+    instruction_ref bb{};
     if(bias != prog.end())
     {
-        long hs    = r->get_shape().lens()[2];
+        long hs    = static_cast<long>(r->get_shape().lens()[2]);
         auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
         auto wb    = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
         auto rb    = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
-        auto b     = prog.insert_instruction(ins, op::add{}, wb, rb);
-        bias       = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape().lens()}, b);
+        auto wrb   = prog.insert_instruction(ins, op::add{}, wb, rb);
+        bb         = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, wrb);
     }
 
     instruction_ref hidden_out = prog.end();

@@ -228,20 +230,15 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
         xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
         auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw);
         auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr);
-        auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
-        instruction_ref ht;
         if(bias != prog.end())
         {
-            ht = prog.insert_instruction(ins, op::add{}, xt_ht, bias);
+            xt_wi = prog.insert_instruction(ins, op::add{}, xt_wi, bb);
         }
-        else
-        {
-            ht = xt_ht;
-        }
+        auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
 
         // apply activation function
-        ht  = prog.insert_instruction(ins, actv_func, ht);
+        auto ht = prog.insert_instruction(ins, actv_func, xt_ht);
         sih = ht;
 
         // add the dimensions of sequence length (axis 0 for sequence length,
         // axis 1 for num_directions

@@ -485,62 +482,41 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
     long hs = static_cast<long>(r_shape.lens()[2]);
 
     migraphx::shape s(seq_shape.type(), {seq_shape.lens()[1], r_shape.lens()[2]});
-    std::vector<int> data(s.elements(), 1);
+    std::vector<float> data(s.elements(), 1.0f);
     auto l1 = prog.add_literal(migraphx::literal{s, data});
 
-    // weight matrix
+    // w matrix squeeze to 2-dim and do a transpose
     std::vector<int64_t> perm{1, 0};
     auto sw      = prog.insert_instruction(ins, op::squeeze{{0}}, w);
-    auto wz      = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sw);
-    auto tran_wz = prog.insert_instruction(ins, op::transpose{perm}, wz);
-    auto wr      = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sw);
-    auto tran_wr = prog.insert_instruction(ins, op::transpose{perm}, wr);
-    auto wh      = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sw);
-    auto tran_wh = prog.insert_instruction(ins, op::transpose{perm}, wh);
+    auto tw      = prog.insert_instruction(ins, op::transpose{perm}, sw);
 
-    auto sr      = prog.insert_instruction(ins, op::squeeze{{0}}, r);
-    auto rz      = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sr);
-    auto tran_rz = prog.insert_instruction(ins, op::transpose{perm}, rz);
-    auto rr      = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sr);
-    auto tran_rr = prog.insert_instruction(ins, op::transpose{perm}, rr);
-    auto rh      = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
-    auto tran_rh = prog.insert_instruction(ins, op::transpose{perm}, rh);
+    // r slide to two part, zr and h
+    auto sr   = prog.insert_instruction(ins, op::squeeze{{0}}, r);
+    auto rzr  = prog.insert_instruction(ins, op::slice{{0}, {0}, {2 * hs}}, sr);
+    auto trzr = prog.insert_instruction(ins, op::transpose{perm}, rzr);
+    auto rh   = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
+    auto trh  = prog.insert_instruction(ins, op::transpose{perm}, rh);
 
     // initial states
     auto sih  = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
+    size_t bs = ih->get_shape().lens()[1];
 
     // bias
-    instruction_ref brcst_bz{};
-    instruction_ref brcst_br{};
-    instruction_ref brcst_wbh{};
-    instruction_ref brcst_rbh{};
-    instruction_ref brcst_bh{};
+    instruction_ref bwb{};
+    instruction_ref brb_zr{};
+    instruction_ref brb_h{};
    if(bias != prog.end())
    {
-        auto broadcast_lens = sih->get_shape().lens();
         auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
-        auto wbz   = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
-        auto wbr   = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
-        auto wbh   = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
-        brcst_wbh  = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, wbh);
-        auto rbz   = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
-        auto rbr   = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
-        auto rbh   = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
-        brcst_rbh  = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, rbh);
-        auto bz    = prog.insert_instruction(ins, op::add{}, wbz, rbz);
-        brcst_bz   = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, bz);
-        auto br    = prog.insert_instruction(ins, op::add{}, wbr, rbr);
-        brcst_br   = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, br);
-        auto bh    = prog.insert_instruction(ins, op::add{}, wbh, rbh);
-        brcst_bh   = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, bh);
+        auto wb    = prog.insert_instruction(ins, op::slice{{0}, {0}, {3 * hs}}, sbias);
+        bwb        = prog.insert_instruction(
+            ins, op::broadcast{1, {bs, static_cast<size_t>(3 * hs)}}, wb);
+        auto rb_zr = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {5 * hs}}, sbias);
+        brb_zr     = prog.insert_instruction(
+            ins, op::broadcast{1, {bs, static_cast<size_t>(2 * hs)}}, rb_zr);
+        auto rb_h  = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
+        brb_h      = prog.insert_instruction(
+            ins, op::broadcast{1, {bs, static_cast<size_t>(hs)}}, rb_h);
     }
 
     for(long i = 0; i < seq_len; i++)

@@ -549,56 +525,58 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
         auto xt = prog.insert_instruction(
             ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
         xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
 
-        // equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz)
-        auto xt_wz = prog.insert_instruction(ins, op::dot{}, xt, tran_wz);
-        auto ht_rz = prog.insert_instruction(ins, op::dot{}, sih, tran_rz);
-        auto xht_z = prog.insert_instruction(ins, op::add{}, xt_wz, ht_rz);
+        auto xt_w    = prog.insert_instruction(ins, op::dot{}, xt, tw);
+        auto ih1_rzr = prog.insert_instruction(ins, op::dot{}, sih, trzr);
         if(bias != prog.end())
         {
-            xht_z = prog.insert_instruction(ins, op::add{}, xht_z, brcst_bz);
+            xt_w    = prog.insert_instruction(ins, op::add{}, xt_w, bwb);
+            ih1_rzr = prog.insert_instruction(ins, op::add{}, ih1_rzr, brb_zr);
         }
-        auto zt = prog.insert_instruction(ins, actv_func1, xht_z);
 
-        // equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
-        auto xt_wr = prog.insert_instruction(ins, op::dot{}, xt, tran_wr);
-        auto ht_rr = prog.insert_instruction(ins, op::dot{}, sih, tran_rr);
-        auto xht_r = prog.insert_instruction(ins, op::add{}, xt_wr, ht_rr);
-        if(bias != prog.end())
-        {
-            xht_r = prog.insert_instruction(ins, op::add{}, xht_r, brcst_br);
-        }
-        auto rt = prog.insert_instruction(ins, actv_func1, xht_r);
-
-        instruction_ref xht_h;
+        auto xw_z    = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_w);
+        auto xw_r    = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, xt_w);
+        auto xw_h    = prog.insert_instruction(ins, op::slice{{1}, {2 * hs}, {3 * hs}}, xt_w);
+        auto hr_z    = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, ih1_rzr);
+        auto hr_r    = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, ih1_rzr);
+        auto xw_hr_z = prog.insert_instruction(ins, op::add{}, xw_z, hr_z);
+        auto zt      = prog.insert_instruction(ins, actv_func1, xw_hr_z);
+        auto xw_hr_r = prog.insert_instruction(ins, op::add{}, xw_r, hr_r);
+        auto rt      = prog.insert_instruction(ins, actv_func1, xw_hr_r);
+
+        instruction_ref hr_h{};
         if(linear_before_reset == 0)
         {
             // equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
-            auto xt_wh  = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
             auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
-            auto rt_rh  = prog.insert_instruction(ins, op::dot{}, rt_ht1, tran_rh);
-            xht_h       = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
             if(bias != prog.end())
             {
-                xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_bh);
+                hr_h = prog.insert_instruction(ins, op::dot{}, rt_ht1, trh, brb_h);
+            }
+            else
+            {
+                hr_h = prog.insert_instruction(ins, op::dot{}, rt_ht1, trh);
             }
         }
         else
         {
             // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
-            auto xt_wh  = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
-            auto ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, tran_rh);
+            instruction_ref ht1_rh{};
             if(bias != prog.end())
             {
-                ht1_rh = prog.insert_instruction(ins, op::add{}, ht1_rh, brcst_rbh);
+                ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, trh, brb_h);
             }
-            auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
-            xht_h      = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
-            if(bias != prog.end())
+            else
             {
-                xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_wbh);
+                ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, trh);
             }
+            hr_h = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
         }
-        auto ht = prog.insert_instruction(ins, actv_func2, xht_h);
+        auto xw_hr_h = prog.insert_instruction(ins, op::add{}, xw_h, hr_h);
+        auto ht      = prog.insert_instruction(ins, actv_func2, xw_hr_h);
 
         // equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
         auto one_minus_zt = prog.insert_instruction(ins, op::sub{}, l1, zt);

@@ -913,35 +891,16 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
     migraphx::shape r_shape = r->get_shape();
     long seq_len = static_cast<long>(seq_shape.lens()[0]);
     long hs      = static_cast<long>(r_shape.lens()[2]);
+    auto bs      = ih->get_shape().lens()[1];
     std::vector<int64_t> perm{1, 0};
 
-    // w matrix
+    // w matrix, squeeze and transpose
     auto sw      = prog.insert_instruction(ins, op::squeeze{{0}}, w);
-    auto wi      = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sw);
-    auto tran_wi = prog.insert_instruction(ins, op::transpose{perm}, wi);
-    auto wo      = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sw);
-    auto tran_wo = prog.insert_instruction(ins, op::transpose{perm}, wo);
-    auto wf      = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sw);
-    auto tran_wf = prog.insert_instruction(ins, op::transpose{perm}, wf);
-    auto wc      = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sw);
-    auto tran_wc = prog.insert_instruction(ins, op::transpose{perm}, wc);
+    auto tsw     = prog.insert_instruction(ins, op::transpose{perm}, sw);
 
-    // r matrix
+    // r matrix, squeeze and transpose
     auto sr      = prog.insert_instruction(ins, op::squeeze{{0}}, r);
-    auto ri      = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sr);
-    auto tran_ri = prog.insert_instruction(ins, op::transpose{perm}, ri);
-    auto ro      = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sr);
-    auto tran_ro = prog.insert_instruction(ins, op::transpose{perm}, ro);
-    auto rf      = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
-    auto tran_rf = prog.insert_instruction(ins, op::transpose{perm}, rf);
-    auto rc      = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sr);
-    auto tran_rc = prog.insert_instruction(ins, op::transpose{perm}, rc);
+    auto tsr     = prog.insert_instruction(ins, op::transpose{perm}, sr);
 
     // initial hidden state
     auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);

@@ -951,40 +910,23 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
     auto ic_lens = sic->get_shape().lens();
 
     // bias
-    instruction_ref bi_brcst{};
-    instruction_ref bo_brcst{};
-    instruction_ref bf_brcst{};
-    instruction_ref bc_brcst{};
+    instruction_ref wrb{};
     if(bias != prog.end())
     {
         auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
-        auto bxi   = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
-        auto bhi   = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
-        auto bi    = prog.insert_instruction(ins, op::add{}, bxi, bhi);
-        bi_brcst   = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, bi);
-        auto bxo   = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
-        auto bho   = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
-        auto bo    = prog.insert_instruction(ins, op::add{}, bxo, bho);
-        bo_brcst   = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, bo);
-        auto bxf   = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
-        auto bhf   = prog.insert_instruction(ins, op::slice{{0}, {6 * hs}, {7 * hs}}, sbias);
-        auto bf    = prog.insert_instruction(ins, op::add{}, bxf, bhf);
-        bf_brcst   = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, bf);
-        auto bxc   = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
-        auto bhc   = prog.insert_instruction(ins, op::slice{{0}, {7 * hs}, {8 * hs}}, sbias);
-        auto bc    = prog.insert_instruction(ins, op::add{}, bxc, bhc);
-        bc_brcst   = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, bc);
+        auto ub_wb  = prog.insert_instruction(ins, op::slice{{0}, {0}, {4 * hs}}, sbias);
+        auto ub_rb  = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {8 * hs}}, sbias);
+        auto ub_wrb = prog.insert_instruction(ins, op::add{}, ub_wb, ub_rb);
+        wrb         = prog.insert_instruction(
+            ins, op::broadcast{1, {bs, 4 * static_cast<size_t>(hs)}}, ub_wrb);
     }
 
     // peep hole
     instruction_ref pphi_brcst{};
     instruction_ref ppho_brcst{};
     instruction_ref pphf_brcst{};
     if(pph != prog.end())
     {
         auto spph = prog.insert_instruction(ins, op::squeeze{{0}}, pph);

@@ -1004,44 +946,31 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
         auto xt = prog.insert_instruction(
             ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
         xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
 
-        // equation it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
-        auto xt_wi          = prog.insert_instruction(ins, op::dot{}, xt, tran_wi);
-        auto ht_ri          = prog.insert_instruction(ins, op::dot{}, sih, tran_ri);
-        auto it_before_actv = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
-        if(pph != prog.end())
-        {
-            auto pphi_ct   = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic);
-            it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, pphi_ct);
-        }
-        if(bias != prog.end())
-        {
-            it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, bi_brcst);
-        }
-        auto it = prog.insert_instruction(ins, actv_func1, it_before_actv);
-
-        // equation ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
-        auto xt_wf          = prog.insert_instruction(ins, op::dot{}, xt, tran_wf);
-        auto ht_rf          = prog.insert_instruction(ins, op::dot{}, sih, tran_rf);
-        auto ft_before_actv = prog.insert_instruction(ins, op::add{}, xt_wf, ht_rf);
+        auto xt_tsw  = prog.insert_instruction(ins, op::dot{}, xt, tsw);
+        auto sih_tsr = prog.insert_instruction(ins, op::dot{}, sih, tsr);
+        auto xt_sih  = prog.insert_instruction(ins, op::add{}, xt_tsw, sih_tsr);
+        if(bias != prog.end())
+        {
+            xt_sih = prog.insert_instruction(ins, op::add{}, xt_sih, wrb);
+        }
+
+        auto it_before_actv = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_sih);
+        auto ot_before_actv =
+            prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, xt_sih);
+        auto ft_before_actv =
+            prog.insert_instruction(ins, op::slice{{1}, {2 * hs}, {3 * hs}}, xt_sih);
+        auto ct_before_actv =
+            prog.insert_instruction(ins, op::slice{{1}, {3 * hs}, {4 * hs}}, xt_sih);
         if(pph != prog.end())
         {
+            auto pphi_ct   = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic);
+            it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, pphi_ct);
             auto pphf_ct   = prog.insert_instruction(ins, op::mul{}, pphf_brcst, sic);
             ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, pphf_ct);
         }
-        if(bias != prog.end())
-        {
-            ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, bf_brcst);
-        }
+        auto it = prog.insert_instruction(ins, actv_func1, it_before_actv);
         auto ft = prog.insert_instruction(ins, actv_func1, ft_before_actv);
-
-        // equation ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
-        auto xt_wc          = prog.insert_instruction(ins, op::dot{}, xt, tran_wc);
-        auto ht_rc          = prog.insert_instruction(ins, op::dot{}, sih, tran_rc);
-        auto ct_before_actv = prog.insert_instruction(ins, op::add{}, xt_wc, ht_rc);
-        if(bias != prog.end())
-        {
-            ct_before_actv = prog.insert_instruction(ins, op::add{}, ct_before_actv, bc_brcst);
-        }
         auto ct = prog.insert_instruction(ins, actv_func2, ct_before_actv);
 
         // equation Ct = ft (.) Ct-1 + it (.) ct

@@ -1050,19 +979,11 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
         auto cellt       = prog.insert_instruction(ins, op::add{}, ft_cell, it_ct);
         last_cell_output = cellt;
 
-        // ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
-        auto xt_wo          = prog.insert_instruction(ins, op::dot{}, xt, tran_wo);
-        auto ht_ro          = prog.insert_instruction(ins, op::dot{}, sih, tran_ro);
-        auto ot_before_actv = prog.insert_instruction(ins, op::add{}, xt_wo, ht_ro);
         if(pph != prog.end())
         {
             auto ppho_cellt = prog.insert_instruction(ins, op::mul{}, ppho_brcst, cellt);
             ot_before_actv  = prog.insert_instruction(ins, op::add{}, ot_before_actv, ppho_cellt);
         }
-        if(bias != prog.end())
-        {
-            ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, bo_brcst);
-        }
         auto ot = prog.insert_instruction(ins, actv_func1, ot_before_actv);
 
         // Ht = ot (.) h(Ct)
...
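The recurring pattern across these RNN hunks: instead of three (GRU) or four (LSTM) small per-gate GEMMs each step, the rewritten cells do one GEMM against the whole transposed weight matrix and slice the gate columns out of the fused result. A hedged sketch of that index arithmetic (plain C++, names hypothetical; hs is the hidden size, gate order follows the ONNX LSTM layout i, o, f, c used by the slices above):

#include <cstddef>
#include <utility>

// Gate g of a fused xt * W^T result occupies columns [g*hs, (g+1)*hs),
// i.e. op::slice{{1}, {g*hs}, {(g+1)*hs}} on axis 1.
std::pair<std::size_t, std::size_t> gate_columns(std::size_t g, std::size_t hs)
{
    return {g * hs, (g + 1) * hs};
}

int main()
{
    const std::size_t hs = 8;
    auto it = gate_columns(0, hs); // input gate:  [0, 8)
    auto ot = gate_columns(1, hs); // output gate: [8, 16)
    return (it.second == ot.first) ? 0 : 1; // gates tile the fused output
}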
src/targets/cpu/lowering.cpp  (+54 -48)

@@ -517,40 +517,60 @@ struct cpu_unary
     }
 };
 
-struct softmax2d
+struct cpu_softmax
 {
-    std::string name() const { return "cpu::softmax2d"; }
-    shape compute_shape(const std::vector<shape>& inputs) const { return inputs.front(); }
+    op::softmax op;
+
+    template <class Self, class F>
+    static auto reflect(Self& self, F f)
+    {
+        return migraphx::reflect(self.op, f);
+    }
+
+    std::string name() const { return "cpu::softmax"; }
+    shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
+
+    template <typename T>
+    std::size_t compute_batch_index(T idx, shape& batch_shape, int axis) const
+    {
+        idx[axis] = 0;
+        return batch_shape.index(idx);
+    }
+
     argument compute(context&, const shape& output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
+        auto batch_lens     = output_shape.lens();
+        batch_lens[op.axis] = 1;
+        shape batch_shape{shape::int32_type, batch_lens};
         visit_all(result, args[0])([&](auto output, auto input) {
             using value_type = typename decltype(input)::value_type;
-            auto nb          = input.get_shape().lens()[0];
-            auto nc          = input.get_shape().lens()[1];
-            auto nh          = input.get_shape().lens()[2];
-            auto nw          = input.get_shape().lens()[3];
-            dfor(nb, nh, nw)([&](std::size_t b, std::size_t i, std::size_t j) {
-                value_type cmax = std::numeric_limits<value_type>::lowest();
-                for(std::size_t c = 0; c < nc; c++)
-                {
-                    cmax = std::max(cmax, input(b, c, i, j));
-                }
-                for(std::size_t c = 0; c < nc; c++)
-                {
-                    output(b, c, i, j) = std::exp(input(b, c, i, j) - cmax);
-                }
-                value_type sum = value_type(0);
-                for(std::size_t c = 0; c < nc; c++)
-                {
-                    sum += output(b, c, i, j);
-                }
-                for(std::size_t c = 0; c < nc; c++)
-                {
-                    output(b, c, i, j) = output(b, c, i, j) / sum;
-                }
+            std::vector<value_type> batch_max(batch_shape.elements(),
+                                              std::numeric_limits<value_type>::lowest());
+            shape_for_each(output_shape, [&](auto idx) {
+                auto index       = this->compute_batch_index(idx, batch_shape, op.axis);
+                batch_max[index] = std::max(batch_max[index], input(idx.begin(), idx.end()));
+            });
+            shape_for_each(output_shape, [&](auto idx) {
+                auto index = this->compute_batch_index(idx, batch_shape, op.axis);
+                output(idx.begin(), idx.end()) =
+                    std::exp(input(idx.begin(), idx.end()) - batch_max[index]);
+            });
+            std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0));
+            shape_for_each(output_shape, [&](auto idx) {
+                auto index = this->compute_batch_index(idx, batch_shape, op.axis);
+                batch_sum[index] += output(idx.begin(), idx.end());
+            });
+            shape_for_each(output_shape, [&](auto idx) {
+                auto index = this->compute_batch_index(idx, batch_shape, op.axis);
+                output(idx.begin(), idx.end()) /= batch_sum[index];
             });
         });
         return result;
     }
 };

@@ -569,33 +589,19 @@ struct cpu_logsoftmax
     shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
 
     template <typename T>
-    std::size_t compute_batch_index(const T& idx, shape& batch_shape, int axis) const
+    std::size_t compute_batch_index(T idx, const shape& batch_shape, int axis) const
     {
-        if(axis == 0)
-        {
-            return 0;
-        }
-        else
-        {
-            std::vector<std::size_t> batch_idx(idx.begin(), idx.begin() + axis);
-            return batch_shape.index(batch_idx.begin(), batch_idx.end());
-        }
+        idx[axis] = 0;
+        return batch_shape.index(idx);
     }
 
     argument compute(context&, const shape& output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
-        auto lens = output_shape.lens();
-        std::vector<std::size_t> batch_lens{};
-        if(op.axis == 0)
-        {
-            batch_lens.push_back(1);
-        }
-        else
-        {
-            batch_lens.insert(batch_lens.begin(), lens.begin(), lens.begin() + op.axis);
-        }
-        shape batch_shape{migraphx::shape::uint32_type, batch_lens};
+        auto batch_lens     = output_shape.lens();
+        batch_lens[op.axis] = 1;
+        shape batch_shape{shape::int32_type, batch_lens};
         visit_all(result, args[0])([&](auto output, auto input) {
             using value_type = typename decltype(input)::value_type;
             std::vector<value_type> batch_max(batch_shape.elements(),

@@ -660,7 +666,7 @@ struct cpu_apply
         apply_map["logsoftmax"] = extend_op<cpu_logsoftmax, op::logsoftmax>();
         apply_map["lrn"]        = extend_op<cpu_lrn, op::lrn>();
         apply_map["pad"]        = extend_op<cpu_pad, op::pad>();
-        apply_map["softmax"]    = simple_op<softmax2d>();
+        apply_map["softmax"]    = extend_op<cpu_softmax, op::softmax>();
     }
 
     void apply()
...
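Both cpu_softmax and cpu_logsoftmax now reduce over an arbitrary axis by mapping each element's multi-index to a batch slot: zero the reduced coordinate, then linearize against a shape whose extent on that axis is 1, so elements differing only along the axis share one max/sum slot. A standalone sketch of that trick:

#include <cstddef>
#include <vector>

// Zero the reduced axis of a multi-index, then linearize it (row-major)
// within batch_lens, where batch_lens[axis] == 1.
std::size_t compute_batch_index(std::vector<std::size_t> idx,
                                const std::vector<std::size_t>& batch_lens,
                                int axis)
{
    idx[axis] = 0;
    std::size_t linear = 0;
    for(std::size_t d = 0; d < batch_lens.size(); d++)
        linear = linear * batch_lens[d] + idx[d];
    return linear;
}

int main()
{
    // shape {2, 3}, softmax over axis 1 -> batch shape {2, 1}
    std::vector<std::size_t> batch_lens{2, 1};
    bool same_row = compute_batch_index({0, 0}, batch_lens, 1) ==
                    compute_batch_index({0, 2}, batch_lens, 1);
    bool diff_row = compute_batch_index({0, 0}, batch_lens, 1) !=
                    compute_batch_index({1, 0}, batch_lens, 1);
    return (same_row && diff_row) ? 0 : 1;
}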
src/targets/gpu/CMakeLists.txt  (+1 -0)

@@ -27,6 +27,7 @@ add_library(migraphx_device
     device/add_relu.cpp
     device/contiguous.cpp
     device/logsoftmax.cpp
+    device/softmax.cpp
     device/convert.cpp
     device/mul.cpp
     device/concat.cpp
...
src/targets/gpu/device/concat.cpp  (+10 -12)

@@ -10,22 +10,20 @@ namespace gpu {
 namespace device {
 
 argument concat(hipStream_t stream,
-                const migraphx::shape& output_shape,
+                const migraphx::shape&,
                 std::vector<migraphx::argument> args,
                 std::vector<std::size_t> offsets)
 {
-    for(std::size_t l = 0; l < args.size() - 1; l++)
+    auto ninputs = args.size() - 1;
+    for(std::size_t j = 0; j < ninputs; j++)
     {
-        auto argl             = args[l];
-        std::size_t nelements = argl.get_shape().elements();
-        visit_all(args.back(), argl)([&](auto output, auto input) {
-            visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
-                auto* outptr      = output.data() + offsets[l];
-                const auto* inptr = input.data();
-                hip_tensor_descriptor<ndim> desc_input(input.get_shape());
-                hip_tensor_descriptor<ndim> desc_output(output.get_shape());
-                gs_launch(stream, nelements)(
-                    [=](auto i) { outptr[desc_output.linear(desc_input.multi(i))] = inptr[i]; });
-            });
+        auto&& arg            = args[j];
+        std::size_t nelements = arg.get_shape().elements();
+        auto offset           = offsets[j];
+        hip_visit_all(args.back(), arg)([&](auto output, auto input) {
+            gs_launch(stream, nelements)([=](auto i) {
+                auto idx = output.get_shape().index(input.get_shape().multi(i));
+                output.data()[idx + offset] = input.data()[i];
+            });
         });
     }
...
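The rewritten kernel maps each input element's multi-index into the output shape and then adds that input's precomputed element offset. The diff does not show how the offsets are produced, but a plausible illustrative helper (names and scheme hypothetical, not the concat op's actual code) is the running sum of the earlier inputs' extents along the concat axis, scaled by the axis stride:

#include <cstddef>
#include <vector>

std::vector<std::size_t> concat_offsets(const std::vector<std::size_t>& axis_lens,
                                        std::size_t axis_stride)
{
    std::vector<std::size_t> offsets(axis_lens.size());
    std::size_t acc = 0;
    for(std::size_t j = 0; j < axis_lens.size(); j++)
    {
        offsets[j] = acc * axis_stride; // where input j's block starts
        acc += axis_lens[j];
    }
    return offsets;
}

int main()
{
    // Concat {2,2} and {2,3} along axis 1 -> output {2,5}, axis stride 1:
    auto offs = concat_offsets({2, 3}, 1);
    return (offs[0] == 0 && offs[1] == 2) ? 0 : 1;
}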
src/targets/gpu/device/gather.cpp  (+18 -23)

@@ -11,35 +11,30 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
 
-argument gather(hipStream_t stream,
-                const migraphx::shape& output_shape,
-                std::vector<migraphx::argument> args,
-                int axis)
+argument gather(hipStream_t stream, argument result, argument arg1, argument arg2, int axis)
 {
-    auto axis_index = (axis < 0) ? (axis + args[0].get_shape().lens().size()) : axis;
-    visit_all(args.back(), args[0])([&](auto output, auto input) {
-        std::size_t nelements = output_shape.elements();
-        args[1].visit([&](auto indices) {
-            const auto* indices_ptr = device_cast(indices.data());
-            auto* out_ptr           = device_cast(output.data());
-            const auto* in_ptr      = device_cast(input.data());
-            auto& input_shape       = args[0].get_shape();
-            auto lens               = input_shape.lens();
-            lens[axis_index]        = args[1].get_shape().elements();
-            migraphx::shape out_comp_shape{output_shape.type(), lens};
-            visit_tensor_size(out_comp_shape.lens().size(), [&](auto n_out_dim) {
-                hip_tensor_descriptor<n_out_dim> desc_input(input_shape);
-                hip_tensor_descriptor<n_out_dim> desc_output(out_comp_shape);
-                gs_launch(stream, nelements)([=](auto ii) {
-                    auto in_idx        = desc_output.multi(ii);
-                    in_idx[axis_index] = indices_ptr[in_idx[axis_index]];
-                    out_ptr[ii]        = in_ptr[desc_input.linear(in_idx)];
-                });
-            });
-        });
-    });
-    return args.back();
+    auto axis_index   = (axis < 0) ? (axis + arg1.get_shape().lens().size()) : axis;
+    auto& input_shape = arg1.get_shape();
+    auto lens         = input_shape.lens();
+    lens[axis_index]  = arg2.get_shape().elements();
+    shape out_comp_shape{result.get_shape().type(), lens};
+    std::size_t nelements = result.get_shape().elements();
+    visit_all(result, arg1)([&](auto output, auto input_v) {
+        hip_visit_views(input_v, out_comp_shape)([&](auto input, auto out_comp) {
+            arg2.visit([&](auto indices) {
+                const auto* indices_ptr = device_cast(indices.data());
+                auto* output_ptr        = device_cast(output.data());
+                gs_launch(stream, nelements)([=](auto i) {
+                    auto idx        = out_comp.multi(i);
+                    idx[axis_index] = indices_ptr[idx[axis_index]];
+                    output_ptr[i]   = input[idx];
+                });
+            });
+        });
+    });
+    return result;
 }
 
 } // namespace device
...
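The kernel's indexing idea: build a comparison shape equal to the input shape with the gather axis resized to the number of indices, take each output element's multi-index in that shape, route the axis coordinate through the indices array, and read the input there. A CPU-side sketch of the 1-D special case, where the arithmetic collapses to out[i] = in[indices[i]]:

#include <cstddef>
#include <vector>

std::vector<float> gather1d(const std::vector<float>& in,
                            const std::vector<std::size_t>& indices)
{
    std::vector<float> out(indices.size());
    for(std::size_t i = 0; i < indices.size(); i++)
        out[i] = in[indices[i]]; // axis coordinate routed through indices
    return out;
}

int main()
{
    auto out = gather1d({10.f, 20.f, 30.f, 40.f}, {3, 0, 1});
    return (out[0] == 40.f && out[1] == 10.f && out[2] == 20.f) ? 0 : 1;
}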
src/targets/gpu/device/include/migraphx/gpu/device/array.hpp  (new file, +60 -0)

#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_ARRAY_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_ARRAY_HPP

#include <migraphx/gpu/device/types.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

template <class T, std::size_t N>
struct hip_array
{
    T d[N];
    MIGRAPHX_DEVICE_CONSTEXPR T& operator[](std::size_t i) { return d[i]; }
    MIGRAPHX_DEVICE_CONSTEXPR const T& operator[](std::size_t i) const { return d[i]; }

    MIGRAPHX_DEVICE_CONSTEXPR T* data() { return d; }
    MIGRAPHX_DEVICE_CONSTEXPR const T* data() const { return d; }

    MIGRAPHX_DEVICE_CONSTEXPR std::integral_constant<std::size_t, N> size() const { return {}; }

    MIGRAPHX_DEVICE_CONSTEXPR T* begin() { return d; }
    MIGRAPHX_DEVICE_CONSTEXPR const T* begin() const { return d; }

    MIGRAPHX_DEVICE_CONSTEXPR T* end() { return d + size(); }
    MIGRAPHX_DEVICE_CONSTEXPR const T* end() const { return d + size(); }

    MIGRAPHX_DEVICE_CONSTEXPR T dot(const hip_array& x) const
    {
        T result = 0;
        for(std::size_t i = 0; i < N; i++)
            result += x[i] * d[i];
        return result;
    }

    MIGRAPHX_DEVICE_CONSTEXPR T product() const
    {
        T result = 1;
        for(std::size_t i = 0; i < N; i++)
            result *= d[i];
        return result;
    }

    friend MIGRAPHX_DEVICE_CONSTEXPR hip_array operator*(const hip_array& x, const hip_array& y)
    {
        hip_array result;
        for(std::size_t i = 0; i < N; i++)
            result[i] = x[i] * y[i];
        return result;
    }
};

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
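hip_array is a fixed-size aggregate that works in device code; its dot() is what turns a multi-index into a linear offset when dotted with a strides array. A host-side sketch of that use (plain C++ stand-in for the device type; names illustrative):

#include <cstddef>

template <class T, std::size_t N>
struct array_sketch
{
    T d[N];
    constexpr T dot(const array_sketch& x) const
    {
        T result = 0;
        for(std::size_t i = 0; i < N; i++)
            result += x.d[i] * d[i]; // multi-index . strides
        return result;
    }
};

int main()
{
    array_sketch<std::size_t, 2> strides{{3, 1}}; // row-major {2,3} tensor
    array_sketch<std::size_t, 2> idx{{1, 2}};     // element (1,2)
    return idx.dot(strides) == 5 ? 0 : 1;         // offset 1*3 + 2*1
}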
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp
View file @
511c8d8f
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_NARY_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_NARY_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_NARY_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_NARY_HPP
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/
types
.hpp>
#include <migraphx/gpu/device/
visit
.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/array.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
@@ -13,57 +13,30 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -13,57 +13,30 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
gpu
{
namespace
device
{
namespace
device
{
template
<
class
T
>
using
vec4
=
T
__attribute__
((
ext_vector_type
(
4
)));
template
<
class
T
>
__device__
__host__
vec4
<
T
>*
as_vec4
(
T
*
x
)
{
return
reinterpret_cast
<
vec4
<
T
>*>
(
x
);
}
template
<
class
T
>
__device__
__host__
T
*
as_pointer
(
vec4
<
T
>*
x
)
{
return
reinterpret_cast
<
T
*>
(
x
);
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
auto
pack
_vec4
(
Ts
...
xs
)
auto
pack
(
Ts
...
xs
)
__device__
{
{
return
[
=
](
auto
f
,
std
::
size_t
n
)
{
return
f
(
as_vec4
(
xs
)[
n
]
...);
};
return
[
=
](
auto
f
)
{
return
f
(
xs
...);
};
}
}
template
<
class
F
,
class
...
Arguments
>
template
<
class
F
,
class
...
Arguments
>
auto
nary_nonstandard_impl
(
hipStream_t
stream
,
F
f
,
argument
result
,
Arguments
...
args
)
auto
nary_nonstandard_impl
(
hipStream_t
stream
,
F
f
,
argument
result
,
Arguments
...
args
)
{
{
const
auto
&
output_shape
=
result
.
get_shape
();
std
::
size_t
nelements
=
result
.
get_shape
().
elements
();
visit_all
(
result
,
args
...)([
&
](
auto
output
,
auto
...
inputs
)
{
hip_visit_all
(
result
,
args
...)([
&
](
auto
output
,
auto
...
inputs
)
{
visit_tensor_size
(
output_shape
.
lens
().
size
(),
[
&
](
auto
ndim
)
{
gs_launch
(
stream
,
nelements
)([
=
](
auto
i
)
{
auto
data
=
pack
(
std
::
make_pair
(
hip_tensor_descriptor
<
ndim
>
{
inputs
.
get_shape
()},
auto
idx
=
output
.
get_shape
().
multi
(
i
);
device_cast
(
inputs
.
data
()))...);
output
[
i
]
=
f
(
inputs
[
idx
]...);
hip_tensor_descriptor
<
ndim
>
out_desc
(
output_shape
);
auto
*
outp
=
device_cast
(
output
.
data
());
gs_launch
(
stream
,
output_shape
.
elements
())([
=
](
auto
i
)
{
data
([
&
](
auto
&&
...
ps
)
{
auto
outidx
=
out_desc
.
multi
(
i
);
outp
[
i
]
=
f
(
ps
.
second
[
ps
.
first
.
linear
(
outidx
)]...);
});
});
});
});
});
});
}
}
-template <class F>
-void trinary_broadcast_vec_impl(hipStream_t stream,
-                                F f,
-                                const argument& result,
-                                const argument& arg1,
-                                const argument& arg2,
-                                const argument& arg3)
+template <class F, class... Arguments>
+void nary_broadcast_vec_impl(
+    hipStream_t stream, F f, argument result, argument barg, Arguments... args)
 {
     const auto& output_shape = result.get_shape();
-    const auto& b_shape      = arg3.get_shape();
+    const auto& b_shape      = barg.get_shape();
     auto bdim =
         std::distance(b_shape.strides().begin(),
                       std::find_if(b_shape.strides().begin(),
                                    b_shape.strides().end(),
                                    [](auto x) {
@@ -73,156 +46,45 @@ void trinary_broadcast_vec_impl(hipStream_t stream,
     auto bdim_stride      = output_shape.strides()[bdim];
     auto bdim_next_stride = bdim_stride * bdim_len;
-    visit_all(result, arg1, arg2, arg3)([&](auto output, auto input1, auto input2, auto input3) {
-        using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
-        auto* xp   = as_vec4(device_cast(input1.data()));
-        auto* yp   = as_vec4(device_cast(input2.data()));
-        auto* zp   = as_vec4(device_cast(input3.data()));
-        auto* outp = as_vec4(device_cast(output.data()));
-        const std::size_t vec_size     = 4;
-        const std::size_t nlocal       = 1024;
-        const std::size_t nglobal      = 256 * nlocal;
-        const std::size_t n            = output.size() / vec_size;
-        const std::size_t bdim_vec_len = bdim_len / vec_size;
-        launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
-            MIGRAPHX_DEVICE_SHARED vec4<type> buffer[2048 / vec_size];
-            // Load bias into LDS
-            for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
-            {
-                buffer[i] = zp[i];
-            }
-            __syncthreads();
-            auto* bp = as_pointer(buffer);
-            // Process the data
-            for(size_t i = idx.global; i < n; i += nglobal)
-            {
-                auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride;
-                auto b    = bp[bidx];
-                vec4<type> x   = xp[i];
-                vec4<type> y   = yp[i];
-                vec4<type> out = outp[i];
-                for(std::size_t j = 0; j < vec_size; j++)
-                {
-                    out[j] = f(x[j], y[j], b);
-                }
-                outp[i] = out;
-            }
-        });
-    });
-}
+    const std::size_t vec_size = 4;
+    hip_vec_visit_all<vec_size>(result, barg, args...)(
+        [&](auto output, auto binput, auto... inputs) {
+            using type = typename decltype(output)::value_type;
+            const std::size_t nelements    = output.size() / vec_size;
+            const std::size_t nlocal       = 1024;
+            const std::size_t nglobal      = 256 * nlocal;
+            const std::size_t bdim_vec_len = bdim_len / vec_size;
+            launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
+                MIGRAPHX_DEVICE_SHARED type buffer[2048 / vec_size];
+                // Load bias into LDS
+                for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
+                {
+                    buffer[i] = binput.data()[i];
+                }
+                __syncthreads();
+                auto* bp = as_pointer(buffer);
+                // Process the data
+                for(size_t i = idx.global; i < nelements; i += nglobal)
+                {
+                    auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride;
+                    auto b    = bp[bidx];
+                    auto out  = output.data()[i];
+                    for(std::size_t j = 0; j < vec_size; j++)
+                    {
+                        out[j] = f(inputs.data()[i][j]..., b);
+                    }
+                    output.data()[i] = out;
+                }
+            });
+        });
+}
-template <class F>
-void trinary_broadcast_impl(hipStream_t stream,
-                            F f,
-                            const argument& result,
-                            const argument& arg1,
-                            const argument& arg2,
-                            const argument& arg3)
-{
-    const auto& output_shape = result.get_shape();
-    const auto& b_shape      = arg3.get_shape();
-    auto bdim =
-        std::distance(b_shape.strides().begin(),
-                      std::find_if(b_shape.strides().begin(),
-                                   b_shape.strides().end(),
-                                   [](auto x) { return x != 0; }));
-    auto bdim_len         = output_shape.lens()[bdim];
-    auto bdim_stride      = output_shape.strides()[bdim];
-    auto bdim_next_stride = bdim_stride * bdim_len;
-    visit_all(result, arg1, arg2, arg3)([&](auto output, auto input1, auto input2, auto input3) {
-        using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
-        auto* xp   = device_cast(input1.data());
-        auto* yp   = device_cast(input2.data());
-        auto* zp   = device_cast(input3.data());
-        auto* outp = device_cast(output.data());
-        const std::size_t nlocal  = 1024;
-        const std::size_t nglobal = 256 * nlocal;
-        const std::size_t n       = output.size();
-        launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
-            MIGRAPHX_DEVICE_SHARED type buffer[2048];
-            // Load bias into LDS
-            for(size_t i = idx.local; i < bdim_len; i += nlocal)
-            {
-                buffer[i] = zp[i];
-            }
-            __syncthreads();
-            // Process the data
-            for(size_t i = idx.global; i < n; i += nglobal)
-            {
-                auto bidx = (i % bdim_next_stride) / bdim_stride;
-                auto b    = buffer[bidx];
-                type x    = xp[i];
-                type y    = yp[i];
-                outp[i]   = f(x, y, b);
-            }
-        });
-    });
-}
-template <class F>
-void binary_broadcast_vec_impl(hipStream_t stream,
-                               F f,
-                               const argument& result,
-                               const argument& arg1,
-                               const argument& arg2)
-{
-    const auto& output_shape = result.get_shape();
-    const auto& b_shape      = arg2.get_shape();
-    auto bdim =
-        std::distance(b_shape.strides().begin(),
-                      std::find_if(b_shape.strides().begin(),
-                                   b_shape.strides().end(),
-                                   [](auto x) { return x != 0; }));
-    auto bdim_len         = output_shape.lens()[bdim];
-    auto bdim_stride      = output_shape.strides()[bdim];
-    auto bdim_next_stride = bdim_stride * bdim_len;
-    visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) {
-        using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
-        auto* xp   = as_vec4(device_cast(input1.data()));
-        auto* yp   = as_vec4(device_cast(input2.data()));
-        auto* outp = as_vec4(device_cast(output.data()));
-        const std::size_t vec_size     = 4;
-        const std::size_t nlocal       = 1024;
-        const std::size_t nglobal      = 256 * nlocal;
-        const std::size_t n            = output.size() / vec_size;
-        const std::size_t bdim_vec_len = bdim_len / vec_size;
-        launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
-            MIGRAPHX_DEVICE_SHARED vec4<type> buffer[2048 / vec_size];
-            // Load bias into LDS
-            for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
-            {
-                buffer[i] = yp[i];
-            }
-            __syncthreads();
-            auto* bp = as_pointer(buffer);
-            // Process the data
-            for(size_t i = idx.global; i < n; i += nglobal)
-            {
-                auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride;
-                auto b    = bp[bidx];
-                vec4<type> x   = xp[i];
-                vec4<type> out = outp[i];
-                for(std::size_t j = 0; j < vec_size; j++)
-                {
-                    out[j] = f(x[j], b);
-                }
-                outp[i] = out;
-            }
-        });
-    });
-}
-template <class F>
-void binary_broadcast_impl(
-    hipStream_t stream, F f, const argument& result, const argument& arg1, const argument& arg2)
+template <class F, class... Arguments>
+void nary_broadcast_impl(
+    hipStream_t stream, F f, argument result, argument barg, Arguments... args)
 {
     const auto& output_shape = result.get_shape();
-    const auto& b_shape      = arg2.get_shape();
+    const auto& b_shape      = barg.get_shape();
     auto bdim =
         std::distance(b_shape.strides().begin(),
                       std::find_if(b_shape.strides().begin(),
                                    b_shape.strides().end(),
                                    [](auto x) {
@@ -232,31 +94,25 @@ void binary_broadcast_impl(
     auto bdim_stride      = output_shape.strides()[bdim];
     auto bdim_next_stride = bdim_stride * bdim_len;
-    visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) {
-        using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
-        auto* xp   = device_cast(input1.data());
-        auto* yp   = device_cast(input2.data());
-        auto* outp = device_cast(output.data());
-        const std::size_t nlocal  = 1024;
-        const std::size_t nglobal = 256 * nlocal;
-        const std::size_t n       = output.size();
+    const std::size_t nlocal  = 1024;
+    const std::size_t nglobal = 256 * nlocal;
+    std::size_t nelements     = result.get_shape().elements();
+    hip_visit_all(result, barg, args...)([&](auto output, auto binput, auto... inputs) {
+        using type = typename decltype(output)::value_type;
         launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
             MIGRAPHX_DEVICE_SHARED type buffer[2048];
             // Load bias into LDS
             for(size_t i = idx.local; i < bdim_len; i += nlocal)
             {
-                buffer[i] = yp[i];
+                buffer[i] = binput.data()[i];
             }
             __syncthreads();
             // Process the data
-            for(size_t i = idx.global; i < n; i += nglobal)
+            for(size_t i = idx.global; i < nelements; i += nglobal)
             {
                 auto bidx = (i % bdim_next_stride) / bdim_stride;
                 auto b    = buffer[bidx];
-                type x    = xp[i];
-                outp[i]   = f(x, b);
+                output.data()[i] = f(inputs.data()[i]..., b);
             }
         });
    });
 }
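All of these broadcast kernels stage the broadcast operand in LDS and then recover its coordinate from the flat element index: with `bdim_stride` the output stride of the broadcast axis and `bdim_next_stride = bdim_stride * bdim_len`, the expression `(i % bdim_next_stride) / bdim_stride` is exactly the coordinate along that axis. A host-side check of the arithmetic, with illustrative NCHW numbers:

#include <cassert>
#include <cstddef>

// For an NCHW output of lens {2, 4, 3, 3} the strides are {36, 9, 3, 1}.
// Broadcasting a per-channel bias means bdim_stride = 9 and
// bdim_next_stride = 9 * 4 = 36.
int main()
{
    const std::size_t bdim_stride      = 9;
    const std::size_t bdim_len         = 4;
    const std::size_t bdim_next_stride = bdim_stride * bdim_len;
    for(std::size_t i = 0; i < 72; i++)
    {
        std::size_t bidx = (i % bdim_next_stride) / bdim_stride;
        // Element i sits at channel (i / 9) % 4
        assert(bidx == (i / 9) % 4);
        (void)bidx;
    }
}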
@@ -265,15 +121,14 @@ void binary_broadcast_impl(
 template <class F, class... Arguments>
 void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments... args)
 {
-    // assert(x.get_shape().elements() == y.get_shape().elements());
     const auto& output_shape = result.get_shape();
     visit_all(result, args...)([&](auto output, auto... inputs) {
         using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
         const std::size_t vec_size = 4;
-        auto data  = pack_vec4(device_cast(inputs.data())...);
-        auto* outp = as_vec4(device_cast(output.data()));
+        auto data  = pack_vec<4>(device_cast(inputs.data())...);
+        auto* outp = as_vec<4>(device_cast(output.data()));
         gs_launch(stream, output_shape.elements() / vec_size)([=](auto i) {
-            vec4<type> out = outp[i];
+            vec<type, 4> out = outp[i];
             data(
                 [&](auto... xs) {
                     for(std::size_t j = 0; j < vec_size; j++)
@@ -290,13 +145,9 @@ void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments.
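The vectorized standard path reinterprets packed buffers as 4-wide extended vectors (`ext_vector_type`, the Clang extension HIP builds on), so each work-item loads and stores four elements at a time. A minimal host-side sketch that compiles with clang:

#include <cstddef>
#include <iostream>

template <class T, std::size_t N>
using vec = T __attribute__((ext_vector_type(N)));

int main()
{
    float data[8] = {0, 1, 2, 3, 4, 5, 6, 7};
    // Reinterpret the packed buffer as two 4-wide vectors, as as_vec<4> does.
    auto* v = reinterpret_cast<vec<float, 4>*>(data);
    for(std::size_t i = 0; i < 2; i++)
        v[i] = v[i] * 2.0f; // one vector op updates four lanes
    std::cout << data[5] << "\n"; // prints 10
}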
 template <class F, class... Arguments>
 void nary_standard_impl(hipStream_t stream, F f, argument result, Arguments... args)
 {
-    // assert(x.get_shape().elements() == y.get_shape().elements());
-    const auto& output_shape = result.get_shape();
-    visit_all(result, args...)([&](auto output, auto... inputs) {
-        auto data  = pack(device_cast(inputs.data())...);
-        auto* outp = device_cast(output.data());
-        gs_launch(stream, output_shape.elements())(
-            [=](auto i) { data([&](auto... xps) { outp[i] = f(xps[i]...); }); });
-    });
-}
+    std::size_t nelements = result.get_shape().elements();
+    hip_pointer_visit_all(result, args...)([&](auto output, auto... inputs) {
+        gs_launch(stream, nelements)([=](auto i) { output[i] = f(inputs[i]...); });
+    });
+}
@@ -313,12 +164,6 @@ void nary_impl(hipStream_t stream, F f, argument result, Arguments... args)
     nary_nonstandard_impl(stream, f, result, args...);
 }
-template <class F>
-void nary_impl(hipStream_t stream, F f, argument result)
-{
-    nary_standard_impl(stream, f, result);
-}
 template <class... Arguments>
 auto nary_nonstandard(hipStream_t stream, argument result, Arguments... args)
 {
@@ -332,71 +177,50 @@ auto nary_standard(hipStream_t stream, argument result, Arguments... args)
 }
-template <class... Arguments>
-auto nary(hipStream_t stream, argument result, Arguments... args)
-{
-    return [=](auto f) { nary_impl(stream, f, result, args...); };
-}
-inline auto
-nary(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
-{
-    return [=](auto f) {
-        // TODO: Check result and arg1 shape is the same
-        if(arg1.get_shape().standard() and arg2.get_shape().broadcasted() and
-           not arg2.get_shape().scalar())
-        {
-            auto not_zero       = [](auto x) { return x != 0; };
-            const auto& strides = arg2.get_shape().strides();
-            auto b_it           = std::find_if(strides.begin(), strides.end(), not_zero);
-            auto b_idx          = std::distance(strides.begin(), b_it);
-            auto b_len          = result.get_shape().lens()[b_idx];
-            auto b_stride       = result.get_shape().strides()[b_idx];
-            assert(arg2.get_shape().lens()[b_idx] == b_len);
-            if(b_len <= 2048 and std::none_of(std::next(b_it), strides.end(), not_zero))
-            {
-                const bool divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and
-                                            (arg1.get_shape().elements() % 4 == 0);
-                if(divisible_by_4)
-                    binary_broadcast_vec_impl(stream, f, result, arg1, arg2);
-                else
-                    binary_broadcast_impl(stream, f, result, arg1, arg2);
-                return;
-            }
-        }
-        nary_impl(stream, f, result, arg1, arg2);
-    };
-}
-inline auto nary(hipStream_t stream,
-                 const argument& result,
-                 const argument& arg1,
-                 const argument& arg2,
-                 const argument& arg3)
-{
-    return [=](auto f) {
-        // TODO: Check result and arg1 shape is the same
-        if(arg1.get_shape().standard() and arg2.get_shape().standard() and
-           arg3.get_shape().broadcasted())
-        {
-            auto not_zero       = [](auto x) { return x != 0; };
-            const auto& strides = arg3.get_shape().strides();
-            auto b_it           = std::find_if(strides.begin(), strides.end(), not_zero);
-            auto b_idx          = std::distance(strides.begin(), b_it);
-            auto b_len          = result.get_shape().lens()[b_idx];
-            auto b_stride       = result.get_shape().strides()[b_idx];
-            assert(arg3.get_shape().lens()[b_idx] == b_len);
-            if(b_len <= 2048 and std::none_of(std::next(b_it), strides.end(), not_zero))
-            {
-                const bool divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and
-                                            (arg1.get_shape().elements() % 4 == 0);
-                if(divisible_by_4)
-                    trinary_broadcast_vec_impl(stream, f, result, arg1, arg2, arg3);
-                else
-                    trinary_broadcast_impl(stream, f, result, arg1, arg2, arg3);
-                return;
-            }
-        }
-        nary_impl(stream, f, result, arg1, arg2, arg3);
-    };
-}
+inline auto nary(hipStream_t stream, argument result)
+{
+    return [=](auto f) { nary_standard_impl(stream, f, result); };
+}
+template <class... Arguments>
+auto nary(hipStream_t stream, argument result, Arguments... args)
+{
+    return [=](auto f) {
+        auto barg     = back_args(args...);
+        bool fallback = pop_back_args(args...)([&](auto&&... args2) {
+            auto bshape = barg.get_shape();
+            const bool standard =
+                all_of({args2.get_shape()...}, [](const shape& s) { return s.standard(); });
+            const bool same_shapes = all_of(
+                {args2.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
+            // TODO: Check result and args shape is the same
+            if(standard and same_shapes and bshape.broadcasted() and not bshape.scalar())
+            {
+                auto not_zero       = [](auto x) { return x != 0; };
+                const auto& strides = bshape.strides();
+                auto b_it           = std::find_if(strides.begin(), strides.end(), not_zero);
+                auto b_idx          = std::distance(strides.begin(), b_it);
+                auto b_len          = result.get_shape().lens()[b_idx];
+                auto b_stride       = result.get_shape().strides()[b_idx];
+                assert(bshape.lens()[b_idx] == b_len);
+                if(b_len <= 2048 and std::none_of(std::next(b_it), strides.end(), not_zero))
+                {
+                    const bool divisible_by_4 =
+                        (b_len % 4 == 0) and (b_stride % 4 == 0) and
+                        (front_args(args...).get_shape().elements() % 4 == 0);
+                    if(divisible_by_4)
+                        nary_broadcast_vec_impl(stream, f, result, barg, args2...);
+                    else
+                        nary_broadcast_impl(stream, f, result, barg, args2...);
+                    return false;
+                }
+            }
+            return true;
+        });
+        if(fallback)
+            nary_impl(stream, f, result, args...);
    };
 }
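The rewritten dispatcher treats the last argument as the broadcast candidate (`back_args`), requires the remaining arguments to be standard and shaped like the result, and takes the LDS fast path only when the broadcast axis is the innermost non-zero-stride axis and fits the 2048-element buffer; `pop_back_args` returns `true` to request the generic `nary_impl` fallback, since the inner lambda cannot return out of its caller directly. A hedged sketch of a call site, modeled on how the pointwise ops elsewhere in this tree use `nary` (the exact op files may differ):

// Illustrative only: an add-with-bias op. When `bias` is broadcast on a
// suitable axis this picks up the LDS fast path; otherwise it falls back
// to nary_impl.
nary(stream, result, x, bias)([](auto a, auto b) __device__ { return a + b; });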
...
src/targets/gpu/device/include/migraphx/gpu/device/shape.hpp
0 → 100644
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_SHAPE_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_SHAPE_HPP

#include <migraphx/gpu/device/array.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

template <std::size_t N>
struct hip_shape
{
    using hip_index = hip_array<std::size_t, N>;

    hip_array<std::size_t, N> lens    = {};
    hip_array<std::size_t, N> strides = {};
    bool standard                     = false;

    __device__ __host__ hip_shape() = default;

    hip_shape(const shape& s) : standard(s.standard())
    {
        assert(s.lens().size() == N);
        assert(s.strides().size() == N);
        std::copy(s.lens().begin(), s.lens().end(), lens.begin());
        std::copy(s.strides().begin(), s.strides().end(), strides.begin());
    }

    MIGRAPHX_DEVICE_CONSTEXPR std::size_t elements() const { return lens.product(); }

    MIGRAPHX_DEVICE_CONSTEXPR std::size_t index(hip_index x) const { return x.dot(strides); }

    MIGRAPHX_DEVICE_CONSTEXPR std::size_t index(std::initializer_list<std::size_t> x) const
    {
        std::size_t idx = 0;
        for(std::size_t i = 0; i < x.size(); i++)
            idx += *(x.begin() + i) * strides[i];
        return idx;
    }

    MIGRAPHX_DEVICE_CONSTEXPR std::size_t index(std::size_t i) const
    {
        if(this->standard)
            return i;
        else
        {
            const std::size_t rank = this->lens.size();
            std::size_t s          = 1;
            std::size_t result     = 0;
            for(std::size_t j = 0; j < this->lens.size(); j++)
            {
                const std::size_t k      = rank - j - 1;
                const std::size_t stride = this->strides[k];
                const std::size_t len    = this->lens[k];
                const std::size_t slen   = s * len;
                const std::size_t idx    = (i % slen) / s;
                result += stride * idx;
                s = slen;
            }
            return result;
        }
    }

    MIGRAPHX_DEVICE_CONSTEXPR hip_index multi(std::size_t idx) const
    {
        hip_index result;
        std::size_t tidx = idx;
        for(std::size_t is = 0; is < result.size(); is++)
        {
            result[is] = tidx / strides[is];
            tidx       = tidx % strides[is];
        }
        return result;
    }
};

template <std::size_t N>
hip_shape<N> make_hip_shape(const shape& x)
{
    return x;
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
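`hip_shape` carries fixed-rank lens and strides by value so a kernel can do all address math in registers: `index` is a dot product with the strides, and `multi` peels coordinates off a flat index by repeated division. Note that the divide-by-stride inversion in `multi` assumes packed, descending strides, which is why `index(std::size_t)` keeps a separate loop over `lens` for non-standard shapes. A host-side check of the arithmetic for lens {2, 3, 4}:

#include <cassert>
#include <cstddef>

// Packed strides for lens {2, 3, 4} are {12, 4, 1}.
int main()
{
    const std::size_t strides[3] = {12, 4, 1};
    std::size_t idx = 17; // flat index
    // multi(): peel off one coordinate per stride
    std::size_t coords[3];
    std::size_t t = idx;
    for(std::size_t d = 0; d < 3; d++)
    {
        coords[d] = t / strides[d];
        t         = t % strides[d];
    }
    assert(coords[0] == 1 && coords[1] == 1 && coords[2] == 1);
    // index(): the dot product with strides recovers the flat index
    std::size_t back = coords[0] * 12 + coords[1] * 4 + coords[2] * 1;
    assert(back == 17);
}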
src/targets/gpu/device/include/migraphx/gpu/device/tensor.hpp
 #ifndef MIGRAPHX_GUARD_RTGLIB_DEAVICE_TENSOR_HPP
 #define MIGRAPHX_GUARD_RTGLIB_DEAVICE_TENSOR_HPP
-#include <hip/hip_runtime.h>
+#include <migraphx/gpu/device/visit.hpp>
 #include <migraphx/functional.hpp>
 #include <migraphx/config.hpp>
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
-template <class F>
-void visit_tensor_size(std::size_t n, F f)
-{
-    switch(n)
-    {
-    case 1: { f(std::integral_constant<std::size_t, 1>{}); break; }
-    case 2: { f(std::integral_constant<std::size_t, 2>{}); break; }
-    case 3: { f(std::integral_constant<std::size_t, 3>{}); break; }
-    case 4: { f(std::integral_constant<std::size_t, 4>{}); break; }
-    case 5: { f(std::integral_constant<std::size_t, 5>{}); break; }
-    default: throw std::runtime_error("Unknown tensor size");
-    }
-}
-template <size_t NDim>
-struct hip_index
-{
-    size_t d[NDim];
-    __device__ __host__ size_t& operator[](size_t i) { return d[i]; }
-    __device__ __host__ size_t operator[](size_t i) const { return d[i]; }
-};
-template <size_t NDim>
+template <std::size_t NDim>
+using hip_tensor_index = hip_array<std::size_t, NDim>;
+template <std::size_t NDim>
 struct hip_tensor_descriptor
 {
     __device__ __host__ hip_tensor_descriptor() = default;
@@ -63,26 +22,26 @@ struct hip_tensor_descriptor
         std::copy(s.strides().begin(), s.strides().end(), strides);
     }
-    __device__ __host__ hip_index<NDim> multi(size_t idx) const
+    __device__ __host__ hip_tensor_index<NDim> multi(std::size_t idx) const
     {
-        hip_index<NDim> result{};
-        size_t tidx = idx;
-        for(size_t is = 0; is < NDim; is++)
+        hip_tensor_index<NDim> result{};
+        std::size_t tidx = idx;
+        for(std::size_t is = 0; is < NDim; is++)
         {
             result[is] = tidx / strides[is];
             tidx       = tidx % strides[is];
         }
         return result;
     }
-    __device__ __host__ size_t linear(hip_index<NDim> s) const
+    __device__ __host__ std::size_t linear(hip_tensor_index<NDim> s) const
     {
-        size_t idx = 0;
-        for(size_t i = 0; i < NDim; i++)
+        std::size_t idx = 0;
+        for(std::size_t i = 0; i < NDim; i++)
             idx += s[i] * strides[i];
         return idx;
     }
-    size_t lens[NDim]    = {};
-    size_t strides[NDim] = {};
+    std::size_t lens[NDim]    = {};
+    std::size_t strides[NDim] = {};
 };
 } // namespace device
...
src/targets/gpu/device/include/migraphx/gpu/device/tensor_view.hpp
0 → 100644
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_TENSOR_VIEW_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_TENSOR_VIEW_HPP

#include <migraphx/gpu/device/shape.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

template <class T, std::size_t N>
struct hip_tensor_view
{
    using value_type = T;
    using hip_index  = typename hip_shape<N>::hip_index;

    __device__ __host__ hip_tensor_view() = default;
    __host__ hip_tensor_view(tensor_view<T> x) : d(x.data()), s(x.get_shape()) {}
    __host__ hip_tensor_view(T* x, const shape& ss) : d(x), s(ss) {}

    MIGRAPHX_DEVICE_CONSTEXPR const hip_shape<N>& get_shape() const { return s; }

    MIGRAPHX_DEVICE_CONSTEXPR std::size_t size() const { return s.elements(); }

    MIGRAPHX_DEVICE_CONSTEXPR value_type* data() const { return d; }

    template <class U>
    MIGRAPHX_DEVICE_CONSTEXPR value_type& operator[](U i) const
    {
        return d[s.index(i)];
    }

    MIGRAPHX_DEVICE_CONSTEXPR value_type* begin() const { return d; }
    MIGRAPHX_DEVICE_CONSTEXPR value_type* end() const { return d + size(); }

    private:
    value_type* d = nullptr;
    hip_shape<N> s{};
};

template <std::size_t N, class T>
hip_tensor_view<T, N> make_hip_view(const shape& s, T* x)
{
    return {x, s};
}

template <std::size_t N, class T>
hip_tensor_view<T, N> make_hip_view(tensor_view<T> x)
{
    return {x};
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
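`hip_tensor_view` pairs a raw pointer with a `hip_shape<N>` and routes every subscript through `hip_shape::index`, so the same `operator[]` accepts a flat index, a `hip_index`, or a brace list. A self-contained host-side analogue of that design (illustrative only, not the library type):

#include <cassert>
#include <cstddef>
#include <initializer_list>

// Minimal 2-D stand-in: a pointer plus strides, with operator[] overloads
// that mirror hip_shape::index.
struct view2d
{
    float* d;
    std::size_t strides[2];
    float& operator[](std::size_t i) const { return d[i]; } // standard shape
    float& operator[](std::initializer_list<std::size_t> x) const
    {
        std::size_t idx = 0, k = 0;
        for(auto v : x)
            idx += v * strides[k++];
        return d[idx];
    }
};

int main()
{
    float buf[6] = {0, 1, 2, 3, 4, 5};
    view2d t{buf, {3, 1}};
    assert(&t[4] == &t[{1, 1}]); // flat index 4 == row 1, column 1
}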
src/targets/gpu/device/include/migraphx/gpu/device/types.hpp
@@ -8,14 +8,45 @@
 #ifndef MIGRAPHX_GUARD_RTGLIB_GPU_DEVICE_TYPES_HPP
 #define MIGRAPHX_GUARD_RTGLIB_GPU_DEVICE_TYPES_HPP
+#include <hip/hip_runtime.h>
 #include <migraphx/half.hpp>
 #include <migraphx/config.hpp>
+#include <migraphx/tensor_view.hpp>
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
+#define MIGRAPHX_DEVICE_CONSTEXPR constexpr __device__ __host__ // NOLINT
+template <class T, std::size_t N>
+using vec = T __attribute__((ext_vector_type(N)));
+template <std::size_t N, class T>
+__device__ __host__ T* as_pointer(vec<T, N>* x)
+{
+    return reinterpret_cast<T*>(x);
+}
+template <std::size_t N, class T>
+__device__ __host__ vec<T, N>* as_vec(T* x)
+{
+    return reinterpret_cast<vec<T, N>*>(x);
+}
+template <std::size_t N, class T>
+tensor_view<vec<T, N>> as_vec(tensor_view<T> x)
+{
+    return {x.get_shape(), as_vec<N>(x.data())};
+}
+template <std::size_t N, class... Ts>
+auto pack_vec(Ts... xs)
+{
+    return [=](auto f, std::size_t n) { return f(as_vec<N>(xs)[n]...); };
+}
 using gpu_half = __fp16;
 namespace detail {
@@ -25,6 +56,12 @@ struct device_type
     using type = T;
 };
+template <class T, std::size_t N>
+struct device_type<vec<T, N>>
+{
+    using type = vec<typename device_type<T>::type, N>;
+};
 template <>
 struct device_type<half>
 {
@@ -38,7 +75,7 @@ struct host_type
 };
 template <>
-struct device_type<gpu_half>
+struct host_type<gpu_half>
 {
     using type = half;
 };
@@ -64,9 +101,9 @@ host_type<T>* host_cast(T* x)
 }
 template <class T>
-device_type<T> device_cast(T x)
+device_type<T> device_cast(const T& x)
 {
-    return reinterpret_cast<device_type<T>>(x);
+    return reinterpret_cast<const device_type<T>&>(x);
 }
 template <class T>
@@ -75,6 +112,12 @@ device_type<T>* device_cast(T* x)
     return reinterpret_cast<device_type<T>*>(x);
 }
+template <class T>
+tensor_view<device_type<T>> device_cast(tensor_view<T> x)
+{
+    return {x.get_shape(), reinterpret_cast<device_type<T>*>(x.data())};
+}
 template <class T>
 T to_hip_type(T x)
 {
...
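Note the fix folded into this hunk: the old header specialized `device_type<gpu_half>` a second time where `host_type<gpu_half>` was evidently intended, so the `half`/`__fp16` mapping now round-trips. The relocated `pack_vec<N>` helper combines `as_vec` with the continuation idiom above: it hands the n-th N-wide vector of every captured buffer to a functor. A host-side sketch (clang, since `ext_vector_type` is a Clang extension):

#include <cstddef>
#include <iostream>

template <class T, std::size_t N>
using vec = T __attribute__((ext_vector_type(N)));

template <std::size_t N, class T>
vec<T, N>* as_vec(T* x)
{
    return reinterpret_cast<vec<T, N>*>(x);
}

template <std::size_t N, class... Ts>
auto pack_vec(Ts... xs)
{
    return [=](auto f, std::size_t n) { return f(as_vec<N>(xs)[n]...); };
}

int main()
{
    float a[8] = {0, 1, 2, 3, 4, 5, 6, 7};
    float b[8] = {7, 6, 5, 4, 3, 2, 1, 0};
    auto data  = pack_vec<4>(&a[0], &b[0]);
    // Apply a continuation to the second 4-wide vector of each buffer.
    vec<float, 4> s = data([](auto x, auto y) { return x + y; }, 1);
    std::cout << s[0] << "\n"; // 4 + 3 = 7
}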