Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
8b25fd3e
Commit
8b25fd3e
authored
Oct 12, 2022
by
charlie
Browse files
Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_ref_broadcast
parents
a88810da
4f3cc417
Changes
31
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
950 additions
and
123 deletions
+950
-123
examples/migraphx/custom_op_hip_kernel/custom_op_hip_kernel.cpp
...es/migraphx/custom_op_hip_kernel/custom_op_hip_kernel.cpp
+9
-2
examples/migraphx/custom_op_miopen_kernel/custom_op_miopen_kernel.cpp
...raphx/custom_op_miopen_kernel/custom_op_miopen_kernel.cpp
+7
-0
examples/migraphx/custom_op_rocblas_kernel/custom_op_rocblas_kernel.cpp
...phx/custom_op_rocblas_kernel/custom_op_rocblas_kernel.cpp
+8
-1
src/api/api.cpp
src/api/api.cpp
+130
-2
src/api/include/migraphx/migraphx.h
src/api/include/migraphx/migraphx.h
+27
-1
src/api/include/migraphx/migraphx.hpp
src/api/include/migraphx/migraphx.hpp
+125
-20
src/api/migraphx.py
src/api/migraphx.py
+13
-0
src/include/migraphx/context.hpp
src/include/migraphx/context.hpp
+67
-1
src/include/migraphx/execution_environment.hpp
src/include/migraphx/execution_environment.hpp
+41
-0
src/include/migraphx/program.hpp
src/include/migraphx/program.hpp
+3
-2
src/program.cpp
src/program.cpp
+53
-40
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+17
-0
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+70
-0
src/targets/gpu/include/migraphx/gpu/context.hpp
src/targets/gpu/include/migraphx/gpu/context.hpp
+26
-2
src/targets/gpu/jit/softmax.cpp
src/targets/gpu/jit/softmax.cpp
+5
-0
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
+8
-8
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp
+9
-5
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+31
-2
test/api/test_custom_op.cpp
test/api/test_custom_op.cpp
+43
-0
test/api/test_custom_op_gpu.cpp
test/api/test_custom_op_gpu.cpp
+258
-37
No files found.
examples/migraphx/custom_op_hip_kernel/custom_op_hip_kernel.cpp
View file @
8b25fd3e
...
@@ -45,6 +45,13 @@ __global__ void vector_square(T* C_d, const T* A_d, size_t N)
...
@@ -45,6 +45,13 @@ __global__ void vector_square(T* C_d, const T* A_d, size_t N)
struct
square_custom_op
final
:
migraphx
::
experimental_custom_op_base
struct
square_custom_op
final
:
migraphx
::
experimental_custom_op_base
{
{
virtual
std
::
string
name
()
const
override
{
return
"square_custom_op"
;
}
virtual
std
::
string
name
()
const
override
{
return
"square_custom_op"
;
}
// flag to identify whether custom op runs on the GPU or on the host.
// Based on this flag MIGraphX would inject necessary copies to and from GPU for the input and
// output buffers as necessary. Therefore if custom_op runs on GPU then it can assume its input
// buffers are in GPU memory, and similarly for the host
virtual
bool
runs_on_offload_target
()
const
override
{
return
true
;
}
virtual
migraphx
::
argument
virtual
migraphx
::
argument
compute
(
migraphx
::
context
ctx
,
migraphx
::
shape
,
migraphx
::
arguments
inputs
)
const
override
compute
(
migraphx
::
context
ctx
,
migraphx
::
shape
,
migraphx
::
arguments
inputs
)
const
override
{
{
...
@@ -54,7 +61,7 @@ struct square_custom_op final : migraphx::experimental_custom_op_base
...
@@ -54,7 +61,7 @@ struct square_custom_op final : migraphx::experimental_custom_op_base
// is output argument, so it should be returned from compute method.
// is output argument, so it should be returned from compute method.
auto
*
input_buffer
=
reinterpret_cast
<
float
*>
(
inputs
[
0
].
data
());
auto
*
input_buffer
=
reinterpret_cast
<
float
*>
(
inputs
[
0
].
data
());
auto
*
output_buffer
=
reinterpret_cast
<
float
*>
(
inputs
[
1
].
data
());
auto
*
output_buffer
=
reinterpret_cast
<
float
*>
(
inputs
[
1
].
data
());
size_t
n_elements
=
inputs
[
0
].
get_shape
().
bytes
()
/
sizeof
(
inputs
[
0
].
get_shape
().
type
()
);
size_t
n_elements
=
inputs
[
0
].
get_shape
().
elements
(
);
MIGRAPHX_HIP_ASSERT
(
hipSetDevice
(
0
));
MIGRAPHX_HIP_ASSERT
(
hipSetDevice
(
0
));
const
unsigned
blocks
=
512
;
const
unsigned
blocks
=
512
;
const
unsigned
threads_per_block
=
256
;
const
unsigned
threads_per_block
=
256
;
...
@@ -103,7 +110,7 @@ int main(int argc, const char* argv[])
...
@@ -103,7 +110,7 @@ int main(int argc, const char* argv[])
options
.
set_offload_copy
();
options
.
set_offload_copy
();
p
.
compile
(
migraphx
::
target
(
"gpu"
),
options
);
p
.
compile
(
migraphx
::
target
(
"gpu"
),
options
);
migraphx
::
program_parameters
pp
;
migraphx
::
program_parameters
pp
;
std
::
vector
<
float
>
x_data
(
s
.
bytes
()
/
sizeof
(
s
.
type
()
));
std
::
vector
<
float
>
x_data
(
s
.
elements
(
));
std
::
iota
(
x_data
.
begin
(),
x_data
.
end
(),
0
);
std
::
iota
(
x_data
.
begin
(),
x_data
.
end
(),
0
);
pp
.
add
(
"x"
,
migraphx
::
argument
(
s
,
x_data
.
data
()));
pp
.
add
(
"x"
,
migraphx
::
argument
(
s
,
x_data
.
data
()));
auto
results
=
p
.
eval
(
pp
);
auto
results
=
p
.
eval
(
pp
);
...
...
examples/migraphx/custom_op_miopen_kernel/custom_op_miopen_kernel.cpp
View file @
8b25fd3e
...
@@ -93,6 +93,13 @@ inline auto make_activation_descriptor(miopenActivationMode_t mode,
...
@@ -93,6 +93,13 @@ inline auto make_activation_descriptor(miopenActivationMode_t mode,
struct
abs_custom_op
final
:
migraphx
::
experimental_custom_op_base
struct
abs_custom_op
final
:
migraphx
::
experimental_custom_op_base
{
{
virtual
std
::
string
name
()
const
override
{
return
"abs_custom_op"
;
}
virtual
std
::
string
name
()
const
override
{
return
"abs_custom_op"
;
}
// flag to identify whether custom op runs on the GPU or on the host.
// Based on this flag MIGraphX would inject necessary copies to and from GPU for the input and
// output buffers as necessary. Therefore if custom_op runs on GPU then it can assume its input
// buffers are in GPU memory, and similarly for the host
virtual
bool
runs_on_offload_target
()
const
override
{
return
true
;
}
virtual
migraphx
::
argument
compute
(
migraphx
::
context
ctx
,
virtual
migraphx
::
argument
compute
(
migraphx
::
context
ctx
,
migraphx
::
shape
output_shape
,
migraphx
::
shape
output_shape
,
migraphx
::
arguments
args
)
const
override
migraphx
::
arguments
args
)
const
override
...
...
examples/migraphx/custom_op_rocblas_kernel/custom_op_rocblas_kernel.cpp
View file @
8b25fd3e
...
@@ -51,6 +51,13 @@ rocblas_handle create_rocblas_handle_ptr(migraphx::context& ctx)
...
@@ -51,6 +51,13 @@ rocblas_handle create_rocblas_handle_ptr(migraphx::context& ctx)
struct
sscal_custom_op
final
:
migraphx
::
experimental_custom_op_base
struct
sscal_custom_op
final
:
migraphx
::
experimental_custom_op_base
{
{
virtual
std
::
string
name
()
const
override
{
return
"sscal_custom_op"
;
}
virtual
std
::
string
name
()
const
override
{
return
"sscal_custom_op"
;
}
// flag to identify whether custom op runs on the GPU or on the host.
// Based on this flag MIGraphX would inject necessary copies to and from GPU for the input and
// output buffers as necessary. Therefore if custom_op runs on GPU then it can assume its input
// buffers are in GPU memory, and similarly for the host
virtual
bool
runs_on_offload_target
()
const
override
{
return
true
;
}
virtual
migraphx
::
argument
compute
(
migraphx
::
context
ctx
,
virtual
migraphx
::
argument
compute
(
migraphx
::
context
ctx
,
migraphx
::
shape
output_shape
,
migraphx
::
shape
output_shape
,
migraphx
::
arguments
args
)
const
override
migraphx
::
arguments
args
)
const
override
...
@@ -107,7 +114,7 @@ int main(int argc, const char* argv[])
...
@@ -107,7 +114,7 @@ int main(int argc, const char* argv[])
options
.
set_offload_copy
();
options
.
set_offload_copy
();
p
.
compile
(
migraphx
::
target
(
"gpu"
),
options
);
p
.
compile
(
migraphx
::
target
(
"gpu"
),
options
);
migraphx
::
program_parameters
pp
;
migraphx
::
program_parameters
pp
;
std
::
vector
<
float
>
x_data
(
x_shape
.
bytes
()
/
sizeof
(
x_shape
.
type
()
));
std
::
vector
<
float
>
x_data
(
x_shape
.
elements
(
));
std
::
vector
<
float
>
scale_data
{
-
1
};
std
::
vector
<
float
>
scale_data
{
-
1
};
std
::
iota
(
x_data
.
begin
(),
x_data
.
end
(),
0
);
std
::
iota
(
x_data
.
begin
(),
x_data
.
end
(),
0
);
pp
.
add
(
"x"
,
migraphx
::
argument
(
x_shape
,
x_data
.
data
()));
pp
.
add
(
"x"
,
migraphx
::
argument
(
x_shape
,
x_data
.
data
()));
...
...
src/api/api.cpp
View file @
8b25fd3e
...
@@ -21,6 +21,7 @@
...
@@ -21,6 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#include <migraphx/execution_environment.hpp>
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.h>
#include <migraphx/rank.hpp>
#include <migraphx/rank.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/shape.hpp>
...
@@ -166,6 +167,13 @@ void set_output_names(tf_options& options, std::vector<const char*> names)
...
@@ -166,6 +167,13 @@ void set_output_names(tf_options& options, std::vector<const char*> names)
options
.
output_node_names
=
std
::
vector
<
std
::
string
>
(
names
.
begin
(),
names
.
end
());
options
.
output_node_names
=
std
::
vector
<
std
::
string
>
(
names
.
begin
(),
names
.
end
());
}
}
std
::
vector
<
argument
>
run_async
(
program
&
p
,
const
parameter_map
&
params
,
void
*
s
,
std
::
string_view
name
)
{
execution_environment
exec_env
{
any_ptr
(
s
,
name
),
true
};
return
p
.
eval
(
params
,
exec_env
);
}
template
<
class
Value
>
template
<
class
Value
>
std
::
vector
<
const
char
*>
get_names
(
const
std
::
unordered_map
<
std
::
string
,
Value
>&
m
)
std
::
vector
<
const
char
*>
get_names
(
const
std
::
unordered_map
<
std
::
string
,
Value
>&
m
)
{
{
...
@@ -265,11 +273,18 @@ struct experimental_custom_op
...
@@ -265,11 +273,18 @@ struct experimental_custom_op
template
<
class
CustomOp
>
template
<
class
CustomOp
>
struct
custom_operation
struct
custom_operation
{
{
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
,
F
)
static
auto
reflect
(
Self
&
,
F
)
{
{
return
pack
();
return
pack
();
}
}
value
attributes
()
const
{
return
{{
"custom_op"
,
true
},
{
"target"
,
op
.
runs_on_offload_target
()
?
"gpu"
:
"cpu"
}};
}
CustomOp
op
;
CustomOp
op
;
std
::
string
name
()
const
{
return
op
.
xobject
.
name
;
}
std
::
string
name
()
const
{
return
op
.
xobject
.
name
;
}
...
@@ -284,6 +299,23 @@ struct custom_operation
...
@@ -284,6 +299,23 @@ struct custom_operation
{
{
return
op
.
compute
(
std
::
move
(
ctx
),
std
::
move
(
output_shape
),
std
::
move
(
inputs
));
return
op
.
compute
(
std
::
move
(
ctx
),
std
::
move
(
output_shape
),
std
::
move
(
inputs
));
}
}
std
::
ptrdiff_t
output_alias
(
std
::
vector
<
shape
>
inputs
)
const
{
auto
alias_vec
=
op
.
output_alias
(
std
::
move
(
inputs
));
// TODO: For now, only support one output alias
if
(
alias_vec
.
empty
())
{
return
-
1
;
}
if
(
alias_vec
.
size
()
>
1
)
{
MIGRAPHX_THROW
(
"Currently, CustomOps in MIGraphX only supports one output_alias"
);
}
return
alias_vec
.
front
();
}
bool
runs_on_offload_target
()
const
{
return
op
.
runs_on_offload_target
();
}
};
};
template
<
class
CustomOp
>
template
<
class
CustomOp
>
...
@@ -613,9 +645,9 @@ struct migraphx_experimental_custom_op
...
@@ -613,9 +645,9 @@ struct migraphx_experimental_custom_op
migraphx
::
shape
output
,
migraphx
::
shape
output
,
std
::
vector
<
migraphx
::
argument
>
inputs
)
const
std
::
vector
<
migraphx
::
argument
>
inputs
)
const
{
{
std
::
remove_pointer_t
<
migraphx_argument_t
>
out
;
if
(
compute_f
==
nullptr
)
if
(
compute_f
==
nullptr
)
throw
std
::
runtime_error
(
"compute function is missing."
);
throw
std
::
runtime_error
(
"compute function is missing."
);
std
::
remove_pointer_t
<
migraphx_argument_t
>
out
;
std
::
array
<
char
,
256
>
exception_msg
;
std
::
array
<
char
,
256
>
exception_msg
;
exception_msg
.
front
()
=
'\0'
;
exception_msg
.
front
()
=
'\0'
;
auto
api_error_result
=
compute_f
(
&
out
,
auto
api_error_result
=
compute_f
(
&
out
,
...
@@ -637,9 +669,9 @@ struct migraphx_experimental_custom_op
...
@@ -637,9 +669,9 @@ struct migraphx_experimental_custom_op
migraphx_experimental_custom_op_compute_shape
compute_shape_f
=
nullptr
;
migraphx_experimental_custom_op_compute_shape
compute_shape_f
=
nullptr
;
migraphx
::
shape
compute_shape
(
std
::
vector
<
migraphx
::
shape
>
inputs
)
const
migraphx
::
shape
compute_shape
(
std
::
vector
<
migraphx
::
shape
>
inputs
)
const
{
{
std
::
remove_pointer_t
<
migraphx_shape_t
>
out
;
if
(
compute_shape_f
==
nullptr
)
if
(
compute_shape_f
==
nullptr
)
throw
std
::
runtime_error
(
"compute_shape function is missing."
);
throw
std
::
runtime_error
(
"compute_shape function is missing."
);
std
::
remove_pointer_t
<
migraphx_shape_t
>
out
;
std
::
array
<
char
,
256
>
exception_msg
;
std
::
array
<
char
,
256
>
exception_msg
;
exception_msg
.
front
()
=
'\0'
;
exception_msg
.
front
()
=
'\0'
;
auto
api_error_result
=
compute_shape_f
(
&
out
,
auto
api_error_result
=
compute_shape_f
(
&
out
,
...
@@ -655,6 +687,49 @@ struct migraphx_experimental_custom_op
...
@@ -655,6 +687,49 @@ struct migraphx_experimental_custom_op
}
}
return
(
&
out
)
->
object
;
return
(
&
out
)
->
object
;
}
}
migraphx_experimental_custom_op_output_alias
output_alias_f
=
nullptr
;
std
::
vector
<
size_t
>
output_alias
(
std
::
vector
<
migraphx
::
shape
>
inputs
)
const
{
if
(
output_alias_f
==
nullptr
)
throw
std
::
runtime_error
(
"output_alias function is missing."
);
std
::
array
<
size_t
,
1024
>
out
;
std
::
remove_pointer_t
<
size_t
*>
out_size
=
1024
;
std
::
array
<
char
,
256
>
exception_msg
;
exception_msg
.
front
()
=
'\0'
;
auto
api_error_result
=
output_alias_f
(
out
.
data
(),
&
out_size
,
object_ptr
.
data
,
exception_msg
.
data
(),
exception_msg
.
size
(),
object_cast
<
migraphx_shapes_t
>
(
&
(
inputs
)));
if
(
api_error_result
!=
migraphx_status_success
)
{
const
std
::
string
exception_str
(
exception_msg
.
data
());
throw
std
::
runtime_error
(
"Error in output_alias of: "
+
std
::
string
(
object_ptr
.
obj_typename
)
+
": "
+
exception_str
);
}
return
{
out
.
begin
(),
out
.
begin
()
+
out_size
};
// cppcheck-suppress returnDanglingLifetime;
}
migraphx_experimental_custom_op_runs_on_offload_target
runs_on_offload_target_f
=
nullptr
;
bool
runs_on_offload_target
()
const
{
if
(
runs_on_offload_target_f
==
nullptr
)
throw
std
::
runtime_error
(
"runs_on_offload_target function is missing."
);
std
::
remove_pointer_t
<
bool
*>
out
;
std
::
array
<
char
,
256
>
exception_msg
;
exception_msg
.
front
()
=
'\0'
;
auto
api_error_result
=
runs_on_offload_target_f
(
&
out
,
object_ptr
.
data
,
exception_msg
.
data
(),
exception_msg
.
size
());
if
(
api_error_result
!=
migraphx_status_success
)
{
const
std
::
string
exception_str
(
exception_msg
.
data
());
throw
std
::
runtime_error
(
"Error in runs_on_offload_target of: "
+
std
::
string
(
object_ptr
.
obj_typename
)
+
": "
+
exception_str
);
}
return
out
;
}
};
};
extern
"C"
migraphx_status
migraphx_shape_destroy
(
migraphx_shape_t
shape
)
extern
"C"
migraphx_status
migraphx_shape_destroy
(
migraphx_shape_t
shape
)
...
@@ -758,6 +833,16 @@ extern "C" migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out,
...
@@ -758,6 +833,16 @@ extern "C" migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out,
return
api_error_result
;
return
api_error_result
;
}
}
extern
"C"
migraphx_status
migraphx_shape_elements
(
size_t
*
out
,
const_migraphx_shape_t
shape
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
if
(
shape
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter shape: Null pointer"
);
*
out
=
(
shape
->
object
).
elements
();
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_shape_bytes
(
size_t
*
out
,
const_migraphx_shape_t
shape
)
extern
"C"
migraphx_status
migraphx_shape_bytes
(
size_t
*
out
,
const_migraphx_shape_t
shape
)
{
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
...
@@ -791,6 +876,16 @@ extern "C" migraphx_status migraphx_shape_standard(bool* out, const_migraphx_sha
...
@@ -791,6 +876,16 @@ extern "C" migraphx_status migraphx_shape_standard(bool* out, const_migraphx_sha
return
api_error_result
;
return
api_error_result
;
}
}
extern
"C"
migraphx_status
migraphx_shape_index
(
size_t
*
out
,
const_migraphx_shape_t
shape
,
size_t
i
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
if
(
shape
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter shape: Null pointer"
);
*
out
=
(
shape
->
object
).
index
((
i
));
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_argument_destroy
(
migraphx_argument_t
argument
)
extern
"C"
migraphx_status
migraphx_argument_destroy
(
migraphx_argument_t
argument
)
{
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
destroy
((
argument
));
});
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
destroy
((
argument
));
});
...
@@ -1348,6 +1443,23 @@ extern "C" migraphx_status migraphx_program_run(migraphx_arguments_t* out,
...
@@ -1348,6 +1443,23 @@ extern "C" migraphx_status migraphx_program_run(migraphx_arguments_t* out,
return
api_error_result
;
return
api_error_result
;
}
}
extern
"C"
migraphx_status
migraphx_program_run_async
(
migraphx_arguments_t
*
out
,
migraphx_program_t
program
,
migraphx_program_parameters_t
params
,
void
*
s
,
const
char
*
name
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
if
(
program
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter program: Null pointer"
);
if
(
params
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter params: Null pointer"
);
*
out
=
allocate
<
migraphx_arguments_t
>
(
migraphx
::
run_async
((
program
->
object
),
(
params
->
object
),
(
s
),
(
name
)));
});
return
api_error_result
;
}
extern
"C"
migraphx_status
extern
"C"
migraphx_status
migraphx_program_equal
(
bool
*
out
,
const_migraphx_program_t
program
,
const_migraphx_program_t
x
)
migraphx_program_equal
(
bool
*
out
,
const_migraphx_program_t
program
,
const_migraphx_program_t
x
)
{
{
...
@@ -1879,6 +1991,22 @@ extern "C" migraphx_status migraphx_experimental_custom_op_set_compute_shape(
...
@@ -1879,6 +1991,22 @@ extern "C" migraphx_status migraphx_experimental_custom_op_set_compute_shape(
return
api_error_result
;
return
api_error_result
;
}
}
extern
"C"
migraphx_status
migraphx_experimental_custom_op_set_output_alias
(
migraphx_experimental_custom_op_t
obj
,
migraphx_experimental_custom_op_output_alias
input
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
(
obj
)
->
output_alias_f
=
(
input
);
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_experimental_custom_op_set_runs_on_offload_target
(
migraphx_experimental_custom_op_t
obj
,
migraphx_experimental_custom_op_runs_on_offload_target
input
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
(
obj
)
->
runs_on_offload_target_f
=
(
input
);
});
return
api_error_result
;
}
extern
"C"
migraphx_status
extern
"C"
migraphx_status
migraphx_experimental_custom_op_register
(
migraphx_experimental_custom_op_t
experimental_custom_op
)
migraphx_experimental_custom_op_register
(
migraphx_experimental_custom_op_t
experimental_custom_op
)
{
{
...
...
src/api/include/migraphx/migraphx.h
View file @
8b25fd3e
...
@@ -26,7 +26,6 @@
...
@@ -26,7 +26,6 @@
#include <stdlib.h>
#include <stdlib.h>
#include <stdbool.h>
#include <stdbool.h>
// Add new types here
// Add new types here
// clang-format off
// clang-format off
#define MIGRAPHX_SHAPE_VISIT_TYPES(m) \
#define MIGRAPHX_SHAPE_VISIT_TYPES(m) \
...
@@ -144,6 +143,16 @@ typedef migraphx_status (*migraphx_experimental_custom_op_compute_shape)(migraph
...
@@ -144,6 +143,16 @@ typedef migraphx_status (*migraphx_experimental_custom_op_compute_shape)(migraph
size_t
exception_msg_size
,
size_t
exception_msg_size
,
migraphx_shapes_t
inputs
);
migraphx_shapes_t
inputs
);
typedef
migraphx_status
(
*
migraphx_experimental_custom_op_output_alias
)(
size_t
*
out
,
size_t
*
out_size
,
void
*
obj
,
char
*
exception_msg
,
size_t
exception_msg_size
,
migraphx_shapes_t
inputs
);
typedef
migraphx_status
(
*
migraphx_experimental_custom_op_runs_on_offload_target
)(
bool
*
out
,
void
*
obj
,
char
*
exception_msg
,
size_t
exception_msg_size
);
typedef
migraphx_status
(
*
migraphx_experimental_custom_op_copy
)(
void
**
out
,
void
*
input
);
typedef
migraphx_status
(
*
migraphx_experimental_custom_op_copy
)(
void
**
out
,
void
*
input
);
typedef
migraphx_status
(
*
migraphx_experimental_custom_op_delete
)(
void
*
input
);
typedef
migraphx_status
(
*
migraphx_experimental_custom_op_delete
)(
void
*
input
);
...
@@ -175,6 +184,8 @@ migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shap
...
@@ -175,6 +184,8 @@ migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shap
migraphx_status
migraphx_shape_type
(
migraphx_shape_datatype_t
*
out
,
const_migraphx_shape_t
shape
);
migraphx_status
migraphx_shape_type
(
migraphx_shape_datatype_t
*
out
,
const_migraphx_shape_t
shape
);
migraphx_status
migraphx_shape_elements
(
size_t
*
out
,
const_migraphx_shape_t
shape
);
migraphx_status
migraphx_shape_bytes
(
size_t
*
out
,
const_migraphx_shape_t
shape
);
migraphx_status
migraphx_shape_bytes
(
size_t
*
out
,
const_migraphx_shape_t
shape
);
migraphx_status
migraphx_status
...
@@ -182,6 +193,8 @@ migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_sha
...
@@ -182,6 +193,8 @@ migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_sha
migraphx_status
migraphx_shape_standard
(
bool
*
out
,
const_migraphx_shape_t
shape
);
migraphx_status
migraphx_shape_standard
(
bool
*
out
,
const_migraphx_shape_t
shape
);
migraphx_status
migraphx_shape_index
(
size_t
*
out
,
const_migraphx_shape_t
shape
,
size_t
i
);
migraphx_status
migraphx_argument_destroy
(
migraphx_argument_t
argument
);
migraphx_status
migraphx_argument_destroy
(
migraphx_argument_t
argument
);
migraphx_status
migraphx_argument_assign_to
(
migraphx_argument_t
output
,
migraphx_status
migraphx_argument_assign_to
(
migraphx_argument_t
output
,
...
@@ -344,6 +357,12 @@ migraphx_status migraphx_program_run(migraphx_arguments_t* out,
...
@@ -344,6 +357,12 @@ migraphx_status migraphx_program_run(migraphx_arguments_t* out,
migraphx_program_t
program
,
migraphx_program_t
program
,
migraphx_program_parameters_t
params
);
migraphx_program_parameters_t
params
);
migraphx_status
migraphx_program_run_async
(
migraphx_arguments_t
*
out
,
migraphx_program_t
program
,
migraphx_program_parameters_t
params
,
void
*
s
,
const
char
*
name
);
migraphx_status
migraphx_status
migraphx_program_equal
(
bool
*
out
,
const_migraphx_program_t
program
,
const_migraphx_program_t
x
);
migraphx_program_equal
(
bool
*
out
,
const_migraphx_program_t
program
,
const_migraphx_program_t
x
);
...
@@ -502,6 +521,13 @@ migraphx_experimental_custom_op_set_compute(migraphx_experimental_custom_op_t ob
...
@@ -502,6 +521,13 @@ migraphx_experimental_custom_op_set_compute(migraphx_experimental_custom_op_t ob
migraphx_status
migraphx_experimental_custom_op_set_compute_shape
(
migraphx_status
migraphx_experimental_custom_op_set_compute_shape
(
migraphx_experimental_custom_op_t
obj
,
migraphx_experimental_custom_op_compute_shape
input
);
migraphx_experimental_custom_op_t
obj
,
migraphx_experimental_custom_op_compute_shape
input
);
migraphx_status
migraphx_experimental_custom_op_set_output_alias
(
migraphx_experimental_custom_op_t
obj
,
migraphx_experimental_custom_op_output_alias
input
);
migraphx_status
migraphx_experimental_custom_op_set_runs_on_offload_target
(
migraphx_experimental_custom_op_t
obj
,
migraphx_experimental_custom_op_runs_on_offload_target
input
);
migraphx_status
migraphx_status
migraphx_experimental_custom_op_register
(
migraphx_experimental_custom_op_t
experimental_custom_op
);
migraphx_experimental_custom_op_register
(
migraphx_experimental_custom_op_t
experimental_custom_op
);
...
...
src/api/include/migraphx/migraphx.hpp
View file @
8b25fd3e
...
@@ -25,10 +25,12 @@
...
@@ -25,10 +25,12 @@
#define MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP
#define MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP
#include "migraphx.h"
#include "migraphx.h"
#include <algorithm>
#include <cstring>
#include <cstring>
#include <initializer_list>
#include <initializer_list>
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.h>
#include <memory>
#include <memory>
#include <numeric>
#include <exception>
#include <exception>
#include <vector>
#include <vector>
#include <cassert>
#include <cassert>
...
@@ -340,6 +342,11 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
...
@@ -340,6 +342,11 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
this->set_handle(p, std::move(lifetime)); \
this->set_handle(p, std::move(lifetime)); \
}
}
template
<
size_t
N
>
struct
out_params
{
};
template
<
class
Base
>
template
<
class
Base
>
struct
interface_base
:
Base
struct
interface_base
:
Base
{
{
...
@@ -391,7 +398,22 @@ struct interface_base : Base
...
@@ -391,7 +398,22 @@ struct interface_base : Base
}
}
template
<
class
T
,
class
Setter
,
class
F
>
template
<
class
T
,
class
Setter
,
class
F
>
void
set_fp
(
Setter
setter
,
F
pf
)
void
set_fp
(
Setter
setter
,
F
pf
,
out_params
<
2
>
)
{
static
F
f
=
pf
;
(
void
)
f
;
// avoid warning on gcc
call
(
setter
,
this
->
get_handle_ptr
(),
[](
auto
out1
,
auto
out2
,
void
*
obj
,
char
*
ex_msg
,
size_t
ex_msg_size
,
auto
...
xs
)
->
migraphx_status
{
return
try_
([
&
]
{
call_cast_arg
<
T
>
(
rank
<
2
>
{},
f
,
out1
,
out2
,
obj
,
xs
...);
},
ex_msg
,
ex_msg_size
);
});
}
template
<
class
T
,
class
Setter
,
class
F
>
void
set_fp
(
Setter
setter
,
F
pf
,
out_params
<
1
>
)
{
{
static
F
f
=
pf
;
static
F
f
=
pf
;
(
void
)
f
;
// avoid warning on gcc
(
void
)
f
;
// avoid warning on gcc
...
@@ -405,11 +427,27 @@ struct interface_base : Base
...
@@ -405,11 +427,27 @@ struct interface_base : Base
}
}
template
<
class
T
,
class
Setter
,
class
F
>
template
<
class
T
,
class
Setter
,
class
F
>
void
set_
auto_
fp
(
Setter
setter
,
F
f
)
void
set_fp
(
Setter
setter
,
F
pf
,
out_params
<
0
>
)
{
{
return
set_fp
<
T
>
(
setter
,
[
=
](
T
&
obj
,
auto
out
,
auto
...
xs
)
{
static
F
f
=
pf
;
auto_invoke
(
f
,
out
,
obj
,
auto_convert_param
(
rank
<
2
>
{},
xs
)...);
(
void
)
f
;
// avoid warning on gcc
});
call
(
setter
,
this
->
get_handle_ptr
(),
[](
void
*
obj
,
char
*
ex_msg
,
size_t
ex_msg_size
,
auto
...
xs
)
->
migraphx_status
{
return
try_
(
[
&
]
{
call_cast_arg
<
T
>
(
rank
<
0
>
{},
f
,
obj
,
xs
...);
},
ex_msg
,
ex_msg_size
);
});
}
template
<
class
T
,
class
Setter
,
class
F
,
class
Out
>
void
set_auto_fp
(
Setter
setter
,
F
f
,
Out
nums
)
{
return
set_fp
<
T
>
(
setter
,
[
=
](
T
&
obj
,
auto
out1
,
auto
out2
,
auto
...
xs
)
{
auto_invoke
(
f
,
out1
,
out2
,
obj
,
auto_convert_param
(
rank
<
2
>
{},
xs
)...);
},
nums
);
}
}
struct
no_out_arg
struct
no_out_arg
...
@@ -419,7 +457,7 @@ struct interface_base : Base
...
@@ -419,7 +457,7 @@ struct interface_base : Base
template
<
class
T
,
class
F
,
class
X
,
class
...
Xs
,
class
=
std
::
enable_if_t
<
std
::
is_void
<
X
>{}
>>
template
<
class
T
,
class
F
,
class
X
,
class
...
Xs
,
class
=
std
::
enable_if_t
<
std
::
is_void
<
X
>{}
>>
static
void
call_cast_arg
(
rank
<
0
>
,
F
f
,
X
*
obj
,
Xs
...
xs
)
static
void
call_cast_arg
(
rank
<
0
>
,
F
f
,
X
*
obj
,
Xs
...
xs
)
{
{
f
(
reinterpret_cast
<
T
*>
(
obj
),
no_out_arg
{},
xs
...);
f
(
reinterpret_cast
<
T
*>
(
obj
),
no_out_arg
{},
no_out_arg
{},
xs
...);
}
}
template
<
class
T
,
template
<
class
T
,
...
@@ -430,17 +468,35 @@ struct interface_base : Base
...
@@ -430,17 +468,35 @@ struct interface_base : Base
class
=
std
::
enable_if_t
<
std
::
is_void
<
X
>{}
>>
class
=
std
::
enable_if_t
<
std
::
is_void
<
X
>{}
>>
static
void
call_cast_arg
(
rank
<
1
>
,
F
f
,
R
result
,
X
*
obj
,
Xs
...
xs
)
static
void
call_cast_arg
(
rank
<
1
>
,
F
f
,
R
result
,
X
*
obj
,
Xs
...
xs
)
{
{
f
(
*
reinterpret_cast
<
T
*>
(
obj
),
result
,
xs
...);
f
(
*
reinterpret_cast
<
T
*>
(
obj
),
result
,
no_out_arg
{},
xs
...);
}
template
<
class
T
,
class
F
,
class
R1
,
class
R2
,
class
X
,
class
...
Xs
,
class
=
std
::
enable_if_t
<
std
::
is_void
<
X
>{}
>>
static
void
call_cast_arg
(
rank
<
2
>
,
F
f
,
R1
result1
,
R2
result2
,
X
*
obj
,
Xs
...
xs
)
{
f
(
*
reinterpret_cast
<
T
*>
(
obj
),
result1
,
result2
,
xs
...);
}
template
<
class
F
,
class
T1
,
class
T2
,
class
...
Ts
>
void
auto_invoke
(
F
f
,
T1
*
out1
,
T2
*
out2
,
Ts
&&
...
xs
)
{
auto_assign
(
rank
<
2
>
{},
out1
,
out2
,
f
(
std
::
forward
<
Ts
>
(
xs
)...));
}
}
template
<
class
F
,
class
T
,
class
...
Ts
>
template
<
class
F
,
class
T
,
class
...
Ts
>
void
auto_invoke
(
F
f
,
T
*
out
,
Ts
&&
...
xs
)
void
auto_invoke
(
F
f
,
T
*
out
,
no_out_arg
,
Ts
&&
...
xs
)
{
{
auto_assign
(
rank
<
2
>
{},
out
,
f
(
std
::
forward
<
Ts
>
(
xs
)...));
auto_assign
(
rank
<
1
>
{},
out
,
f
(
std
::
forward
<
Ts
>
(
xs
)...));
}
}
template
<
class
F
,
class
T
,
class
...
Ts
>
template
<
class
F
,
class
T
,
class
...
Ts
>
void
auto_invoke
(
F
f
,
no_out_arg
,
Ts
&&
...
xs
)
void
auto_invoke
(
F
f
,
no_out_arg
,
no_out_arg
,
Ts
&&
...
xs
)
{
{
f
(
std
::
forward
<
Ts
>
(
xs
)...);
f
(
std
::
forward
<
Ts
>
(
xs
)...);
}
}
...
@@ -469,7 +525,7 @@ struct interface_base : Base
...
@@ -469,7 +525,7 @@ struct interface_base : Base
template
<
class
T
,
class
U
>
template
<
class
T
,
class
U
>
void
auto_assign
(
rank
<
0
>
,
T
*
out
,
U
x
)
void
auto_assign
(
rank
<
0
>
,
T
*
out
,
U
x
)
{
{
return
*
out
=
x
;
*
out
=
x
;
}
}
template
<
class
T
,
class
U
>
template
<
class
T
,
class
U
>
...
@@ -477,12 +533,21 @@ struct interface_base : Base
...
@@ -477,12 +533,21 @@ struct interface_base : Base
{
{
x
.
assign_to_handle
(
out
);
x
.
assign_to_handle
(
out
);
}
}
template
<
class
T1
,
class
T2
,
class
U
,
class
=
std
::
enable_if_t
<
std
::
is_same
<
T2
,
size_t
>{}
>>
auto
auto_assign
(
rank
<
2
>
,
T1
*
out_ptr
,
T2
*
out_size
,
U
x
)
{
*
out_size
=
std
::
min
(
*
out_size
,
x
.
size
());
std
::
copy_n
(
x
.
begin
(),
*
out_size
,
out_ptr
);
}
};
};
// NOLINTNEXTLINE
// NOLINTNEXTLINE
#define MIGRAPHX_INTERFACE_LIFT(T, prefix, name) \
#define MIGRAPHX_INTERFACE_LIFT(n_out, T, prefix, name) \
this->set_auto_fp<T>(&migraphx_##prefix##_set_##name, \
this->set_auto_fp<T>( \
[](T& x, auto... xs) { return x.name(xs...); })
&migraphx_##prefix##_set_##name, \
[](T& x, auto... xs) { return x.name(xs...); }, \
out_params<n_out>{})
template
<
class
Base
,
class
T
>
template
<
class
Base
,
class
T
>
using
require_interface
=
using
require_interface
=
...
@@ -567,6 +632,13 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
...
@@ -567,6 +632,13 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
return
pout
;
return
pout
;
}
}
size_t
elements
()
const
{
size_t
pout
;
call
(
&
migraphx_shape_elements
,
&
pout
,
this
->
get_handle_ptr
());
return
pout
;
}
size_t
bytes
()
const
size_t
bytes
()
const
{
{
size_t
pout
;
size_t
pout
;
...
@@ -581,6 +653,14 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
...
@@ -581,6 +653,14 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
return
result
;
return
result
;
}
}
// map element index to space index
size_t
index
(
size_t
i
)
const
{
size_t
result
;
call
(
&
migraphx_shape_index
,
&
result
,
this
->
get_handle_ptr
(),
i
);
return
result
;
}
friend
bool
operator
==
(
const
shape
&
px
,
const
shape
&
py
)
friend
bool
operator
==
(
const
shape
&
px
,
const
shape
&
py
)
{
{
bool
pout
;
bool
pout
;
...
@@ -628,9 +708,15 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
...
@@ -628,9 +708,15 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
template
<
typename
T
>
template
<
typename
T
>
std
::
vector
<
T
>
as_vector
()
const
std
::
vector
<
T
>
as_vector
()
const
{
{
size_t
vector_len
=
this
->
get_shape
().
bytes
()
/
sizeof
(
T
);
auto
ss
=
this
->
get_shape
();
T
*
buffer_ptr
=
reinterpret_cast
<
T
*>
(
this
->
data
());
auto
num_elements
=
ss
.
elements
();
return
{
buffer_ptr
,
buffer_ptr
+
vector_len
};
std
::
vector
<
T
>
res
(
num_elements
);
T
*
buffer_ptr
=
reinterpret_cast
<
T
*>
(
this
->
data
());
for
(
size_t
i
=
0
;
i
<
num_elements
;
i
++
)
{
res
[
i
]
=
buffer_ptr
[
ss
.
index
(
i
)];
}
return
res
;
}
}
/// Generate an argument using random data
/// Generate an argument using random data
...
@@ -979,6 +1065,20 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
...
@@ -979,6 +1065,20 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
return
arguments
(
pout
,
own
{});
return
arguments
(
pout
,
own
{});
}
}
template
<
class
Stream
>
/// Overloaded to allow for execution_environment input
arguments
run_async
(
const
program_parameters
&
pparams
,
Stream
*
s
)
const
{
migraphx_arguments_t
pout
;
call
(
&
migraphx_program_run_async
,
&
pout
,
this
->
get_handle_ptr
(),
pparams
.
get_handle_ptr
(),
s
,
get_type_name
<
Stream
>
().
c_str
());
return
arguments
(
pout
,
own
{});
}
void
print
()
const
{
call
(
&
migraphx_program_print
,
this
->
get_handle_ptr
());
}
void
print
()
const
{
call
(
&
migraphx_program_print
,
this
->
get_handle_ptr
());
}
program
sort
()
program
sort
()
...
@@ -1255,7 +1355,10 @@ struct experimental_custom_op_base
...
@@ -1255,7 +1355,10 @@ struct experimental_custom_op_base
virtual
std
::
string
name
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
argument
compute
(
context
ctx
,
shape
output
,
arguments
inputs
)
const
=
0
;
virtual
argument
compute
(
context
ctx
,
shape
output
,
arguments
inputs
)
const
=
0
;
virtual
shape
compute_shape
(
shapes
inputs
)
const
=
0
;
virtual
shape
compute_shape
(
shapes
inputs
)
const
=
0
;
virtual
~
experimental_custom_op_base
()
=
default
;
virtual
std
::
vector
<
size_t
>
output_alias
(
shapes
)
const
{
return
{};
}
// TODO: Return target string instead of bool
virtual
bool
runs_on_offload_target
()
const
=
0
;
virtual
~
experimental_custom_op_base
()
=
default
;
};
};
struct
experimental_custom_op
:
interface_base
<
MIGRAPHX_HANDLE_BASE
(
experimental_custom_op
)
>
struct
experimental_custom_op
:
interface_base
<
MIGRAPHX_HANDLE_BASE
(
experimental_custom_op
)
>
...
@@ -1267,8 +1370,10 @@ struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental
...
@@ -1267,8 +1370,10 @@ struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental
obj
,
obj
,
get_type_name
(
obj
).
c_str
(),
get_type_name
(
obj
).
c_str
(),
obj
.
name
().
c_str
());
obj
.
name
().
c_str
());
MIGRAPHX_INTERFACE_LIFT
(
T
,
experimental_custom_op
,
compute_shape
);
MIGRAPHX_INTERFACE_LIFT
(
1
,
T
,
experimental_custom_op
,
compute_shape
);
MIGRAPHX_INTERFACE_LIFT
(
T
,
experimental_custom_op
,
compute
);
MIGRAPHX_INTERFACE_LIFT
(
1
,
T
,
experimental_custom_op
,
compute
);
MIGRAPHX_INTERFACE_LIFT
(
2
,
T
,
experimental_custom_op
,
output_alias
);
MIGRAPHX_INTERFACE_LIFT
(
1
,
T
,
experimental_custom_op
,
runs_on_offload_target
);
}
}
void
register_op
()
{
call
(
&
migraphx_experimental_custom_op_register
,
this
->
get_handle_ptr
());
}
void
register_op
()
{
call
(
&
migraphx_experimental_custom_op_register
,
this
->
get_handle_ptr
());
}
...
...
src/api/migraphx.py
View file @
8b25fd3e
...
@@ -115,6 +115,7 @@ def shape(h):
...
@@ -115,6 +115,7 @@ def shape(h):
const
=
True
)
const
=
True
)
h
.
method
(
'strides'
,
returns
=
'const std::vector<size_t>&'
,
const
=
True
)
h
.
method
(
'strides'
,
returns
=
'const std::vector<size_t>&'
,
const
=
True
)
h
.
method
(
'type'
,
returns
=
'migraphx::shape::type_t'
,
const
=
True
)
h
.
method
(
'type'
,
returns
=
'migraphx::shape::type_t'
,
const
=
True
)
h
.
method
(
'elements'
,
returns
=
'size_t'
,
const
=
True
)
h
.
method
(
'bytes'
,
returns
=
'size_t'
,
const
=
True
)
h
.
method
(
'bytes'
,
returns
=
'size_t'
,
const
=
True
)
h
.
method
(
'equal'
,
h
.
method
(
'equal'
,
api
.
params
(
x
=
'const migraphx::shape&'
),
api
.
params
(
x
=
'const migraphx::shape&'
),
...
@@ -122,6 +123,7 @@ def shape(h):
...
@@ -122,6 +123,7 @@ def shape(h):
returns
=
'bool'
,
returns
=
'bool'
,
const
=
True
)
const
=
True
)
h
.
method
(
'standard'
,
returns
=
'bool'
,
const
=
True
)
h
.
method
(
'standard'
,
returns
=
'bool'
,
const
=
True
)
h
.
method
(
'index'
,
api
.
params
(
i
=
'size_t'
),
returns
=
'size_t'
,
const
=
True
)
@
auto_handle
()
@
auto_handle
()
...
@@ -274,6 +276,13 @@ def program(h):
...
@@ -274,6 +276,13 @@ def program(h):
params
=
'std::unordered_map<std::string, migraphx::argument>'
),
params
=
'std::unordered_map<std::string, migraphx::argument>'
),
invoke
=
'migraphx::run($@)'
,
invoke
=
'migraphx::run($@)'
,
returns
=
'std::vector<migraphx::argument>'
)
returns
=
'std::vector<migraphx::argument>'
)
h
.
method
(
'run_async'
,
api
.
params
(
params
=
'std::unordered_map<std::string, migraphx::argument>'
,
s
=
'void*'
,
name
=
'const char *'
),
invoke
=
'migraphx::run_async($@)'
,
returns
=
'std::vector<migraphx::argument>'
)
h
.
method
(
'equal'
,
h
.
method
(
'equal'
,
api
.
params
(
x
=
'const migraphx::program&'
),
api
.
params
(
x
=
'const migraphx::program&'
),
invoke
=
'migraphx::equal($@)'
,
invoke
=
'migraphx::equal($@)'
,
...
@@ -450,4 +459,8 @@ def experimental_custom_op(h):
...
@@ -450,4 +459,8 @@ def experimental_custom_op(h):
h
.
virtual
(
'compute_shape'
,
h
.
virtual
(
'compute_shape'
,
api
.
params
(
inputs
=
'std::vector<migraphx::shape>'
),
api
.
params
(
inputs
=
'std::vector<migraphx::shape>'
),
returns
=
'migraphx::shape'
)
returns
=
'migraphx::shape'
)
h
.
virtual
(
'output_alias'
,
api
.
params
(
inputs
=
'std::vector<migraphx::shape>'
),
returns
=
'std::vector<size_t>'
)
h
.
virtual
(
'runs_on_offload_target'
,
returns
=
'bool'
)
h
.
method
(
'register'
,
invoke
=
'migraphx::register_custom_op($@)'
)
h
.
method
(
'register'
,
invoke
=
'migraphx::register_custom_op($@)'
)
src/include/migraphx/context.hpp
View file @
8b25fd3e
...
@@ -66,6 +66,15 @@ any_ptr get_queue_context(T&)
...
@@ -66,6 +66,15 @@ any_ptr get_queue_context(T&)
{
{
return
{};
return
{};
}
}
template
<
class
T
>
void
wait_for_context
(
T
&
,
any_ptr
)
{
}
template
<
class
T
>
void
finish_on_context
(
T
&
,
any_ptr
)
{
}
#ifdef TYPE_ERASED_DECLARATION
#ifdef TYPE_ERASED_DECLARATION
...
@@ -78,6 +87,10 @@ struct context
...
@@ -78,6 +87,10 @@ struct context
void
from_value
(
const
value
&
v
);
void
from_value
(
const
value
&
v
);
// (optional)
// (optional)
any_ptr
get_queue
();
any_ptr
get_queue
();
// (optional)
void
wait_for
(
any_ptr
queue
);
// (optional)
void
finish_on
(
any_ptr
queue
);
//
//
void
finish
()
const
;
void
finish
()
const
;
};
};
...
@@ -165,6 +178,18 @@ struct context
...
@@ -165,6 +178,18 @@ struct context
return
(
*
this
).
private_detail_te_get_handle
().
get_queue
();
return
(
*
this
).
private_detail_te_get_handle
().
get_queue
();
}
}
void
wait_for
(
any_ptr
queue
)
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
(
*
this
).
private_detail_te_get_handle
().
wait_for
(
queue
);
}
void
finish_on
(
any_ptr
queue
)
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
(
*
this
).
private_detail_te_get_handle
().
finish_on
(
queue
);
}
void
finish
()
const
void
finish
()
const
{
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
assert
((
*
this
).
private_detail_te_handle_mem_var
);
...
@@ -187,6 +212,8 @@ struct context
...
@@ -187,6 +212,8 @@ struct context
virtual
value
to_value
()
const
=
0
;
virtual
value
to_value
()
const
=
0
;
virtual
void
from_value
(
const
value
&
v
)
=
0
;
virtual
void
from_value
(
const
value
&
v
)
=
0
;
virtual
any_ptr
get_queue
()
=
0
;
virtual
any_ptr
get_queue
()
=
0
;
virtual
void
wait_for
(
any_ptr
queue
)
=
0
;
virtual
void
finish_on
(
any_ptr
queue
)
=
0
;
virtual
void
finish
()
const
=
0
;
virtual
void
finish
()
const
=
0
;
};
};
...
@@ -231,6 +258,33 @@ struct context
...
@@ -231,6 +258,33 @@ struct context
return
get_queue_context
(
private_detail_te_self
);
return
get_queue_context
(
private_detail_te_self
);
}
}
template
<
class
T
>
static
auto
private_detail_te_default_wait_for
(
char
,
T
&&
private_detail_te_self
,
any_ptr
queue
)
->
decltype
(
private_detail_te_self
.
wait_for
(
queue
))
{
private_detail_te_self
.
wait_for
(
queue
);
}
template
<
class
T
>
static
void
private_detail_te_default_wait_for
(
float
,
T
&&
private_detail_te_self
,
any_ptr
queue
)
{
wait_for_context
(
private_detail_te_self
,
queue
);
}
template
<
class
T
>
static
auto
private_detail_te_default_finish_on
(
char
,
T
&&
private_detail_te_self
,
any_ptr
queue
)
->
decltype
(
private_detail_te_self
.
finish_on
(
queue
))
{
private_detail_te_self
.
finish_on
(
queue
);
}
template
<
class
T
>
static
void
private_detail_te_default_finish_on
(
float
,
T
&&
private_detail_te_self
,
any_ptr
queue
)
{
finish_on_context
(
private_detail_te_self
,
queue
);
}
template
<
typename
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedT
>
struct
private_detail_te_handle_type
:
private_detail_te_handle_base_type
struct
private_detail_te_handle_type
:
private_detail_te_handle_base_type
{
{
...
@@ -248,7 +302,7 @@ struct context
...
@@ -248,7 +302,7 @@ struct context
PrivateDetailTypeErasedT
value
,
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
)
)
:
private_detail_te_value
(
value
)
{
{
}
}
...
@@ -277,6 +331,18 @@ struct context
...
@@ -277,6 +331,18 @@ struct context
return
private_detail_te_default_get_queue
(
char
(
0
),
private_detail_te_value
);
return
private_detail_te_default_get_queue
(
char
(
0
),
private_detail_te_value
);
}
}
void
wait_for
(
any_ptr
queue
)
override
{
private_detail_te_default_wait_for
(
char
(
0
),
private_detail_te_value
,
queue
);
}
void
finish_on
(
any_ptr
queue
)
override
{
private_detail_te_default_finish_on
(
char
(
0
),
private_detail_te_value
,
queue
);
}
void
finish
()
const
override
{
private_detail_te_value
.
finish
();
}
void
finish
()
const
override
{
private_detail_te_value
.
finish
();
}
PrivateDetailTypeErasedT
private_detail_te_value
;
PrivateDetailTypeErasedT
private_detail_te_value
;
...
...
src/include/migraphx/execution_environment.hpp
0 → 100644
View file @
8b25fd3e
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_EXECUTION_ENV_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_EXECUTION_ENV_HPP
#include <migraphx/any_ptr.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
execution_environment
{
any_ptr
queue
=
any_ptr
{};
bool
async
=
false
;
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
/* MIGRAPHX_GUARD_MIGRAPHLIB_EXECUTION_ENV_HPP */
src/include/migraphx/program.hpp
View file @
8b25fd3e
...
@@ -37,6 +37,7 @@
...
@@ -37,6 +37,7 @@
#include <migraphx/assignment_options.hpp>
#include <migraphx/assignment_options.hpp>
#include <migraphx/env.hpp>
#include <migraphx/env.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/execution_environment.hpp>
#include <algorithm>
#include <algorithm>
#include <iostream>
#include <iostream>
...
@@ -76,8 +77,8 @@ struct program
...
@@ -76,8 +77,8 @@ struct program
std
::
unordered_map
<
std
::
string
,
shape
>
get_parameter_shapes
()
const
;
std
::
unordered_map
<
std
::
string
,
shape
>
get_parameter_shapes
()
const
;
std
::
vector
<
argument
>
eval
(
parameter_map
params
)
const
;
std
::
vector
<
argument
>
eval
(
parameter_map
params
,
execution_environment
exec_env
=
execution_environment
{})
const
;
std
::
size_t
size
()
const
;
std
::
size_t
size
()
const
;
std
::
vector
<
shape
>
get_output_shapes
()
const
;
std
::
vector
<
shape
>
get_output_shapes
()
const
;
...
...
src/program.cpp
View file @
8b25fd3e
...
@@ -398,7 +398,7 @@ std::vector<argument> generic_eval(const program& p,
...
@@ -398,7 +398,7 @@ std::vector<argument> generic_eval(const program& p,
return
generic_eval
(
mm
,
ctx
,
params
,
{},
make_trace
);
return
generic_eval
(
mm
,
ctx
,
params
,
{},
make_trace
);
}
}
std
::
vector
<
argument
>
program
::
eval
(
parameter_map
params
)
const
std
::
vector
<
argument
>
program
::
eval
(
parameter_map
params
,
execution_environment
exec_env
)
const
{
{
auto
&
ctx
=
this
->
impl
->
ctx
;
auto
&
ctx
=
this
->
impl
->
ctx
;
#ifndef NDEBUG
#ifndef NDEBUG
...
@@ -423,6 +423,12 @@ std::vector<argument> program::eval(parameter_map params) const
...
@@ -423,6 +423,12 @@ std::vector<argument> program::eval(parameter_map params) const
#endif
#endif
auto
trace_level
=
value_of
(
MIGRAPHX_TRACE_EVAL
{});
auto
trace_level
=
value_of
(
MIGRAPHX_TRACE_EVAL
{});
std
::
vector
<
argument
>
ret
;
if
(
exec_env
.
async
)
{
ctx
.
wait_for
(
exec_env
.
queue
);
}
if
(
trace_level
>
0
)
if
(
trace_level
>
0
)
{
{
...
@@ -434,49 +440,56 @@ std::vector<argument> program::eval(parameter_map params) const
...
@@ -434,49 +440,56 @@ std::vector<argument> program::eval(parameter_map params) const
ins_out
[
x
]
=
ss
.
str
();
ins_out
[
x
]
=
ss
.
str
();
});
});
ret
urn
generic_eval
(
*
this
,
ret
=
generic_eval
(
*
this
,
ctx
,
ctx
,
std
::
move
(
params
),
std
::
move
(
params
),
with_check_context
([
&
](
auto
&
ins
,
auto
f
,
auto
&&
check_context
)
{
with_check_context
([
&
](
auto
&
ins
,
auto
f
,
auto
&&
check_context
)
{
ctx
.
finish
();
ctx
.
finish
();
std
::
cout
<<
"Run instruction: "
<<
ins_out
.
at
(
ins
)
<<
std
::
endl
;
std
::
cout
<<
"Run instruction: "
<<
ins_out
.
at
(
ins
)
<<
std
::
endl
;
timer
t
{};
timer
t
{};
auto
result
=
check_context
(
f
);
auto
result
=
check_context
(
f
);
double
t1
=
t
.
record
<
milliseconds
>
();
double
t1
=
t
.
record
<
milliseconds
>
();
ctx
.
finish
();
ctx
.
finish
();
double
t2
=
t
.
record
<
milliseconds
>
();
double
t2
=
t
.
record
<
milliseconds
>
();
std
::
cout
<<
"Time: "
<<
t1
<<
"ms, "
<<
t2
<<
"ms"
<<
std
::
endl
;
std
::
cout
<<
"Time: "
<<
t1
<<
"ms, "
<<
t2
<<
"ms"
<<
std
::
endl
;
if
(
trace_level
>
1
and
ins
->
name
().
front
()
!=
'@'
and
if
(
trace_level
>
1
and
ins
->
name
().
front
()
!=
'@'
and
ins
->
name
()
!=
"load"
and
not
result
.
empty
())
ins
->
name
()
!=
"load"
and
not
result
.
empty
())
{
{
target
tgt
=
make_target
(
this
->
impl
->
target_name
);
target
tgt
=
make_target
(
this
->
impl
->
target_name
);
auto
buffer
=
tgt
.
copy_from
(
result
);
auto
buffer
=
tgt
.
copy_from
(
result
);
if
(
trace_level
==
2
)
if
(
trace_level
==
2
)
{
{
std
::
cout
<<
"Output has "
std
::
cout
<<
"Output has "
<<
to_string_range
(
classify_argument
(
buffer
))
<<
to_string_range
(
classify_argument
(
buffer
))
<<
std
::
endl
;
<<
std
::
endl
;
std
::
cout
<<
"Output: "
;
std
::
cout
<<
"Output: "
;
preview_argument
(
std
::
cout
,
buffer
);
preview_argument
(
std
::
cout
,
buffer
);
std
::
cout
<<
std
::
endl
;
std
::
cout
<<
std
::
endl
;
}
}
else
else
{
{
std
::
cout
<<
"Output: "
<<
buffer
<<
std
::
endl
;
std
::
cout
<<
"Output: "
<<
buffer
<<
std
::
endl
;
}
}
}
}
return
result
;
return
result
;
}));
}));
}
}
else
else
{
{
ret
urn
generic_eval
(
*
this
,
ret
=
generic_eval
(
*
this
,
ctx
,
ctx
,
std
::
move
(
params
),
std
::
move
(
params
),
with_check_context
([
&
](
auto
&
,
auto
f
,
auto
&&
check_context
)
{
with_check_context
([
&
](
auto
&
,
auto
f
,
auto
&&
check_context
)
{
return
check_context
(
f
);
return
check_context
(
f
);
}));
}));
}
}
if
(
exec_env
.
async
)
{
ctx
.
finish_on
(
exec_env
.
queue
);
}
return
ret
;
}
}
const
int
program_file_version
=
5
;
const
int
program_file_version
=
5
;
...
...
src/py/migraphx_py.cpp
View file @
8b25fd3e
...
@@ -355,6 +355,23 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
...
@@ -355,6 +355,23 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
}
}
return
p
.
eval
(
pm
);
return
p
.
eval
(
pm
);
})
})
.
def
(
"run_async"
,
[](
migraphx
::
program
&
p
,
py
::
dict
params
,
std
::
uintptr_t
stream
,
std
::
string
stream_name
)
{
migraphx
::
parameter_map
pm
;
for
(
auto
x
:
params
)
{
std
::
string
key
=
x
.
first
.
cast
<
std
::
string
>
();
py
::
buffer
b
=
x
.
second
.
cast
<
py
::
buffer
>
();
py
::
buffer_info
info
=
b
.
request
();
pm
[
key
]
=
migraphx
::
argument
(
to_shape
(
info
),
info
.
ptr
);
}
migraphx
::
execution_environment
exec_env
{
migraphx
::
any_ptr
(
reinterpret_cast
<
void
*>
(
stream
),
stream_name
),
true
};
return
p
.
eval
(
pm
,
exec_env
);
})
.
def
(
"sort"
,
&
migraphx
::
program
::
sort
)
.
def
(
"sort"
,
&
migraphx
::
program
::
sort
)
.
def
(
"print"
,
[](
const
migraphx
::
program
&
p
)
{
std
::
cout
<<
p
<<
std
::
endl
;
})
.
def
(
"print"
,
[](
const
migraphx
::
program
&
p
)
{
std
::
cout
<<
p
<<
std
::
endl
;
})
.
def
(
"__eq__"
,
std
::
equal_to
<
migraphx
::
program
>
{})
.
def
(
"__eq__"
,
std
::
equal_to
<
migraphx
::
program
>
{})
...
...
src/simplify_algebra.cpp
View file @
8b25fd3e
...
@@ -933,6 +933,73 @@ struct find_div_const
...
@@ -933,6 +933,73 @@ struct find_div_const
}
}
};
};
struct
find_unit_ops
{
auto
matcher
()
const
{
auto
mul_1
=
match
::
name
(
"mul"
)(
match
::
either_arg
(
0
,
1
)(
match
::
has_value
(
1.0
f
),
match
::
any
().
bind
(
"x"
)));
auto
div_1
=
match
::
name
(
"div"
)(
match
::
args
(
match
::
any
().
bind
(
"x"
),
match
::
has_value
(
1.0
f
)));
auto
add_0
=
match
::
name
(
"add"
)(
match
::
either_arg
(
0
,
1
)(
match
::
has_value
(
0.0
f
,
1e-12
),
match
::
any
().
bind
(
"x"
)));
auto
sub_0
=
match
::
name
(
"sub"
)(
match
::
args
(
match
::
any
().
bind
(
"x"
),
match
::
has_value
(
0.0
f
)));
return
match
::
any_of
(
mul_1
,
div_1
,
add_0
,
sub_0
);
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
c_in
=
r
.
instructions
[
"x"
];
m
.
replace_instruction
(
ins
,
c_in
);
}
};
struct
find_neg_unit_ops
{
auto
matcher
()
const
{
auto
mul_neg_1
=
match
::
name
(
"mul"
)(
match
::
either_arg
(
0
,
1
)(
match
::
has_value
(
-
1.0
f
),
match
::
any
().
bind
(
"x"
)));
auto
div_neg_1
=
match
::
name
(
"div"
)(
match
::
args
(
match
::
any
().
bind
(
"x"
),
match
::
has_value
(
-
1.0
f
)));
auto
sub_0
=
match
::
name
(
"sub"
)(
match
::
args
(
match
::
has_value
(
0.0
f
),
match
::
any
().
bind
(
"x"
)));
return
match
::
any_of
(
mul_neg_1
,
div_neg_1
,
sub_0
);
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
c_in
=
r
.
instructions
[
"x"
];
auto
neg
=
m
.
add_instruction
(
make_op
(
"neg"
),
c_in
);
m
.
replace_instruction
(
ins
,
neg
);
}
};
struct
find_zero_ops
{
auto
matcher
()
const
{
auto
mul_zero
=
match
::
name
(
"mul"
)(
match
::
either_arg
(
0
,
1
)(
match
::
has_value
(
0.0
f
).
bind
(
"x"
),
match
::
any
()));
auto
div_zero
=
match
::
name
(
"div"
)(
match
::
args
(
match
::
has_value
(
0.0
f
).
bind
(
"x"
),
match
::
any
()));
return
match
::
any_of
(
mul_zero
,
div_zero
);
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
zero_ins
=
r
.
instructions
[
"x"
];
m
.
replace_instruction
(
ins
,
zero_ins
);
}
};
struct
find_sub_const
struct
find_sub_const
{
{
auto
matcher
()
const
auto
matcher
()
const
...
@@ -1149,6 +1216,9 @@ void simplify_algebra::apply(module& m) const
...
@@ -1149,6 +1216,9 @@ void simplify_algebra::apply(module& m) const
find_mul_conv
{},
find_mul_conv
{},
find_mul_slice_conv
{},
find_mul_slice_conv
{},
find_mul_add
{},
find_mul_add
{},
find_unit_ops
{},
find_neg_unit_ops
{},
find_zero_ops
{},
find_dot_add
{},
find_dot_add
{},
find_div_const
{},
find_div_const
{},
find_sub_const
{},
find_sub_const
{},
...
...
src/targets/gpu/include/migraphx/gpu/context.hpp
View file @
8b25fd3e
...
@@ -197,7 +197,9 @@ struct hip_device
...
@@ -197,7 +197,9 @@ struct hip_device
struct
context
struct
context
{
{
context
(
std
::
size_t
device_id
=
0
,
std
::
size_t
n
=
value_of
(
MIGRAPHX_NSTREAMS
{},
1
))
context
(
std
::
size_t
device_id
=
0
,
std
::
size_t
n
=
value_of
(
MIGRAPHX_NSTREAMS
{},
1
))
:
current_device
(
std
::
make_shared
<
hip_device
>
(
device_id
,
n
))
:
current_device
(
std
::
make_shared
<
hip_device
>
(
device_id
,
n
)),
begin_event
(
create_event
()),
finish_event
(
create_event
())
{
{
}
}
...
@@ -274,6 +276,24 @@ struct context
...
@@ -274,6 +276,24 @@ struct context
this
->
current_device
=
std
::
make_shared
<
hip_device
>
(
0
,
n_streams
);
this
->
current_device
=
std
::
make_shared
<
hip_device
>
(
0
,
n_streams
);
}
}
void
wait_for
(
any_ptr
queue
)
{
auto
status
=
hipEventRecord
(
begin_event
.
get
(),
queue
.
get
<
hipStream_t
>
());
if
(
status
!=
hipSuccess
)
MIGRAPHX_THROW
(
"failed to record "
+
hip_error
(
status
));
get_stream
().
wait
(
begin_event
.
get
());
}
void
finish_on
(
any_ptr
queue
)
{
get_stream
().
record
(
finish_event
.
get
());
auto
status
=
hipStreamWaitEvent
(
queue
.
get
<
hipStream_t
>
(),
finish_event
.
get
(),
0
);
if
(
status
!=
hipSuccess
)
MIGRAPHX_THROW
(
"Failed to wait on event "
+
hip_error
(
status
));
}
any_ptr
get_queue
()
{
return
get_stream
().
get
();
}
any_ptr
get_queue
()
{
return
get_stream
().
get
();
}
void
enable_perf_measurement
(
bool
b
=
true
)
void
enable_perf_measurement
(
bool
b
=
true
)
...
@@ -316,9 +336,13 @@ struct context
...
@@ -316,9 +336,13 @@ struct context
// TODO: Make this a vector to support multiple devices
// TODO: Make this a vector to support multiple devices
std
::
shared_ptr
<
hip_device
>
current_device
;
std
::
shared_ptr
<
hip_device
>
current_device
;
std
::
vector
<
shared
<
hip_event_ptr
>>
events
;
std
::
vector
<
shared
<
hip_event_ptr
>>
events
;
bool
measure_perf
=
false
;
bool
measure_perf
=
false
;
// for event perf timing
shared
<
hip_event_ptr
>
start_event
=
nullptr
;
shared
<
hip_event_ptr
>
start_event
=
nullptr
;
shared
<
hip_event_ptr
>
stop_event
=
nullptr
;
shared
<
hip_event_ptr
>
stop_event
=
nullptr
;
// for stream syncronization
shared
<
hip_event_ptr
>
begin_event
=
nullptr
;
shared
<
hip_event_ptr
>
finish_event
=
nullptr
;
};
};
inline
void
migraphx_to_value
(
value
&
v
,
const
context
&
ctx
)
{
v
=
ctx
.
to_value
();
}
inline
void
migraphx_to_value
(
value
&
v
,
const
context
&
ctx
)
{
v
=
ctx
.
to_value
();
}
...
...
src/targets/gpu/jit/softmax.cpp
View file @
8b25fd3e
...
@@ -32,6 +32,8 @@ namespace migraphx {
...
@@ -32,6 +32,8 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
gpu
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_USE_FAST_SOFTMAX
)
using
namespace
migraphx
::
gpu
::
gen
;
// NOLINT
using
namespace
migraphx
::
gpu
::
gen
;
// NOLINT
static
const
char
*
const
softmax_kernel
=
R"__migraphx__(
static
const
char
*
const
softmax_kernel
=
R"__migraphx__(
...
@@ -81,6 +83,9 @@ struct softmax_compiler : compiler<softmax_compiler>
...
@@ -81,6 +83,9 @@ struct softmax_compiler : compiler<softmax_compiler>
options
.
inputs
=
inputs
;
options
.
inputs
=
inputs
;
options
.
kernel_name
=
"softmax_kernel"
;
options
.
kernel_name
=
"softmax_kernel"
;
if
(
enabled
(
MIGRAPHX_USE_FAST_SOFTMAX
{}))
options
.
params
=
"-DMIGRAPHX_USE_FAST_SOFTMAX"
;
auto
src
=
interpolate_string
(
auto
src
=
interpolate_string
(
softmax_kernel
,
softmax_kernel
,
{{
"transformers"
,
make_transformer_args
(
vec
)},
{
"axis"
,
to_string
(
axis
)}});
{{
"transformers"
,
make_transformer_args
(
vec
)},
{
"axis"
,
to_string
(
axis
)}});
...
...
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
View file @
8b25fd3e
...
@@ -197,11 +197,11 @@ struct block
...
@@ -197,11 +197,11 @@ struct block
struct
reducer
struct
reducer
{
{
index
idx
;
index
idx
;
Slicer
slice
r
;
Slicer
slice
;
template
<
class
Op
,
class
T
,
class
Read
>
template
<
class
Op
,
class
T
,
class
Read
>
__device__
auto
reduce
(
Op
op
,
T
init
,
Read
read
)
const
__device__
auto
reduce
(
Op
op
,
T
init
,
Read
read
)
const
{
{
return
sliced
(
slice
r
,
[
=
](
auto
x
,
auto
...
xs
)
{
return
sliced
(
slice
,
[
=
](
auto
x
,
auto
...
xs
)
{
return
block_reduce
(
idx
,
op
,
init
,
x
.
get_shape
().
elements
(),
[
&
](
auto
j
)
{
return
block_reduce
(
idx
,
op
,
init
,
x
.
get_shape
().
elements
(),
[
&
](
auto
j
)
{
return
vec_reduce
(
read
(
x
[
j
],
xs
[
j
]...),
op
);
return
vec_reduce
(
read
(
x
[
j
],
xs
[
j
]...),
op
);
});
});
...
@@ -218,7 +218,7 @@ struct block
...
@@ -218,7 +218,7 @@ struct block
template
<
class
F
>
template
<
class
F
>
__device__
auto
inner
(
F
f
)
const
__device__
auto
inner
(
F
f
)
const
{
{
return
sliced
(
slice
r
,
[
=
](
auto
x
,
auto
...
xs
)
{
return
sliced
(
slice
,
[
=
](
auto
x
,
auto
...
xs
)
{
idx
.
local_stride
(
x
.
get_shape
().
elements
(),
[
&
](
auto
j
)
{
f
(
x
[
j
],
xs
[
j
]...);
});
idx
.
local_stride
(
x
.
get_shape
().
elements
(),
[
&
](
auto
j
)
{
f
(
x
[
j
],
xs
[
j
]...);
});
});
});
}
}
...
@@ -226,7 +226,7 @@ struct block
...
@@ -226,7 +226,7 @@ struct block
template
<
class
Input
>
template
<
class
Input
>
constexpr
auto
elements
()
const
constexpr
auto
elements
()
const
{
{
using
reduce_type
=
decltype
(
slice
r
(
Input
{}));
using
reduce_type
=
decltype
(
slice
(
Input
{}));
using
value_type
=
typename
Input
::
type
;
using
value_type
=
typename
Input
::
type
;
constexpr
auto
relements
=
get_shape_c
<
reduce_type
>
{}.
elements
();
constexpr
auto
relements
=
get_shape_c
<
reduce_type
>
{}.
elements
();
if
constexpr
(
vec_size
<
value_type
>
()
>
1
)
if
constexpr
(
vec_size
<
value_type
>
()
>
1
)
...
@@ -260,11 +260,11 @@ struct lane
...
@@ -260,11 +260,11 @@ struct lane
struct
reducer
struct
reducer
{
{
index
idx
;
index
idx
;
Slicer
slice
r
;
Slicer
slice
;
template
<
class
Op
,
class
T
,
class
Read
>
template
<
class
Op
,
class
T
,
class
Read
>
__device__
auto
reduce
(
Op
op
,
T
init
,
Read
read
)
const
__device__
auto
reduce
(
Op
op
,
T
init
,
Read
read
)
const
{
{
return
sliced
(
slice
r
,
[
=
](
auto
x
,
auto
...
xs
)
{
return
sliced
(
slice
,
[
=
](
auto
x
,
auto
...
xs
)
{
using
type
=
typename
decltype
(
x
)
::
type
;
using
type
=
typename
decltype
(
x
)
::
type
;
type
r
=
init
;
type
r
=
init
;
for
(
index_int
j
=
0
;
j
<
x
.
get_shape
().
elements
();
j
++
)
for
(
index_int
j
=
0
;
j
<
x
.
get_shape
().
elements
();
j
++
)
...
@@ -284,7 +284,7 @@ struct lane
...
@@ -284,7 +284,7 @@ struct lane
template
<
class
F
>
template
<
class
F
>
__device__
auto
inner
(
F
f
)
const
__device__
auto
inner
(
F
f
)
const
{
{
return
sliced
(
slice
r
,
[
=
](
auto
x
,
auto
...
xs
)
{
return
sliced
(
slice
,
[
=
](
auto
x
,
auto
...
xs
)
{
for
(
index_int
j
=
0
;
j
<
x
.
get_shape
().
elements
();
j
++
)
for
(
index_int
j
=
0
;
j
<
x
.
get_shape
().
elements
();
j
++
)
{
{
f
(
x
[
j
],
xs
[
j
]...);
f
(
x
[
j
],
xs
[
j
]...);
...
@@ -295,7 +295,7 @@ struct lane
...
@@ -295,7 +295,7 @@ struct lane
template
<
class
Input
>
template
<
class
Input
>
constexpr
auto
elements
()
const
constexpr
auto
elements
()
const
{
{
using
reduce_type
=
decltype
(
slice
r
(
Input
{}));
using
reduce_type
=
decltype
(
slice
(
Input
{}));
return
get_shape_c
<
reduce_type
>
{}.
elements
();
return
get_shape_c
<
reduce_type
>
{}.
elements
();
}
}
};
};
...
...
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp
View file @
8b25fd3e
...
@@ -33,11 +33,15 @@ template <index_int Axis, class Input, class Output>
...
@@ -33,11 +33,15 @@ template <index_int Axis, class Input, class Output>
__device__
void
softmax
(
Input
input
,
Output
output
)
__device__
void
softmax
(
Input
input
,
Output
output
)
{
{
reduce
::
block
::
run
<
reduce
::
with_axis
<
Input
,
Axis
>>
([
&
](
auto
,
auto
r
)
{
reduce
::
block
::
run
<
reduce
::
with_axis
<
Input
,
Axis
>>
([
&
](
auto
,
auto
r
)
{
auto
batch_max
=
r
.
reduce
(
op
::
max
{},
lowest
{},
op
::
id
{})(
input
);
#ifdef MIGRAPHX_USE_FAST_SOFTMAX
auto
batch_sum
=
const
auto
c
=
vec_at
(
r
.
slice
(
input
)[
0
],
0
);
r
.
reduce
(
op
::
sum
{},
0
,
[
&
](
auto
x
)
{
return
migraphx
::
exp
(
x
-
batch_max
);
})(
input
);
#else
r
.
inner
([
&
](
auto
&
y
,
auto
x
)
{
y
=
migraphx
::
exp
(
x
-
batch_max
)
/
batch_sum
;
})(
output
,
const
auto
c
=
r
.
reduce
(
op
::
max
{},
lowest
{},
op
::
id
{})(
input
);
input
);
#endif
auto
batch_sum
=
r
.
reduce
(
op
::
sum
{},
0
,
[
&
](
auto
x
)
{
return
migraphx
::
convert
<
float
>
(
migraphx
::
exp
(
x
-
c
));
})(
input
);
r
.
inner
([
&
](
auto
&
y
,
auto
x
)
{
y
=
migraphx
::
exp
(
x
-
c
)
/
batch_sum
;
})(
output
,
input
);
});
});
}
}
...
...
src/targets/gpu/lowering.cpp
View file @
8b25fd3e
...
@@ -26,6 +26,8 @@
...
@@ -26,6 +26,8 @@
#include <migraphx/manage_ptr.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/deconvolution.hpp>
#include <migraphx/op/deconvolution.hpp>
...
@@ -171,7 +173,8 @@ struct miopen_apply
...
@@ -171,7 +173,8 @@ struct miopen_apply
init
();
init
();
for
(
auto
it
=
mod
->
begin
();
it
!=
mod
->
end
();
it
++
)
for
(
auto
it
=
mod
->
begin
();
it
!=
mod
->
end
();
it
++
)
{
{
auto
s
=
it
->
get_shape
();
auto
s
=
it
->
get_shape
();
auto
attrs
=
it
->
get_operator
().
attributes
();
if
(
apply_map
.
count
(
it
->
name
())
>
0
)
if
(
apply_map
.
count
(
it
->
name
())
>
0
)
{
{
check_shape
(
s
,
apply_map
.
at
(
it
->
name
())(
it
));
check_shape
(
s
,
apply_map
.
at
(
it
->
name
())(
it
));
...
@@ -180,11 +183,37 @@ struct miopen_apply
...
@@ -180,11 +183,37 @@ struct miopen_apply
{
{
check_shape
(
s
,
insert_precompile_op
(
it
));
check_shape
(
s
,
insert_precompile_op
(
it
));
}
}
else
if
(
attrs
.
contains
(
"target"
))
{
check_shape
(
s
,
insert_custom_op
(
it
,
attrs
));
}
}
}
copy_params
();
copy_params
();
}
}
instruction_ref
insert_custom_op
(
instruction_ref
ins
,
const
value
&
attrs
)
const
{
const
auto
&
custom_op
=
ins
->
get_operator
();
if
(
attrs
.
at
(
"target"
)
==
"cpu"
)
{
auto
s
=
ins
->
get_shape
();
std
::
vector
<
instruction_ref
>
cpu_inputs
;
auto
inputs
=
ins
->
inputs
();
auto
output
=
inputs
.
back
();
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
std
::
back_inserter
(
cpu_inputs
),
[
&
](
auto
in
)
{
return
mod
->
insert_instruction
(
ins
,
make_op
(
"hip::copy_from_gpu"
),
in
);
});
cpu_inputs
.
front
()
=
mod
->
insert_instruction
(
ins
,
make_op
(
"hip::sync_stream"
),
cpu_inputs
);
auto
cpu_out
=
mod
->
insert_instruction
(
ins
,
custom_op
,
cpu_inputs
);
auto
gpu_out
=
mod
->
insert_instruction
(
ins
,
make_op
(
"hip::copy_to_gpu"
),
cpu_out
,
output
);
return
mod
->
replace_instruction
(
ins
,
gpu_out
);
}
return
ins
;
}
instruction_ref
insert_precompile_op
(
instruction_ref
ins
)
const
instruction_ref
insert_precompile_op
(
instruction_ref
ins
)
const
{
{
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
...
...
test/api/test_custom_op.cpp
View file @
8b25fd3e
...
@@ -43,6 +43,8 @@ struct sigmoid_custom_op final : migraphx::experimental_custom_op_base
...
@@ -43,6 +43,8 @@ struct sigmoid_custom_op final : migraphx::experimental_custom_op_base
return
inputs
[
1
];
return
inputs
[
1
];
}
}
virtual
bool
runs_on_offload_target
()
const
override
{
return
true
;
}
virtual
migraphx
::
shape
compute_shape
(
migraphx
::
shapes
inputs
)
const
override
virtual
migraphx
::
shape
compute_shape
(
migraphx
::
shapes
inputs
)
const
override
{
{
if
(
inputs
.
size
()
!=
2
)
if
(
inputs
.
size
()
!=
2
)
...
@@ -111,4 +113,45 @@ TEST_CASE(run_sigmoid_with_incorrect_shape)
...
@@ -111,4 +113,45 @@ TEST_CASE(run_sigmoid_with_incorrect_shape)
"Error in compute_shape of: sigmoid_custom_op: op must have two inputs"
));
"Error in compute_shape of: sigmoid_custom_op: op must have two inputs"
));
}
}
struct
identity_custom_op
final
:
migraphx
::
experimental_custom_op_base
{
virtual
std
::
string
name
()
const
override
{
return
"identity_custom_op"
;
}
virtual
migraphx
::
argument
compute
(
migraphx
::
context
,
migraphx
::
shape
,
migraphx
::
arguments
inputs
)
const
override
{
return
inputs
[
0
];
}
virtual
bool
runs_on_offload_target
()
const
override
{
return
true
;
}
virtual
migraphx
::
shape
compute_shape
(
migraphx
::
shapes
inputs
)
const
override
{
if
(
inputs
.
size
()
!=
1
)
{
throw
std
::
runtime_error
(
"Identity op must have only one input"
);
}
return
inputs
.
back
();
}
virtual
std
::
vector
<
size_t
>
output_alias
(
migraphx
::
shapes
)
const
override
{
return
{
0
,
1
};
}
};
TEST_CASE
(
run_custom_op_with_invalid_output_alias
)
{
identity_custom_op
i_op
;
migraphx
::
register_experimental_custom_op
(
i_op
);
auto
op
=
migraphx
::
operation
(
"identity_custom_op"
);
EXPECT
(
op
.
name
()
==
"identity_custom_op"
);
migraphx
::
program
p
;
migraphx
::
shape
s
{
migraphx_shape_float_type
,
{
12
}};
migraphx
::
module
m
=
p
.
get_main_module
();
auto
x
=
m
.
add_parameter
(
"x"
,
s
);
auto
i_ins
=
m
.
add_instruction
(
migraphx
::
operation
(
"identity_custom_op"
),
{
x
});
migraphx_test_private_disable_exception_catch
(
true
);
EXPECT
(
test
::
throws
<
std
::
exception
>
(
[
&
]
{
p
.
compile
(
migraphx
::
target
(
"ref"
));
},
"Currently, CustomOps in MIGraphX only supports one output_alias"
));
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/api/test_custom_op_gpu.cpp
View file @
8b25fd3e
...
@@ -24,40 +24,89 @@
...
@@ -24,40 +24,89 @@
#include <hip/hip_runtime_api.h>
#include <hip/hip_runtime_api.h>
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp>
#include <migraphx/migraphx.hpp>
#include <numeric>
#include <stdexcept>
#include <stdexcept>
#include "test.hpp"
#include "test.hpp"
#define MIGRAPHX_HIP_ASSERT(x) (EXPECT(x == hipSuccess))
#define MIGRAPHX_HIP_ASSERT(x) (EXPECT(x == hipSuccess))
struct
simple_custom_op
final
:
migraphx
::
experimental_custom_op_base
struct
half_copy_host
final
:
migraphx
::
experimental_custom_op_base
{
{
virtual
std
::
string
name
()
const
override
{
return
"simple_custom_op"
;
}
virtual
std
::
string
name
()
const
override
{
return
"half_copy_host"
;
}
virtual
bool
runs_on_offload_target
()
const
override
{
return
false
;
}
virtual
migraphx
::
argument
virtual
migraphx
::
argument
compute
(
migraphx
::
context
ctx
,
migraphx
::
shape
,
migraphx
::
arguments
inputs
)
const
override
compute
(
migraphx
::
context
ctx
,
migraphx
::
shape
,
migraphx
::
arguments
inputs
)
const
override
{
{
// sets first half size_bytes of the input 0, and rest of the half bytes are copied.
// This custom op simply sets first half size_bytes of the input to 0, and rest of the half
int
*
h_output
=
nullptr
;
// bytes are copied. for this custom_op, it does its computation on the host. Therefore,
auto
*
d_output
=
reinterpret_cast
<
int
*>
(
inputs
[
0
].
data
());
// `runs_on_offload_target()` is set to false. MIGraphX would inject necessary buffer copies
auto
input_bytes
=
inputs
[
0
].
get_shape
().
bytes
();
// to and from GPU to Host based on `runs_on_offload_targe()` flag for input buffers as well
auto
*
output_ptr
=
inputs
[
1
].
data
();
// as the output buffers
auto
copy_bytes
=
input_bytes
/
2
;
auto
*
input_buffer_ptr
=
inputs
[
0
].
data
();
auto
*
output_buffer_ptr
=
inputs
[
1
].
data
();
auto
input_bytes
=
inputs
[
0
].
get_shape
().
bytes
();
auto
copy_bytes
=
input_bytes
/
2
;
MIGRAPHX_HIP_ASSERT
(
hipSetDevice
(
0
));
MIGRAPHX_HIP_ASSERT
(
hipSetDevice
(
0
));
MIGRAPHX_HIP_ASSERT
(
hipHostMalloc
(
&
h_output
,
input_bytes
));
MIGRAPHX_HIP_ASSERT
(
hipMemcpyAsync
(
output_buffer_ptr
,
MIGRAPHX_HIP_ASSERT
(
hipMemcpyAsync
(
input_buffer_ptr
,
h_output
,
d_output
,
input_bytes
,
hipMemcpyDeviceToHost
,
ctx
.
get_queue
<
hipStream_t
>
()));
input_bytes
,
hipMemcpyHostToHost
,
ctx
.
get_queue
<
hipStream_t
>
()));
MIGRAPHX_HIP_ASSERT
(
hipDeviceSynchronize
());
MIGRAPHX_HIP_ASSERT
(
hipMemset
(
output_buffer_ptr
,
0
,
copy_bytes
));
MIGRAPHX_HIP_ASSERT
(
hipDeviceSynchronize
());
MIGRAPHX_HIP_ASSERT
(
hipDeviceSynchronize
());
MIGRAPHX_HIP_ASSERT
(
hipMemset
(
h_output
,
0
,
copy_bytes
));
return
inputs
[
1
];
}
virtual
migraphx
::
shape
compute_shape
(
migraphx
::
shapes
inputs
)
const
override
{
if
(
not
inputs
[
0
].
standard
()
or
not
inputs
[
1
].
standard
())
{
throw
std
::
runtime_error
(
"Input args must be standard shaped"
);
}
if
(
inputs
.
size
()
!=
2
)
{
throw
std
::
runtime_error
(
"number of inputs must be 2"
);
}
return
inputs
.
back
();
}
};
struct
half_copy_device
final
:
migraphx
::
experimental_custom_op_base
{
virtual
std
::
string
name
()
const
override
{
return
"half_copy_device"
;
}
virtual
bool
runs_on_offload_target
()
const
override
{
return
true
;
}
virtual
migraphx
::
argument
compute
(
migraphx
::
context
ctx
,
migraphx
::
shape
,
migraphx
::
arguments
inputs
)
const
override
{
// This custom op simply sets first half size_bytes of the input to 0, and rest of the half
// bytes are copied. for this custom_op, it does its computation on the "GPU". Therefore,
// `runs_on_offload_target()` is set to "true".
auto
*
input_buffer_ptr
=
inputs
[
0
].
data
();
auto
*
output_buffer_ptr
=
inputs
[
1
].
data
();
auto
input_bytes
=
inputs
[
0
].
get_shape
().
bytes
();
auto
copy_bytes
=
input_bytes
/
2
;
MIGRAPHX_HIP_ASSERT
(
hipSetDevice
(
0
));
MIGRAPHX_HIP_ASSERT
(
hipMemcpyAsync
(
output_buffer_ptr
,
input_buffer_ptr
,
input_bytes
,
hipMemcpyDeviceToDevice
,
ctx
.
get_queue
<
hipStream_t
>
()));
MIGRAPHX_HIP_ASSERT
(
hipDeviceSynchronize
());
MIGRAPHX_HIP_ASSERT
(
hipDeviceSynchronize
());
MIGRAPHX_HIP_ASSERT
(
hipMem
cpy
(
output_
ptr
,
h_output
,
input_bytes
,
hipMemcpyHostToDevice
));
MIGRAPHX_HIP_ASSERT
(
hipMem
set
(
output_
buffer_ptr
,
0
,
copy_bytes
));
MIGRAPHX_HIP_ASSERT
(
hipDeviceSynchronize
());
MIGRAPHX_HIP_ASSERT
(
hipDeviceSynchronize
());
MIGRAPHX_HIP_ASSERT
(
hipHostFree
(
h_output
));
return
inputs
[
1
];
return
inputs
[
1
];
}
}
virtual
migraphx
::
shape
compute_shape
(
migraphx
::
shapes
inputs
)
const
override
virtual
migraphx
::
shape
compute_shape
(
migraphx
::
shapes
inputs
)
const
override
{
{
if
(
not
inputs
[
0
].
standard
())
if
(
not
inputs
[
0
].
standard
()
or
not
inputs
[
1
].
standard
()
)
{
{
throw
std
::
runtime_error
(
"
firs
t arg must be standard shaped"
);
throw
std
::
runtime_error
(
"
Inpu
t arg
s
must be standard shaped"
);
}
}
if
(
inputs
.
size
()
!=
2
)
if
(
inputs
.
size
()
!=
2
)
{
{
...
@@ -67,36 +116,208 @@ struct simple_custom_op final : migraphx::experimental_custom_op_base
...
@@ -67,36 +116,208 @@ struct simple_custom_op final : migraphx::experimental_custom_op_base
}
}
};
};
TEST_CASE
(
run_simple_custom_op
)
// overwrites input buffer
struct
half_copy_device_same_buffer
final
:
migraphx
::
experimental_custom_op_base
{
virtual
std
::
string
name
()
const
override
{
return
"half_copy_device_same_buffer"
;
}
virtual
bool
runs_on_offload_target
()
const
override
{
return
true
;
}
virtual
migraphx
::
argument
compute
(
migraphx
::
context
,
migraphx
::
shape
,
migraphx
::
arguments
inputs
)
const
override
{
// This custom op simply sets first half size_bytes of the input 0, and rest of the half
// bytes are copied. for this custom_op, it does its computation on the "device". Therefore,
// `runs_on_offload_target()` is set to "true"
auto
*
buffer_ptr
=
inputs
[
0
].
data
();
auto
input_bytes
=
inputs
[
0
].
get_shape
().
bytes
();
auto
copy_bytes
=
input_bytes
/
2
;
MIGRAPHX_HIP_ASSERT
(
hipSetDevice
(
0
));
MIGRAPHX_HIP_ASSERT
(
hipMemset
(
buffer_ptr
,
0
,
copy_bytes
));
MIGRAPHX_HIP_ASSERT
(
hipDeviceSynchronize
());
return
inputs
[
0
];
}
virtual
migraphx
::
shape
compute_shape
(
migraphx
::
shapes
inputs
)
const
override
{
if
(
not
inputs
[
0
].
standard
())
{
throw
std
::
runtime_error
(
"Input arg must be standard shaped"
);
}
return
inputs
.
front
();
}
};
TEST_CASE
(
register_half_copy_op
)
{
half_copy_host
hch
;
migraphx
::
register_experimental_custom_op
(
hch
);
auto
op
=
migraphx
::
operation
(
"half_copy_host"
);
EXPECT
(
op
.
name
()
==
"half_copy_host"
);
half_copy_device
hcd
;
migraphx
::
register_experimental_custom_op
(
hcd
);
op
=
migraphx
::
operation
(
"half_copy_device"
);
EXPECT
(
op
.
name
()
==
"half_copy_device"
);
half_copy_device_same_buffer
hcdsb
;
migraphx
::
register_experimental_custom_op
(
hcdsb
);
op
=
migraphx
::
operation
(
"half_copy_device_same_buffer"
);
EXPECT
(
op
.
name
()
==
"half_copy_device_same_buffer"
);
}
TEST_CASE
(
half_copy_custom_op_test
)
{
auto
run_test_prog
=
[](
const
std
::
string
&
op_name
,
bool
buffer_alloc
)
{
migraphx
::
program
p
;
migraphx
::
module
m
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx_shape_float_type
,
{
4
,
3
}};
auto
x
=
m
.
add_parameter
(
"x"
,
s
);
migraphx
::
instructions
inputs
=
{
x
};
if
(
buffer_alloc
)
{
auto
alloc
=
m
.
add_allocation
(
s
);
inputs
=
{
x
,
alloc
};
}
auto
half_copy_ins
=
m
.
add_instruction
(
migraphx
::
operation
(
op_name
.
c_str
()),
inputs
);
m
.
add_return
({
half_copy_ins
});
migraphx
::
compile_options
options
;
options
.
set_offload_copy
();
p
.
compile
(
migraphx
::
target
(
"gpu"
),
options
);
migraphx
::
program_parameters
pp
;
std
::
vector
<
float
>
x_data
(
12
);
std
::
iota
(
x_data
.
begin
(),
x_data
.
end
(),
0
);
pp
.
add
(
"x"
,
migraphx
::
argument
(
s
,
x_data
.
data
()));
auto
results
=
p
.
eval
(
pp
);
auto
result
=
results
[
0
];
auto
result_vec
=
result
.
as_vector
<
float
>
();
std
::
vector
<
float
>
expected_result
(
12
,
0
);
std
::
iota
(
expected_result
.
begin
()
+
6
,
expected_result
.
end
(),
6
);
EXPECT
(
bool
{
result
==
migraphx
::
argument
(
s
,
expected_result
.
data
())});
};
// register all the ops
half_copy_host
hch
;
migraphx
::
register_experimental_custom_op
(
hch
);
half_copy_device
hcd
;
migraphx
::
register_experimental_custom_op
(
hcd
);
half_copy_device_same_buffer
hcdsb
;
migraphx
::
register_experimental_custom_op
(
hcdsb
);
std
::
vector
<
std
::
pair
<
std
::
string
,
bool
>>
tests_config
=
{
{
"half_copy_host"
,
true
},
{
"half_copy_device"
,
true
},
{
"half_copy_device_same_buffer"
,
false
}};
for
(
const
auto
&
i
:
tests_config
)
{
run_test_prog
(
i
.
first
,
i
.
second
);
}
}
struct
stride_two
final
:
migraphx
::
experimental_custom_op_base
{
virtual
std
::
string
name
()
const
override
{
return
"stride_two"
;
}
virtual
migraphx
::
argument
compute
(
migraphx
::
context
,
migraphx
::
shape
out_shape
,
migraphx
::
arguments
inputs
)
const
override
{
return
{
out_shape
,
inputs
[
0
].
data
()};
}
virtual
migraphx
::
shape
compute_shape
(
migraphx
::
shapes
inputs
)
const
override
{
if
(
inputs
.
size
()
!=
1
)
{
throw
std
::
runtime_error
(
"stride_two op must have only one input argument"
);
};
if
(
not
inputs
[
0
].
standard
())
{
throw
std
::
runtime_error
(
"stride_two op only works on the standard input shapes"
);
}
migraphx
::
shape
input_s
=
inputs
[
0
];
std
::
vector
<
size_t
>
dims
=
input_s
.
lengths
();
std
::
vector
<
size_t
>
new_dims
;
std
::
vector
<
size_t
>
strides
=
input_s
.
strides
();
std
::
vector
<
size_t
>
new_strides
;
std
::
for_each
(
dims
.
begin
(),
dims
.
end
(),
[
&
](
auto
i
)
{
new_dims
.
push_back
(
i
/
2
);
});
std
::
for_each
(
strides
.
begin
(),
strides
.
end
(),
[
&
](
auto
i
)
{
new_strides
.
push_back
(
i
*
2
);
});
migraphx
::
shape
output_shape
{
input_s
.
type
(),
new_dims
,
new_strides
};
return
output_shape
;
}
virtual
bool
runs_on_offload_target
()
const
override
{
return
true
;
}
virtual
std
::
vector
<
size_t
>
output_alias
(
migraphx
::
shapes
)
const
override
{
return
{
0
};
};
};
TEST_CASE
(
stride_two_custom_op_test
)
{
{
simple_custom_op
simple_op
;
stride_two
st
;
migraphx
::
register_experimental_custom_op
(
simple_op
);
migraphx
::
register_experimental_custom_op
(
st
);
migraphx
::
program
p
;
migraphx
::
module
m
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx_shape_float_type
,
{
4
,
4
,
4
}};
auto
x
=
m
.
add_parameter
(
"x"
,
s
);
auto
stride_two_ins
=
m
.
add_instruction
(
migraphx
::
operation
(
"stride_two"
),
{
x
});
m
.
add_return
({
stride_two_ins
});
migraphx
::
compile_options
options
;
options
.
set_offload_copy
();
p
.
compile
(
migraphx
::
target
(
"gpu"
),
options
);
migraphx
::
program_parameters
pp
;
std
::
vector
<
float
>
x_data
(
64
);
std
::
iota
(
x_data
.
begin
(),
x_data
.
end
(),
0
);
pp
.
add
(
"x"
,
migraphx
::
argument
(
s
,
x_data
.
data
()));
auto
results
=
p
.
eval
(
pp
);
auto
result
=
results
[
0
];
auto
result_vec
=
result
.
as_vector
<
float
>
();
std
::
vector
<
float
>
expected_result
=
{
0
,
2
,
8
,
10
,
32
,
34
,
40
,
42
};
EXPECT
(
result_vec
==
expected_result
);
}
TEST_CASE
(
custom_op_with_pre_and_post_subgraph_test
)
{
half_copy_host
hco
;
migraphx
::
register_experimental_custom_op
(
hco
);
stride_two
st
;
migraphx
::
register_experimental_custom_op
(
st
);
migraphx
::
program
p
;
migraphx
::
program
p
;
migraphx
::
shape
s
{
migraphx_shape_int32_type
,
{
4
,
3
}};
migraphx
::
shape
s
{
migraphx_shape_float_type
,
{
4
,
6
}};
migraphx
::
shape
trans_shape
{
migraphx_shape_int32_type
,
{
3
,
4
}};
migraphx
::
module
m
=
p
.
get_main_module
();
migraphx
::
module
m
=
p
.
get_main_module
();
auto
x
=
m
.
add_parameter
(
"x"
,
s
);
auto
x
=
m
.
add_parameter
(
"x"
,
s
);
auto
neg
=
m
.
add_instruction
(
migraphx
::
operation
(
"neg"
),
x
);
// pre-subgraph
auto
alloc
=
m
.
add_allocation
(
trans_shape
);
auto
neg_ins
=
m
.
add_instruction
(
migraphx
::
operation
(
"neg"
),
x
);
auto
neg_trans
=
auto
trans_ins
=
m
.
add_instruction
(
migraphx
::
operation
(
"transpose"
,
"{permutation: [1, 0]}"
),
{
neg
});
m
.
add_instruction
(
migraphx
::
operation
(
"transpose"
,
"{permutation: [1, 0]}"
),
{
neg_ins
});
auto
neg_cont
=
m
.
add_instruction
(
migraphx
::
operation
(
"contiguous"
),
{
neg_trans
});
auto
cont_ins
=
m
.
add_instruction
(
migraphx
::
operation
(
"contiguous"
),
{
trans_ins
});
auto
custom_kernel
=
// custom_op
m
.
add_instruction
(
migraphx
::
operation
(
"simple_custom_op"
),
{
neg_cont
,
alloc
});
migraphx
::
shape
trans_shape
{
migraphx_shape_float_type
,
{
6
,
4
}};
auto
relu
=
m
.
add_instruction
(
migraphx
::
operation
(
"relu"
),
custom_kernel
);
auto
alloc
=
m
.
add_allocation
(
trans_shape
);
m
.
add_return
({
relu
});
auto
half_copy_ins
=
m
.
add_instruction
(
migraphx
::
operation
(
"half_copy_host"
),
{
cont_ins
,
alloc
});
// post-subgraph
auto
abs_ins
=
m
.
add_instruction
(
migraphx
::
operation
(
"abs"
),
{
half_copy_ins
});
// another custom_op
auto
stride_two_ins
=
m
.
add_instruction
(
migraphx
::
operation
(
"stride_two"
),
{
abs_ins
});
// post-subgraph
auto
relu_ins
=
m
.
add_instruction
(
migraphx
::
operation
(
"relu"
),
{
stride_two_ins
});
m
.
add_return
({
relu_ins
});
migraphx
::
compile_options
options
;
migraphx
::
compile_options
options
;
options
.
set_offload_copy
();
options
.
set_offload_copy
();
p
.
compile
(
migraphx
::
target
(
"gpu"
),
options
);
p
.
compile
(
migraphx
::
target
(
"gpu"
),
options
);
migraphx
::
program_parameters
pp
;
migraphx
::
program_parameters
pp
;
std
::
vector
<
int
>
x_data
(
12
,
-
3
);
std
::
vector
<
float
>
x_data
(
s
.
elements
());
std
::
iota
(
x_data
.
begin
(),
x_data
.
end
(),
0
);
pp
.
add
(
"x"
,
migraphx
::
argument
(
s
,
x_data
.
data
()));
pp
.
add
(
"x"
,
migraphx
::
argument
(
s
,
x_data
.
data
()));
auto
results
=
p
.
eval
(
pp
);
auto
results
=
p
.
eval
(
pp
);
auto
result
=
results
[
0
];
auto
result
=
results
[
0
];
auto
result_vec
=
result
.
as_vector
<
in
t
>
();
auto
result_vec
=
result
.
as_vector
<
floa
t
>
();
std
::
vector
<
in
t
>
expected_result
(
12
,
0
)
;
std
::
vector
<
floa
t
>
expected_result
=
{
0
,
0
,
0
,
0
,
4
,
16
}
;
std
::
fill
(
expected_result
.
begin
()
+
6
,
expected_result
.
end
(),
3
);
EXPECT
(
bool
{
result
==
migraphx
::
argument
(
migraphx
::
shape
{
migraphx_shape_float_type
,
{
3
,
2
}},
EXPECT
(
bool
{
result
==
migraphx
::
argument
(
trans_shape
,
expected_result
.
data
())});
expected_result
.
data
())});
}
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment