Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
988bf26b
Commit
988bf26b
authored
May 11, 2022
by
turneram
Browse files
Merge remote-tracking branch 'origin/jit-softmax' into HEAD
parents
f99a3036
bb0fff52
Changes
39
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
820 additions
and
150 deletions
+820
-150
.clang-tidy
.clang-tidy
+1
-1
CMakeLists.txt
CMakeLists.txt
+1
-1
src/CMakeLists.txt
src/CMakeLists.txt
+1
-0
src/api/api.cpp
src/api/api.cpp
+16
-3
src/api/include/migraphx/migraphx.h
src/api/include/migraphx/migraphx.h
+2
-0
src/api/include/migraphx/migraphx.hpp
src/api/include/migraphx/migraphx.hpp
+139
-121
src/api/migraphx.py
src/api/migraphx.py
+1
-0
src/include/migraphx/op/gathernd.hpp
src/include/migraphx/op/gathernd.hpp
+131
-0
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+1
-0
src/onnx/parse_generic_op.cpp
src/onnx/parse_generic_op.cpp
+1
-0
src/targets/gpu/driver/run_op.cpp
src/targets/gpu/driver/run_op.cpp
+1
-1
src/targets/gpu/jit/gathernd.cpp
src/targets/gpu/jit/gathernd.cpp
+75
-0
src/targets/gpu/jit/reduce.cpp
src/targets/gpu/jit/reduce.cpp
+52
-5
src/targets/gpu/jit/softmax.cpp
src/targets/gpu/jit/softmax.cpp
+78
-0
src/targets/gpu/kernels/include/migraphx/kernels/algorithm.hpp
...argets/gpu/kernels/include/migraphx/kernels/algorithm.hpp
+10
-0
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
+8
-0
src/targets/gpu/kernels/include/migraphx/kernels/debug.hpp
src/targets/gpu/kernels/include/migraphx/kernels/debug.hpp
+70
-4
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
...rgets/gpu/kernels/include/migraphx/kernels/functional.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/gathernd.hpp
...targets/gpu/kernels/include/migraphx/kernels/gathernd.hpp
+81
-0
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
+150
-13
No files found.
.clang-tidy
View file @
988bf26b
...
@@ -4,7 +4,7 @@ CheckOptions:
...
@@ -4,7 +4,7 @@ CheckOptions:
- key: bugprone-unused-return-value.CheckedFunctions
- key: bugprone-unused-return-value.CheckedFunctions
value: '::std::async;::std::launder;::std::remove;::std::remove_if;::std::unique;::std::unique_ptr::release;::std::basic_string::empty;::std::vector::empty;::std::find;::std::find_if;::std::find_if_not;::std::all_of;::std::any_of;::std::none_of;::std::count;::std::count_if;::std::mismatch;::std::find_end;::std::find_first_of;::std::adjacent_find;::std::search;::std::search_n;::std::nth_element;::std::lower_bound;::std::upper_bound;::std::binary_search;::std::equal_range;::std::max;::std::max_element;::std::min;::std::min_element;::std::minmax;::std::minmax_element;::std::equal;::std::lexicographical_compare;::std::accumulate;::std::inner_product'
value: '::std::async;::std::launder;::std::remove;::std::remove_if;::std::unique;::std::unique_ptr::release;::std::basic_string::empty;::std::vector::empty;::std::find;::std::find_if;::std::find_if_not;::std::all_of;::std::any_of;::std::none_of;::std::count;::std::count_if;::std::mismatch;::std::find_end;::std::find_first_of;::std::adjacent_find;::std::search;::std::search_n;::std::nth_element;::std::lower_bound;::std::upper_bound;::std::binary_search;::std::equal_range;::std::max;::std::max_element;::std::min;::std::min_element;::std::minmax;::std::minmax_element;::std::equal;::std::lexicographical_compare;::std::accumulate;::std::inner_product'
- key: cppcoreguidelines-macro-usage.AllowedRegexp
- key: cppcoreguidelines-macro-usage.AllowedRegexp
value: 'DEBUG|ASSERT|ASSUME|UNREACHABLE|FALLTHROUGH|STRINGIZE|_HAS_|_THROW|_REQUIRES|_DECLARE_|_VISIT_|_REGISTER_|_GENERATE_|_DETAIL_|_TIDY_|_MANAGE_PTR|_MATCHER|DEVICE_SHARED|_WORKAROUND_'
value: 'DEBUG|ASSERT|ASSUME|UNREACHABLE|FALLTHROUGH|
DEPRECATED|
STRINGIZE|_HAS_|_THROW|_REQUIRES|_DECLARE_|_VISIT_|_REGISTER_|_GENERATE_|_DETAIL_|_TIDY_|_MANAGE_PTR|_MATCHER|DEVICE_SHARED|_WORKAROUND_'
- key: modernize-loop-convert.MinConfidence
- key: modernize-loop-convert.MinConfidence
value: risky
value: risky
- key: modernize-loop-convert.NamingStyle
- key: modernize-loop-convert.NamingStyle
...
...
CMakeLists.txt
View file @
988bf26b
...
@@ -42,7 +42,7 @@ find_package(nlohmann_json 3.8.0 REQUIRED)
...
@@ -42,7 +42,7 @@ find_package(nlohmann_json 3.8.0 REQUIRED)
include
(
ROCMSetupVersion
)
include
(
ROCMSetupVersion
)
rocm_setup_version
(
VERSION 2.
2
)
rocm_setup_version
(
VERSION 2.
3
)
set
(
MIGRAPHX_SO_VERSION
${
PROJECT_VERSION_MAJOR
}
.
${
PROJECT_VERSION_MINOR
}
)
set
(
MIGRAPHX_SO_VERSION
${
PROJECT_VERSION_MAJOR
}
.
${
PROJECT_VERSION_MINOR
}
)
option
(
BUILD_SHARED_LIBS
"Build as a shared library"
ON
)
option
(
BUILD_SHARED_LIBS
"Build as a shared library"
ON
)
...
...
src/CMakeLists.txt
View file @
988bf26b
...
@@ -109,6 +109,7 @@ register_migraphx_ops(
...
@@ -109,6 +109,7 @@ register_migraphx_ops(
flatten
flatten
floor
floor
gather
gather
gathernd
get_tuple_elem
get_tuple_elem
greater
greater
gru
gru
...
...
src/api/api.cpp
View file @
988bf26b
...
@@ -401,7 +401,8 @@ extern "C" struct migraphx_instruction;
...
@@ -401,7 +401,8 @@ extern "C" struct migraphx_instruction;
struct
migraphx_instruction
struct
migraphx_instruction
{
{
template
<
class
...
Ts
>
template
<
class
...
Ts
>
migraphx_instruction
(
Ts
&&
...
xs
)
:
object
(
std
::
forward
<
Ts
>
(
xs
)...)
migraphx_instruction
(
Ts
&&
...
xs
)
:
object
(
std
::
forward
<
Ts
>
(
xs
)...)
// NOLINT(readability-redundant-member-init)
{
{
}
}
migraphx
::
instruction_ref
object
;
migraphx
::
instruction_ref
object
;
...
@@ -411,7 +412,8 @@ extern "C" struct migraphx_instructions;
...
@@ -411,7 +412,8 @@ extern "C" struct migraphx_instructions;
struct
migraphx_instructions
struct
migraphx_instructions
{
{
template
<
class
...
Ts
>
template
<
class
...
Ts
>
migraphx_instructions
(
Ts
&&
...
xs
)
:
object
(
std
::
forward
<
Ts
>
(
xs
)...)
migraphx_instructions
(
Ts
&&
...
xs
)
:
object
(
std
::
forward
<
Ts
>
(
xs
)...)
// NOLINT(readability-redundant-member-init)
{
{
}
}
std
::
vector
<
migraphx
::
instruction_ref
>
object
;
std
::
vector
<
migraphx
::
instruction_ref
>
object
;
...
@@ -421,7 +423,8 @@ extern "C" struct migraphx_modules;
...
@@ -421,7 +423,8 @@ extern "C" struct migraphx_modules;
struct
migraphx_modules
struct
migraphx_modules
{
{
template
<
class
...
Ts
>
template
<
class
...
Ts
>
migraphx_modules
(
Ts
&&
...
xs
)
:
object
(
std
::
forward
<
Ts
>
(
xs
)...)
migraphx_modules
(
Ts
&&
...
xs
)
:
object
(
std
::
forward
<
Ts
>
(
xs
)...)
// NOLINT(readability-redundant-member-init)
{
{
}
}
std
::
vector
<
migraphx
::
module
*>
object
;
std
::
vector
<
migraphx
::
module
*>
object
;
...
@@ -1691,6 +1694,16 @@ extern "C" migraphx_status migraphx_context_finish(const_migraphx_context_t cont
...
@@ -1691,6 +1694,16 @@ extern "C" migraphx_status migraphx_context_finish(const_migraphx_context_t cont
return
api_error_result
;
return
api_error_result
;
}
}
extern
"C"
migraphx_status
migraphx_context_get_queue
(
void
**
out
,
migraphx_context_t
context
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
if
(
context
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter context: Null pointer"
);
*
out
=
(
context
->
object
).
get_queue
().
unsafe_get
();
});
return
api_error_result
;
}
extern
"C"
migraphx_status
extern
"C"
migraphx_status
migraphx_experimental_custom_op_destroy
(
migraphx_experimental_custom_op_t
experimental_custom_op
)
migraphx_experimental_custom_op_destroy
(
migraphx_experimental_custom_op_t
experimental_custom_op
)
{
{
...
...
src/api/include/migraphx/migraphx.h
View file @
988bf26b
...
@@ -433,6 +433,8 @@ migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
...
@@ -433,6 +433,8 @@ migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
migraphx_status
migraphx_context_finish
(
const_migraphx_context_t
context
);
migraphx_status
migraphx_context_finish
(
const_migraphx_context_t
context
);
migraphx_status
migraphx_context_get_queue
(
void
**
out
,
migraphx_context_t
context
);
migraphx_status
migraphx_status
migraphx_experimental_custom_op_destroy
(
migraphx_experimental_custom_op_t
experimental_custom_op
);
migraphx_experimental_custom_op_destroy
(
migraphx_experimental_custom_op_t
experimental_custom_op
);
...
...
src/api/include/migraphx/migraphx.hpp
View file @
988bf26b
...
@@ -15,6 +15,16 @@ namespace migraphx {
...
@@ -15,6 +15,16 @@ namespace migraphx {
inline
namespace
api
{
// NOLINT
inline
namespace
api
{
// NOLINT
#endif
#endif
#ifdef __has_cpp_attribute
#if __has_cpp_attribute(deprecated)
#define MIGRAPHX_DEPRECATED(...) [[deprecated(__VA_ARGS__)]]
#endif
#endif
#ifndef MIGRAPHX_DEPRECATED
#define MIGRAPHX_DEPRECATED(...)
#endif
template
<
int
N
>
template
<
int
N
>
struct
rank
:
rank
<
N
-
1
>
struct
rank
:
rank
<
N
-
1
>
{
{
...
@@ -99,34 +109,22 @@ struct iota_iterator
...
@@ -99,34 +109,22 @@ struct iota_iterator
return
it
;
return
it
;
}
}
// TODO: operator->
// TODO: operator->
reference
operator
*
()
const
{
return
(
*
f
)(
index
);
}
reference
operator
*
()
const
{
return
f
(
index
);
}
};
template
<
class
F
,
class
Iterator
>
friend
iota_iterator
operator
+
(
iota_iterator
x
,
iota_iterator
y
)
inline
iota_iterator
<
F
,
Iterator
>
operator
+
(
iota_iterator
<
F
,
Iterator
>
x
,
{
iota_iterator
<
F
,
Iterator
>
y
)
return
iota_iterator
(
x
.
index
+
y
.
index
,
x
.
f
);
{
}
return
iota_iterator
<
F
,
Iterator
>
(
x
.
index
+
y
.
index
,
x
.
f
);
}
template
<
class
F
,
class
Iterator
>
friend
iota_iterator
operator
-
(
iota_iterator
x
,
iota_iterator
y
)
inline
iota_iterator
<
F
,
Iterator
>
operator
-
(
iota_iterator
<
F
,
Iterator
>
x
,
{
iota_iterator
<
F
,
Iterator
>
y
)
return
iota_iterator
(
x
.
index
-
y
.
index
,
x
.
f
);
{
}
return
iota_iterator
<
F
,
Iterator
>
(
x
.
index
-
y
.
index
,
x
.
f
);
}
template
<
class
F
,
class
Iterator
>
friend
bool
operator
==
(
iota_iterator
x
,
iota_iterator
y
)
{
return
x
.
index
==
y
.
index
;
}
inline
bool
operator
==
(
iota_iterator
<
F
,
Iterator
>
x
,
iota_iterator
<
F
,
Iterator
>
y
)
{
return
x
.
index
==
y
.
index
;
}
template
<
class
F
,
class
Iterator
>
friend
bool
operator
!=
(
iota_iterator
x
,
iota_iterator
y
)
{
return
x
.
index
!=
y
.
index
;
}
inline
bool
operator
!=
(
iota_iterator
<
F
,
Iterator
>
x
,
iota_iterator
<
F
,
Iterator
>
y
)
};
{
return
x
.
index
!=
y
.
index
;
}
template
<
class
Derived
>
template
<
class
Derived
>
struct
array_base
struct
array_base
...
@@ -136,8 +134,20 @@ struct array_base
...
@@ -136,8 +134,20 @@ struct array_base
template
<
class
T
>
template
<
class
T
>
using
value_type_t
=
decltype
(
std
::
declval
<
T
>
()[
0
]);
using
value_type_t
=
decltype
(
std
::
declval
<
T
>
()[
0
]);
struct
iterator_read
{
const
Derived
*
self
;
template
<
class
D
=
Derived
>
value_type_t
<
D
>
operator
()(
size_t
pidx
)
const
{
return
(
*
self
)[
pidx
];
}
};
template
<
class
T
>
template
<
class
T
>
using
iterator_t
=
iota_iterator
<
typename
T
::
iterator_read
>
;
using
iterator_t
=
iota_iterator
<
iterator_read
>
;
bool
empty
()
const
{
return
derived
().
size
()
==
0
;
}
template
<
class
D
=
Derived
>
template
<
class
D
=
Derived
>
value_type_t
<
D
>
front
()
const
value_type_t
<
D
>
front
()
const
...
@@ -154,13 +164,13 @@ struct array_base
...
@@ -154,13 +164,13 @@ struct array_base
template
<
class
D
=
Derived
>
template
<
class
D
=
Derived
>
iterator_t
<
D
>
begin
()
const
iterator_t
<
D
>
begin
()
const
{
{
return
{
0
,
{
derived
()
.
get_handle_ptr
()
}};
return
{
0
,
{
&
derived
()}};
}
}
template
<
class
D
=
Derived
>
template
<
class
D
=
Derived
>
iterator_t
<
D
>
end
()
const
iterator_t
<
D
>
end
()
const
{
{
return
{
derived
().
size
(),
{
derived
()
.
get_handle_ptr
()
}};
return
{
derived
().
size
(),
{
&
derived
()}};
}
}
};
};
...
@@ -200,9 +210,25 @@ struct borrow
...
@@ -200,9 +210,25 @@ struct borrow
{
{
};
};
template
<
class
T
>
struct
share
{
share
(
std
::
shared_ptr
<
T
>
p
)
:
ptr
(
std
::
move
(
p
))
{}
template
<
class
U
>
std
::
shared_ptr
<
U
>
alias
(
U
*
p
)
const
{
return
std
::
shared_ptr
<
U
>
{
ptr
,
p
};
}
private:
std
::
shared_ptr
<
T
>
ptr
;
};
template
<
class
Derived
,
class
T
,
class
D
,
D
Deleter
,
class
A
,
A
Assigner
>
template
<
class
Derived
,
class
T
,
class
D
,
D
Deleter
,
class
A
,
A
Assigner
>
struct
handle_base
:
handle_lookup
<
Derived
,
std
::
remove_cv_t
<
T
>>
struct
handle_base
:
handle_lookup
<
Derived
,
std
::
remove_cv_t
<
T
>>
{
{
using
handle_type
=
T
;
handle_base
()
:
m_handle
(
nullptr
)
{}
handle_base
()
:
m_handle
(
nullptr
)
{}
template
<
class
F
,
class
...
Ts
>
template
<
class
F
,
class
...
Ts
>
void
make_handle
(
F
f
,
Ts
&&
...
xs
)
void
make_handle
(
F
f
,
Ts
&&
...
xs
)
...
@@ -231,6 +257,14 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
...
@@ -231,6 +257,14 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
m_handle
=
std
::
shared_ptr
<
U
>
{
ptr
,
[](
U
*
)
{}};
m_handle
=
std
::
shared_ptr
<
U
>
{
ptr
,
[](
U
*
)
{}};
}
}
template
<
class
U
,
class
V
>
void
set_handle
(
U
*
ptr
,
share
<
V
>
b
)
{
m_handle
=
std
::
shared_ptr
<
T
>
{
ptr
,
[
b
](
U
*
)
{}};
}
share
<
T
>
share_handle
()
const
{
return
{
m_handle
};
}
template
<
class
U
>
template
<
class
U
>
void
assign_to_handle
(
U
*
x
)
void
assign_to_handle
(
U
*
x
)
{
{
...
@@ -241,6 +275,17 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
...
@@ -241,6 +275,17 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
std
::
shared_ptr
<
T
>
m_handle
;
std
::
shared_ptr
<
T
>
m_handle
;
};
};
// NOLINTNEXTLINE
#define MIGRAPHX_HANDLE_CONSTRUCTOR(name) \
template <class HandleType, \
class Lifetime, \
class = \
typename std::enable_if<std::is_convertible<HandleType*, handle_type*>{}>::type> \
name(HandleType* p, Lifetime lifetime) \
{ \
this->set_handle(p, std::move(lifetime)); \
}
template
<
class
Base
>
template
<
class
Base
>
struct
interface_base
:
Base
struct
interface_base
:
Base
{
{
...
@@ -398,11 +443,10 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
...
@@ -398,11 +443,10 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
{
{
shape
()
{}
shape
()
{}
MIGRAPHX_DEPRECATED
(
"Contructor without lifetime annotation is deprecated."
)
shape
(
const
migraphx_shape
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
shape
(
const
migraphx_shape
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
shape
(
migraphx_shape
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
shape
);
shape
(
migraphx_shape
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
/// Construct a scalar shape
/// Construct a scalar shape
shape
(
migraphx_shape_datatype_t
type
)
shape
(
migraphx_shape_datatype_t
type
)
...
@@ -479,10 +523,9 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
...
@@ -479,10 +523,9 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
{
{
argument
()
{}
argument
()
{}
argument
(
migraphx_argument
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
argument
);
argument
(
migraphx_argument
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_DEPRECATED
(
"Contructor without lifetime annotation is deprecated."
)
argument
(
const
migraphx_argument
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
argument
(
const
migraphx_argument
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
argument
(
shape
pshape
,
void
*
pbuffer
)
argument
(
shape
pshape
,
void
*
pbuffer
)
...
@@ -494,7 +537,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
...
@@ -494,7 +537,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
{
{
const_migraphx_shape_t
pout
;
const_migraphx_shape_t
pout
;
call
(
&
migraphx_argument_shape
,
&
pout
,
this
->
get_handle_ptr
());
call
(
&
migraphx_argument_shape
,
&
pout
,
this
->
get_handle_ptr
());
return
{
pout
};
return
{
pout
,
this
->
share_handle
()
};
}
}
char
*
data
()
const
char
*
data
()
const
...
@@ -526,9 +569,7 @@ struct target : MIGRAPHX_HANDLE_BASE(target)
...
@@ -526,9 +569,7 @@ struct target : MIGRAPHX_HANDLE_BASE(target)
{
{
target
()
{}
target
()
{}
target
(
migraphx_target
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
target
);
target
(
migraphx_target
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
/// Construct a target from its name
/// Construct a target from its name
target
(
const
char
*
name
)
{
this
->
make_handle
(
&
migraphx_target_create
,
name
);
}
target
(
const
char
*
name
)
{
this
->
make_handle
(
&
migraphx_target_create
,
name
);
}
...
@@ -538,15 +579,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
...
@@ -538,15 +579,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
{
{
program_parameter_shapes
()
{}
program_parameter_shapes
()
{}
program_parameter_shapes
(
migraphx_program_parameter_shapes
*
p
,
own
)
MIGRAPHX_HANDLE_CONSTRUCTOR
(
program_parameter_shapes
);
{
this
->
set_handle
(
p
,
own
{});
}
program_parameter_shapes
(
migraphx_program_parameter_shapes
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
size_t
size
()
const
size_t
size
()
const
{
{
...
@@ -559,7 +592,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
...
@@ -559,7 +592,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
{
{
const_migraphx_shape_t
pout
;
const_migraphx_shape_t
pout
;
call
(
&
migraphx_program_parameter_shapes_get
,
&
pout
,
this
->
get_handle_ptr
(),
pname
);
call
(
&
migraphx_program_parameter_shapes_get
,
&
pout
,
this
->
get_handle_ptr
(),
pname
);
return
{
pout
};
return
{
pout
,
this
->
share_handle
()
};
}
}
std
::
vector
<
const
char
*>
names
()
const
std
::
vector
<
const
char
*>
names
()
const
...
@@ -576,10 +609,9 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
...
@@ -576,10 +609,9 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
/// A class to construct the inputs parameters for a program
/// A class to construct the inputs parameters for a program
struct
program_parameters
:
MIGRAPHX_HANDLE_BASE
(
program_parameters
)
struct
program_parameters
:
MIGRAPHX_HANDLE_BASE
(
program_parameters
)
{
{
program_parameters
(
migraphx_program_parameters
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
program_parameters
);
program_parameters
(
migraphx_program_parameters
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
MIGRAPHX_DEPRECATED
(
"Contructor without lifetime annotation is deprecated."
)
program_parameters
(
migraphx_program_parameters
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
program_parameters
(
migraphx_program_parameters
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
program_parameters
()
{
this
->
make_handle
(
&
migraphx_program_parameters_create
);
}
program_parameters
()
{
this
->
make_handle
(
&
migraphx_program_parameters_create
);
}
...
@@ -604,9 +636,7 @@ struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
...
@@ -604,9 +636,7 @@ struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
struct
arguments
:
MIGRAPHX_HANDLE_BASE
(
arguments
),
array_base
<
arguments
>
struct
arguments
:
MIGRAPHX_HANDLE_BASE
(
arguments
),
array_base
<
arguments
>
{
{
arguments
(
migraphx_arguments
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
arguments
);
arguments
(
migraphx_arguments
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
size_t
size
()
const
size_t
size
()
const
{
{
...
@@ -619,27 +649,13 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
...
@@ -619,27 +649,13 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
{
{
const_migraphx_argument_t
pout
;
const_migraphx_argument_t
pout
;
call
(
&
migraphx_arguments_get
,
&
pout
,
this
->
get_handle_ptr
(),
pidx
);
call
(
&
migraphx_arguments_get
,
&
pout
,
this
->
get_handle_ptr
(),
pidx
);
return
{
pout
};
return
{
pout
,
this
->
share_handle
()
};
}
}
struct
iterator_read
{
migraphx_arguments
*
self
;
argument
operator
()(
size_t
pidx
)
const
{
const_migraphx_argument_t
pout
;
call
(
&
migraphx_arguments_get
,
&
pout
,
self
,
pidx
);
return
{
pout
};
}
};
};
};
struct
shapes
:
MIGRAPHX_HANDLE_BASE
(
shapes
),
array_base
<
shapes
>
struct
shapes
:
MIGRAPHX_HANDLE_BASE
(
shapes
),
array_base
<
shapes
>
{
{
shapes
(
migraphx_shapes
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
shapes
);
shapes
(
migraphx_shapes
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
size_t
size
()
const
size_t
size
()
const
{
{
...
@@ -652,26 +668,13 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
...
@@ -652,26 +668,13 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
{
{
const_migraphx_shape_t
pout
;
const_migraphx_shape_t
pout
;
call
(
&
migraphx_shapes_get
,
&
pout
,
this
->
get_handle_ptr
(),
pidx
);
call
(
&
migraphx_shapes_get
,
&
pout
,
this
->
get_handle_ptr
(),
pidx
);
return
{
pout
};
return
{
pout
,
this
->
share_handle
()
};
}
}
struct
iterator_read
{
migraphx_shapes
*
self
;
shape
operator
()(
size_t
pidx
)
const
{
const_migraphx_shape_t
pout
;
call
(
&
migraphx_shapes_get
,
&
pout
,
self
,
pidx
);
return
{
pout
};
}
};
};
};
struct
operation
:
MIGRAPHX_HANDLE_BASE
(
operation
)
struct
operation
:
MIGRAPHX_HANDLE_BASE
(
operation
)
{
{
operation
(
migraphx_operation
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
operation
);
operation
(
migraphx_operation
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
operation
(
const
char
*
name
,
const
char
*
attributes
=
nullptr
,
Ts
...
xs
)
operation
(
const
char
*
name
,
const
char
*
attributes
=
nullptr
,
Ts
...
xs
)
...
@@ -689,15 +692,12 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation)
...
@@ -689,15 +692,12 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation)
struct
instruction
:
MIGRAPHX_CONST_HANDLE_BASE
(
instruction
)
struct
instruction
:
MIGRAPHX_CONST_HANDLE_BASE
(
instruction
)
{
{
instruction
(
migraphx_instruction
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
instruction
);
};
};
struct
instructions
:
MIGRAPHX_HANDLE_BASE
(
instructions
)
struct
instructions
:
MIGRAPHX_HANDLE_BASE
(
instructions
)
{
{
MIGRAPHX_HANDLE_CONSTRUCTOR
(
instructions
);
instructions
(
migraphx_instructions
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
instructions
(
migraphx_instructions
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
instructions
(
Ts
...
xs
)
instructions
(
Ts
...
xs
)
...
@@ -711,33 +711,36 @@ struct module;
...
@@ -711,33 +711,36 @@ struct module;
struct
modules
:
MIGRAPHX_HANDLE_BASE
(
modules
)
struct
modules
:
MIGRAPHX_HANDLE_BASE
(
modules
)
{
{
MIGRAPHX_HANDLE_CONSTRUCTOR
(
modules
);
modules
(
migraphx_modules
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
modules
(
migraphx_modules
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
modules
(
Ts
...
xs
)
modules
(
Ts
...
xs
)
{
{
std
::
array
<
migraphx_module_t
,
sizeof
...(
Ts
)
>
a
=
{
xs
.
mm
...};
std
::
array
<
migraphx_module_t
,
sizeof
...(
Ts
)
>
a
=
{
xs
.
get_handle_ptr
()
...};
this
->
make_handle
(
&
migraphx_modules_create
,
a
.
data
(),
a
.
size
());
this
->
make_handle
(
&
migraphx_modules_create
,
a
.
data
(),
a
.
size
());
}
}
};
};
struct
module
struct
module
{
{
migraphx_module_t
mm
;
MIGRAPHX_DEPRECATED
(
"Constructor without lifetime annotation is deprecated."
)
module
(
migraphx_module
*
m
)
:
mm
(
std
::
shared_ptr
<
migraphx_module
*>
(),
m
)
{}
module
(
migraphx_module
*
m
,
borrow
)
:
mm
(
std
::
shared_ptr
<
migraphx_module
*>
(),
m
)
{}
module
(
const
migraphx_module_t
&
m
)
:
mm
(
m
)
{}
template
<
class
T
>
module
(
migraphx_module
*
m
,
share
<
T
>
b
)
:
mm
(
b
.
alias
(
m
))
{
}
void
print
()
const
{
call
(
&
migraphx_module_print
,
mm
);
}
void
print
()
const
{
call
(
&
migraphx_module_print
,
mm
.
get
()
);
}
instruction
add_instruction
(
const
migraphx
::
operation
&
op
,
const
migraphx
::
instructions
&
args
)
instruction
add_instruction
(
const
migraphx
::
operation
&
op
,
const
migraphx
::
instructions
&
args
)
{
{
migraphx_instruction_t
op_ins
;
migraphx_instruction_t
op_ins
;
call
(
&
migraphx_module_add_instruction
,
call
(
&
migraphx_module_add_instruction
,
&
op_ins
,
&
op_ins
,
mm
,
mm
.
get
()
,
op
.
get_handle_ptr
(),
op
.
get_handle_ptr
(),
args
.
get_handle_ptr
());
args
.
get_handle_ptr
());
return
instruction
(
op_ins
,
own
{});
return
instruction
(
op_ins
,
own
{});
...
@@ -750,7 +753,7 @@ struct module
...
@@ -750,7 +753,7 @@ struct module
migraphx_instruction_t
op_ins
;
migraphx_instruction_t
op_ins
;
call
(
&
migraphx_module_add_instruction_with_mod_args
,
call
(
&
migraphx_module_add_instruction_with_mod_args
,
&
op_ins
,
&
op_ins
,
mm
,
mm
.
get
()
,
op
.
get_handle_ptr
(),
op
.
get_handle_ptr
(),
args
.
get_handle_ptr
(),
args
.
get_handle_ptr
(),
module_args
.
get_handle_ptr
());
module_args
.
get_handle_ptr
());
...
@@ -760,30 +763,53 @@ struct module
...
@@ -760,30 +763,53 @@ struct module
instruction
add_parameter
(
const
std
::
string
&
name
,
shape
s
)
instruction
add_parameter
(
const
std
::
string
&
name
,
shape
s
)
{
{
migraphx_instruction_t
param_ins
;
migraphx_instruction_t
param_ins
;
call
(
&
migraphx_module_add_parameter
,
&
param_ins
,
mm
,
name
.
c_str
(),
s
.
get_handle_ptr
());
call
(
&
migraphx_module_add_parameter
,
&
param_ins
,
mm
.
get
(),
name
.
c_str
(),
s
.
get_handle_ptr
());
return
instruction
(
param_ins
,
own
{});
return
instruction
(
param_ins
,
own
{});
}
}
instruction
add_return
(
const
migraphx
::
instructions
&
args
)
instruction
add_return
(
const
migraphx
::
instructions
&
args
)
{
{
migraphx_instruction_t
ret_ins
;
migraphx_instruction_t
ret_ins
;
call
(
&
migraphx_module_add_return
,
&
ret_ins
,
mm
,
args
.
get_handle_ptr
());
call
(
&
migraphx_module_add_return
,
&
ret_ins
,
mm
.
get
()
,
args
.
get_handle_ptr
());
return
instruction
(
ret_ins
,
own
{});
return
instruction
(
ret_ins
,
own
{});
}
}
migraphx_module_t
get_handle_ptr
()
const
{
return
mm
.
get
();
}
private:
std
::
shared_ptr
<
migraphx_module
>
mm
;
};
};
struct
context
struct
context
{
{
migraphx_context_t
ctx
;
context
(
migraphx_context
*
p
,
borrow
)
:
ctx
(
std
::
shared_ptr
<
migraphx_context
*>
(),
p
)
{}
template
<
class
T
>
context
(
migraphx_context
*
p
,
share
<
T
>
b
)
:
ctx
(
b
.
alias
(
p
))
{
}
void
finish
()
const
{
call
(
&
migraphx_context_finish
,
ctx
.
get
());
}
template
<
class
T
>
T
get_queue
()
{
void
*
out
;
call
(
&
migraphx_context_get_queue
,
&
out
,
ctx
.
get
());
// TODO: check type here
return
reinterpret_cast
<
T
>
(
out
);
}
void
finish
()
const
{
call
(
&
migraphx_context_finish
,
ctx
);
}
private:
std
::
shared_ptr
<
migraphx_context
>
ctx
;
};
};
struct
compile_options
:
MIGRAPHX_HANDLE_BASE
(
compile_options
)
struct
compile_options
:
MIGRAPHX_HANDLE_BASE
(
compile_options
)
{
{
compile_options
()
{
this
->
make_handle
(
&
migraphx_compile_options_create
);
}
compile_options
()
{
this
->
make_handle
(
&
migraphx_compile_options_create
);
}
compile_options
(
migraphx_compile_options
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
());
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
compile_options
);
/// For targets with offloaded memory(such as the gpu), this will insert
/// For targets with offloaded memory(such as the gpu), this will insert
/// instructions during compilation to copy the input parameters to the
/// instructions during compilation to copy the input parameters to the
...
@@ -807,9 +833,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
...
@@ -807,9 +833,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
{
{
program
()
{
this
->
make_handle
(
&
migraphx_program_create
);
}
program
()
{
this
->
make_handle
(
&
migraphx_program_create
);
}
program
(
migraphx_program
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
program
);
program
(
migraphx_program
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
/// Compile the program for a specific target to be ran on
/// Compile the program for a specific target to be ran on
void
compile
(
const
target
&
ptarget
,
const
compile_options
&
poptions
)
const
void
compile
(
const
target
&
ptarget
,
const
compile_options
&
poptions
)
const
...
@@ -872,21 +896,21 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
...
@@ -872,21 +896,21 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
{
{
migraphx_module_t
p_modu
;
migraphx_module_t
p_modu
;
call
(
&
migraphx_program_get_main_module
,
&
p_modu
,
this
->
get_handle_ptr
());
call
(
&
migraphx_program_get_main_module
,
&
p_modu
,
this
->
get_handle_ptr
());
return
module
{
p_modu
};
return
module
{
p_modu
,
this
->
share_handle
()
};
}
}
context
experimental_get_context
()
context
experimental_get_context
()
{
{
migraphx_context_t
ctx
;
migraphx_context_t
ctx
;
call
(
&
migraphx_program_experimental_get_context
,
&
ctx
,
this
->
get_handle_ptr
());
call
(
&
migraphx_program_experimental_get_context
,
&
ctx
,
this
->
get_handle_ptr
());
return
context
{
ctx
};
return
context
{
ctx
,
this
->
share_handle
()
};
}
}
module
create_module
(
const
std
::
string
&
name
)
module
create_module
(
const
std
::
string
&
name
)
{
{
migraphx_module_t
p_modu
;
migraphx_module_t
p_modu
;
call
(
&
migraphx_program_create_module
,
&
p_modu
,
this
->
get_handle_ptr
(),
name
.
data
());
call
(
&
migraphx_program_create_module
,
&
p_modu
,
this
->
get_handle_ptr
(),
name
.
data
());
return
module
{
p_modu
};
return
module
{
p_modu
,
this
->
share_handle
()
};
}
}
friend
bool
operator
!=
(
const
program
&
px
,
const
program
&
py
)
{
return
!
(
px
==
py
);
}
friend
bool
operator
!=
(
const
program
&
px
,
const
program
&
py
)
{
return
!
(
px
==
py
);
}
...
@@ -895,10 +919,9 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
...
@@ -895,10 +919,9 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
// options for migraphx file format options
// options for migraphx file format options
struct
file_options
:
MIGRAPHX_HANDLE_BASE
(
file_options
)
struct
file_options
:
MIGRAPHX_HANDLE_BASE
(
file_options
)
{
{
MIGRAPHX_HANDLE_CONSTRUCTOR
(
file_options
);
file_options
()
{
this
->
make_handle
(
&
migraphx_file_options_create
);
}
file_options
()
{
this
->
make_handle
(
&
migraphx_file_options_create
);
}
file_options
(
migraphx_file_options
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
());
}
// set file format
// set file format
void
set_file_format
(
const
char
*
format
)
void
set_file_format
(
const
char
*
format
)
{
{
...
@@ -938,7 +961,7 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
...
@@ -938,7 +961,7 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
{
{
onnx_options
()
{
this
->
make_handle
(
&
migraphx_onnx_options_create
);
}
onnx_options
()
{
this
->
make_handle
(
&
migraphx_onnx_options_create
);
}
onnx_options
(
migraphx_onnx_options
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
onnx_options
);
/// Make onnx parser treat an inputs with a certain dimensions
/// Make onnx parser treat an inputs with a certain dimensions
void
set_input_parameter_shape
(
const
std
::
string
&
name
,
std
::
vector
<
std
::
size_t
>
dim
)
void
set_input_parameter_shape
(
const
std
::
string
&
name
,
std
::
vector
<
std
::
size_t
>
dim
)
...
@@ -1020,7 +1043,7 @@ struct tf_options : MIGRAPHX_HANDLE_BASE(tf_options)
...
@@ -1020,7 +1043,7 @@ struct tf_options : MIGRAPHX_HANDLE_BASE(tf_options)
{
{
tf_options
()
{
this
->
make_handle
(
&
migraphx_tf_options_create
);
}
tf_options
()
{
this
->
make_handle
(
&
migraphx_tf_options_create
);
}
tf_options
(
migraphx_tf_options
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
tf_options
);
/// Make tf parser treat an inputs with a certain dimensions
/// Make tf parser treat an inputs with a certain dimensions
void
set_input_parameter_shape
(
const
std
::
string
&
name
,
std
::
vector
<
std
::
size_t
>
dim
)
void
set_input_parameter_shape
(
const
std
::
string
&
name
,
std
::
vector
<
std
::
size_t
>
dim
)
...
@@ -1073,7 +1096,7 @@ struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names)
...
@@ -1073,7 +1096,7 @@ struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names)
{
{
quantize_op_names
()
{
this
->
make_handle
(
&
migraphx_quantize_op_names_create
);
}
quantize_op_names
()
{
this
->
make_handle
(
&
migraphx_quantize_op_names_create
);
}
quantize_op_names
(
migraphx_quantize_op_names
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
quantize_op_names
);
void
add
(
const
std
::
string
&
name
)
void
add
(
const
std
::
string
&
name
)
{
{
...
@@ -1098,12 +1121,7 @@ struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options)
...
@@ -1098,12 +1121,7 @@ struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options)
{
{
quantize_int8_options
()
{
this
->
make_handle
(
&
migraphx_quantize_int8_options_create
);
}
quantize_int8_options
()
{
this
->
make_handle
(
&
migraphx_quantize_int8_options_create
);
}
quantize_int8_options
(
migraphx_quantize_int8_options
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
quantize_int8_options
);
quantize_int8_options
(
migraphx_quantize_int8_options
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
/// Add an operator that should be quantized
/// Add an operator that should be quantized
void
add_op_name
(
const
std
::
string
&
name
)
void
add_op_name
(
const
std
::
string
&
name
)
...
...
src/api/migraphx.py
View file @
988bf26b
...
@@ -403,6 +403,7 @@ api.add_function('migraphx_quantize_int8',
...
@@ -403,6 +403,7 @@ api.add_function('migraphx_quantize_int8',
@
auto_handle
(
ref
=
True
)
@
auto_handle
(
ref
=
True
)
def
context
(
h
):
def
context
(
h
):
h
.
method
(
'finish'
,
const
=
True
)
h
.
method
(
'finish'
,
const
=
True
)
h
.
method
(
'get_queue'
,
returns
=
'void*'
,
fname
=
'get_queue().unsafe_get'
)
@
api
.
interface
(
'migraphx_experimental_custom_op'
,
@
api
.
interface
(
'migraphx_experimental_custom_op'
,
...
...
src/include/migraphx/op/gathernd.hpp
0 → 100644
View file @
988bf26b
#ifndef MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP
#define MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
gathernd
{
int
batch_dims
=
0
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
batch_dims
,
"batch_dims"
));
}
std
::
string
name
()
const
{
return
"gathernd"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
auto
r
=
inputs
.
front
().
lens
().
size
();
auto
q
=
inputs
.
back
().
lens
().
size
();
auto
k
=
inputs
.
back
().
lens
().
back
();
if
(
k
>
r
-
batch_dims
)
{
MIGRAPHX_THROW
(
"GATHERND: Indices of length "
+
std
::
to_string
(
k
)
+
" cannot be used to access data of rank "
+
std
::
to_string
(
r
-
batch_dims
));
}
auto
indices_lens_iter
=
inputs
.
back
().
lens
().
begin
();
auto
output_lens_size
=
q
+
r
-
k
-
batch_dims
-
1
;
std
::
vector
<
std
::
size_t
>
output_lens
(
output_lens_size
);
std
::
copy
(
indices_lens_iter
,
indices_lens_iter
+
(
q
-
1
),
output_lens
.
begin
());
if
(
k
<
r
-
batch_dims
)
{
auto
data_lens
=
inputs
.
front
().
lens
();
std
::
copy
(
data_lens
.
begin
()
+
batch_dims
+
k
,
data_lens
.
end
(),
output_lens
.
begin
()
+
q
-
1
);
}
shape
output_shape
{
inputs
.
front
().
type
(),
output_lens
};
return
output_shape
;
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
data
)
{
args
[
1
].
visit
([
&
](
auto
indices
)
{
auto
indices_shape
=
indices
.
get_shape
();
auto
indices_shape_lens
=
indices_shape
.
lens
();
auto
data_shape
=
data
.
get_shape
();
auto
data_shape_lens
=
data_shape
.
lens
();
auto
k
=
indices_shape
.
lens
().
back
();
const
auto
num_slice_dims
=
k
;
std
::
size_t
num_slices
=
std
::
accumulate
(
indices_shape_lens
.
begin
(),
indices_shape_lens
.
end
()
-
1
,
1
,
std
::
multiplies
<
std
::
size_t
>
());
std
::
size_t
slice_size
=
std
::
accumulate
(
data_shape_lens
.
begin
()
+
k
+
batch_dims
,
data_shape_lens
.
end
(),
1
,
std
::
multiplies
<
std
::
size_t
>
());
std
::
size_t
num_batches
=
std
::
accumulate
(
data_shape_lens
.
begin
(),
data_shape_lens
.
begin
()
+
batch_dims
,
1
,
std
::
multiplies
<
std
::
size_t
>
());
std
::
size_t
data_batch_stride
=
std
::
accumulate
(
data_shape_lens
.
begin
()
+
batch_dims
,
data_shape_lens
.
end
(),
1
,
std
::
multiplies
<
std
::
size_t
>
());
auto
num_slices_per_batch
=
num_slices
/
num_batches
;
std
::
vector
<
std
::
size_t
>
sizes_from_slice_dims
(
num_slice_dims
);
{
auto
running_product
=
slice_size
;
for
(
std
::
size_t
i
=
0
;
i
<
num_slice_dims
;
++
i
)
{
sizes_from_slice_dims
[
num_slice_dims
-
1
-
i
]
=
running_product
;
running_product
*=
data_shape_lens
[
batch_dims
+
num_slice_dims
-
1
-
i
];
}
}
std
::
vector
<
std
::
size_t
>
input_slice_offsets
(
num_slices
);
par_for
(
num_slices
,
[
&
](
const
auto
i
)
{
std
::
size_t
batch_idx
=
i
/
num_slices_per_batch
;
auto
slice_indices
=
indices
.
begin
()
+
(
i
*
num_slice_dims
);
std
::
size_t
relative_slice_offset
=
0
;
for
(
size_t
dim_idx
=
0
;
dim_idx
<
num_slice_dims
;
++
dim_idx
)
{
int64_t
index
=
*
(
slice_indices
+
dim_idx
);
const
std
::
size_t
input_dim_idx
=
batch_dims
+
dim_idx
;
const
auto
input_dim
=
data_shape_lens
[
input_dim_idx
];
if
(
index
<
-
static_cast
<
int64_t
>
(
input_dim
)
or
index
>=
static_cast
<
int64_t
>
(
input_dim
))
MIGRAPHX_THROW
(
"GatherND: index "
+
std
::
to_string
(
index
)
+
" is out of bounds for dim of len "
+
std
::
to_string
(
input_dim
));
if
(
index
<
0
)
index
+=
input_dim
;
relative_slice_offset
+=
index
*
sizes_from_slice_dims
[
dim_idx
];
}
input_slice_offsets
[
i
]
=
(
batch_idx
*
data_batch_stride
)
+
relative_slice_offset
;
});
par_for
(
num_slices
*
slice_size
,
[
&
](
const
auto
i
)
{
auto
slice_offset
=
input_slice_offsets
[
i
/
slice_size
];
output
[
i
]
=
data
[
slice_offset
+
i
%
slice_size
];
});
});
});
return
result
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/operators.hpp
View file @
988bf26b
...
@@ -35,6 +35,7 @@
...
@@ -35,6 +35,7 @@
#include <migraphx/op/flatten.hpp>
#include <migraphx/op/flatten.hpp>
#include <migraphx/op/floor.hpp>
#include <migraphx/op/floor.hpp>
#include <migraphx/op/gather.hpp>
#include <migraphx/op/gather.hpp>
#include <migraphx/op/gathernd.hpp>
#include <migraphx/op/get_tuple_elem.hpp>
#include <migraphx/op/get_tuple_elem.hpp>
#include <migraphx/op/greater.hpp>
#include <migraphx/op/greater.hpp>
#include <migraphx/op/gru.hpp>
#include <migraphx/op/gru.hpp>
...
...
src/onnx/parse_generic_op.cpp
View file @
988bf26b
...
@@ -28,6 +28,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
...
@@ -28,6 +28,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
{
"Flatten"
,
"flatten"
},
{
"Flatten"
,
"flatten"
},
{
"Floor"
,
"floor"
},
{
"Floor"
,
"floor"
},
{
"Gather"
,
"gather"
},
{
"Gather"
,
"gather"
},
{
"GatherND"
,
"gathernd"
},
{
"Identity"
,
"identity"
},
{
"Identity"
,
"identity"
},
{
"IsNaN"
,
"isnan"
},
{
"IsNaN"
,
"isnan"
},
{
"LeakyRelu"
,
"leaky_relu"
},
{
"LeakyRelu"
,
"leaky_relu"
},
...
...
src/targets/gpu/driver/run_op.cpp
View file @
988bf26b
...
@@ -20,7 +20,7 @@ struct run_op : action<run_op>
...
@@ -20,7 +20,7 @@ struct run_op : action<run_op>
auto
op
=
make_op
(
name
);
auto
op
=
make_op
(
name
);
if
(
v
.
contains
(
"fields"
))
if
(
v
.
contains
(
"fields"
))
op
.
from_value
(
v
.
at
(
"fields"
));
op
.
from_value
(
v
.
at
(
"fields"
));
double
t
=
time_op
(
ctx
,
op
,
inputs
);
double
t
=
time_op
(
ctx
,
op
,
inputs
,
p
.
get
(
v
,
"iterations"
,
100
)
);
std
::
cout
<<
op
<<
": "
<<
t
<<
"ms"
<<
std
::
endl
;
std
::
cout
<<
op
<<
": "
<<
t
<<
"ms"
<<
std
::
endl
;
}
}
};
};
...
...
src/targets/gpu/jit/gathernd.cpp
0 → 100644
View file @
988bf26b
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
// NOLINTNEXTLINE
static
const
char
*
const
gathernd_kernel
=
R"__migraphx__(
#include <migraphx/kernels/gathernd.hpp>
#include <migraphx/kernels/basic_ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void gathernd_kernel(void* in_data, void* in_indices, void* output)
{
make_tensors()(in_data, in_indices, output)([](auto&&... xs) {
auto settings = make_gathernd_settings(MIGRAPHX_MAKE_CONSTANT(int64_t{BATCH_DIMS}));
gathernd(xs..., settings);
});
}
}
} // namespace migraphx
)__migraphx__"
;
struct
gathernd_compiler
:
compiler
<
gathernd_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"gathernd"
};
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
hip_compile_options
options
;
auto
out_s
=
inputs
.
back
();
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
out_s
.
elements
()));
options
.
inputs
=
inputs
;
options
.
output
=
out_s
;
options
.
kernel_name
=
"gathernd_kernel"
;
options
.
virtual_inputs
=
inputs
;
// batch_dims
assert
(
v
.
contains
(
"batch_dims"
));
auto
batch_dims
=
v
.
at
(
"batch_dims"
).
to
<
int64_t
>
();
options
.
params
+=
" -DBATCH_DIMS="
+
std
::
to_string
(
batch_dims
);
return
compile_hip_code_object
(
gathernd_kernel
,
options
);
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
op
.
to_value
()));
}
};
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/jit/reduce.cpp
View file @
988bf26b
...
@@ -30,7 +30,7 @@ __global__ void kernel(void* input_p, void* output_p)
...
@@ -30,7 +30,7 @@ __global__ void kernel(void* input_p, void* output_p)
{
{
make_tensors()(input_p, output_p)([](auto input, auto output) {
make_tensors()(input_p, output_p)([](auto input, auto output) {
simple_reduce(${reduction}, ${init}, input, output, ${read}, ${write});
simple_reduce
<reduce::${algo}>
(${reduction}, ${init}, input, output, ${read}, ${write});
});
});
}
}
...
@@ -57,6 +57,40 @@ static std::size_t get_reduce_elements(const std::vector<instruction_ref>& input
...
@@ -57,6 +57,40 @@ static std::size_t get_reduce_elements(const std::vector<instruction_ref>& input
return
get_reduce_elements
(
to_shapes
(
inputs
));
return
get_reduce_elements
(
to_shapes
(
inputs
));
}
}
static
std
::
vector
<
std
::
size_t
>
get_reduce_lens
(
const
std
::
vector
<
std
::
size_t
>&
input_lens
,
const
std
::
vector
<
std
::
size_t
>&
output_lens
)
{
std
::
vector
<
std
::
size_t
>
reduce_lens
;
std
::
transform
(
output_lens
.
begin
(),
output_lens
.
end
(),
input_lens
.
begin
(),
std
::
back_inserter
(
reduce_lens
),
[](
auto
x
,
auto
y
)
->
std
::
size_t
{
if
(
x
==
y
)
return
1
;
else
return
y
;
});
return
reduce_lens
;
}
static
std
::
string
get_reduce_algo
(
const
std
::
vector
<
shape
>&
inputs
)
{
auto
rlens
=
get_reduce_lens
(
inputs
.
front
().
lens
(),
inputs
.
back
().
lens
());
const
auto
init
=
std
::
numeric_limits
<
std
::
size_t
>::
max
();
// The minimum stride
auto
min_stride
=
std
::
inner_product
(
rlens
.
begin
(),
rlens
.
end
(),
inputs
.
front
().
strides
().
begin
(),
init
,
[](
auto
x
,
auto
y
)
{
return
std
::
min
(
x
,
y
);
},
[](
auto
len
,
auto
stride
)
{
return
len
==
1
?
init
:
stride
;
});
if
(
min_stride
>
2
)
return
"lane"
;
return
"block"
;
}
struct
reduce_compiler
:
compiler
<
reduce_compiler
>
struct
reduce_compiler
:
compiler
<
reduce_compiler
>
{
{
std
::
vector
<
std
::
string
>
names
()
const
std
::
vector
<
std
::
string
>
names
()
const
...
@@ -68,20 +102,33 @@ struct reduce_compiler : compiler<reduce_compiler>
...
@@ -68,20 +102,33 @@ struct reduce_compiler : compiler<reduce_compiler>
{
{
hip_compile_options
options
;
hip_compile_options
options
;
auto
reduce_elements
=
get_reduce_elements
(
inputs
);
auto
reduce_elements
=
get_reduce_elements
(
inputs
);
auto
block_size
=
compute_block_size
(
reduce_elements
,
256
);
auto
algo
=
v
.
get
(
"algo"
,
get_reduce_algo
(
inputs
));
options
.
set_launch_params
(
if
(
algo
==
"block"
)
v
,
compute_global_for
(
ctx
,
inputs
.
back
().
elements
()
*
block_size
,
256
),
block_size
);
{
auto
block_size
=
compute_block_size
(
reduce_elements
,
256
);
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
inputs
.
back
().
elements
()
*
block_size
,
256
),
block_size
);
}
else
if
(
algo
==
"lane"
)
{
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
inputs
.
back
().
elements
(),
256
));
}
else
{
MIGRAPHX_THROW
(
"Unknown reduce algo: "
+
algo
);
}
options
.
inputs
=
inputs
;
options
.
inputs
=
inputs
;
options
.
output
=
inputs
.
back
();
options
.
output
=
inputs
.
back
();
options
.
virtual_inputs
=
reduce_dims
(
inputs
);
options
.
virtual_inputs
=
reduce_dims
(
inputs
);
options
.
params
=
"-Wno-float-equal"
;
std
::
string
identity
=
"[](auto x) { return x; }"
;
std
::
string
identity
=
"[](auto x) { return x; }"
;
auto
src
=
interpolate_string
(
simple_reduce_kernel
,
auto
src
=
interpolate_string
(
simple_reduce_kernel
,
{{
"reduction"
,
v
.
at
(
"reduction"
).
to
<
std
::
string
>
()},
{{
"reduction"
,
v
.
at
(
"reduction"
).
to
<
std
::
string
>
()},
{
"init"
,
v
.
get
(
"init"
,
std
::
string
{
"0"
})},
{
"init"
,
v
.
get
(
"init"
,
std
::
string
{
"0"
})},
{
"read"
,
v
.
get
(
"read"
,
identity
)},
{
"read"
,
v
.
get
(
"read"
,
identity
)},
{
"write"
,
v
.
get
(
"write"
,
identity
)},
{
"write"
,
v
.
get
(
"write"
,
identity
)},
{
"algo"
,
algo
},
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})}});
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})}});
options
.
params
+=
"-Wno-float-equal"
;
return
compile_hip_code_object
(
src
,
options
);
return
compile_hip_code_object
(
src
,
options
);
}
}
...
...
src/targets/gpu/jit/softmax.cpp
0 → 100644
View file @
988bf26b
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
static
const
char
*
const
softmax_kernel
=
R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/softmax.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void softmax_kernel(void* input_p, void* output_p)
{
make_tensors()(input_p, output_p)([](auto input, auto output) {
softmax<${axis}>(input, output);
});
}
}
} // namespace migraphx
)__migraphx__"
;
constexpr
std
::
size_t
compute_block_size
(
std
::
size_t
n
,
std
::
size_t
max_block_size
=
1024
)
{
size_t
block_size
=
128
;
while
(
block_size
<=
max_block_size
and
block_size
<=
n
)
block_size
*=
2
;
return
block_size
/
2
;
}
struct
softmax_compiler
:
compiler
<
softmax_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"softmax"
};
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
auto
axis
=
v
.
at
(
"axis"
).
to
<
int64_t
>
();
auto
relements
=
inputs
[
0
].
lens
()[
axis
];
auto
nelements
=
inputs
.
back
().
elements
()
/
relements
;
auto
block_size
=
compute_block_size
(
relements
,
256
);
hip_compile_options
options
;
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
nelements
*
block_size
,
256
),
block_size
);
options
.
output
=
inputs
.
back
();
options
.
inputs
=
inputs
;
options
.
kernel_name
=
"softmax_kernel"
;
auto
src
=
interpolate_string
(
softmax_kernel
,
{{
"axis"
,
to_string
(
axis
)}});
return
compile_hip_code_object
(
src
,
options
);
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
op
.
to_value
()));
}
};
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/kernels/include/migraphx/kernels/algorithm.hpp
View file @
988bf26b
...
@@ -21,6 +21,16 @@ struct greater
...
@@ -21,6 +21,16 @@ struct greater
}
}
};
};
template
<
class
InputIt
,
class
T
,
class
BinaryOperation
>
constexpr
T
accumulate
(
InputIt
first
,
InputIt
last
,
T
init
,
BinaryOperation
op
)
{
for
(;
first
!=
last
;
++
first
)
{
init
=
op
(
std
::
move
(
init
),
*
first
);
}
return
init
;
}
template
<
class
InputIt
,
class
OutputIt
>
template
<
class
InputIt
,
class
OutputIt
>
constexpr
OutputIt
copy
(
InputIt
first
,
InputIt
last
,
OutputIt
d_first
)
constexpr
OutputIt
copy
(
InputIt
first
,
InputIt
last
,
OutputIt
d_first
)
{
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
View file @
988bf26b
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/debug.hpp>
#include <migraphx/kernels/debug.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
@@ -190,6 +191,13 @@ constexpr auto transform(integral_const_array<T, Xs...>, F f)
...
@@ -190,6 +191,13 @@ constexpr auto transform(integral_const_array<T, Xs...>, F f)
return
integral_const_array
<
T
,
f
(
Xs
)...
>
{};
return
integral_const_array
<
T
,
f
(
Xs
)...
>
{};
}
}
template
<
class
T
,
T
...
Xs
,
class
F
>
constexpr
auto
transform_i
(
integral_const_array
<
T
,
Xs
...
>
,
F
f
)
{
return
sequence_c
<
sizeof
...(
Xs
)
>
(
[
=
](
auto
...
is
)
{
return
integral_const_array
<
T
,
f
(
Xs
,
is
)...
>
{};
});
}
template
<
class
T
,
T
...
Xs
,
class
U
,
U
...
Ys
,
class
F
>
template
<
class
T
,
T
...
Xs
,
class
U
,
U
...
Ys
,
class
F
>
constexpr
auto
transform
(
integral_const_array
<
T
,
Xs
...
>
,
integral_const_array
<
U
,
Ys
...
>
,
F
f
)
constexpr
auto
transform
(
integral_const_array
<
T
,
Xs
...
>
,
integral_const_array
<
U
,
Ys
...
>
,
F
f
)
{
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/debug.hpp
View file @
988bf26b
...
@@ -42,6 +42,32 @@ struct print_buffer
...
@@ -42,6 +42,32 @@ struct print_buffer
pos
++
;
pos
++
;
}
}
}
}
template
<
class
T
,
class
=
decltype
(
T
{}
%
10
,
-
T
{}
)>
constexpr
void
append
(
T
i
)
{
if
(
i
<
0
)
{
append
(
'-'
);
i
=
-
i
;
}
char
c
=
(
i
%
10
)
+
'0'
;
if
(
i
>
9
)
append
(
i
/
10
);
append
(
c
);
}
constexpr
void
append
(
const
char
*
str
)
{
if
(
str
==
nullptr
)
return
;
int
i
=
512
;
while
(
*
str
!=
0
and
i
>
0
)
{
append
(
*
str
);
str
++
;
i
--
;
}
}
template
<
size_t
M
>
template
<
size_t
M
>
constexpr
void
append
(
const
char
(
&
array
)[
M
])
constexpr
void
append
(
const
char
(
&
array
)[
M
])
...
@@ -54,14 +80,36 @@ struct print_buffer
...
@@ -54,14 +80,36 @@ struct print_buffer
template
<
class
...
Ts
>
template
<
class
...
Ts
>
__host__
__device__
void
print
(
const
Ts
&
...
xs
)
__host__
__device__
void
print
(
const
Ts
&
...
xs
)
{
{
const
auto
size
=
(
sizeof
(
xs
)
+
...);
print_buffer
<
1024
>
buffer
;
print_buffer
<
size
>
buffer
;
swallow
{(
buffer
.
append
(
xs
),
0
)...};
swallow
{(
buffer
.
append
(
xs
),
0
)...};
printf
(
"%s"
,
buffer
.
buffer
);
printf
(
"%s"
,
buffer
.
buffer
);
}
}
}
// namespace debug
}
// namespace debug
struct
source_location
{
int
line
=
__builtin_LINE
();
const
char
*
file
=
__builtin_FILE
();
const
char
*
function
=
__builtin_FUNCTION
();
};
template
<
class
T
>
struct
source_location_capture
{
T
x
;
source_location
loc
;
template
<
class
U
,
class
=
decltype
(
T
(
U
{}
))>
constexpr
source_location_capture
(
U
px
,
source_location
ploc
=
source_location
{})
:
x
(
px
),
loc
(
ploc
)
{
}
constexpr
operator
source_location
()
const
{
return
loc
;
}
constexpr
operator
T
()
const
{
return
x
;
}
};
// noreturn cannot be used on this function because abort in hip is broken
// noreturn cannot be used on this function because abort in hip is broken
template
<
class
T1
,
class
T2
,
class
T3
,
class
T4
>
template
<
class
T1
,
class
T2
,
class
T3
,
class
T4
>
MIGRAPHX_HIP_NORETURN
inline
__host__
__device__
void
MIGRAPHX_HIP_NORETURN
inline
__host__
__device__
void
...
@@ -73,20 +121,38 @@ assert_fail(const T1& assertion, const T2& file, const T3& line, const T4& funct
...
@@ -73,20 +121,38 @@ assert_fail(const T1& assertion, const T2& file, const T3& line, const T4& funct
abort
();
abort
();
}
}
template
<
class
...
Ts
>
MIGRAPHX_HIP_NORETURN
inline
__host__
__device__
void
assert_fail
(
const
source_location
&
loc
,
Ts
...
xs
)
{
debug
::
print
(
loc
.
file
,
":"
,
loc
.
line
,
": "
,
loc
.
function
,
": error: "
,
xs
...,
"
\n
"
);
abort
();
}
// NOLINTNEXTLINE
// NOLINTNEXTLINE
#define MIGRAPHX_
CHECK(cond)
\
#define MIGRAPHX_
ASSERT_FAIL(cond, ...)
\
((cond) ? void(0) : [](auto&&... private_migraphx_xs) { \
((cond) ? void(0) : [](auto&&... private_migraphx_xs) { \
assert_fail(private_migraphx_xs...); \
assert_fail(private_migraphx_xs...); \
}(#cond, __FILE__, MIGRAPHX_STRINGIZE(__LINE__), __PRETTY_FUNCTION__))
}(__VA_ARGS__))
// NOLINTNEXTLINE
#define MIGRAPHX_CHECK(cond) \
MIGRAPHX_ASSERT_FAIL(cond, #cond, __FILE__, __LINE__, __PRETTY_FUNCTION__)
#ifdef MIGRAPHX_DEBUG
#ifdef MIGRAPHX_DEBUG
// NOLINTNEXTLINE
#define MIGRAPHX_CAPTURE_SOURCE_LOCATION(T) source_location_capture<T>
#define MIGRAPHX_WARN(cond, loc, ...) MIGRAPHX_ASSERT_FAIL(cond, loc, __VA_ARGS__)
#define MIGRAPHX_ASSERT MIGRAPHX_CHECK
#define MIGRAPHX_ASSERT MIGRAPHX_CHECK
#define MIGRAPHX_ASSUME MIGRAPHX_CHECK
#define MIGRAPHX_ASSUME MIGRAPHX_CHECK
#define MIGRAPHX_UNREACHABLE() MIGRAPHX_ASSERT(false)
#define MIGRAPHX_UNREACHABLE() MIGRAPHX_ASSERT(false)
#else
#else
// NOLINTNEXTLINE
#define MIGRAPHX_CAPTURE_SOURCE_LOCATION(T) T
#define MIGRAPHX_ASSUME __builtin_assume
#define MIGRAPHX_ASSUME __builtin_assume
#define MIGRAPHX_UNREACHABLE __builtin_unreachable
#define MIGRAPHX_UNREACHABLE __builtin_unreachable
#define MIGRAPHX_ASSERT(cond)
#define MIGRAPHX_ASSERT(cond)
#define MIGRAPHX_WARN(...)
#endif
#endif
}
// namespace migraphx
}
// namespace migraphx
...
...
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
View file @
988bf26b
#ifndef MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
#ifndef MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
#define MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
#define MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
#include <migraphx/kernels/
array
.hpp>
#include <migraphx/kernels/
integral_constant
.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/gathernd.hpp
0 → 100644
View file @
988bf26b
#ifndef MIGRAPHX_GUARD_KERNELS_GATHERND_HPP
#define MIGRAPHX_GUARD_KERNELS_GATHERND_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
namespace
migraphx
{
template
<
class
T
>
struct
gathernd_settings
{
T
batch_dims
{};
};
template
<
class
...
Ts
>
constexpr
gathernd_settings
<
Ts
...
>
make_gathernd_settings
(
Ts
...
xs
)
{
return
{
xs
...};
}
template
<
class
T
,
class
U
,
class
V
,
class
Settings
>
__device__
void
gathernd
(
const
T
&
data_t
,
const
U
&
indices_t
,
const
V
&
output_t
,
Settings
s
)
{
auto
ind
=
make_index
();
auto
batch_dims
=
s
.
batch_dims
;
auto
output_shape
=
output_t
.
get_shape
();
auto
indices_shape
=
indices_t
.
get_shape
();
auto
data_shape
=
data_t
.
get_shape
();
auto
indices_shape_lens
=
indices_shape
.
lens
;
auto
data_shape_lens
=
data_shape
.
lens
;
auto
num_slice_dims
=
indices_shape_lens
.
back
();
std
::
size_t
num_slices
=
accumulate
(
indices_shape_lens
.
begin
(),
indices_shape_lens
.
end
()
-
1
,
1
,
std
::
multiplies
<
std
::
size_t
>
());
std
::
size_t
slice_size
=
accumulate
(
data_shape_lens
.
begin
()
+
num_slice_dims
+
batch_dims
,
data_shape_lens
.
end
(),
1
,
std
::
multiplies
<
std
::
size_t
>
());
const
std
::
size_t
num_batches
=
accumulate
(
data_shape_lens
.
begin
(),
data_shape_lens
.
begin
()
+
batch_dims
,
1
,
std
::
multiplies
<
std
::
size_t
>
());
const
std
::
size_t
data_batch_stride
=
accumulate
(
data_shape_lens
.
begin
()
+
batch_dims
,
data_shape_lens
.
end
(),
1
,
std
::
multiplies
<
std
::
size_t
>
());
const
auto
num_slices_per_batch
=
num_slices
/
num_batches
;
ind
.
global_stride
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
const
auto
*
indices_ptr
=
indices_t
.
data
();
const
std
::
size_t
j
=
i
/
slice_size
;
const
std
::
size_t
batch_idx
=
j
/
num_slices_per_batch
;
auto
*
slice_indices
=
indices_ptr
+
(
j
*
num_slice_dims
);
std
::
size_t
relative_slice_offset
=
0
;
for
(
std
::
size_t
idx
=
0
;
idx
<
num_slice_dims
;
++
idx
)
{
int64_t
index
=
slice_indices
[
idx
];
const
std
::
size_t
input_dim_idx
=
batch_dims
+
idx
;
const
auto
input_dim
=
data_shape_lens
[
input_dim_idx
];
assert
(
index
>=
-
static_cast
<
int64_t
>
(
input_dim
)
and
index
<
static_cast
<
int64_t
>
(
input_dim
));
if
(
index
<
0
)
index
+=
input_dim
;
std
::
size_t
size_from_slice_dims
=
accumulate
(
data_shape_lens
.
begin
()
+
batch_dims
+
idx
+
1
,
data_shape_lens
.
begin
()
+
batch_dims
+
num_slice_dims
,
slice_size
,
std
::
multiplies
<
std
::
size_t
>
());
relative_slice_offset
+=
index
*
size_from_slice_dims
;
}
auto
slice_offset
=
(
batch_idx
*
data_batch_stride
)
+
relative_slice_offset
;
output_t
[
i
]
=
data_t
[
slice_offset
+
i
%
slice_size
];
});
}
}
// namespace migraphx
#endif
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
View file @
988bf26b
...
@@ -124,8 +124,8 @@ __device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
...
@@ -124,8 +124,8 @@ __device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
}
}
#endif
#endif
template
<
class
Input
,
class
T
,
class
Output
>
template
<
class
Output
,
class
Input
,
class
T
>
constexpr
auto
reduce_slice
(
Input
input
,
T
i
,
Output
)
constexpr
auto
reduce_slice
(
Input
input
,
T
i
)
{
{
constexpr
auto
lens
=
transform
(
get_shape_c
<
Input
>
{}.
lens
,
constexpr
auto
lens
=
transform
(
get_shape_c
<
Input
>
{}.
lens
,
get_shape_c
<
Output
>
{}.
lens
,
get_shape_c
<
Output
>
{}.
lens
,
...
@@ -136,23 +136,160 @@ constexpr auto reduce_slice(Input input, T i, Output)
...
@@ -136,23 +136,160 @@ constexpr auto reduce_slice(Input input, T i, Output)
});
});
;
;
constexpr
auto
s
=
make_shape
(
lens
,
get_shape_c
<
Input
>
{}.
strides
);
constexpr
auto
s
=
make_shape
(
lens
,
get_shape_c
<
Input
>
{}.
strides
);
MIGRAPHX_ASSERT
((
input
.
get_shape
().
index
(
i
)
+
s
.
element_space
())
<=
input
.
get_shape
().
element_space
());
return
make_tensor_view
(
&
input
[
i
],
s
);
return
make_tensor_view
(
&
input
[
i
],
s
);
}
}
template
<
class
Op
,
class
T
,
class
Input
,
class
Output
,
class
ReadInput
,
class
WriteOuput
>
namespace
reduce
{
template
<
class
Slicer
,
class
F
>
constexpr
auto
sliced
(
Slicer
slicer
,
F
f
)
{
return
[
=
](
auto
x
,
auto
...
xs
)
{
// TODO: assert all elements are the same
return
f
(
slicer
(
x
),
slicer
(
xs
)...);
};
}
template
<
class
Input
,
index_int
Axis
>
constexpr
auto
compute_reduce_axis
()
{
constexpr
auto
lens
=
transform_i
(
get_shape_c
<
Input
>
{}.
lens
,
[](
index_int
x
,
index_int
i
)
->
index_int
{
if
(
i
==
Axis
)
return
1
;
return
x
;
});
return
make_shape
(
lens
,
get_shape_c
<
Input
>
{}.
strides
);
}
template
<
class
Input
,
index_int
Axis
>
using
with_axis
=
decltype
(
compute_reduce_axis
<
Input
,
Axis
>
());
struct
block
{
template
<
class
Slicer
>
struct
reducer
{
index
idx
;
Slicer
slicer
;
template
<
class
Op
,
class
T
,
class
Read
>
__device__
auto
reduce
(
Op
op
,
T
init
,
Read
read
)
const
{
return
sliced
(
slicer
,
[
=
](
auto
x
,
auto
...
xs
)
{
return
block_reduce
(
idx
,
op
,
init
,
x
.
get_shape
().
elements
(),
[
&
](
auto
j
)
{
return
read
(
x
[
j
],
xs
[
j
]...);
});
});
}
template
<
class
F
>
__device__
void
outer
(
F
f
)
const
{
if
(
idx
.
local
==
0
)
f
();
}
template
<
class
F
>
__device__
auto
inner
(
F
f
)
const
{
return
sliced
(
slicer
,
[
=
](
auto
x
,
auto
...
xs
)
{
idx
.
local_stride
(
x
.
get_shape
().
elements
(),
[
&
](
auto
j
)
{
f
(
x
[
j
],
xs
[
j
]...);
});
});
}
};
template
<
class
Slicer
>
static
__device__
auto
make
(
index
idx
,
Slicer
slicer
)
{
return
reducer
<
Slicer
>
{
idx
,
slicer
};
}
template
<
class
Output
,
class
F
>
static
__device__
void
run
(
F
f
)
{
auto
idx
=
make_index
();
constexpr
auto
nelements
=
get_shape_c
<
Output
>
{}.
elements
();
idx
.
global_stride
(
nelements
*
idx
.
nlocal
(),
[
&
](
auto
i
)
{
const
auto
out_idx
=
get_shape_c
<
Output
>
{}.
multi
(
i
/
idx
.
nlocal
());
f
(
out_idx
,
make
(
idx
,
[
&
](
auto
input
)
{
return
reduce_slice
<
Output
>
(
input
,
out_idx
);
}));
});
}
};
struct
lane
{
template
<
class
Slicer
>
struct
reducer
{
index
idx
;
Slicer
slicer
;
template
<
class
Op
,
class
T
,
class
Read
>
__device__
auto
reduce
(
Op
op
,
T
init
,
Read
read
)
const
{
return
sliced
(
slicer
,
[
=
](
auto
x
,
auto
...
xs
)
{
using
type
=
typename
decltype
(
x
)
::
type
;
type
r
=
init
;
for
(
index_int
j
=
0
;
j
<
x
.
get_shape
().
elements
();
j
++
)
{
r
=
op
(
r
,
read
(
x
[
j
],
xs
[
j
]...));
}
return
r
;
});
}
template
<
class
F
>
__device__
void
outer
(
F
f
)
const
{
f
();
}
template
<
class
F
>
__device__
auto
inner
(
F
f
)
const
{
return
sliced
(
slicer
,
[
=
](
auto
x
,
auto
...
xs
)
{
for
(
index_int
j
=
0
;
j
<
x
.
get_shape
().
elements
();
j
++
)
{
f
(
x
[
j
],
xs
[
j
]...);
}
});
}
};
template
<
class
Slicer
>
static
__device__
auto
make
(
index
idx
,
Slicer
slicer
)
{
return
reducer
<
Slicer
>
{
idx
,
slicer
};
}
template
<
class
Output
,
class
F
>
static
__device__
void
run
(
F
f
)
{
auto
idx
=
make_index
();
constexpr
auto
nelements
=
get_shape_c
<
Output
>
{}.
elements
();
idx
.
global_stride
(
nelements
,
[
&
](
auto
i
)
{
const
auto
out_idx
=
get_shape_c
<
Output
>
{}.
multi
(
i
);
f
(
out_idx
,
make
(
idx
,
[
&
](
auto
input
)
{
return
reduce_slice
<
Output
>
(
input
,
out_idx
);
}));
});
}
};
}
// namespace reduce
template
<
class
Algo
,
class
Op
,
class
T
,
class
Input
,
class
Output
,
class
ReadInput
,
class
WriteOuput
>
__device__
void
__device__
void
simple_reduce
(
Op
op
,
T
init
,
Input
input
,
Output
output
,
ReadInput
read
,
WriteOuput
write
)
simple_reduce
(
Op
op
,
T
init
,
Input
input
,
Output
output
,
ReadInput
read
,
WriteOuput
write
)
{
{
auto
idx
=
make_index
();
Algo
::
template
run
<
Output
>([
&
](
auto
out_idx
,
auto
r
)
{
constexpr
auto
nelements
=
get_shape_c
<
Output
>
{}.
elements
();
auto
x
=
r
.
reduce
(
op
,
init
,
read
)(
input
);
constexpr
auto
relements
=
get_shape_c
<
Input
>
{}.
elements
()
/
get_shape_c
<
Output
>
{}.
elements
();
r
.
outer
([
&
]
{
output
[
out_idx
]
=
write
(
x
);
});
idx
.
global_stride
(
nelements
*
idx
.
nlocal
(),
[
&
](
auto
i
)
{
const
auto
out_idx
=
output
.
get_shape
().
multi
(
i
/
idx
.
nlocal
());
auto
rs
=
reduce_slice
(
input
,
out_idx
,
output
);
MIGRAPHX_ASSERT
(
relements
==
rs
.
get_shape
().
elements
());
auto
r
=
block_reduce
(
idx
,
op
,
init
,
relements
,
[
&
](
auto
j
)
{
return
read
(
rs
[
j
]);
});
if
(
idx
.
local
==
0
)
output
[
out_idx
]
=
write
(
r
);
});
});
}
}
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment