Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
dfa26315
Commit
dfa26315
authored
May 05, 2022
by
charlie
Browse files
Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_shape_update
parents
aa085491
4a5a23a4
Changes
42
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
843 additions
and
155 deletions
+843
-155
.clang-tidy
.clang-tidy
+1
-1
CMakeLists.txt
CMakeLists.txt
+1
-1
src/CMakeLists.txt
src/CMakeLists.txt
+1
-0
src/api/api.cpp
src/api/api.cpp
+16
-3
src/api/include/migraphx/migraphx.h
src/api/include/migraphx/migraphx.h
+2
-0
src/api/include/migraphx/migraphx.hpp
src/api/include/migraphx/migraphx.hpp
+139
-121
src/api/migraphx.py
src/api/migraphx.py
+1
-0
src/include/migraphx/op/gathernd.hpp
src/include/migraphx/op/gathernd.hpp
+131
-0
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+1
-0
src/onnx/parse_generic_op.cpp
src/onnx/parse_generic_op.cpp
+1
-0
src/onnx/parse_reversesequence.cpp
src/onnx/parse_reversesequence.cpp
+125
-0
src/targets/gpu/driver/run_op.cpp
src/targets/gpu/driver/run_op.cpp
+1
-1
src/targets/gpu/jit/gathernd.cpp
src/targets/gpu/jit/gathernd.cpp
+75
-0
src/targets/gpu/jit/reduce.cpp
src/targets/gpu/jit/reduce.cpp
+52
-5
src/targets/gpu/kernels/include/migraphx/kernels/algorithm.hpp
...argets/gpu/kernels/include/migraphx/kernels/algorithm.hpp
+10
-0
src/targets/gpu/kernels/include/migraphx/kernels/debug.hpp
src/targets/gpu/kernels/include/migraphx/kernels/debug.hpp
+70
-4
src/targets/gpu/kernels/include/migraphx/kernels/gathernd.hpp
...targets/gpu/kernels/include/migraphx/kernels/gathernd.hpp
+81
-0
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
+116
-13
src/targets/gpu/kernels/include/migraphx/kernels/tensor_view.hpp
...gets/gpu/kernels/include/migraphx/kernels/tensor_view.hpp
+16
-4
test/api/CMakeLists.txt
test/api/CMakeLists.txt
+3
-2
No files found.
.clang-tidy
View file @
dfa26315
...
...
@@ -4,7 +4,7 @@ CheckOptions:
- key: bugprone-unused-return-value.CheckedFunctions
value: '::std::async;::std::launder;::std::remove;::std::remove_if;::std::unique;::std::unique_ptr::release;::std::basic_string::empty;::std::vector::empty;::std::find;::std::find_if;::std::find_if_not;::std::all_of;::std::any_of;::std::none_of;::std::count;::std::count_if;::std::mismatch;::std::find_end;::std::find_first_of;::std::adjacent_find;::std::search;::std::search_n;::std::nth_element;::std::lower_bound;::std::upper_bound;::std::binary_search;::std::equal_range;::std::max;::std::max_element;::std::min;::std::min_element;::std::minmax;::std::minmax_element;::std::equal;::std::lexicographical_compare;::std::accumulate;::std::inner_product'
- key: cppcoreguidelines-macro-usage.AllowedRegexp
value: 'DEBUG|ASSERT|ASSUME|UNREACHABLE|FALLTHROUGH|STRINGIZE|_HAS_|_THROW|_REQUIRES|_DECLARE_|_VISIT_|_REGISTER_|_GENERATE_|_DETAIL_|_TIDY_|_MANAGE_PTR|_MATCHER|DEVICE_SHARED|_WORKAROUND_'
value: 'DEBUG|ASSERT|ASSUME|UNREACHABLE|FALLTHROUGH|
DEPRECATED|
STRINGIZE|_HAS_|_THROW|_REQUIRES|_DECLARE_|_VISIT_|_REGISTER_|_GENERATE_|_DETAIL_|_TIDY_|_MANAGE_PTR|_MATCHER|DEVICE_SHARED|_WORKAROUND_'
- key: modernize-loop-convert.MinConfidence
value: risky
- key: modernize-loop-convert.NamingStyle
...
...
CMakeLists.txt
View file @
dfa26315
...
...
@@ -42,7 +42,7 @@ find_package(nlohmann_json 3.8.0 REQUIRED)
include
(
ROCMSetupVersion
)
rocm_setup_version
(
VERSION 2.
2
)
rocm_setup_version
(
VERSION 2.
3
)
set
(
MIGRAPHX_SO_VERSION
${
PROJECT_VERSION_MAJOR
}
.
${
PROJECT_VERSION_MINOR
}
)
option
(
BUILD_SHARED_LIBS
"Build as a shared library"
ON
)
...
...
src/CMakeLists.txt
View file @
dfa26315
...
...
@@ -109,6 +109,7 @@ register_migraphx_ops(
flatten
floor
gather
gathernd
get_tuple_elem
greater
gru
...
...
src/api/api.cpp
View file @
dfa26315
...
...
@@ -401,7 +401,8 @@ extern "C" struct migraphx_instruction;
struct
migraphx_instruction
{
template
<
class
...
Ts
>
migraphx_instruction
(
Ts
&&
...
xs
)
:
object
(
std
::
forward
<
Ts
>
(
xs
)...)
migraphx_instruction
(
Ts
&&
...
xs
)
:
object
(
std
::
forward
<
Ts
>
(
xs
)...)
// NOLINT(readability-redundant-member-init)
{
}
migraphx
::
instruction_ref
object
;
...
...
@@ -411,7 +412,8 @@ extern "C" struct migraphx_instructions;
struct
migraphx_instructions
{
template
<
class
...
Ts
>
migraphx_instructions
(
Ts
&&
...
xs
)
:
object
(
std
::
forward
<
Ts
>
(
xs
)...)
migraphx_instructions
(
Ts
&&
...
xs
)
:
object
(
std
::
forward
<
Ts
>
(
xs
)...)
// NOLINT(readability-redundant-member-init)
{
}
std
::
vector
<
migraphx
::
instruction_ref
>
object
;
...
...
@@ -421,7 +423,8 @@ extern "C" struct migraphx_modules;
struct
migraphx_modules
{
template
<
class
...
Ts
>
migraphx_modules
(
Ts
&&
...
xs
)
:
object
(
std
::
forward
<
Ts
>
(
xs
)...)
migraphx_modules
(
Ts
&&
...
xs
)
:
object
(
std
::
forward
<
Ts
>
(
xs
)...)
// NOLINT(readability-redundant-member-init)
{
}
std
::
vector
<
migraphx
::
module
*>
object
;
...
...
@@ -1691,6 +1694,16 @@ extern "C" migraphx_status migraphx_context_finish(const_migraphx_context_t cont
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_context_get_queue
(
void
**
out
,
migraphx_context_t
context
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
if
(
context
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter context: Null pointer"
);
*
out
=
(
context
->
object
).
get_queue
().
unsafe_get
();
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_experimental_custom_op_destroy
(
migraphx_experimental_custom_op_t
experimental_custom_op
)
{
...
...
src/api/include/migraphx/migraphx.h
View file @
dfa26315
...
...
@@ -433,6 +433,8 @@ migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
migraphx_status
migraphx_context_finish
(
const_migraphx_context_t
context
);
migraphx_status
migraphx_context_get_queue
(
void
**
out
,
migraphx_context_t
context
);
migraphx_status
migraphx_experimental_custom_op_destroy
(
migraphx_experimental_custom_op_t
experimental_custom_op
);
...
...
src/api/include/migraphx/migraphx.hpp
View file @
dfa26315
...
...
@@ -15,6 +15,16 @@ namespace migraphx {
inline
namespace
api
{
// NOLINT
#endif
#ifdef __has_cpp_attribute
#if __has_cpp_attribute(deprecated)
#define MIGRAPHX_DEPRECATED(...) [[deprecated(__VA_ARGS__)]]
#endif
#endif
#ifndef MIGRAPHX_DEPRECATED
#define MIGRAPHX_DEPRECATED(...)
#endif
template
<
int
N
>
struct
rank
:
rank
<
N
-
1
>
{
...
...
@@ -99,34 +109,22 @@ struct iota_iterator
return
it
;
}
// TODO: operator->
reference
operator
*
()
const
{
return
(
*
f
)(
index
);
}
};
reference
operator
*
()
const
{
return
f
(
index
);
}
template
<
class
F
,
class
Iterator
>
inline
iota_iterator
<
F
,
Iterator
>
operator
+
(
iota_iterator
<
F
,
Iterator
>
x
,
iota_iterator
<
F
,
Iterator
>
y
)
{
return
iota_iterator
<
F
,
Iterator
>
(
x
.
index
+
y
.
index
,
x
.
f
);
}
friend
iota_iterator
operator
+
(
iota_iterator
x
,
iota_iterator
y
)
{
return
iota_iterator
(
x
.
index
+
y
.
index
,
x
.
f
);
}
template
<
class
F
,
class
Iterator
>
inline
iota_iterator
<
F
,
Iterator
>
operator
-
(
iota_iterator
<
F
,
Iterator
>
x
,
iota_iterator
<
F
,
Iterator
>
y
)
{
return
iota_iterator
<
F
,
Iterator
>
(
x
.
index
-
y
.
index
,
x
.
f
);
}
friend
iota_iterator
operator
-
(
iota_iterator
x
,
iota_iterator
y
)
{
return
iota_iterator
(
x
.
index
-
y
.
index
,
x
.
f
);
}
template
<
class
F
,
class
Iterator
>
inline
bool
operator
==
(
iota_iterator
<
F
,
Iterator
>
x
,
iota_iterator
<
F
,
Iterator
>
y
)
{
return
x
.
index
==
y
.
index
;
}
friend
bool
operator
==
(
iota_iterator
x
,
iota_iterator
y
)
{
return
x
.
index
==
y
.
index
;
}
template
<
class
F
,
class
Iterator
>
inline
bool
operator
!=
(
iota_iterator
<
F
,
Iterator
>
x
,
iota_iterator
<
F
,
Iterator
>
y
)
{
return
x
.
index
!=
y
.
index
;
}
friend
bool
operator
!=
(
iota_iterator
x
,
iota_iterator
y
)
{
return
x
.
index
!=
y
.
index
;
}
};
template
<
class
Derived
>
struct
array_base
...
...
@@ -136,8 +134,20 @@ struct array_base
template
<
class
T
>
using
value_type_t
=
decltype
(
std
::
declval
<
T
>
()[
0
]);
struct
iterator_read
{
const
Derived
*
self
;
template
<
class
D
=
Derived
>
value_type_t
<
D
>
operator
()(
size_t
pidx
)
const
{
return
(
*
self
)[
pidx
];
}
};
template
<
class
T
>
using
iterator_t
=
iota_iterator
<
typename
T
::
iterator_read
>
;
using
iterator_t
=
iota_iterator
<
iterator_read
>
;
bool
empty
()
const
{
return
derived
().
size
()
==
0
;
}
template
<
class
D
=
Derived
>
value_type_t
<
D
>
front
()
const
...
...
@@ -154,13 +164,13 @@ struct array_base
template
<
class
D
=
Derived
>
iterator_t
<
D
>
begin
()
const
{
return
{
0
,
{
derived
()
.
get_handle_ptr
()
}};
return
{
0
,
{
&
derived
()}};
}
template
<
class
D
=
Derived
>
iterator_t
<
D
>
end
()
const
{
return
{
derived
().
size
(),
{
derived
()
.
get_handle_ptr
()
}};
return
{
derived
().
size
(),
{
&
derived
()}};
}
};
...
...
@@ -200,9 +210,25 @@ struct borrow
{
};
template
<
class
T
>
struct
share
{
share
(
std
::
shared_ptr
<
T
>
p
)
:
ptr
(
std
::
move
(
p
))
{}
template
<
class
U
>
std
::
shared_ptr
<
U
>
alias
(
U
*
p
)
const
{
return
std
::
shared_ptr
<
U
>
{
ptr
,
p
};
}
private:
std
::
shared_ptr
<
T
>
ptr
;
};
template
<
class
Derived
,
class
T
,
class
D
,
D
Deleter
,
class
A
,
A
Assigner
>
struct
handle_base
:
handle_lookup
<
Derived
,
std
::
remove_cv_t
<
T
>>
{
using
handle_type
=
T
;
handle_base
()
:
m_handle
(
nullptr
)
{}
template
<
class
F
,
class
...
Ts
>
void
make_handle
(
F
f
,
Ts
&&
...
xs
)
...
...
@@ -231,6 +257,14 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
m_handle
=
std
::
shared_ptr
<
U
>
{
ptr
,
[](
U
*
)
{}};
}
template
<
class
U
,
class
V
>
void
set_handle
(
U
*
ptr
,
share
<
V
>
b
)
{
m_handle
=
std
::
shared_ptr
<
T
>
{
ptr
,
[
b
](
U
*
)
{}};
}
share
<
T
>
share_handle
()
const
{
return
{
m_handle
};
}
template
<
class
U
>
void
assign_to_handle
(
U
*
x
)
{
...
...
@@ -241,6 +275,17 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
std
::
shared_ptr
<
T
>
m_handle
;
};
// NOLINTNEXTLINE
#define MIGRAPHX_HANDLE_CONSTRUCTOR(name) \
template <class HandleType, \
class Lifetime, \
class = \
typename std::enable_if<std::is_convertible<HandleType*, handle_type*>{}>::type> \
name(HandleType* p, Lifetime lifetime) \
{ \
this->set_handle(p, std::move(lifetime)); \
}
template
<
class
Base
>
struct
interface_base
:
Base
{
...
...
@@ -398,11 +443,10 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
{
shape
()
{}
MIGRAPHX_DEPRECATED
(
"Contructor without lifetime annotation is deprecated."
)
shape
(
const
migraphx_shape
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
shape
(
migraphx_shape
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
shape
(
migraphx_shape
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
shape
);
/// Construct a scalar shape
shape
(
migraphx_shape_datatype_t
type
)
...
...
@@ -479,10 +523,9 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
{
argument
()
{}
argument
(
migraphx_argument
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
argument
(
migraphx_argument
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
argument
);
MIGRAPHX_DEPRECATED
(
"Contructor without lifetime annotation is deprecated."
)
argument
(
const
migraphx_argument
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
argument
(
shape
pshape
,
void
*
pbuffer
)
...
...
@@ -494,7 +537,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
{
const_migraphx_shape_t
pout
;
call
(
&
migraphx_argument_shape
,
&
pout
,
this
->
get_handle_ptr
());
return
{
pout
};
return
{
pout
,
this
->
share_handle
()
};
}
char
*
data
()
const
...
...
@@ -526,9 +569,7 @@ struct target : MIGRAPHX_HANDLE_BASE(target)
{
target
()
{}
target
(
migraphx_target
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
target
(
migraphx_target
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
target
);
/// Construct a target from its name
target
(
const
char
*
name
)
{
this
->
make_handle
(
&
migraphx_target_create
,
name
);
}
...
...
@@ -538,15 +579,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
{
program_parameter_shapes
()
{}
program_parameter_shapes
(
migraphx_program_parameter_shapes
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
program_parameter_shapes
(
migraphx_program_parameter_shapes
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
program_parameter_shapes
);
size_t
size
()
const
{
...
...
@@ -559,7 +592,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
{
const_migraphx_shape_t
pout
;
call
(
&
migraphx_program_parameter_shapes_get
,
&
pout
,
this
->
get_handle_ptr
(),
pname
);
return
{
pout
};
return
{
pout
,
this
->
share_handle
()
};
}
std
::
vector
<
const
char
*>
names
()
const
...
...
@@ -576,10 +609,9 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
/// A class to construct the inputs parameters for a program
struct
program_parameters
:
MIGRAPHX_HANDLE_BASE
(
program_parameters
)
{
program_parameters
(
migraphx_program_parameters
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
program_parameters
(
migraphx_program_parameters
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
program_parameters
);
MIGRAPHX_DEPRECATED
(
"Contructor without lifetime annotation is deprecated."
)
program_parameters
(
migraphx_program_parameters
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
program_parameters
()
{
this
->
make_handle
(
&
migraphx_program_parameters_create
);
}
...
...
@@ -604,9 +636,7 @@ struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
struct
arguments
:
MIGRAPHX_HANDLE_BASE
(
arguments
),
array_base
<
arguments
>
{
arguments
(
migraphx_arguments
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
arguments
(
migraphx_arguments
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
arguments
);
size_t
size
()
const
{
...
...
@@ -619,27 +649,13 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
{
const_migraphx_argument_t
pout
;
call
(
&
migraphx_arguments_get
,
&
pout
,
this
->
get_handle_ptr
(),
pidx
);
return
{
pout
};
return
{
pout
,
this
->
share_handle
()
};
}
struct
iterator_read
{
migraphx_arguments
*
self
;
argument
operator
()(
size_t
pidx
)
const
{
const_migraphx_argument_t
pout
;
call
(
&
migraphx_arguments_get
,
&
pout
,
self
,
pidx
);
return
{
pout
};
}
};
};
struct
shapes
:
MIGRAPHX_HANDLE_BASE
(
shapes
),
array_base
<
shapes
>
{
shapes
(
migraphx_shapes
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
shapes
(
migraphx_shapes
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
shapes
);
size_t
size
()
const
{
...
...
@@ -652,26 +668,13 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
{
const_migraphx_shape_t
pout
;
call
(
&
migraphx_shapes_get
,
&
pout
,
this
->
get_handle_ptr
(),
pidx
);
return
{
pout
};
return
{
pout
,
this
->
share_handle
()
};
}
struct
iterator_read
{
migraphx_shapes
*
self
;
shape
operator
()(
size_t
pidx
)
const
{
const_migraphx_shape_t
pout
;
call
(
&
migraphx_shapes_get
,
&
pout
,
self
,
pidx
);
return
{
pout
};
}
};
};
struct
operation
:
MIGRAPHX_HANDLE_BASE
(
operation
)
{
operation
(
migraphx_operation
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
operation
(
migraphx_operation
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
operation
);
template
<
class
...
Ts
>
operation
(
const
char
*
name
,
const
char
*
attributes
=
nullptr
,
Ts
...
xs
)
...
...
@@ -689,15 +692,12 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation)
struct
instruction
:
MIGRAPHX_CONST_HANDLE_BASE
(
instruction
)
{
instruction
(
migraphx_instruction
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
instruction
);
};
struct
instructions
:
MIGRAPHX_HANDLE_BASE
(
instructions
)
{
instructions
(
migraphx_instructions
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
instructions
(
migraphx_instructions
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
instructions
);
template
<
class
...
Ts
>
instructions
(
Ts
...
xs
)
...
...
@@ -711,33 +711,36 @@ struct module;
struct
modules
:
MIGRAPHX_HANDLE_BASE
(
modules
)
{
modules
(
migraphx_modules
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
modules
(
migraphx_modules
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
modules
);
template
<
class
...
Ts
>
modules
(
Ts
...
xs
)
{
std
::
array
<
migraphx_module_t
,
sizeof
...(
Ts
)
>
a
=
{
xs
.
mm
...};
std
::
array
<
migraphx_module_t
,
sizeof
...(
Ts
)
>
a
=
{
xs
.
get_handle_ptr
()
...};
this
->
make_handle
(
&
migraphx_modules_create
,
a
.
data
(),
a
.
size
());
}
};
struct
module
{
migraphx_module_t
mm
;
MIGRAPHX_DEPRECATED
(
"Constructor without lifetime annotation is deprecated."
)
module
(
migraphx_module
*
m
)
:
mm
(
std
::
shared_ptr
<
migraphx_module
*>
(),
m
)
{}
module
(
migraphx_module
*
m
,
borrow
)
:
mm
(
std
::
shared_ptr
<
migraphx_module
*>
(),
m
)
{}
module
(
const
migraphx_module_t
&
m
)
:
mm
(
m
)
{}
template
<
class
T
>
module
(
migraphx_module
*
m
,
share
<
T
>
b
)
:
mm
(
b
.
alias
(
m
))
{
}
void
print
()
const
{
call
(
&
migraphx_module_print
,
mm
);
}
void
print
()
const
{
call
(
&
migraphx_module_print
,
mm
.
get
()
);
}
instruction
add_instruction
(
const
migraphx
::
operation
&
op
,
const
migraphx
::
instructions
&
args
)
{
migraphx_instruction_t
op_ins
;
call
(
&
migraphx_module_add_instruction
,
&
op_ins
,
mm
,
mm
.
get
()
,
op
.
get_handle_ptr
(),
args
.
get_handle_ptr
());
return
instruction
(
op_ins
,
own
{});
...
...
@@ -750,7 +753,7 @@ struct module
migraphx_instruction_t
op_ins
;
call
(
&
migraphx_module_add_instruction_with_mod_args
,
&
op_ins
,
mm
,
mm
.
get
()
,
op
.
get_handle_ptr
(),
args
.
get_handle_ptr
(),
module_args
.
get_handle_ptr
());
...
...
@@ -760,30 +763,53 @@ struct module
instruction
add_parameter
(
const
std
::
string
&
name
,
shape
s
)
{
migraphx_instruction_t
param_ins
;
call
(
&
migraphx_module_add_parameter
,
&
param_ins
,
mm
,
name
.
c_str
(),
s
.
get_handle_ptr
());
call
(
&
migraphx_module_add_parameter
,
&
param_ins
,
mm
.
get
(),
name
.
c_str
(),
s
.
get_handle_ptr
());
return
instruction
(
param_ins
,
own
{});
}
instruction
add_return
(
const
migraphx
::
instructions
&
args
)
{
migraphx_instruction_t
ret_ins
;
call
(
&
migraphx_module_add_return
,
&
ret_ins
,
mm
,
args
.
get_handle_ptr
());
call
(
&
migraphx_module_add_return
,
&
ret_ins
,
mm
.
get
()
,
args
.
get_handle_ptr
());
return
instruction
(
ret_ins
,
own
{});
}
migraphx_module_t
get_handle_ptr
()
const
{
return
mm
.
get
();
}
private:
std
::
shared_ptr
<
migraphx_module
>
mm
;
};
struct
context
{
migraphx_context_t
ctx
;
context
(
migraphx_context
*
p
,
borrow
)
:
ctx
(
std
::
shared_ptr
<
migraphx_context
*>
(),
p
)
{}
template
<
class
T
>
context
(
migraphx_context
*
p
,
share
<
T
>
b
)
:
ctx
(
b
.
alias
(
p
))
{
}
void
finish
()
const
{
call
(
&
migraphx_context_finish
,
ctx
.
get
());
}
template
<
class
T
>
T
get_queue
()
{
void
*
out
;
call
(
&
migraphx_context_get_queue
,
&
out
,
ctx
.
get
());
// TODO: check type here
return
reinterpret_cast
<
T
>
(
out
);
}
void
finish
()
const
{
call
(
&
migraphx_context_finish
,
ctx
);
}
private:
std
::
shared_ptr
<
migraphx_context
>
ctx
;
};
struct
compile_options
:
MIGRAPHX_HANDLE_BASE
(
compile_options
)
{
compile_options
()
{
this
->
make_handle
(
&
migraphx_compile_options_create
);
}
compile_options
(
migraphx_compile_options
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
());
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
compile_options
);
/// For targets with offloaded memory(such as the gpu), this will insert
/// instructions during compilation to copy the input parameters to the
...
...
@@ -807,9 +833,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
{
program
()
{
this
->
make_handle
(
&
migraphx_program_create
);
}
program
(
migraphx_program
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
program
(
migraphx_program
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
program
);
/// Compile the program for a specific target to be ran on
void
compile
(
const
target
&
ptarget
,
const
compile_options
&
poptions
)
const
...
...
@@ -872,21 +896,21 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
{
migraphx_module_t
p_modu
;
call
(
&
migraphx_program_get_main_module
,
&
p_modu
,
this
->
get_handle_ptr
());
return
module
{
p_modu
};
return
module
{
p_modu
,
this
->
share_handle
()
};
}
context
experimental_get_context
()
{
migraphx_context_t
ctx
;
call
(
&
migraphx_program_experimental_get_context
,
&
ctx
,
this
->
get_handle_ptr
());
return
context
{
ctx
};
return
context
{
ctx
,
this
->
share_handle
()
};
}
module
create_module
(
const
std
::
string
&
name
)
{
migraphx_module_t
p_modu
;
call
(
&
migraphx_program_create_module
,
&
p_modu
,
this
->
get_handle_ptr
(),
name
.
data
());
return
module
{
p_modu
};
return
module
{
p_modu
,
this
->
share_handle
()
};
}
friend
bool
operator
!=
(
const
program
&
px
,
const
program
&
py
)
{
return
!
(
px
==
py
);
}
...
...
@@ -895,10 +919,9 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
// options for migraphx file format options
struct
file_options
:
MIGRAPHX_HANDLE_BASE
(
file_options
)
{
MIGRAPHX_HANDLE_CONSTRUCTOR
(
file_options
);
file_options
()
{
this
->
make_handle
(
&
migraphx_file_options_create
);
}
file_options
(
migraphx_file_options
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
());
}
// set file format
void
set_file_format
(
const
char
*
format
)
{
...
...
@@ -938,7 +961,7 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
{
onnx_options
()
{
this
->
make_handle
(
&
migraphx_onnx_options_create
);
}
onnx_options
(
migraphx_onnx_options
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
onnx_options
);
/// Make onnx parser treat an inputs with a certain dimensions
void
set_input_parameter_shape
(
const
std
::
string
&
name
,
std
::
vector
<
std
::
size_t
>
dim
)
...
...
@@ -1020,7 +1043,7 @@ struct tf_options : MIGRAPHX_HANDLE_BASE(tf_options)
{
tf_options
()
{
this
->
make_handle
(
&
migraphx_tf_options_create
);
}
tf_options
(
migraphx_tf_options
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
tf_options
);
/// Make tf parser treat an inputs with a certain dimensions
void
set_input_parameter_shape
(
const
std
::
string
&
name
,
std
::
vector
<
std
::
size_t
>
dim
)
...
...
@@ -1073,7 +1096,7 @@ struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names)
{
quantize_op_names
()
{
this
->
make_handle
(
&
migraphx_quantize_op_names_create
);
}
quantize_op_names
(
migraphx_quantize_op_names
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
quantize_op_names
);
void
add
(
const
std
::
string
&
name
)
{
...
...
@@ -1098,12 +1121,7 @@ struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options)
{
quantize_int8_options
()
{
this
->
make_handle
(
&
migraphx_quantize_int8_options_create
);
}
quantize_int8_options
(
migraphx_quantize_int8_options
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
quantize_int8_options
(
migraphx_quantize_int8_options
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
quantize_int8_options
);
/// Add an operator that should be quantized
void
add_op_name
(
const
std
::
string
&
name
)
...
...
src/api/migraphx.py
View file @
dfa26315
...
...
@@ -403,6 +403,7 @@ api.add_function('migraphx_quantize_int8',
@
auto_handle
(
ref
=
True
)
def
context
(
h
):
h
.
method
(
'finish'
,
const
=
True
)
h
.
method
(
'get_queue'
,
returns
=
'void*'
,
fname
=
'get_queue().unsafe_get'
)
@
api
.
interface
(
'migraphx_experimental_custom_op'
,
...
...
src/include/migraphx/op/gathernd.hpp
0 → 100644
View file @
dfa26315
#ifndef MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP
#define MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
gathernd
{
int
batch_dims
=
0
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
batch_dims
,
"batch_dims"
));
}
std
::
string
name
()
const
{
return
"gathernd"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
auto
r
=
inputs
.
front
().
lens
().
size
();
auto
q
=
inputs
.
back
().
lens
().
size
();
auto
k
=
inputs
.
back
().
lens
().
back
();
if
(
k
>
r
-
batch_dims
)
{
MIGRAPHX_THROW
(
"GATHERND: Indices of length "
+
std
::
to_string
(
k
)
+
" cannot be used to access data of rank "
+
std
::
to_string
(
r
-
batch_dims
));
}
auto
indices_lens_iter
=
inputs
.
back
().
lens
().
begin
();
auto
output_lens_size
=
q
+
r
-
k
-
batch_dims
-
1
;
std
::
vector
<
std
::
size_t
>
output_lens
(
output_lens_size
);
std
::
copy
(
indices_lens_iter
,
indices_lens_iter
+
(
q
-
1
),
output_lens
.
begin
());
if
(
k
<
r
-
batch_dims
)
{
auto
data_lens
=
inputs
.
front
().
lens
();
std
::
copy
(
data_lens
.
begin
()
+
batch_dims
+
k
,
data_lens
.
end
(),
output_lens
.
begin
()
+
q
-
1
);
}
shape
output_shape
{
inputs
.
front
().
type
(),
output_lens
};
return
output_shape
;
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
data
)
{
args
[
1
].
visit
([
&
](
auto
indices
)
{
auto
indices_shape
=
indices
.
get_shape
();
auto
indices_shape_lens
=
indices_shape
.
lens
();
auto
data_shape
=
data
.
get_shape
();
auto
data_shape_lens
=
data_shape
.
lens
();
auto
k
=
indices_shape
.
lens
().
back
();
const
auto
num_slice_dims
=
k
;
std
::
size_t
num_slices
=
std
::
accumulate
(
indices_shape_lens
.
begin
(),
indices_shape_lens
.
end
()
-
1
,
1
,
std
::
multiplies
<
std
::
size_t
>
());
std
::
size_t
slice_size
=
std
::
accumulate
(
data_shape_lens
.
begin
()
+
k
+
batch_dims
,
data_shape_lens
.
end
(),
1
,
std
::
multiplies
<
std
::
size_t
>
());
std
::
size_t
num_batches
=
std
::
accumulate
(
data_shape_lens
.
begin
(),
data_shape_lens
.
begin
()
+
batch_dims
,
1
,
std
::
multiplies
<
std
::
size_t
>
());
std
::
size_t
data_batch_stride
=
std
::
accumulate
(
data_shape_lens
.
begin
()
+
batch_dims
,
data_shape_lens
.
end
(),
1
,
std
::
multiplies
<
std
::
size_t
>
());
auto
num_slices_per_batch
=
num_slices
/
num_batches
;
std
::
vector
<
std
::
size_t
>
sizes_from_slice_dims
(
num_slice_dims
);
{
auto
running_product
=
slice_size
;
for
(
std
::
size_t
i
=
0
;
i
<
num_slice_dims
;
++
i
)
{
sizes_from_slice_dims
[
num_slice_dims
-
1
-
i
]
=
running_product
;
running_product
*=
data_shape_lens
[
batch_dims
+
num_slice_dims
-
1
-
i
];
}
}
std
::
vector
<
std
::
size_t
>
input_slice_offsets
(
num_slices
);
par_for
(
num_slices
,
[
&
](
const
auto
i
)
{
std
::
size_t
batch_idx
=
i
/
num_slices_per_batch
;
auto
slice_indices
=
indices
.
begin
()
+
(
i
*
num_slice_dims
);
std
::
size_t
relative_slice_offset
=
0
;
for
(
size_t
dim_idx
=
0
;
dim_idx
<
num_slice_dims
;
++
dim_idx
)
{
int64_t
index
=
*
(
slice_indices
+
dim_idx
);
const
std
::
size_t
input_dim_idx
=
batch_dims
+
dim_idx
;
const
auto
input_dim
=
data_shape_lens
[
input_dim_idx
];
if
(
index
<
-
static_cast
<
int64_t
>
(
input_dim
)
or
index
>=
static_cast
<
int64_t
>
(
input_dim
))
MIGRAPHX_THROW
(
"GatherND: index "
+
std
::
to_string
(
index
)
+
" is out of bounds for dim of len "
+
std
::
to_string
(
input_dim
));
if
(
index
<
0
)
index
+=
input_dim
;
relative_slice_offset
+=
index
*
sizes_from_slice_dims
[
dim_idx
];
}
input_slice_offsets
[
i
]
=
(
batch_idx
*
data_batch_stride
)
+
relative_slice_offset
;
});
par_for
(
num_slices
*
slice_size
,
[
&
](
const
auto
i
)
{
auto
slice_offset
=
input_slice_offsets
[
i
/
slice_size
];
output
[
i
]
=
data
[
slice_offset
+
i
%
slice_size
];
});
});
});
return
result
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/operators.hpp
View file @
dfa26315
...
...
@@ -35,6 +35,7 @@
#include <migraphx/op/flatten.hpp>
#include <migraphx/op/floor.hpp>
#include <migraphx/op/gather.hpp>
#include <migraphx/op/gathernd.hpp>
#include <migraphx/op/get_tuple_elem.hpp>
#include <migraphx/op/greater.hpp>
#include <migraphx/op/gru.hpp>
...
...
src/onnx/parse_generic_op.cpp
View file @
dfa26315
...
...
@@ -28,6 +28,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
{
"Flatten"
,
"flatten"
},
{
"Floor"
,
"floor"
},
{
"Gather"
,
"gather"
},
{
"GatherND"
,
"gathernd"
},
{
"Identity"
,
"identity"
},
{
"IsNaN"
,
"isnan"
},
{
"LeakyRelu"
,
"leaky_relu"
},
...
...
src/onnx/parse_reversesequence.cpp
0 → 100644
View file @
dfa26315
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
onnx
{
//! Parser for ReverseSequence ONNX operator.
/*!
Reverses the data along the time axis for the batches along the batch axis.
The sequence lengths can be given to reverse up to the given length for each batch, keeping the
rest of the sequence in the original order. Variable sequence_lens is not supported in this
version of MIGraphX. You can pass the sequence_lens either as a constant node or an attribute. The
batch axis and time axis must be [0, 1] and not the same.
*/
struct parse_reversesequence : op_parser<parse_reversesequence>
{
    std::vector<op_desc> operators() const { return {{"ReverseSequence"}}; }

    /// Parse an ONNX ReverseSequence node by decomposing it into per-batch
    /// slice / reverse / concat instructions. The sequence lengths must be
    /// known at compile time, either as a constant second input or as a
    /// "sequence_lens" attribute; variable lengths are rejected.
    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& parser,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        // ONNX defaults: batch_axis = 1, time_axis = 0; both must be 0 or 1
        // and must differ.
        int batch_axis = 1;
        if(contains(info.attributes, "batch_axis"))
        {
            batch_axis = info.attributes.at("batch_axis").i();
        }
        if(batch_axis != 0 and batch_axis != 1)
        {
            MIGRAPHX_THROW("REVERSESEQUENCE: batch axis not 0 or 1");
        }

        int time_axis = 0;
        if(contains(info.attributes, "time_axis"))
        {
            time_axis = info.attributes.at("time_axis").i();
        }
        if(time_axis != 0 and time_axis != 1)
        {
            MIGRAPHX_THROW("REVERSESEQUENCE: time axis not 0 or 1");
        }
        if(time_axis == batch_axis)
        {
            MIGRAPHX_THROW("REVERSESEQUENCE: time axis and batch axis are the same");
        }

        auto input      = args[0];
        auto input_lens = input->get_shape().lens();
        if(input_lens.size() < 2)
        {
            MIGRAPHX_THROW("REVERSESEQUENCE: input tensor must have rank >= 2");
        }

        // Collect the constant sequence lengths, either from the second input
        // (which must evaluate to a constant) or from the attribute.
        std::vector<int64_t> sequence_lens;
        if(args.size() == 2)
        {
            migraphx::argument seq_lens_arg = args.back()->eval();
            check_arg_empty(seq_lens_arg, "REVERSESEQUENCE: cannot handle variable sequence_lens");
            seq_lens_arg.visit([&](auto s) { sequence_lens.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "sequence_lens"))
        {
            literal s = parser.parse_value(info.attributes.at("sequence_lens"));
            s.visit([&](auto v) { sequence_lens.assign(v.begin(), v.end()); });
        }

        // int64_t throughout avoids signed/unsigned comparisons and avoids
        // narrowing the 64-bit sequence lengths when building slice operators.
        const auto batch_size = static_cast<int64_t>(input_lens[batch_axis]);
        const auto time_size  = static_cast<int64_t>(input_lens[time_axis]);
        // this condition may still work if sequence_len's shape was incorrect
        if(static_cast<int64_t>(sequence_lens.size()) != batch_size)
        {
            MIGRAPHX_THROW("REVERSESEQUENCE: sequence_lens has incorrect shape");
        }
        if(batch_size == 0)
        {
            // Without this guard the loop below would never execute and an
            // uninitialized instruction_ref would be returned.
            MIGRAPHX_THROW("REVERSESEQUENCE: batch size must be nonzero");
        }

        // Slice out time steps [t_start, t_end) of batch b.
        auto add_slice = [&info, &input, batch_axis, time_axis](
                             int64_t b, int64_t t_start, int64_t t_end) {
            return info.add_instruction(
                make_op("slice",
                        {{"axes", {batch_axis, time_axis}},
                         {"starts", {b, t_start}},
                         {"ends", {b + 1, t_end}}}),
                input);
        };

        instruction_ref ret;
        for(int64_t b = 0; b < batch_size; ++b)
        {
            instruction_ref s0;
            if(sequence_lens[b] > 1)
            {
                // Reverse the first sequence_lens[b] time steps of this batch.
                s0 = add_slice(b, 0, sequence_lens[b]);
                s0 = info.add_instruction(make_op("reverse", {{"axes", {time_axis}}}), s0);
                // if reversed less than whole batch, concat rest of batch
                if(sequence_lens[b] < time_size)
                {
                    auto s1 = add_slice(b, sequence_lens[b], time_size);
                    s0 = info.add_instruction(make_op("concat", {{"axis", time_axis}}), s0, s1);
                }
            }
            else
            {
                // cases where nothing changes (length 0 or 1: reversal is identity)
                s0 = add_slice(b, 0, time_size);
            }
            // Stitch the per-batch results back together along the batch axis.
            if(b == 0)
            {
                ret = s0;
            }
            else
            {
                ret = info.add_instruction(make_op("concat", {{"axis", batch_axis}}), ret, s0);
            }
        }
        return ret;
    }
};
}
// namespace onnx
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/driver/run_op.cpp
View file @
dfa26315
...
...
@@ -20,7 +20,7 @@ struct run_op : action<run_op>
auto
op
=
make_op
(
name
);
if
(
v
.
contains
(
"fields"
))
op
.
from_value
(
v
.
at
(
"fields"
));
double
t
=
time_op
(
ctx
,
op
,
inputs
);
double
t
=
time_op
(
ctx
,
op
,
inputs
,
p
.
get
(
v
,
"iterations"
,
100
)
);
std
::
cout
<<
op
<<
": "
<<
t
<<
"ms"
<<
std
::
endl
;
}
};
...
...
src/targets/gpu/jit/gathernd.cpp
0 → 100644
View file @
dfa26315
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
// NOLINTNEXTLINE
static
const
char
*
const
gathernd_kernel
=
R"__migraphx__(
#include <migraphx/kernels/gathernd.hpp>
#include <migraphx/kernels/basic_ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void gathernd_kernel(void* in_data, void* in_indices, void* output)
{
make_tensors()(in_data, in_indices, output)([](auto&&... xs) {
auto settings = make_gathernd_settings(MIGRAPHX_MAKE_CONSTANT(int64_t{BATCH_DIMS}));
gathernd(xs..., settings);
});
}
}
} // namespace migraphx
)__migraphx__"
;
// JIT compiler registration for the "gathernd" operator: builds the hipRTC
// compile options and compiles the gathernd kernel source above.
struct gathernd_compiler : compiler<gathernd_compiler>
{
    std::vector<std::string> names() const { return {"gathernd"}; }

    // Compile the kernel for the given shapes; `v` carries the operator
    // fields (must contain "batch_dims") and optional launch parameters.
    operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
    {
        // batch_dims is required and baked into the kernel as a -D define
        assert(v.contains("batch_dims"));
        const auto batch_dims = v.at("batch_dims").to<int64_t>();

        hip_compile_options options;
        const auto output_shape = inputs.back();
        options.set_launch_params(v, compute_global_for(ctx, output_shape.elements()));
        options.inputs         = inputs;
        options.output         = output_shape;
        options.kernel_name    = "gathernd_kernel";
        options.virtual_inputs = inputs;
        options.params += " -DBATCH_DIMS=" + std::to_string(batch_dims);

        return compile_hip_code_object(gathernd_kernel, options);
    }

    // Hook used by the compile pass: derive the shapes from the instruction.
    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
    {
        return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
    }
};
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/jit/reduce.cpp
View file @
dfa26315
...
...
@@ -30,7 +30,7 @@ __global__ void kernel(void* input_p, void* output_p)
{
make_tensors()(input_p, output_p)([](auto input, auto output) {
simple_reduce(${reduction}, ${init}, input, output, ${read}, ${write});
simple_reduce
<reduce::${algo}>
(${reduction}, ${init}, input, output, ${read}, ${write});
});
}
...
...
@@ -57,6 +57,40 @@ static std::size_t get_reduce_elements(const std::vector<instruction_ref>& input
return
get_reduce_elements
(
to_shapes
(
inputs
));
}
// Compute the per-axis reduction lengths: for each output axis, 1 when the
// input and output extents match (nothing reduced along that axis),
// otherwise the input extent. Assumes input_lens has at least as many
// dimensions as output_lens.
static std::vector<std::size_t> get_reduce_lens(const std::vector<std::size_t>& input_lens,
                                                const std::vector<std::size_t>& output_lens)
{
    std::vector<std::size_t> reduce_lens;
    reduce_lens.reserve(output_lens.size());
    for(std::size_t axis = 0; axis < output_lens.size(); ++axis)
    {
        const auto in_len  = input_lens[axis];
        const auto out_len = output_lens[axis];
        reduce_lens.push_back(out_len == in_len ? 1 : in_len);
    }
    return reduce_lens;
}
// Choose the reduction kernel variant from the input/output shapes.
// Looks at the strides of the dimensions being reduced: when the smallest
// such stride is greater than 2, the "lane" algorithm (one thread reduces
// its slice serially) is selected; otherwise "block" (cooperative
// block-wide reduction) is used.
static std::string get_reduce_algo(const std::vector<shape>& inputs)
{
    auto rlens = get_reduce_lens(inputs.front().lens(), inputs.back().lens());
    // Sentinel for non-reduced dimensions (rlen == 1): map them to the
    // largest stride value so they never win the minimum below.
    const auto init = std::numeric_limits<std::size_t>::max();
    // The minimum stride
    auto min_stride = std::inner_product(
        rlens.begin(),
        rlens.end(),
        inputs.front().strides().begin(),
        init,
        [](auto x, auto y) { return std::min(x, y); },
        [](auto len, auto stride) { return len == 1 ? init : stride; });
    if(min_stride > 2)
        return "lane";
    return "block";
}
struct
reduce_compiler
:
compiler
<
reduce_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
...
...
@@ -68,20 +102,33 @@ struct reduce_compiler : compiler<reduce_compiler>
{
hip_compile_options
options
;
auto
reduce_elements
=
get_reduce_elements
(
inputs
);
auto
block_size
=
compute_block_size
(
reduce_elements
,
256
);
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
inputs
.
back
().
elements
()
*
block_size
,
256
),
block_size
);
auto
algo
=
v
.
get
(
"algo"
,
get_reduce_algo
(
inputs
));
if
(
algo
==
"block"
)
{
auto
block_size
=
compute_block_size
(
reduce_elements
,
256
);
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
inputs
.
back
().
elements
()
*
block_size
,
256
),
block_size
);
}
else
if
(
algo
==
"lane"
)
{
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
inputs
.
back
().
elements
(),
256
));
}
else
{
MIGRAPHX_THROW
(
"Unknown reduce algo: "
+
algo
);
}
options
.
inputs
=
inputs
;
options
.
output
=
inputs
.
back
();
options
.
virtual_inputs
=
reduce_dims
(
inputs
);
options
.
params
=
"-Wno-float-equal"
;
std
::
string
identity
=
"[](auto x) { return x; }"
;
auto
src
=
interpolate_string
(
simple_reduce_kernel
,
{{
"reduction"
,
v
.
at
(
"reduction"
).
to
<
std
::
string
>
()},
{
"init"
,
v
.
get
(
"init"
,
std
::
string
{
"0"
})},
{
"read"
,
v
.
get
(
"read"
,
identity
)},
{
"write"
,
v
.
get
(
"write"
,
identity
)},
{
"algo"
,
algo
},
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})}});
options
.
params
+=
"-Wno-float-equal"
;
return
compile_hip_code_object
(
src
,
options
);
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/algorithm.hpp
View file @
dfa26315
...
...
@@ -21,6 +21,16 @@ struct greater
}
};
// Left fold of [first, last) into `init` using `op`, equivalent to the
// four-argument form of std::accumulate. The running value is moved into
// each application of `op` to avoid copies of heavy accumulator types.
template <class InputIt, class T, class BinaryOperation>
constexpr T accumulate(InputIt first, InputIt last, T init, BinaryOperation op)
{
    T result = std::move(init);
    while(first != last)
    {
        result = op(std::move(result), *first);
        ++first;
    }
    return result;
}
template
<
class
InputIt
,
class
OutputIt
>
constexpr
OutputIt
copy
(
InputIt
first
,
InputIt
last
,
OutputIt
d_first
)
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/debug.hpp
View file @
dfa26315
...
...
@@ -42,6 +42,32 @@ struct print_buffer
pos
++
;
}
}
template
<
class
T
,
class
=
decltype
(
T
{}
%
10
,
-
T
{}
)>
constexpr
void
append
(
T
i
)
{
if
(
i
<
0
)
{
append
(
'-'
);
i
=
-
i
;
}
char
c
=
(
i
%
10
)
+
'0'
;
if
(
i
>
9
)
append
(
i
/
10
);
append
(
c
);
}
// Append a NUL-terminated C string, copying at most 512 characters to
// guard against unterminated or runaway input.
constexpr void append(const char* str)
{
    if(str == nullptr)
        return;
    for(int remaining = 512; remaining > 0 and *str != 0; remaining--)
    {
        append(*str);
        str++;
    }
}
template
<
size_t
M
>
constexpr
void
append
(
const
char
(
&
array
)[
M
])
...
...
@@ -54,14 +80,36 @@ struct print_buffer
template
<
class
...
Ts
>
__host__
__device__
void
print
(
const
Ts
&
...
xs
)
{
const
auto
size
=
(
sizeof
(
xs
)
+
...);
print_buffer
<
size
>
buffer
;
print_buffer
<
1024
>
buffer
;
swallow
{(
buffer
.
append
(
xs
),
0
)...};
printf
(
"%s"
,
buffer
.
buffer
);
}
}
// namespace debug
struct
source_location
{
int
line
=
__builtin_LINE
();
const
char
*
file
=
__builtin_FILE
();
const
char
*
function
=
__builtin_FUNCTION
();
};
// Wraps a value of type T together with the source_location of the site
// that constructed it. Using this as a parameter type lets a function
// implicitly receive its caller's location alongside the value.
template <class T>
struct source_location_capture
{
    T x;                 // the wrapped value
    source_location loc; // location of the construction (i.e. call) site

    // Implicitly constructible from any U convertible to T; `ploc` defaults
    // to the caller's location via source_location's builtin initializers.
    template <class U, class = decltype(T(U{}))>
    constexpr source_location_capture(U px, source_location ploc = source_location{})
        : x(px), loc(ploc)
    {
    }

    // Converts to either the captured location or the wrapped value,
    // depending on what the context requires.
    constexpr operator source_location() const { return loc; }
    constexpr operator T() const { return x; }
};
// noreturn cannot be used on this function because abort in hip is broken
template
<
class
T1
,
class
T2
,
class
T3
,
class
T4
>
MIGRAPHX_HIP_NORETURN
inline
__host__
__device__
void
...
...
@@ -73,20 +121,38 @@ assert_fail(const T1& assertion, const T2& file, const T3& line, const T4& funct
abort
();
}
// Variadic overload taking a captured source_location: prints
// "<file>:<line>: <function>: error: " followed by the caller-supplied
// message pieces, then aborts. MIGRAPHX_HIP_NORETURN is used instead of
// [[noreturn]] because abort is problematic in HIP device code.
template <class... Ts>
MIGRAPHX_HIP_NORETURN inline __host__ __device__ void
assert_fail(const source_location& loc, Ts... xs)
{
    debug::print(loc.file, ":", loc.line, ": ", loc.function, ": error: ", xs..., "\n");
    abort();
}
// NOLINTNEXTLINE
#define MIGRAPHX_
CHECK(cond)
\
#define MIGRAPHX_
ASSERT_FAIL(cond, ...)
\
((cond) ? void(0) : [](auto&&... private_migraphx_xs) { \
assert_fail(private_migraphx_xs...); \
}(#cond, __FILE__, MIGRAPHX_STRINGIZE(__LINE__), __PRETTY_FUNCTION__))
}(__VA_ARGS__))
// NOLINTNEXTLINE
#define MIGRAPHX_CHECK(cond) \
MIGRAPHX_ASSERT_FAIL(cond, #cond, __FILE__, __LINE__, __PRETTY_FUNCTION__)
#ifdef MIGRAPHX_DEBUG
// NOLINTNEXTLINE
#define MIGRAPHX_CAPTURE_SOURCE_LOCATION(T) source_location_capture<T>
#define MIGRAPHX_WARN(cond, loc, ...) MIGRAPHX_ASSERT_FAIL(cond, loc, __VA_ARGS__)
#define MIGRAPHX_ASSERT MIGRAPHX_CHECK
#define MIGRAPHX_ASSUME MIGRAPHX_CHECK
#define MIGRAPHX_UNREACHABLE() MIGRAPHX_ASSERT(false)
#else
// NOLINTNEXTLINE
#define MIGRAPHX_CAPTURE_SOURCE_LOCATION(T) T
#define MIGRAPHX_ASSUME __builtin_assume
#define MIGRAPHX_UNREACHABLE __builtin_unreachable
#define MIGRAPHX_ASSERT(cond)
#define MIGRAPHX_WARN(...)
#endif
}
// namespace migraphx
...
...
src/targets/gpu/kernels/include/migraphx/kernels/gathernd.hpp
0 → 100644
View file @
dfa26315
#ifndef MIGRAPHX_GUARD_KERNELS_GATHERND_HPP
#define MIGRAPHX_GUARD_KERNELS_GATHERND_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
namespace
migraphx
{
// Compile-time settings for the gathernd kernel. batch_dims is typically an
// integral constant (produced via MIGRAPHX_MAKE_CONSTANT) so its value stays
// available at compile time inside the kernel.
template <class T>
struct gathernd_settings
{
    T batch_dims{};
};

// Deduce the settings type from the supplied argument(s).
template <class... Ts>
constexpr gathernd_settings<Ts...> make_gathernd_settings(Ts... xs)
{
    return gathernd_settings<Ts...>{xs...};
}
// Gather slices of `data_t` into `output_t` according to `indices_t`
// (ONNX GatherND-style semantics). The last dimension of the indices tensor
// holds a tuple indexing into `data_t`; the first `s.batch_dims` dimensions
// of data and indices are treated as matching batch dimensions.
template <class T, class U, class V, class Settings>
__device__ void gathernd(const T& data_t, const U& indices_t, const V& output_t, Settings s)
{
    auto ind           = make_index();
    auto batch_dims    = s.batch_dims;
    auto output_shape  = output_t.get_shape();
    auto indices_shape = indices_t.get_shape();
    auto data_shape    = data_t.get_shape();

    auto indices_shape_lens = indices_shape.lens;
    auto data_shape_lens    = data_shape.lens;

    // Number of components in each index tuple (last dim of indices).
    auto num_slice_dims = indices_shape_lens.back();
    // Total number of index tuples across all batches.
    std::size_t num_slices = accumulate(indices_shape_lens.begin(),
                                        indices_shape_lens.end() - 1,
                                        1,
                                        std::multiplies<std::size_t>());
    // Elements copied per gathered slice: the trailing data dims that the
    // index tuple does not address.
    std::size_t slice_size = accumulate(data_shape_lens.begin() + num_slice_dims + batch_dims,
                                        data_shape_lens.end(),
                                        1,
                                        std::multiplies<std::size_t>());
    const std::size_t num_batches = accumulate(data_shape_lens.begin(),
                                               data_shape_lens.begin() + batch_dims,
                                               1,
                                               std::multiplies<std::size_t>());
    // Linear distance between consecutive batches of the data tensor.
    const std::size_t data_batch_stride = accumulate(data_shape_lens.begin() + batch_dims,
                                                     data_shape_lens.end(),
                                                     1,
                                                     std::multiplies<std::size_t>());
    const auto num_slices_per_batch = num_slices / num_batches;

    // One output element per work-item across the whole output tensor.
    ind.global_stride(output_shape.elements(), [&](auto i) {
        const auto* indices_ptr = indices_t.data();
        // Which slice this output element belongs to, and its batch.
        const std::size_t j         = i / slice_size;
        const std::size_t batch_idx = j / num_slices_per_batch;

        // The index tuple for this slice.
        auto* slice_indices               = indices_ptr + (j * num_slice_dims);
        std::size_t relative_slice_offset = 0;
        for(std::size_t idx = 0; idx < num_slice_dims; ++idx)
        {
            int64_t index                   = slice_indices[idx];
            const std::size_t input_dim_idx = batch_dims + idx;
            const auto input_dim            = data_shape_lens[input_dim_idx];
            // Indices may be negative (counted from the end); they must fall
            // within [-input_dim, input_dim).
            assert(index >= -static_cast<int64_t>(input_dim) and
                   index < static_cast<int64_t>(input_dim));
            if(index < 0)
                index += input_dim;

            // Contribution of this index component: product of the remaining
            // sliced dims times slice_size (a packed-layout stride).
            std::size_t size_from_slice_dims =
                accumulate(data_shape_lens.begin() + batch_dims + idx + 1,
                           data_shape_lens.begin() + batch_dims + num_slice_dims,
                           slice_size,
                           std::multiplies<std::size_t>());
            relative_slice_offset += index * size_from_slice_dims;
        }

        // NOTE(review): offsets are computed assuming packed row-major data;
        // confirm inputs are standard-laid-out before relying on this kernel.
        auto slice_offset = (batch_idx * data_batch_stride) + relative_slice_offset;
        output_t[i]       = data_t[slice_offset + i % slice_size];
    });
}
}
// namespace migraphx
#endif
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
View file @
dfa26315
...
...
@@ -124,8 +124,8 @@ __device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
}
#endif
template
<
class
Input
,
class
T
,
class
Output
>
constexpr
auto
reduce_slice
(
Input
input
,
T
i
,
Output
)
template
<
class
Output
,
class
Input
,
class
T
>
constexpr
auto
reduce_slice
(
Input
input
,
T
i
)
{
constexpr
auto
lens
=
transform
(
get_shape_c
<
Input
>
{}.
lens
,
get_shape_c
<
Output
>
{}.
lens
,
...
...
@@ -136,23 +136,126 @@ constexpr auto reduce_slice(Input input, T i, Output)
});
;
constexpr
auto
s
=
make_shape
(
lens
,
get_shape_c
<
Input
>
{}.
strides
);
MIGRAPHX_ASSERT
((
input
.
get_shape
().
index
(
i
)
+
s
.
element_space
())
<=
input
.
get_shape
().
element_space
());
return
make_tensor_view
(
&
input
[
i
],
s
);
}
template
<
class
Op
,
class
T
,
class
Input
,
class
Output
,
class
ReadInput
,
class
WriteOuput
>
namespace
reduce
{
// Adapt `f` so that every argument is first transformed by `slicer`: the
// returned callable slices each argument and forwards the sliced views to `f`.
template <class Slicer, class F>
constexpr auto sliced(Slicer slicer, F f)
{
    return [=](auto y, auto... ys) {
        // TODO: assert all elements are the same
        return f(slicer(y), slicer(ys)...);
    };
}
// Reduction strategy where a whole workgroup cooperatively computes each
// output element via block_reduce.
struct block
{
    // Thread index plus the slicer that maps an input tensor to the slice
    // feeding one output element; created per output element by run().
    template <class Slicer>
    struct reducer
    {
        index idx;
        Slicer slicer;

        // Returns a callable that reduces the sliced input tensor(s) with
        // `op`, reading elements through `read`; the fold is performed
        // cooperatively by all threads of the block.
        template <class Op, class T, class Read>
        __device__ auto reduce(Op op, T init, Read read) const
        {
            return sliced(slicer, [=](auto x, auto... xs) {
                return block_reduce(idx, op, init, x.get_shape().elements(), [&](auto j) {
                    return read(x[j], xs[j]...);
                });
            });
        }

        // Run the epilogue `f` once per output element (first thread only),
        // e.g. to write the final reduced value.
        template <class F>
        __device__ void outer(F f) const
        {
            if(idx.local == 0)
                f();
        }
    };

    template <class Slicer>
    static __device__ auto make(index idx, Slicer slicer)
    {
        return reducer<Slicer>{idx, slicer};
    }

    // Invoke f(out_idx, reducer) for every output element; the global stride
    // is scaled by the local size so each output element gets a full block.
    template <class Output, class F>
    static __device__ void run(F f)
    {
        auto idx                 = make_index();
        constexpr auto nelements = get_shape_c<Output>{}.elements();
        idx.global_stride(nelements * idx.nlocal(), [&](auto i) {
            const auto out_idx = get_shape_c<Output>{}.multi(i / idx.nlocal());
            f(out_idx, make(idx, [&](auto input) { return reduce_slice<Output>(input, out_idx); }));
        });
    }
};
// Reduction strategy where a single thread computes each output element by
// looping over its reduce slice serially (selected by the compiler when the
// reduced elements are strided far apart).
struct lane
{
    template <class Slicer>
    struct reducer
    {
        index idx;
        Slicer slicer;

        // Returns a callable that serially folds the sliced input tensor(s)
        // with `op`, reading elements through `read`.
        template <class Op, class T, class Read>
        __device__ auto reduce(Op op, T init, Read read) const
        {
            return sliced(slicer, [=](auto x, auto... xs) {
                // Accumulate in the tensor's element type.
                using type = typename decltype(x)::type;
                type r     = init;
                for(index_int j = 0; j < x.get_shape().elements(); j++)
                {
                    r = op(r, read(x[j], xs[j]...));
                }
                return r;
            });
        }

        // Each thread owns its output element, so the epilogue always runs.
        template <class F>
        __device__ void outer(F f) const
        {
            f();
        }
    };

    template <class Slicer>
    static __device__ auto make(index idx, Slicer slicer)
    {
        return reducer<Slicer>{idx, slicer};
    }

    // Invoke f(out_idx, reducer) with one thread per output element.
    template <class Output, class F>
    static __device__ void run(F f)
    {
        auto idx                 = make_index();
        constexpr auto nelements = get_shape_c<Output>{}.elements();
        idx.global_stride(nelements, [&](auto i) {
            const auto out_idx = get_shape_c<Output>{}.multi(i);
            f(out_idx, make(idx, [&](auto input) { return reduce_slice<Output>(input, out_idx); }));
        });
    }
};
}
// namespace reduce
template
<
class
Algo
,
class
Op
,
class
T
,
class
Input
,
class
Output
,
class
ReadInput
,
class
WriteOuput
>
__device__
void
simple_reduce
(
Op
op
,
T
init
,
Input
input
,
Output
output
,
ReadInput
read
,
WriteOuput
write
)
{
auto
idx
=
make_index
();
constexpr
auto
nelements
=
get_shape_c
<
Output
>
{}.
elements
();
constexpr
auto
relements
=
get_shape_c
<
Input
>
{}.
elements
()
/
get_shape_c
<
Output
>
{}.
elements
();
idx
.
global_stride
(
nelements
*
idx
.
nlocal
(),
[
&
](
auto
i
)
{
const
auto
out_idx
=
output
.
get_shape
().
multi
(
i
/
idx
.
nlocal
());
auto
rs
=
reduce_slice
(
input
,
out_idx
,
output
);
MIGRAPHX_ASSERT
(
relements
==
rs
.
get_shape
().
elements
());
auto
r
=
block_reduce
(
idx
,
op
,
init
,
relements
,
[
&
](
auto
j
)
{
return
read
(
rs
[
j
]);
});
if
(
idx
.
local
==
0
)
output
[
out_idx
]
=
write
(
r
);
Algo
::
template
run
<
Output
>([
&
](
auto
out_idx
,
auto
r
)
{
auto
x
=
r
.
reduce
(
op
,
init
,
read
)(
input
);
r
.
outer
([
&
]
{
output
[
out_idx
]
=
write
(
x
);
});
});
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/tensor_view.hpp
View file @
dfa26315
...
...
@@ -29,11 +29,23 @@ struct tensor_view
constexpr
Shape
get_shape
()
const
{
return
Shape
{};
}
constexpr
auto
size
()
const
{
return
get_shape
().
elements
();
}
template
<
class
U
>
constexpr
T
&
operator
[](
U
i
)
const
struct
index_to_offset
{
MIGRAPHX_ASSERT
(
get_shape
().
index
(
i
)
<
get_shape
().
element_space
());
return
x
[
get_shape
().
index
(
i
)];
index_int
offset
;
template
<
class
U
>
constexpr
index_to_offset
(
U
i
)
:
offset
(
Shape
{}.
index
(
i
))
{
}
};
constexpr
T
&
operator
[](
MIGRAPHX_CAPTURE_SOURCE_LOCATION
(
index_to_offset
)
i
)
const
{
index_to_offset
ito
=
i
;
MIGRAPHX_WARN
(
ito
.
offset
<
get_shape
().
element_space
(),
i
,
"Out of bounds access at offset: "
,
ito
.
offset
);
return
x
[
ito
.
offset
];
}
constexpr
T
*
data
()
const
{
return
x
;
}
...
...
test/api/CMakeLists.txt
View file @
dfa26315
function
(
add_api_test TEST_NAME TEST_SRC TEST_DIR
)
set
(
NAME test_api_
${
TEST_NAME
}
)
add_executable
(
${
NAME
}
EXCLUDE_FROM_ALL
${
TEST_SRC
}
)
...
...
@@ -10,6 +9,7 @@ function(add_api_test TEST_NAME TEST_SRC TEST_DIR)
add_dependencies
(
check
${
NAME
}
)
endfunction
()
add_api_test
(
array_base test_array_base.cpp
${
TEST_ONNX_DIR
}
)
add_api_test
(
assign test_assign.cpp
${
TEST_ONNX_DIR
}
)
add_api_test
(
custom_op test_custom_op.cpp
${
TEST_ONNX_DIR
}
)
add_api_test
(
compile_options test_compile_options.cpp
${
TEST_ONNX_DIR
}
)
...
...
@@ -19,7 +19,8 @@ add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR})
add_api_test
(
save_load test_save_load.cpp
${
TEST_ONNX_DIR
}
)
add_api_test
(
op test_op_construct.cpp
${
TEST_ONNX_DIR
}
)
add_api_test
(
tf_parser test_tf_parser.cpp
${
TEST_TF_DIR
}
)
# GPU-based tests
if
(
MIGRAPHX_ENABLE_GPU
)
add_api_test
(
gpu test_gpu.cpp
${
TEST_ONNX_DIR
}
)
# GPU-based tests
target_link_libraries
(
test_api_gpu migraphx_gpu
)
endif
()
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment