Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c722117d
Unverified
Commit
c722117d
authored
Jul 22, 2022
by
Umang Yadav
Committed by
GitHub
Jul 22, 2022
Browse files
Improve error reporting in the API (#1274)
C++ API is not printing thrown exception string. this improves on it.
parent
6e6cb994
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
270 additions
and
77 deletions
+270
-77
src/api/api.cpp
src/api/api.cpp
+83
-34
src/api/include/migraphx/migraphx.h
src/api/include/migraphx/migraphx.h
+7
-0
src/api/include/migraphx/migraphx.hpp
src/api/include/migraphx/migraphx.hpp
+65
-5
src/api/migraphx.py
src/api/migraphx.py
+3
-1
test/api/test_custom_op.cpp
test/api/test_custom_op.cpp
+32
-4
test/api/test_custom_op_gpu.cpp
test/api/test_custom_op_gpu.cpp
+18
-4
tools/api.py
tools/api.py
+31
-11
tools/api/api.cpp
tools/api/api.cpp
+31
-18
No files found.
src/api/api.cpp
View file @
c722117d
...
@@ -39,34 +39,47 @@
...
@@ -39,34 +39,47 @@
#include <migraphx/convert_to_json.hpp>
#include <migraphx/convert_to_json.hpp>
#include <algorithm>
#include <algorithm>
#include <cstdarg>
#include <cstdarg>
namespace
migraphx
{
namespace
migraphx
{
static
thread_local
bool
disable_exception_catch
=
false
;
// NOLINT
extern
"C"
void
migraphx_test_private_disable_exception_catch
(
bool
b
)
{
disable_exception_catch
=
b
;
}
template
<
class
F
>
template
<
class
F
>
migraphx_status
try_
(
F
f
,
bool
output
=
true
)
// NOLINT
migraphx_status
try_
(
F
f
,
bool
output
=
true
)
// NOLINT
{
{
try
if
(
disable_exception_catch
)
{
{
f
();
f
();
}
}
catch
(
const
migraphx
::
exception
&
ex
)
else
{
{
if
(
output
)
try
std
::
cerr
<<
"MIGraphX Error: "
<<
ex
.
what
()
<<
std
::
endl
;
{
if
(
ex
.
error
>
0
)
f
();
return
migraphx_status
(
ex
.
error
);
}
else
catch
(
const
migraphx
::
exception
&
ex
)
{
if
(
output
)
std
::
cerr
<<
"MIGraphX Error: "
<<
ex
.
what
()
<<
std
::
endl
;
if
(
ex
.
error
>
0
)
return
migraphx_status
(
ex
.
error
);
else
return
migraphx_status_unknown_error
;
}
catch
(
const
std
::
exception
&
ex
)
{
if
(
output
)
std
::
cerr
<<
"MIGraphX Error: "
<<
ex
.
what
()
<<
std
::
endl
;
return
migraphx_status_unknown_error
;
return
migraphx_status_unknown_error
;
}
}
catch
(
const
std
::
exception
&
ex
)
catch
(...)
{
{
if
(
output
)
return
migraphx_status_unknown_error
;
std
::
cerr
<<
"MIGraphX Error: "
<<
ex
.
what
()
<<
std
::
endl
;
}
return
migraphx_status_unknown_error
;
}
catch
(...)
{
return
migraphx_status_unknown_error
;
}
}
return
migraphx_status_success
;
return
migraphx_status_success
;
}
}
...
@@ -305,6 +318,7 @@ void destroy(T* x)
...
@@ -305,6 +318,7 @@ void destroy(T* x)
{
{
delete
x
;
// NOLINT
delete
x
;
// NOLINT
}
}
// TODO: Move to interface preamble
// TODO: Move to interface preamble
template
<
class
C
,
class
D
>
template
<
class
C
,
class
D
>
struct
manage_generic_ptr
struct
manage_generic_ptr
...
@@ -313,30 +327,35 @@ struct manage_generic_ptr
...
@@ -313,30 +327,35 @@ struct manage_generic_ptr
manage_generic_ptr
(
std
::
nullptr_t
)
{}
manage_generic_ptr
(
std
::
nullptr_t
)
{}
manage_generic_ptr
(
void
*
pdata
,
C
pcopier
,
D
pdeleter
)
manage_generic_ptr
(
void
*
pdata
,
const
char
*
obj_tname
,
C
pcopier
,
D
pdeleter
)
:
data
(
nullptr
),
copier
(
pcopier
),
deleter
(
pdeleter
)
:
data
(
nullptr
),
obj_typename
(
obj_tname
),
copier
(
pcopier
),
deleter
(
pdeleter
)
{
{
copier
(
&
data
,
pdata
);
copier
(
&
data
,
pdata
);
}
}
manage_generic_ptr
(
const
manage_generic_ptr
&
rhs
)
manage_generic_ptr
(
const
manage_generic_ptr
&
rhs
)
:
data
(
nullptr
),
copier
(
rhs
.
copier
),
deleter
(
rhs
.
deleter
)
:
data
(
nullptr
),
obj_typename
(
rhs
.
obj_typename
),
copier
(
rhs
.
copier
),
deleter
(
rhs
.
deleter
)
{
{
if
(
copier
)
if
(
copier
)
copier
(
&
data
,
rhs
.
data
);
copier
(
&
data
,
rhs
.
data
);
}
}
manage_generic_ptr
(
manage_generic_ptr
&&
other
)
noexcept
manage_generic_ptr
(
manage_generic_ptr
&&
other
)
noexcept
:
data
(
other
.
data
),
copier
(
other
.
copier
),
deleter
(
other
.
deleter
)
:
data
(
other
.
data
),
obj_typename
(
other
.
obj_typename
),
copier
(
other
.
copier
),
deleter
(
other
.
deleter
)
{
{
other
.
data
=
nullptr
;
other
.
data
=
nullptr
;
other
.
copier
=
nullptr
;
other
.
obj_typename
=
""
;
other
.
deleter
=
nullptr
;
other
.
copier
=
nullptr
;
other
.
deleter
=
nullptr
;
}
}
manage_generic_ptr
&
operator
=
(
manage_generic_ptr
rhs
)
manage_generic_ptr
&
operator
=
(
manage_generic_ptr
rhs
)
{
{
std
::
swap
(
data
,
rhs
.
data
);
std
::
swap
(
data
,
rhs
.
data
);
std
::
swap
(
obj_typename
,
rhs
.
obj_typename
);
std
::
swap
(
copier
,
rhs
.
copier
);
std
::
swap
(
copier
,
rhs
.
copier
);
std
::
swap
(
deleter
,
rhs
.
deleter
);
std
::
swap
(
deleter
,
rhs
.
deleter
);
return
*
this
;
return
*
this
;
...
@@ -348,9 +367,10 @@ struct manage_generic_ptr
...
@@ -348,9 +367,10 @@ struct manage_generic_ptr
deleter
(
data
);
deleter
(
data
);
}
}
void
*
data
=
nullptr
;
void
*
data
=
nullptr
;
C
copier
=
nullptr
;
const
char
*
obj_typename
=
""
;
D
deleter
=
nullptr
;
C
copier
=
nullptr
;
D
deleter
=
nullptr
;
};
};
extern
"C"
struct
migraphx_shape
;
extern
"C"
struct
migraphx_shape
;
...
@@ -580,8 +600,9 @@ struct migraphx_experimental_custom_op
...
@@ -580,8 +600,9 @@ struct migraphx_experimental_custom_op
migraphx_experimental_custom_op
(
void
*
p
,
migraphx_experimental_custom_op
(
void
*
p
,
migraphx_experimental_custom_op_copy
c
,
migraphx_experimental_custom_op_copy
c
,
migraphx_experimental_custom_op_delete
d
,
migraphx_experimental_custom_op_delete
d
,
const
char
*
obj_typename
,
Ts
&&
...
xs
)
Ts
&&
...
xs
)
:
object_ptr
(
p
,
c
,
d
),
xobject
(
std
::
forward
<
Ts
>
(
xs
)...)
:
object_ptr
(
p
,
obj_typename
,
c
,
d
),
xobject
(
std
::
forward
<
Ts
>
(
xs
)...)
{
{
}
}
manage_generic_ptr
<
migraphx_experimental_custom_op_copy
,
migraphx_experimental_custom_op_delete
>
manage_generic_ptr
<
migraphx_experimental_custom_op_copy
,
migraphx_experimental_custom_op_delete
>
...
@@ -595,13 +616,21 @@ struct migraphx_experimental_custom_op
...
@@ -595,13 +616,21 @@ struct migraphx_experimental_custom_op
std
::
remove_pointer_t
<
migraphx_argument_t
>
out
;
std
::
remove_pointer_t
<
migraphx_argument_t
>
out
;
if
(
compute_f
==
nullptr
)
if
(
compute_f
==
nullptr
)
throw
std
::
runtime_error
(
"compute function is missing."
);
throw
std
::
runtime_error
(
"compute function is missing."
);
std
::
array
<
char
,
256
>
exception_msg
;
exception_msg
.
front
()
=
'\0'
;
auto
api_error_result
=
compute_f
(
&
out
,
auto
api_error_result
=
compute_f
(
&
out
,
object_ptr
.
data
,
object_ptr
.
data
,
exception_msg
.
data
(),
exception_msg
.
size
(),
object_cast
<
migraphx_context_t
>
(
&
(
ctx
)),
object_cast
<
migraphx_context_t
>
(
&
(
ctx
)),
object_cast
<
migraphx_shape_t
>
(
&
(
output
)),
object_cast
<
migraphx_shape_t
>
(
&
(
output
)),
object_cast
<
migraphx_arguments_t
>
(
&
(
inputs
)));
object_cast
<
migraphx_arguments_t
>
(
&
(
inputs
)));
if
(
api_error_result
!=
migraphx_status_success
)
if
(
api_error_result
!=
migraphx_status_success
)
throw
std
::
runtime_error
(
"Error in compute."
);
{
const
std
::
string
exception_str
(
exception_msg
.
data
());
throw
std
::
runtime_error
(
"Error in compute of: "
+
std
::
string
(
object_ptr
.
obj_typename
)
+
": "
+
exception_str
);
}
return
(
&
out
)
->
object
;
return
(
&
out
)
->
object
;
}
}
...
@@ -611,10 +640,19 @@ struct migraphx_experimental_custom_op
...
@@ -611,10 +640,19 @@ struct migraphx_experimental_custom_op
std
::
remove_pointer_t
<
migraphx_shape_t
>
out
;
std
::
remove_pointer_t
<
migraphx_shape_t
>
out
;
if
(
compute_shape_f
==
nullptr
)
if
(
compute_shape_f
==
nullptr
)
throw
std
::
runtime_error
(
"compute_shape function is missing."
);
throw
std
::
runtime_error
(
"compute_shape function is missing."
);
auto
api_error_result
=
std
::
array
<
char
,
256
>
exception_msg
;
compute_shape_f
(
&
out
,
object_ptr
.
data
,
object_cast
<
migraphx_shapes_t
>
(
&
(
inputs
)));
exception_msg
.
front
()
=
'\0'
;
auto
api_error_result
=
compute_shape_f
(
&
out
,
object_ptr
.
data
,
exception_msg
.
data
(),
exception_msg
.
size
(),
object_cast
<
migraphx_shapes_t
>
(
&
(
inputs
)));
if
(
api_error_result
!=
migraphx_status_success
)
if
(
api_error_result
!=
migraphx_status_success
)
throw
std
::
runtime_error
(
"Error in compute_shape."
);
{
const
std
::
string
exception_str
(
exception_msg
.
data
());
throw
std
::
runtime_error
(
"Error in compute_shape of: "
+
std
::
string
(
object_ptr
.
obj_typename
)
+
": "
+
exception_str
);
}
return
(
&
out
)
->
object
;
return
(
&
out
)
->
object
;
}
}
};
};
...
@@ -743,6 +781,16 @@ migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_sha
...
@@ -743,6 +781,16 @@ migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_sha
return
api_error_result
;
return
api_error_result
;
}
}
extern
"C"
migraphx_status
migraphx_shape_standard
(
bool
*
out
,
const_migraphx_shape_t
shape
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
if
(
shape
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter shape: Null pointer"
);
*
out
=
(
shape
->
object
).
standard
();
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_argument_destroy
(
migraphx_argument_t
argument
)
extern
"C"
migraphx_status
migraphx_argument_destroy
(
migraphx_argument_t
argument
)
{
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
destroy
((
argument
));
});
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
destroy
((
argument
));
});
...
@@ -1806,11 +1854,12 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi
...
@@ -1806,11 +1854,12 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi
void
*
obj
,
void
*
obj
,
migraphx_experimental_custom_op_copy
c
,
migraphx_experimental_custom_op_copy
c
,
migraphx_experimental_custom_op_delete
d
,
migraphx_experimental_custom_op_delete
d
,
const
char
*
obj_typename
,
const
char
*
name
)
const
char
*
name
)
{
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
*
experimental_custom_op
=
*
experimental_custom_op
=
allocate
<
migraphx_experimental_custom_op_t
>
((
obj
),
(
c
),
(
d
),
(
name
));
allocate
<
migraphx_experimental_custom_op_t
>
((
obj
),
(
c
),
(
d
),
(
obj_typename
),
(
name
));
});
});
return
api_error_result
;
return
api_error_result
;
}
}
...
...
src/api/include/migraphx/migraphx.h
View file @
c722117d
...
@@ -132,12 +132,16 @@ typedef const struct migraphx_experimental_custom_op* const_migraphx_experimenta
...
@@ -132,12 +132,16 @@ typedef const struct migraphx_experimental_custom_op* const_migraphx_experimenta
typedef
migraphx_status
(
*
migraphx_experimental_custom_op_compute
)(
migraphx_argument_t
out
,
typedef
migraphx_status
(
*
migraphx_experimental_custom_op_compute
)(
migraphx_argument_t
out
,
void
*
obj
,
void
*
obj
,
char
*
exception_msg
,
size_t
exception_msg_size
,
migraphx_context_t
ctx
,
migraphx_context_t
ctx
,
migraphx_shape_t
output
,
migraphx_shape_t
output
,
migraphx_arguments_t
inputs
);
migraphx_arguments_t
inputs
);
typedef
migraphx_status
(
*
migraphx_experimental_custom_op_compute_shape
)(
migraphx_shape_t
out
,
typedef
migraphx_status
(
*
migraphx_experimental_custom_op_compute_shape
)(
migraphx_shape_t
out
,
void
*
obj
,
void
*
obj
,
char
*
exception_msg
,
size_t
exception_msg_size
,
migraphx_shapes_t
inputs
);
migraphx_shapes_t
inputs
);
typedef
migraphx_status
(
*
migraphx_experimental_custom_op_copy
)(
void
**
out
,
void
*
input
);
typedef
migraphx_status
(
*
migraphx_experimental_custom_op_copy
)(
void
**
out
,
void
*
input
);
...
@@ -176,6 +180,8 @@ migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape);
...
@@ -176,6 +180,8 @@ migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape);
migraphx_status
migraphx_status
migraphx_shape_equal
(
bool
*
out
,
const_migraphx_shape_t
shape
,
const_migraphx_shape_t
x
);
migraphx_shape_equal
(
bool
*
out
,
const_migraphx_shape_t
shape
,
const_migraphx_shape_t
x
);
migraphx_status
migraphx_shape_standard
(
bool
*
out
,
const_migraphx_shape_t
shape
);
migraphx_status
migraphx_argument_destroy
(
migraphx_argument_t
argument
);
migraphx_status
migraphx_argument_destroy
(
migraphx_argument_t
argument
);
migraphx_status
migraphx_argument_assign_to
(
migraphx_argument_t
output
,
migraphx_status
migraphx_argument_assign_to
(
migraphx_argument_t
output
,
...
@@ -486,6 +492,7 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi
...
@@ -486,6 +492,7 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi
void
*
obj
,
void
*
obj
,
migraphx_experimental_custom_op_copy
c
,
migraphx_experimental_custom_op_copy
c
,
migraphx_experimental_custom_op_delete
d
,
migraphx_experimental_custom_op_delete
d
,
const
char
*
obj_typename
,
const
char
*
name
);
const
char
*
name
);
migraphx_status
migraphx_status
...
...
src/api/include/migraphx/migraphx.hpp
View file @
c722117d
...
@@ -25,6 +25,7 @@
...
@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP
#define MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP
#include "migraphx.h"
#include "migraphx.h"
#include <cstring>
#include <initializer_list>
#include <initializer_list>
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.h>
#include <memory>
#include <memory>
...
@@ -58,6 +59,42 @@ struct rank<0>
...
@@ -58,6 +59,42 @@ struct rank<0>
{
{
};
};
template
<
class
PrivateMigraphTypeNameProbe
>
std
::
string
compute_type_name
()
{
std
::
string
name
;
#ifdef _MSC_VER
name
=
typeid
(
PrivateMigraphTypeNameProbe
).
name
();
name
=
name
.
substr
(
7
);
#else
const
char
parameter_name
[]
=
"PrivateMigraphTypeNameProbe ="
;
// NOLINT
name
=
__PRETTY_FUNCTION__
;
auto
begin
=
name
.
find
(
parameter_name
)
+
sizeof
(
parameter_name
);
#if(defined(__GNUC__) && !defined(__clang__) && __GNUC__ == 4 && __GNUC_MINOR__ < 7)
auto
length
=
name
.
find_last_of
(
","
)
-
begin
;
#else
auto
length
=
name
.
find_first_of
(
"];"
,
begin
)
-
begin
;
#endif
name
=
name
.
substr
(
begin
,
length
);
#endif
return
name
;
}
template
<
class
T
>
const
std
::
string
&
get_type_name
()
{
static
const
std
::
string
name
=
compute_type_name
<
T
>
();
return
name
;
}
template
<
class
T
>
const
std
::
string
&
get_type_name
(
const
T
&
)
{
return
get_type_name
<
T
>
();
}
template
<
class
T
,
class
F
,
class
...
Ts
>
template
<
class
T
,
class
F
,
class
...
Ts
>
T
*
make
(
F
f
,
Ts
&&
...
xs
)
T
*
make
(
F
f
,
Ts
&&
...
xs
)
{
{
...
@@ -310,13 +347,22 @@ struct interface_base : Base
...
@@ -310,13 +347,22 @@ struct interface_base : Base
protected:
protected:
template
<
class
F
>
template
<
class
F
>
static
migraphx_status
try_
(
F
f
)
// NOLINT
static
migraphx_status
try_
(
F
f
,
char
*
ex_msg
=
nullptr
,
size_t
ex_msg_size
=
0
)
// NOLINT
{
{
try
try
{
{
f
();
f
();
return
migraphx_status_success
;
return
migraphx_status_success
;
}
}
catch
(
const
std
::
exception
&
ex
)
{
if
(
ex_msg
)
{
std
::
strncpy
(
ex_msg
,
ex
.
what
(),
ex_msg_size
);
ex_msg
[
ex_msg_size
-
1
]
=
'\0'
;
}
return
migraphx_status_unknown_error
;
}
catch
(...)
catch
(...)
{
{
return
migraphx_status_unknown_error
;
return
migraphx_status_unknown_error
;
...
@@ -349,9 +395,13 @@ struct interface_base : Base
...
@@ -349,9 +395,13 @@ struct interface_base : Base
{
{
static
F
f
=
pf
;
static
F
f
=
pf
;
(
void
)
f
;
// avoid warning on gcc
(
void
)
f
;
// avoid warning on gcc
call
(
setter
,
this
->
get_handle_ptr
(),
[](
auto
...
xs
)
->
migraphx_status
{
call
(
setter
,
return
try_
([
&
]
{
call_cast_arg
<
T
>
(
rank
<
1
>
{},
f
,
xs
...);
});
this
->
get_handle_ptr
(),
});
[](
auto
out
,
void
*
obj
,
char
*
ex_msg
,
size_t
ex_msg_size
,
auto
...
xs
)
->
migraphx_status
{
return
try_
(
[
&
]
{
call_cast_arg
<
T
>
(
rank
<
1
>
{},
f
,
out
,
obj
,
xs
...);
},
ex_msg
,
ex_msg_size
);
});
}
}
template
<
class
T
,
class
Setter
,
class
F
>
template
<
class
T
,
class
Setter
,
class
F
>
...
@@ -524,6 +574,13 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
...
@@ -524,6 +574,13 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
return
pout
;
return
pout
;
}
}
bool
standard
()
const
{
bool
result
=
false
;
call
(
&
migraphx_shape_standard
,
&
result
,
this
->
get_handle_ptr
());
return
result
;
}
friend
bool
operator
==
(
const
shape
&
px
,
const
shape
&
py
)
friend
bool
operator
==
(
const
shape
&
px
,
const
shape
&
py
)
{
{
bool
pout
;
bool
pout
;
...
@@ -1206,7 +1263,10 @@ struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental
...
@@ -1206,7 +1263,10 @@ struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental
template
<
class
T
>
template
<
class
T
>
experimental_custom_op
(
T
&
obj
)
experimental_custom_op
(
T
&
obj
)
{
{
this
->
make_interface
(
&
migraphx_experimental_custom_op_create
,
obj
,
obj
.
name
().
c_str
());
this
->
make_interface
(
&
migraphx_experimental_custom_op_create
,
obj
,
get_type_name
(
obj
).
c_str
(),
obj
.
name
().
c_str
());
MIGRAPHX_INTERFACE_LIFT
(
T
,
experimental_custom_op
,
compute_shape
);
MIGRAPHX_INTERFACE_LIFT
(
T
,
experimental_custom_op
,
compute_shape
);
MIGRAPHX_INTERFACE_LIFT
(
T
,
experimental_custom_op
,
compute
);
MIGRAPHX_INTERFACE_LIFT
(
T
,
experimental_custom_op
,
compute
);
}
}
...
...
src/api/migraphx.py
View file @
c722117d
...
@@ -121,6 +121,7 @@ def shape(h):
...
@@ -121,6 +121,7 @@ def shape(h):
invoke
=
'migraphx::equal($@)'
,
invoke
=
'migraphx::equal($@)'
,
returns
=
'bool'
,
returns
=
'bool'
,
const
=
True
)
const
=
True
)
h
.
method
(
'standard'
,
returns
=
'bool'
,
const
=
True
)
@
auto_handle
()
@
auto_handle
()
...
@@ -439,7 +440,8 @@ def context(h):
...
@@ -439,7 +440,8 @@ def context(h):
@
api
.
interface
(
'migraphx_experimental_custom_op'
,
@
api
.
interface
(
'migraphx_experimental_custom_op'
,
'migraphx::experimental_custom_op'
)
'migraphx::experimental_custom_op'
)
def
experimental_custom_op
(
h
):
def
experimental_custom_op
(
h
):
h
.
constructor
(
'create'
,
api
.
params
(
name
=
'const char*'
))
h
.
constructor
(
'create'
,
api
.
params
(
obj_typename
=
'const char*'
,
name
=
'const char*'
))
h
.
virtual
(
'compute'
,
h
.
virtual
(
'compute'
,
api
.
params
(
ctx
=
'migraphx::context'
,
api
.
params
(
ctx
=
'migraphx::context'
,
output
=
'migraphx::shape'
,
output
=
'migraphx::shape'
,
...
...
test/api/test_custom_op.cpp
View file @
c722117d
...
@@ -23,8 +23,10 @@
...
@@ -23,8 +23,10 @@
*/
*/
#include <algorithm>
#include <algorithm>
#include <cmath>
#include <cmath>
#include <exception>
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp>
#include <migraphx/migraphx.hpp>
#include <stdexcept>
#include "test.hpp"
#include "test.hpp"
struct
sigmoid_custom_op
final
:
migraphx
::
experimental_custom_op_base
struct
sigmoid_custom_op
final
:
migraphx
::
experimental_custom_op_base
...
@@ -43,10 +45,22 @@ struct sigmoid_custom_op final : migraphx::experimental_custom_op_base
...
@@ -43,10 +45,22 @@ struct sigmoid_custom_op final : migraphx::experimental_custom_op_base
virtual
migraphx
::
shape
compute_shape
(
migraphx
::
shapes
inputs
)
const
override
virtual
migraphx
::
shape
compute_shape
(
migraphx
::
shapes
inputs
)
const
override
{
{
CHECK
(
inputs
.
size
()
==
2
);
if
(
inputs
.
size
()
!=
2
)
CHECK
(
inputs
[
0
].
lengths
().
size
()
==
1
);
{
CHECK
(
inputs
[
0
].
type
()
==
migraphx_shape_float_type
);
throw
std
::
runtime_error
(
"op must have two inputs"
);
CHECK
(
bool
{
inputs
[
0
]
==
inputs
[
1
]});
}
if
(
inputs
[
0
].
lengths
().
size
()
!=
1
)
{
throw
std
::
runtime_error
(
"input arg must be a vector or scalar"
);
}
if
(
inputs
[
0
].
type
()
!=
migraphx_shape_float_type
)
{
throw
std
::
runtime_error
(
"input arg must be of type float"
);
}
if
(
inputs
[
0
]
!=
inputs
[
1
])
{
throw
std
::
runtime_error
(
"input arg and buffer allocation must be of same shape"
);
}
return
inputs
.
back
();
return
inputs
.
back
();
}
}
};
};
...
@@ -83,4 +97,18 @@ TEST_CASE(run_sigmoid_custom_op)
...
@@ -83,4 +97,18 @@ TEST_CASE(run_sigmoid_custom_op)
EXPECT
(
bool
{
result
==
migraphx
::
argument
(
s
,
expected_result
.
data
())});
EXPECT
(
bool
{
result
==
migraphx
::
argument
(
s
,
expected_result
.
data
())});
}
}
extern
"C"
void
migraphx_test_private_disable_exception_catch
(
bool
b
);
TEST_CASE
(
run_sigmoid_with_incorrect_shape
)
{
migraphx
::
program
p
;
migraphx
::
shape
s
{
migraphx_shape_float_type
,
{
12
}};
migraphx
::
module
m
=
p
.
get_main_module
();
auto
x
=
m
.
add_parameter
(
"x"
,
s
);
migraphx_test_private_disable_exception_catch
(
true
);
EXPECT
(
test
::
throws
<
std
::
exception
>
(
[
&
]
{
m
.
add_instruction
(
migraphx
::
operation
(
"sigmoid_custom_op"
),
{
x
});
},
"Error in compute_shape of: sigmoid_custom_op: op must have two inputs"
));
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/api/test_custom_op_gpu.cpp
View file @
c722117d
...
@@ -24,6 +24,7 @@
...
@@ -24,6 +24,7 @@
#include <hip/hip_runtime_api.h>
#include <hip/hip_runtime_api.h>
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp>
#include <migraphx/migraphx.hpp>
#include <stdexcept>
#include "test.hpp"
#include "test.hpp"
#define MIGRAPHX_HIP_ASSERT(x) (EXPECT(x == hipSuccess))
#define MIGRAPHX_HIP_ASSERT(x) (EXPECT(x == hipSuccess))
...
@@ -54,6 +55,14 @@ struct simple_custom_op final : migraphx::experimental_custom_op_base
...
@@ -54,6 +55,14 @@ struct simple_custom_op final : migraphx::experimental_custom_op_base
virtual
migraphx
::
shape
compute_shape
(
migraphx
::
shapes
inputs
)
const
override
virtual
migraphx
::
shape
compute_shape
(
migraphx
::
shapes
inputs
)
const
override
{
{
if
(
!
inputs
[
0
].
standard
())
{
throw
std
::
runtime_error
(
"first arg must be standard shaped"
);
}
if
(
inputs
.
size
()
!=
2
)
{
throw
std
::
runtime_error
(
"number of inputs must be 2"
);
}
return
inputs
.
back
();
return
inputs
.
back
();
}
}
};
};
...
@@ -64,12 +73,17 @@ TEST_CASE(run_simple_custom_op)
...
@@ -64,12 +73,17 @@ TEST_CASE(run_simple_custom_op)
migraphx
::
register_experimental_custom_op
(
simple_op
);
migraphx
::
register_experimental_custom_op
(
simple_op
);
migraphx
::
program
p
;
migraphx
::
program
p
;
migraphx
::
shape
s
{
migraphx_shape_int32_type
,
{
4
,
3
}};
migraphx
::
shape
s
{
migraphx_shape_int32_type
,
{
4
,
3
}};
migraphx
::
shape
trans_shape
{
migraphx_shape_int32_type
,
{
3
,
4
}};
migraphx
::
module
m
=
p
.
get_main_module
();
migraphx
::
module
m
=
p
.
get_main_module
();
auto
x
=
m
.
add_parameter
(
"x"
,
s
);
auto
x
=
m
.
add_parameter
(
"x"
,
s
);
auto
neg
=
m
.
add_instruction
(
migraphx
::
operation
(
"neg"
),
x
);
auto
neg
=
m
.
add_instruction
(
migraphx
::
operation
(
"neg"
),
x
);
auto
alloc
=
m
.
add_allocation
(
s
);
auto
alloc
=
m
.
add_allocation
(
trans_shape
);
auto
custom_kernel
=
m
.
add_instruction
(
migraphx
::
operation
(
"simple_custom_op"
),
{
neg
,
alloc
});
auto
neg_trans
=
auto
relu
=
m
.
add_instruction
(
migraphx
::
operation
(
"relu"
),
custom_kernel
);
m
.
add_instruction
(
migraphx
::
operation
(
"transpose"
,
"{permutation: [1, 0]}"
),
{
neg
});
auto
neg_cont
=
m
.
add_instruction
(
migraphx
::
operation
(
"contiguous"
),
{
neg_trans
});
auto
custom_kernel
=
m
.
add_instruction
(
migraphx
::
operation
(
"simple_custom_op"
),
{
neg_cont
,
alloc
});
auto
relu
=
m
.
add_instruction
(
migraphx
::
operation
(
"relu"
),
custom_kernel
);
m
.
add_return
({
relu
});
m
.
add_return
({
relu
});
migraphx
::
compile_options
options
;
migraphx
::
compile_options
options
;
options
.
set_offload_copy
();
options
.
set_offload_copy
();
...
@@ -82,7 +96,7 @@ TEST_CASE(run_simple_custom_op)
...
@@ -82,7 +96,7 @@ TEST_CASE(run_simple_custom_op)
auto
result_vec
=
result
.
as_vector
<
int
>
();
auto
result_vec
=
result
.
as_vector
<
int
>
();
std
::
vector
<
int
>
expected_result
(
12
,
0
);
std
::
vector
<
int
>
expected_result
(
12
,
0
);
std
::
fill
(
expected_result
.
begin
()
+
6
,
expected_result
.
end
(),
3
);
std
::
fill
(
expected_result
.
begin
()
+
6
,
expected_result
.
end
(),
3
);
EXPECT
(
bool
{
result
==
migraphx
::
argument
(
s
,
expected_result
.
data
())});
EXPECT
(
bool
{
result
==
migraphx
::
argument
(
trans_shape
,
expected_result
.
data
())});
}
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
tools/api.py
View file @
c722117d
...
@@ -197,7 +197,8 @@ class Parameter:
...
@@ -197,7 +197,8 @@ class Parameter:
optional
:
bool
=
False
,
optional
:
bool
=
False
,
returns
:
bool
=
False
,
returns
:
bool
=
False
,
virtual
:
bool
=
False
,
virtual
:
bool
=
False
,
this
:
bool
=
False
)
->
None
:
this
:
bool
=
False
,
hidden
:
bool
=
False
)
->
None
:
self
.
name
=
name
self
.
name
=
name
self
.
type
=
Type
(
type
)
self
.
type
=
Type
(
type
)
self
.
optional
=
optional
self
.
optional
=
optional
...
@@ -211,6 +212,7 @@ class Parameter:
...
@@ -211,6 +212,7 @@ class Parameter:
self
.
returns
=
returns
self
.
returns
=
returns
self
.
virtual
=
virtual
self
.
virtual
=
virtual
self
.
this
=
this
self
.
this
=
this
self
.
hidden
=
hidden
self
.
bad_param_check
:
Optional
[
BadParam
]
=
None
self
.
bad_param_check
:
Optional
[
BadParam
]
=
None
self
.
virtual_read
:
Optional
[
List
[
str
]]
=
None
self
.
virtual_read
:
Optional
[
List
[
str
]]
=
None
self
.
virtual_write
:
Optional
[
str
]
=
None
self
.
virtual_write
:
Optional
[
str
]
=
None
...
@@ -744,6 +746,8 @@ void destroy(T* x)
...
@@ -744,6 +746,8 @@ void destroy(T* x)
{
{
delete x; // NOLINT
delete x; // NOLINT
}
}
// TODO: Move to interface preamble
// TODO: Move to interface preamble
template <class C, class D>
template <class C, class D>
struct manage_generic_ptr
struct manage_generic_ptr
...
@@ -754,23 +758,24 @@ struct manage_generic_ptr
...
@@ -754,23 +758,24 @@ struct manage_generic_ptr
{
{
}
}
manage_generic_ptr(void* pdata, C pcopier, D pdeleter)
manage_generic_ptr(void* pdata,
const char* obj_tname,
C pcopier, D pdeleter)
: data(nullptr), copier(pcopier), deleter(pdeleter)
: data(nullptr),
obj_typename(obj_tname),
copier(pcopier), deleter(pdeleter)
{
{
copier(&data, pdata);
copier(&data, pdata);
}
}
manage_generic_ptr(const manage_generic_ptr& rhs)
manage_generic_ptr(const manage_generic_ptr& rhs)
: data(nullptr), copier(rhs.copier), deleter(rhs.deleter)
: data(nullptr),
obj_typename(rhs.obj_typename),
copier(rhs.copier), deleter(rhs.deleter)
{
{
if(copier)
if(copier)
copier(&data, rhs.data);
copier(&data, rhs.data);
}
}
manage_generic_ptr(manage_generic_ptr&& other) noexcept
manage_generic_ptr(manage_generic_ptr&& other) noexcept
: data(other.data), copier(other.copier), deleter(other.deleter)
: data(other.data),
obj_typename(other.obj_typename),
copier(other.copier), deleter(other.deleter)
{
{
other.data = nullptr;
other.data = nullptr;
other.obj_typename = "";
other.copier = nullptr;
other.copier = nullptr;
other.deleter = nullptr;
other.deleter = nullptr;
}
}
...
@@ -778,6 +783,7 @@ struct manage_generic_ptr
...
@@ -778,6 +783,7 @@ struct manage_generic_ptr
manage_generic_ptr& operator=(manage_generic_ptr rhs)
manage_generic_ptr& operator=(manage_generic_ptr rhs)
{
{
std::swap(data, rhs.data);
std::swap(data, rhs.data);
std::swap(obj_typename, rhs.obj_typename);
std::swap(copier, rhs.copier);
std::swap(copier, rhs.copier);
std::swap(deleter, rhs.deleter);
std::swap(deleter, rhs.deleter);
return *this;
return *this;
...
@@ -790,6 +796,7 @@ struct manage_generic_ptr
...
@@ -790,6 +796,7 @@ struct manage_generic_ptr
}
}
void* data = nullptr;
void* data = nullptr;
const char* obj_typename = "";
C copier = nullptr;
C copier = nullptr;
D deleter = nullptr;
D deleter = nullptr;
};
};
...
@@ -1042,8 +1049,8 @@ interface_handle_definition = Template('''
...
@@ -1042,8 +1049,8 @@ interface_handle_definition = Template('''
extern "C" struct ${ctype};
extern "C" struct ${ctype};
struct ${ctype} {
struct ${ctype} {
template<class... Ts>
template<class... Ts>
${ctype}(void* p, ${copier} c, ${deleter} d, Ts&&... xs)
${ctype}(void* p, ${copier} c, ${deleter} d,
const char* obj_typename,
Ts&&... xs)
: object_ptr(p, c, d), xobject(std::forward<Ts>(xs)...)
: object_ptr(p,
obj_typename,
c, d), xobject(std::forward<Ts>(xs)...)
{}
{}
manage_generic_ptr<${copier}, ${deleter}> object_ptr = nullptr;
manage_generic_ptr<${copier}, ${deleter}> object_ptr = nullptr;
${cpptype} xobject;
${cpptype} xobject;
...
@@ -1057,9 +1064,13 @@ ${return_type} ${name}(${params}) const
...
@@ -1057,9 +1064,13 @@ ${return_type} ${name}(${params}) const
${output_decls}
${output_decls}
if (${fname} == nullptr)
if (${fname} == nullptr)
throw std::runtime_error("${name} function is missing.");
throw std::runtime_error("${name} function is missing.");
std::array<char, 256> exception_msg;
exception_msg.front() = '
\\
0';
auto api_error_result = ${fname}(${args});
auto api_error_result = ${fname}(${args});
if (api_error_result != ${success})
if (api_error_result != ${success}) {
throw std::runtime_error("Error in ${name}.");
const std::string exception_str(exception_msg.data());
throw std::runtime_error("Error in ${name} of: " + std::string(object_ptr.obj_typename) + ": " + exception_str);
}
return ${output};
return ${output};
}
}
'''
)
'''
)
...
@@ -1079,7 +1090,9 @@ def generate_virtual_impl(f: Function, fname: str) -> str:
...
@@ -1079,7 +1090,9 @@ def generate_virtual_impl(f: Function, fname: str) -> str:
largs
+=
f
.
returns
.
virtual_output_args
()
largs
+=
f
.
returns
.
virtual_output_args
()
output
=
f
.
returns
.
virtual_output
()
output
=
f
.
returns
.
virtual_output
()
largs
+=
[
arg
for
p
in
f
.
params
for
arg
in
p
.
virtual_arg
()]
largs
+=
[
arg
for
p
in
f
.
params
for
arg
in
p
.
virtual_arg
()]
lparams
+=
[
p
.
virtual_param
()
for
p
in
f
.
params
if
not
p
.
this
]
lparams
+=
[
p
.
virtual_param
()
for
p
in
f
.
params
if
not
(
p
.
this
or
p
.
hidden
)
]
args
=
', '
.
join
(
largs
)
args
=
', '
.
join
(
largs
)
params
=
', '
.
join
(
lparams
)
params
=
', '
.
join
(
lparams
)
return
c_api_virtual_impl
.
substitute
(
locals
())
return
c_api_virtual_impl
.
substitute
(
locals
())
...
@@ -1126,8 +1139,15 @@ class Interface(Handle):
...
@@ -1126,8 +1139,15 @@ class Interface(Handle):
# Add this parameter to the function
# Add this parameter to the function
this
=
Parameter
(
'obj'
,
'void*'
,
this
=
True
)
this
=
Parameter
(
'obj'
,
'void*'
,
this
=
True
)
this
.
virtual_read
=
[
'object_ptr.data'
]
this
.
virtual_read
=
[
'object_ptr.data'
]
exception_msg
=
Parameter
(
'exception_msg'
,
'char*'
,
hidden
=
True
)
exception_msg
.
virtual_read
=
[
'${name}.data()'
]
exception_msg_size
=
Parameter
(
'exception_msg_size'
,
'size_t'
,
hidden
=
True
)
exception_msg_size
.
virtual_read
=
[
'exception_msg.size()'
]
f
=
Function
(
name
,
f
=
Function
(
name
,
params
=
[
this
]
+
(
params
or
[]),
params
=
[
this
,
exception_msg
,
exception_msg_size
]
+
(
params
or
[]),
virtual
=
True
,
virtual
=
True
,
**
kwargs
)
**
kwargs
)
self
.
ifunctions
.
append
(
f
)
self
.
ifunctions
.
append
(
f
)
...
...
tools/api/api.cpp
View file @
c722117d
...
@@ -39,34 +39,47 @@
...
@@ -39,34 +39,47 @@
#include <migraphx/convert_to_json.hpp>
#include <migraphx/convert_to_json.hpp>
#include <algorithm>
#include <algorithm>
#include <cstdarg>
#include <cstdarg>
namespace
migraphx
{
namespace
migraphx
{
static
thread_local
bool
disable_exception_catch
=
false
;
// NOLINT
extern
"C"
void
migraphx_test_private_disable_exception_catch
(
bool
b
)
{
disable_exception_catch
=
b
;
}
template
<
class
F
>
template
<
class
F
>
migraphx_status
try_
(
F
f
,
bool
output
=
true
)
// NOLINT
migraphx_status
try_
(
F
f
,
bool
output
=
true
)
// NOLINT
{
{
try
if
(
disable_exception_catch
)
{
{
f
();
f
();
}
}
catch
(
const
migraphx
::
exception
&
ex
)
else
{
{
if
(
output
)
try
std
::
cerr
<<
"MIGraphX Error: "
<<
ex
.
what
()
<<
std
::
endl
;
{
if
(
ex
.
error
>
0
)
f
();
return
migraphx_status
(
ex
.
error
);
}
else
catch
(
const
migraphx
::
exception
&
ex
)
{
if
(
output
)
std
::
cerr
<<
"MIGraphX Error: "
<<
ex
.
what
()
<<
std
::
endl
;
if
(
ex
.
error
>
0
)
return
migraphx_status
(
ex
.
error
);
else
return
migraphx_status_unknown_error
;
}
catch
(
const
std
::
exception
&
ex
)
{
if
(
output
)
std
::
cerr
<<
"MIGraphX Error: "
<<
ex
.
what
()
<<
std
::
endl
;
return
migraphx_status_unknown_error
;
return
migraphx_status_unknown_error
;
}
}
catch
(
const
std
::
exception
&
ex
)
catch
(...)
{
{
if
(
output
)
return
migraphx_status_unknown_error
;
std
::
cerr
<<
"MIGraphX Error: "
<<
ex
.
what
()
<<
std
::
endl
;
}
return
migraphx_status_unknown_error
;
}
catch
(...)
{
return
migraphx_status_unknown_error
;
}
}
return
migraphx_status_success
;
return
migraphx_status_success
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment