gaoqiong / MIGraphX / Commits / 84add0fc

Commit 84add0fc, authored Oct 06, 2022 by charlie

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into refactor_auto_pad_conv

Parents: c1caf40a, f7d987ba

Changes: 29. Showing 20 changed files with 940 additions and 125 deletions.
examples/migraphx/custom_op_hip_kernel/custom_op_hip_kernel.cpp          +9   -2
examples/migraphx/custom_op_miopen_kernel/custom_op_miopen_kernel.cpp    +7   -0
examples/migraphx/custom_op_rocblas_kernel/custom_op_rocblas_kernel.cpp  +8   -1
src/api/api.cpp                                                          +130 -2
src/api/include/migraphx/migraphx.h                                      +27  -1
src/api/include/migraphx/migraphx.hpp                                    +125 -20
src/api/migraphx.py                                                      +13  -0
src/include/migraphx/context.hpp                                         +67  -1
src/include/migraphx/execution_environment.hpp                           +41  -0
src/include/migraphx/program.hpp                                         +3   -2
src/program.cpp                                                          +53  -40
src/py/migraphx_py.cpp                                                   +17  -0
src/targets/gpu/include/migraphx/gpu/context.hpp                         +26  -2
src/targets/gpu/jit/softmax.cpp                                          +5   -0
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp              +8   -8
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp             +9   -5
src/targets/gpu/lowering.cpp                                             +31  -2
test/api/test_custom_op.cpp                                              +43  -0
test/api/test_custom_op_gpu.cpp                                          +258 -37
test/api/test_gpu.cpp                                                    +60  -2
examples/migraphx/custom_op_hip_kernel/custom_op_hip_kernel.cpp

@@ -45,6 +45,13 @@ __global__ void vector_square(T* C_d, const T* A_d, size_t N)
 struct square_custom_op final : migraphx::experimental_custom_op_base
 {
     virtual std::string name() const override { return "square_custom_op"; }
+    // Flag to identify whether the custom op runs on the GPU or on the host.
+    // Based on this flag, MIGraphX injects the necessary copies to and from the GPU for the
+    // input and output buffers. Therefore, if the custom op runs on the GPU it can assume its
+    // input buffers are in GPU memory, and similarly for the host.
+    virtual bool runs_on_offload_target() const override { return true; }
     virtual migraphx::argument
     compute(migraphx::context ctx, migraphx::shape, migraphx::arguments inputs) const override
     {
@@ -54,7 +61,7 @@ struct square_custom_op final : migraphx::experimental_custom_op_base
         // is output argument, so it should be returned from compute method.
         auto* input_buffer  = reinterpret_cast<float*>(inputs[0].data());
         auto* output_buffer = reinterpret_cast<float*>(inputs[1].data());
-        size_t n_elements = inputs[0].get_shape().bytes() / sizeof(inputs[0].get_shape().type());
+        size_t n_elements = inputs[0].get_shape().elements();
         MIGRAPHX_HIP_ASSERT(hipSetDevice(0));
         const unsigned blocks            = 512;
         const unsigned threads_per_block = 256;
@@ -103,7 +110,7 @@ int main(int argc, const char* argv[])
     options.set_offload_copy();
     p.compile(migraphx::target("gpu"), options);
     migraphx::program_parameters pp;
-    std::vector<float> x_data(s.bytes() / sizeof(s.type()));
+    std::vector<float> x_data(s.elements());
     std::iota(x_data.begin(), x_data.end(), 0);
     pp.add("x", migraphx::argument(s, x_data.data()));
     auto results = p.eval(pp);
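The pattern above generalizes. As a minimal sketch only (assuming just the experimental API surface shown in this diff; "negate_custom_op" is an illustrative name, not part of this commit), a host-side custom op flips the flag and works directly on host pointers:

// Sketch: a host-side custom op. MIGraphX inserts the GPU<->host copies
// because runs_on_offload_target() returns false.
struct negate_custom_op final : migraphx::experimental_custom_op_base
{
    virtual std::string name() const override { return "negate_custom_op"; }
    virtual bool runs_on_offload_target() const override { return false; }
    virtual migraphx::argument
    compute(migraphx::context, migraphx::shape, migraphx::arguments inputs) const override
    {
        auto* in  = reinterpret_cast<float*>(inputs[0].data());
        auto* out = reinterpret_cast<float*>(inputs[1].data());
        for(size_t i = 0; i < inputs[0].get_shape().elements(); i++)
            out[i] = -in[i];
        return inputs[1]; // the last input is the preallocated output buffer
    }
    virtual migraphx::shape compute_shape(migraphx::shapes inputs) const override
    {
        return inputs.back();
    }
};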
examples/migraphx/custom_op_miopen_kernel/custom_op_miopen_kernel.cpp

@@ -93,6 +93,13 @@ inline auto make_activation_descriptor(miopenActivationMode_t mode, ...)
 struct abs_custom_op final : migraphx::experimental_custom_op_base
 {
     virtual std::string name() const override { return "abs_custom_op"; }
+    // Flag to identify whether the custom op runs on the GPU or on the host.
+    // Based on this flag, MIGraphX injects the necessary copies to and from the GPU for the
+    // input and output buffers. Therefore, if the custom op runs on the GPU it can assume its
+    // input buffers are in GPU memory, and similarly for the host.
+    virtual bool runs_on_offload_target() const override { return true; }
     virtual migraphx::argument compute(migraphx::context ctx,
                                        migraphx::shape output_shape,
                                        migraphx::arguments args) const override
examples/migraphx/custom_op_rocblas_kernel/custom_op_rocblas_kernel.cpp

@@ -51,6 +51,13 @@ rocblas_handle create_rocblas_handle_ptr(migraphx::context& ctx)
 struct sscal_custom_op final : migraphx::experimental_custom_op_base
 {
     virtual std::string name() const override { return "sscal_custom_op"; }
+    // Flag to identify whether the custom op runs on the GPU or on the host.
+    // Based on this flag, MIGraphX injects the necessary copies to and from the GPU for the
+    // input and output buffers. Therefore, if the custom op runs on the GPU it can assume its
+    // input buffers are in GPU memory, and similarly for the host.
+    virtual bool runs_on_offload_target() const override { return true; }
     virtual migraphx::argument compute(migraphx::context ctx,
                                        migraphx::shape output_shape,
                                        migraphx::arguments args) const override
@@ -107,7 +114,7 @@ int main(int argc, const char* argv[])
     options.set_offload_copy();
     p.compile(migraphx::target("gpu"), options);
     migraphx::program_parameters pp;
-    std::vector<float> x_data(x_shape.bytes() / sizeof(x_shape.type()));
+    std::vector<float> x_data(x_shape.elements());
     std::vector<float> scale_data{-1};
     std::iota(x_data.begin(), x_data.end(), 0);
     pp.add("x", migraphx::argument(x_shape, x_data.data()));
src/api/api.cpp

@@ -21,6 +21,7 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
+#include <migraphx/execution_environment.hpp>
 #include <migraphx/migraphx.h>
 #include <migraphx/rank.hpp>
 #include <migraphx/shape.hpp>

@@ -166,6 +167,13 @@ void set_output_names(tf_options& options, std::vector<const char*> names)
     options.output_node_names = std::vector<std::string>(names.begin(), names.end());
 }

+std::vector<argument>
+run_async(program& p, const parameter_map& params, void* s, std::string_view name)
+{
+    execution_environment exec_env{any_ptr(s, name), true};
+    return p.eval(params, exec_env);
+}
+
 template <class Value>
 std::vector<const char*> get_names(const std::unordered_map<std::string, Value>& m)
 {

@@ -265,11 +273,18 @@ struct experimental_custom_op
 template <class CustomOp>
 struct custom_operation
 {
     template <class Self, class F>
     static auto reflect(Self&, F)
     {
         return pack();
     }

+    value attributes() const
+    {
+        return {{"custom_op", true},
+                {"target", op.runs_on_offload_target() ? "gpu" : "cpu"}};
+    }
+
     CustomOp op;

     std::string name() const { return op.xobject.name; }

@@ -284,6 +299,23 @@ struct custom_operation
     {
         return op.compute(std::move(ctx), std::move(output_shape), std::move(inputs));
     }

+    std::ptrdiff_t output_alias(std::vector<shape> inputs) const
+    {
+        auto alias_vec = op.output_alias(std::move(inputs));
+        // TODO: For now, only support one output alias
+        if(alias_vec.empty())
+        {
+            return -1;
+        }
+        if(alias_vec.size() > 1)
+        {
+            MIGRAPHX_THROW("Currently, CustomOps in MIGraphX only supports one output_alias");
+        }
+        return alias_vec.front();
+    }
+
+    bool runs_on_offload_target() const { return op.runs_on_offload_target(); }
 };

 template <class CustomOp>

@@ -613,9 +645,9 @@ struct migraphx_experimental_custom_op
                               migraphx::shape output,
                               std::vector<migraphx::argument> inputs) const
     {
-        std::remove_pointer_t<migraphx_argument_t> out;
         if(compute_f == nullptr)
             throw std::runtime_error("compute function is missing.");
+        std::remove_pointer_t<migraphx_argument_t> out;
         std::array<char, 256> exception_msg;
         exception_msg.front() = '\0';
         auto api_error_result = compute_f(&out, ...

@@ -637,9 +669,9 @@ struct migraphx_experimental_custom_op
     migraphx_experimental_custom_op_compute_shape compute_shape_f = nullptr;
     migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
     {
-        std::remove_pointer_t<migraphx_shape_t> out;
         if(compute_shape_f == nullptr)
             throw std::runtime_error("compute_shape function is missing.");
+        std::remove_pointer_t<migraphx_shape_t> out;
         std::array<char, 256> exception_msg;
         exception_msg.front() = '\0';
         auto api_error_result = compute_shape_f(&out, ...

@@ -655,6 +687,49 @@ struct migraphx_experimental_custom_op
         }
         return (&out)->object;
     }

+    migraphx_experimental_custom_op_output_alias output_alias_f = nullptr;
+    std::vector<size_t> output_alias(std::vector<migraphx::shape> inputs) const
+    {
+        if(output_alias_f == nullptr)
+            throw std::runtime_error("output_alias function is missing.");
+        std::array<size_t, 1024> out;
+        std::remove_pointer_t<size_t*> out_size = 1024;
+        std::array<char, 256> exception_msg;
+        exception_msg.front() = '\0';
+        auto api_error_result = output_alias_f(out.data(),
+                                               &out_size,
+                                               object_ptr.data,
+                                               exception_msg.data(),
+                                               exception_msg.size(),
+                                               object_cast<migraphx_shapes_t>(&(inputs)));
+        if(api_error_result != migraphx_status_success)
+        {
+            const std::string exception_str(exception_msg.data());
+            throw std::runtime_error("Error in output_alias of: " +
+                                     std::string(object_ptr.obj_typename) + ": " + exception_str);
+        }
+        return {out.begin(), out.begin() + out_size};
+        // cppcheck-suppress returnDanglingLifetime;
+    }
+
+    migraphx_experimental_custom_op_runs_on_offload_target runs_on_offload_target_f = nullptr;
+    bool runs_on_offload_target() const
+    {
+        if(runs_on_offload_target_f == nullptr)
+            throw std::runtime_error("runs_on_offload_target function is missing.");
+        std::remove_pointer_t<bool*> out;
+        std::array<char, 256> exception_msg;
+        exception_msg.front() = '\0';
+        auto api_error_result = runs_on_offload_target_f(
+            &out, object_ptr.data, exception_msg.data(), exception_msg.size());
+        if(api_error_result != migraphx_status_success)
+        {
+            const std::string exception_str(exception_msg.data());
+            throw std::runtime_error("Error in runs_on_offload_target of: " +
+                                     std::string(object_ptr.obj_typename) + ": " + exception_str);
+        }
+        return out;
+    }
 };

 extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape)

@@ -758,6 +833,16 @@ extern "C" migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out, ...)
     return api_error_result;
 }

+extern "C" migraphx_status migraphx_shape_elements(size_t* out, const_migraphx_shape_t shape)
+{
+    auto api_error_result = migraphx::try_([&] {
+        if(shape == nullptr)
+            MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
+        *out = (shape->object).elements();
+    });
+    return api_error_result;
+}
+
 extern "C" migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape)
 {
     auto api_error_result = migraphx::try_([&] {

@@ -791,6 +876,16 @@ extern "C" migraphx_status migraphx_shape_standard(bool* out, const_migraphx_shape_t ...)
     return api_error_result;
 }

+extern "C" migraphx_status
+migraphx_shape_index(size_t* out, const_migraphx_shape_t shape, size_t i)
+{
+    auto api_error_result = migraphx::try_([&] {
+        if(shape == nullptr)
+            MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
+        *out = (shape->object).index((i));
+    });
+    return api_error_result;
+}
+
 extern "C" migraphx_status migraphx_argument_destroy(migraphx_argument_t argument)
 {
     auto api_error_result = migraphx::try_([&] { destroy((argument)); });

@@ -1348,6 +1443,23 @@ extern "C" migraphx_status migraphx_program_run(migraphx_arguments_t* out, ...)
     return api_error_result;
 }

+extern "C" migraphx_status migraphx_program_run_async(migraphx_arguments_t* out,
+                                                      migraphx_program_t program,
+                                                      migraphx_program_parameters_t params,
+                                                      void* s,
+                                                      const char* name)
+{
+    auto api_error_result = migraphx::try_([&] {
+        if(program == nullptr)
+            MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
+        if(params == nullptr)
+            MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter params: Null pointer");
+        *out = allocate<migraphx_arguments_t>(
+            migraphx::run_async((program->object), (params->object), (s), (name)));
+    });
+    return api_error_result;
+}
+
 extern "C" migraphx_status
 migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migraphx_program_t x)
 {

@@ -1879,6 +1991,22 @@ extern "C" migraphx_status migraphx_experimental_custom_op_set_compute_shape(...)
     return api_error_result;
 }

+extern "C" migraphx_status migraphx_experimental_custom_op_set_output_alias(
+    migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_output_alias input)
+{
+    auto api_error_result = migraphx::try_([&] { (obj)->output_alias_f = (input); });
+    return api_error_result;
+}
+
+extern "C" migraphx_status migraphx_experimental_custom_op_set_runs_on_offload_target(
+    migraphx_experimental_custom_op_t obj,
+    migraphx_experimental_custom_op_runs_on_offload_target input)
+{
+    auto api_error_result = migraphx::try_([&] { (obj)->runs_on_offload_target_f = (input); });
+    return api_error_result;
+}
+
 extern "C" migraphx_status
 migraphx_experimental_custom_op_register(migraphx_experimental_custom_op_t experimental_custom_op)
 {
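A hedged usage sketch of the new C entry point (not from this commit): `prog`, `params`, and `stream` are assumed to exist, and the `name` string must match the type name MIGraphX's any_ptr records for a HIP stream (the C++ wrapper passes get_type_name<Stream>()); the literal below is illustrative only.

migraphx_arguments_t outputs;
migraphx_status st = migraphx_program_run_async(
    &outputs, prog, params, (void*)stream, "ihipStream_t*"); // illustrative type name
if(st != migraphx_status_success)
{
    /* handle error */
}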
src/api/include/migraphx/migraphx.h

@@ -26,7 +26,6 @@
 #include <stdlib.h>
 #include <stdbool.h>

 // Add new types here
 // clang-format off
 #define MIGRAPHX_SHAPE_VISIT_TYPES(m) \

@@ -144,6 +143,16 @@ typedef migraphx_status (*migraphx_experimental_custom_op_compute_shape)(migraphx_shape_t* out, ...
     size_t exception_msg_size,
     migraphx_shapes_t inputs);

+typedef migraphx_status (*migraphx_experimental_custom_op_output_alias)(size_t* out,
+                                                                        size_t* out_size,
+                                                                        void* obj,
+                                                                        char* exception_msg,
+                                                                        size_t exception_msg_size,
+                                                                        migraphx_shapes_t inputs);
+
+typedef migraphx_status (*migraphx_experimental_custom_op_runs_on_offload_target)(
+    bool* out, void* obj, char* exception_msg, size_t exception_msg_size);
+
 typedef migraphx_status (*migraphx_experimental_custom_op_copy)(void** out, void* input);

 typedef migraphx_status (*migraphx_experimental_custom_op_delete)(void* input);

@@ -175,6 +184,8 @@ migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shape_t shape);
 migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out, const_migraphx_shape_t shape);

+migraphx_status migraphx_shape_elements(size_t* out, const_migraphx_shape_t shape);
+
 migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape);

@@ -182,6 +193,8 @@ migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_shape_t x);
 migraphx_status migraphx_shape_standard(bool* out, const_migraphx_shape_t shape);

+migraphx_status migraphx_shape_index(size_t* out, const_migraphx_shape_t shape, size_t i);
+
 migraphx_status migraphx_argument_destroy(migraphx_argument_t argument);

 migraphx_status migraphx_argument_assign_to(migraphx_argument_t output, ...

@@ -344,6 +357,12 @@ migraphx_status migraphx_program_run(migraphx_arguments_t* out,
                                      migraphx_program_t program,
                                      migraphx_program_parameters_t params);

+migraphx_status migraphx_program_run_async(migraphx_arguments_t* out,
+                                           migraphx_program_t program,
+                                           migraphx_program_parameters_t params,
+                                           void* s,
+                                           const char* name);
+
 migraphx_status
 migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migraphx_program_t x);

@@ -502,6 +521,13 @@ migraphx_experimental_custom_op_set_compute(migraphx_experimental_custom_op_t obj, ...)
 migraphx_status migraphx_experimental_custom_op_set_compute_shape(
     migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input);

+migraphx_status migraphx_experimental_custom_op_set_output_alias(
+    migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_output_alias input);
+
+migraphx_status migraphx_experimental_custom_op_set_runs_on_offload_target(
+    migraphx_experimental_custom_op_t obj,
+    migraphx_experimental_custom_op_runs_on_offload_target input);
+
 migraphx_status
 migraphx_experimental_custom_op_register(migraphx_experimental_custom_op_t experimental_custom_op);
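To make the new typedefs concrete: a callback matching migraphx_experimental_custom_op_output_alias could look like the sketch below. `out_size` is in/out — it arrives as the capacity of `out` (api.cpp passes 1024) and must be set to the number of aliases written. All names here are illustrative, not part of this commit.

migraphx_status my_output_alias(size_t* out,
                                size_t* out_size,
                                void* obj,
                                char* exception_msg,
                                size_t exception_msg_size,
                                migraphx_shapes_t inputs)
{
    (void)obj;
    (void)exception_msg;
    (void)exception_msg_size;
    (void)inputs;
    out[0]    = 0; // alias the output to input 0
    *out_size = 1; // one alias written
    return migraphx_status_success;
}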
src/api/include/migraphx/migraphx.hpp

@@ -25,10 +25,12 @@
 #define MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP

 #include "migraphx.h"
 #include <algorithm>
 #include <cstring>
 #include <initializer_list>
 #include <migraphx/migraphx.h>
 #include <memory>
 #include <numeric>
 #include <exception>
 #include <vector>
 #include <cassert>

@@ -340,6 +342,11 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
         this->set_handle(p, std::move(lifetime)); \
     }

+template <size_t N>
+struct out_params
+{
+};
+
 template <class Base>
 struct interface_base : Base
 {

@@ -391,7 +398,22 @@ struct interface_base : Base
     }

     template <class T, class Setter, class F>
-    void set_fp(Setter setter, F pf)
+    void set_fp(Setter setter, F pf, out_params<2>)
+    {
+        static F f = pf;
+        (void)f; // avoid warning on gcc
+        call(setter,
+             this->get_handle_ptr(),
+             [](auto out1, auto out2, void* obj, char* ex_msg, size_t ex_msg_size, auto... xs)
+                 -> migraphx_status {
+                 return try_([&] { call_cast_arg<T>(rank<2>{}, f, out1, out2, obj, xs...); },
+                             ex_msg,
+                             ex_msg_size);
+             });
+    }
+
+    template <class T, class Setter, class F>
+    void set_fp(Setter setter, F pf, out_params<1>)
     {
         static F f = pf;
         (void)f; // avoid warning on gcc

@@ -405,11 +427,27 @@ struct interface_base : Base
     }

     template <class T, class Setter, class F>
-    void set_auto_fp(Setter setter, F f)
+    void set_fp(Setter setter, F pf, out_params<0>)
     {
-        return set_fp<T>(setter, [=](T& obj, auto out, auto... xs) {
-            auto_invoke(f, out, obj, auto_convert_param(rank<2>{}, xs)...);
-        });
+        static F f = pf;
+        (void)f; // avoid warning on gcc
+        call(setter,
+             this->get_handle_ptr(),
+             [](void* obj, char* ex_msg, size_t ex_msg_size, auto... xs) -> migraphx_status {
+                 return try_([&] { call_cast_arg<T>(rank<0>{}, f, obj, xs...); },
+                             ex_msg,
+                             ex_msg_size);
+             });
+    }
+
+    template <class T, class Setter, class F, class Out>
+    void set_auto_fp(Setter setter, F f, Out nums)
+    {
+        return set_fp<T>(
+            setter,
+            [=](T& obj, auto out1, auto out2, auto... xs) {
+                auto_invoke(f, out1, out2, obj, auto_convert_param(rank<2>{}, xs)...);
+            },
+            nums);
     }

     struct no_out_arg

@@ -419,7 +457,7 @@ struct interface_base : Base
     template <class T, class F, class X, class... Xs,
               class = std::enable_if_t<std::is_void<X>{}>>
     static void call_cast_arg(rank<0>, F f, X* obj, Xs... xs)
     {
-        f(reinterpret_cast<T*>(obj), no_out_arg{}, xs...);
+        f(reinterpret_cast<T*>(obj), no_out_arg{}, no_out_arg{}, xs...);
     }

     template <class T,

@@ -430,17 +468,35 @@ struct interface_base : Base
               class = std::enable_if_t<std::is_void<X>{}>>
     static void call_cast_arg(rank<1>, F f, R result, X* obj, Xs... xs)
     {
-        f(*reinterpret_cast<T*>(obj), result, xs...);
+        f(*reinterpret_cast<T*>(obj), result, no_out_arg{}, xs...);
     }

+    template <class T, class F, class R1, class R2, class X, class... Xs,
+              class = std::enable_if_t<std::is_void<X>{}>>
+    static void call_cast_arg(rank<2>, F f, R1 result1, R2 result2, X* obj, Xs... xs)
+    {
+        f(*reinterpret_cast<T*>(obj), result1, result2, xs...);
+    }
+
+    template <class F, class T1, class T2, class... Ts>
+    void auto_invoke(F f, T1* out1, T2* out2, Ts&&... xs)
+    {
+        auto_assign(rank<2>{}, out1, out2, f(std::forward<Ts>(xs)...));
+    }
+
     template <class F, class T, class... Ts>
-    void auto_invoke(F f, T* out, Ts&&... xs)
+    void auto_invoke(F f, T* out, no_out_arg, Ts&&... xs)
     {
-        auto_assign(rank<2>{}, out, f(std::forward<Ts>(xs)...));
+        auto_assign(rank<1>{}, out, f(std::forward<Ts>(xs)...));
     }

     template <class F, class T, class... Ts>
-    void auto_invoke(F f, no_out_arg, Ts&&... xs)
+    void auto_invoke(F f, no_out_arg, no_out_arg, Ts&&... xs)
     {
         f(std::forward<Ts>(xs)...);
     }

@@ -469,7 +525,7 @@ struct interface_base : Base
     template <class T, class U>
     void auto_assign(rank<0>, T* out, U x)
     {
-        return *out = x;
+        *out = x;
     }

     template <class T, class U>

@@ -477,12 +533,21 @@ struct interface_base : Base
     {
         x.assign_to_handle(out);
     }

+    template <class T1, class T2, class U,
+              class = std::enable_if_t<std::is_same<T2, size_t>{}>>
+    auto auto_assign(rank<2>, T1* out_ptr, T2* out_size, U x)
+    {
+        *out_size = std::min(*out_size, x.size());
+        std::copy_n(x.begin(), *out_size, out_ptr);
+    }
 };

 // NOLINTNEXTLINE
-#define MIGRAPHX_INTERFACE_LIFT(T, prefix, name) \
-    this->set_auto_fp<T>(&migraphx_##prefix##_set_##name, \
-                         [](T& x, auto... xs) { return x.name(xs...); })
+#define MIGRAPHX_INTERFACE_LIFT(n_out, T, prefix, name) \
+    this->set_auto_fp<T>( \
+        &migraphx_##prefix##_set_##name, \
+        [](T& x, auto... xs) { return x.name(xs...); }, \
+        out_params<n_out>{})

 template <class Base, class T>
 using require_interface = ...

@@ -567,6 +632,13 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
         return pout;
     }

+    size_t elements() const
+    {
+        size_t pout;
+        call(&migraphx_shape_elements, &pout, this->get_handle_ptr());
+        return pout;
+    }
+
     size_t bytes() const
     {
         size_t pout;

@@ -581,6 +653,14 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
         return result;
     }

+    // map element index to space index
+    size_t index(size_t i) const
+    {
+        size_t result;
+        call(&migraphx_shape_index, &result, this->get_handle_ptr(), i);
+        return result;
+    }
+
     friend bool operator==(const shape& px, const shape& py)
     {
         bool pout;

@@ -628,9 +708,15 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
     template <typename T>
     std::vector<T> as_vector() const
     {
-        size_t vector_len = this->get_shape().bytes() / sizeof(T);
-        T* buffer_ptr     = reinterpret_cast<T*>(this->data());
-        return {buffer_ptr, buffer_ptr + vector_len};
+        auto ss           = this->get_shape();
+        auto num_elements = ss.elements();
+        std::vector<T> res(num_elements);
+        T* buffer_ptr = reinterpret_cast<T*>(this->data());
+        for(size_t i = 0; i < num_elements; i++)
+        {
+            res[i] = buffer_ptr[ss.index(i)];
+        }
+        return res;
     }

     /// Generate an argument using random data

@@ -979,6 +1065,20 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
         return arguments(pout, own{});
     }

+    /// Overloaded to allow for execution_environment input
+    template <class Stream>
+    arguments run_async(const program_parameters& pparams, Stream* s) const
+    {
+        migraphx_arguments_t pout;
+        call(&migraphx_program_run_async,
+             &pout,
+             this->get_handle_ptr(),
+             pparams.get_handle_ptr(),
+             s,
+             get_type_name<Stream>().c_str());
+        return arguments(pout, own{});
+    }
+
     void print() const { call(&migraphx_program_print, this->get_handle_ptr()); }

     program sort()

@@ -1255,7 +1355,10 @@ struct experimental_custom_op_base
     virtual std::string name() const = 0;
     virtual argument compute(context ctx, shape output, arguments inputs) const = 0;
     virtual shape compute_shape(shapes inputs) const                            = 0;
-    virtual ~experimental_custom_op_base() = default;
+    virtual std::vector<size_t> output_alias(shapes) const { return {}; }
+    // TODO: Return target string instead of bool
+    virtual bool runs_on_offload_target() const = 0;
+    virtual ~experimental_custom_op_base()      = default;
 };

 struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental_custom_op)>

@@ -1267,8 +1370,10 @@ struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental_custom_op)>
              obj,
              get_type_name(obj).c_str(),
              obj.name().c_str());
-        MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute_shape);
-        MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute);
+        MIGRAPHX_INTERFACE_LIFT(1, T, experimental_custom_op, compute_shape);
+        MIGRAPHX_INTERFACE_LIFT(1, T, experimental_custom_op, compute);
+        MIGRAPHX_INTERFACE_LIFT(2, T, experimental_custom_op, output_alias);
+        MIGRAPHX_INTERFACE_LIFT(1, T, experimental_custom_op, runs_on_offload_target);
     }

     void register_op() { call(&migraphx_experimental_custom_op_register, this->get_handle_ptr()); }
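The new index method is what makes the rewritten as_vector correct for non-packed outputs (such as the stride_two result later in this diff): bytes()/sizeof(T) measures the underlying buffer, while elements() counts logical elements and index(i) maps each logical element to its buffer offset. A small illustration, assuming a 2x2 view with strides {4, 2} over a 4x4 buffer (not from this commit):

// Illustration only: 4 logical elements spread over 7 buffer slots.
migraphx::shape s{migraphx_shape_float_type, {2, 2}, {4, 2}};
// s.elements() == 4, but the old bytes()/sizeof(float) idiom would count the
// span of the underlying buffer instead and over-read.
// s.index(i) for i = 0..3 yields the buffer offsets 0, 2, 4, 6.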
src/api/migraphx.py

@@ -115,6 +115,7 @@ def shape(h):
              const=True)
     h.method('strides', returns='const std::vector<size_t>&', const=True)
     h.method('type', returns='migraphx::shape::type_t', const=True)
+    h.method('elements', returns='size_t', const=True)
     h.method('bytes', returns='size_t', const=True)
     h.method('equal',
              api.params(x='const migraphx::shape&'),

@@ -122,6 +123,7 @@ def shape(h):
              returns='bool',
              const=True)
     h.method('standard', returns='bool', const=True)
+    h.method('index', api.params(i='size_t'), returns='size_t', const=True)

 @auto_handle()

@@ -274,6 +276,13 @@ def program(h):
              api.params(params='std::unordered_map<std::string, migraphx::argument>'),
              invoke='migraphx::run($@)',
              returns='std::vector<migraphx::argument>')
+    h.method('run_async',
+             api.params(params='std::unordered_map<std::string, migraphx::argument>',
+                        s='void*',
+                        name='const char *'),
+             invoke='migraphx::run_async($@)',
+             returns='std::vector<migraphx::argument>')
     h.method('equal',
              api.params(x='const migraphx::program&'),
              invoke='migraphx::equal($@)',

@@ -450,4 +459,8 @@ def experimental_custom_op(h):
     h.virtual('compute_shape',
               api.params(inputs='std::vector<migraphx::shape>'),
               returns='migraphx::shape')
+    h.virtual('output_alias',
+              api.params(inputs='std::vector<migraphx::shape>'),
+              returns='std::vector<size_t>')
+    h.virtual('runs_on_offload_target', returns='bool')
     h.method('register', invoke='migraphx::register_custom_op($@)')
src/include/migraphx/context.hpp

@@ -66,6 +66,15 @@ any_ptr get_queue_context(T&)
 {
     return {};
 }

+template <class T>
+void wait_for_context(T&, any_ptr)
+{
+}
+
+template <class T>
+void finish_on_context(T&, any_ptr)
+{
+}
+
 #ifdef TYPE_ERASED_DECLARATION

@@ -78,6 +87,10 @@ struct context
     void from_value(const value& v);
     // (optional)
     any_ptr get_queue();
+    // (optional)
+    void wait_for(any_ptr queue);
+    // (optional)
+    void finish_on(any_ptr queue);
     //
     void finish() const;
 };

@@ -165,6 +178,18 @@ struct context
         return (*this).private_detail_te_get_handle().get_queue();
     }

+    void wait_for(any_ptr queue)
+    {
+        assert((*this).private_detail_te_handle_mem_var);
+        (*this).private_detail_te_get_handle().wait_for(queue);
+    }
+
+    void finish_on(any_ptr queue)
+    {
+        assert((*this).private_detail_te_handle_mem_var);
+        (*this).private_detail_te_get_handle().finish_on(queue);
+    }
+
     void finish() const
     {
         assert((*this).private_detail_te_handle_mem_var);

@@ -187,6 +212,8 @@ struct context
         virtual value to_value() const          = 0;
         virtual void from_value(const value& v) = 0;
         virtual any_ptr get_queue()             = 0;
+        virtual void wait_for(any_ptr queue)    = 0;
+        virtual void finish_on(any_ptr queue)   = 0;
         virtual void finish() const             = 0;
     };

@@ -231,6 +258,33 @@ struct context
         return get_queue_context(private_detail_te_self);
     }

+    template <class T>
+    static auto private_detail_te_default_wait_for(char, T&& private_detail_te_self, any_ptr queue)
+        -> decltype(private_detail_te_self.wait_for(queue))
+    {
+        private_detail_te_self.wait_for(queue);
+    }
+
+    template <class T>
+    static void
+    private_detail_te_default_wait_for(float, T&& private_detail_te_self, any_ptr queue)
+    {
+        wait_for_context(private_detail_te_self, queue);
+    }
+
+    template <class T>
+    static auto private_detail_te_default_finish_on(char, T&& private_detail_te_self, any_ptr queue)
+        -> decltype(private_detail_te_self.finish_on(queue))
+    {
+        private_detail_te_self.finish_on(queue);
+    }
+
+    template <class T>
+    static void
+    private_detail_te_default_finish_on(float, T&& private_detail_te_self, any_ptr queue)
+    {
+        finish_on_context(private_detail_te_self, queue);
+    }
+
     template <typename PrivateDetailTypeErasedT>
     struct private_detail_te_handle_type : private_detail_te_handle_base_type
     {

@@ -248,7 +302,7 @@ struct context
             PrivateDetailTypeErasedT value,
             typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
                                     int>::type* = nullptr) noexcept
-            : private_detail_te_value(std::move(value))
+            : private_detail_te_value(value)
         {
         }

@@ -277,6 +331,18 @@ struct context
             return private_detail_te_default_get_queue(char(0), private_detail_te_value);
         }

+        void wait_for(any_ptr queue) override
+        {
+            private_detail_te_default_wait_for(char(0), private_detail_te_value, queue);
+        }
+
+        void finish_on(any_ptr queue) override
+        {
+            private_detail_te_default_finish_on(char(0), private_detail_te_value, queue);
+        }
+
         void finish() const override { private_detail_te_value.finish(); }

         PrivateDetailTypeErasedT private_detail_te_value;
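The free-function defaults above mean a target context that never deals with external queues needs no changes: the char/float overload trick selects the no-op fallbacks. A minimal conforming context, as a sketch under that assumption (illustrative, not from this commit):

// Sketch: a context without queue support still satisfies the interface;
// get_queue/wait_for/finish_on are omitted, so the type-erased wrapper
// dispatches to get_queue_context/wait_for_context/finish_on_context (no-ops).
struct my_context
{
    migraphx::value to_value() const { return {}; }
    void from_value(const migraphx::value&) {}
    void finish() const {}
};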
src/include/migraphx/execution_environment.hpp (new file, mode 100644)
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_EXECUTION_ENV_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_EXECUTION_ENV_HPP

#include <migraphx/any_ptr.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct execution_environment
{
    any_ptr queue = any_ptr{};
    bool async    = false;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif /* MIGRAPHX_GUARD_MIGRAPHLIB_EXECUTION_ENV_HPP */
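How this struct is used, as a sketch (the stream and the compiled program are assumed to exist; constructing any_ptr directly from a hipStream_t is an assumption based on how api.cpp and the GPU context use it):

hipStream_t stream; // assumed created, e.g. with hipStreamCreateWithFlags
migraphx::execution_environment exec_env{migraphx::any_ptr(stream), true};
auto results = prog.eval(params, exec_env);
// With async = true, eval first waits on work already enqueued on `stream`,
// then makes `stream` wait on the program's own streams afterwards
// (see ctx.wait_for/finish_on in src/program.cpp below).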
src/include/migraphx/program.hpp

@@ -37,6 +37,7 @@
 #include <migraphx/assignment_options.hpp>
 #include <migraphx/env.hpp>
 #include <migraphx/config.hpp>
+#include <migraphx/execution_environment.hpp>
 #include <algorithm>
 #include <iostream>

@@ -76,8 +77,8 @@ struct program
     std::unordered_map<std::string, shape> get_parameter_shapes() const;

-    std::vector<argument> eval(parameter_map params) const;
+    std::vector<argument>
+    eval(parameter_map params, execution_environment exec_env = execution_environment{}) const;

     std::size_t size() const;
     std::vector<shape> get_output_shapes() const;
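Because the new parameter is defaulted, existing callers compile unchanged. The two call forms, for reference:

auto r1 = prog.eval(params);            // as before: synchronous, default environment
auto r2 = prog.eval(params, exec_env);  // async path when exec_env.async is true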
src/program.cpp

@@ -398,7 +398,7 @@ std::vector<argument> generic_eval(const program& p, ...)
     return generic_eval(mm, ctx, params, {}, make_trace);
 }

-std::vector<argument> program::eval(parameter_map params) const
+std::vector<argument> program::eval(parameter_map params, execution_environment exec_env) const
 {
     auto& ctx = this->impl->ctx;
 #ifndef NDEBUG

@@ -423,6 +423,12 @@ std::vector<argument> program::eval(parameter_map params, ...) const
 #endif
     auto trace_level = value_of(MIGRAPHX_TRACE_EVAL{});

+    std::vector<argument> ret;
+    if(exec_env.async)
+    {
+        ctx.wait_for(exec_env.queue);
+    }
+
     if(trace_level > 0)
     {

@@ -434,49 +440,56 @@ std::vector<argument> program::eval(parameter_map params, ...) const
             ins_out[x] = ss.str();
         });

-        return generic_eval(
+        ret = generic_eval(
             *this,
             ctx,
             std::move(params),
             with_check_context([&](auto& ins, auto f, auto&& check_context) {
                 ctx.finish();
                 std::cout << "Run instruction: " << ins_out.at(ins) << std::endl;
                 timer t{};
                 auto result = check_context(f);
                 double t1   = t.record<milliseconds>();
                 ctx.finish();
                 double t2 = t.record<milliseconds>();
                 std::cout << "Time: " << t1 << "ms, " << t2 << "ms" << std::endl;
                 if(trace_level > 1 and ins->name().front() != '@' and ins->name() != "load" and
                    not result.empty())
                 {
                     target tgt  = make_target(this->impl->target_name);
                     auto buffer = tgt.copy_from(result);
                     if(trace_level == 2)
                     {
                         std::cout << "Output has " << to_string_range(classify_argument(buffer))
                                   << std::endl;
                         std::cout << "Output: ";
                         preview_argument(std::cout, buffer);
                         std::cout << std::endl;
                     }
                     else
                     {
                         std::cout << "Output: " << buffer << std::endl;
                     }
                 }
                 return result;
             }));
     }
     else
     {
-        return generic_eval(*this,
-                            ctx,
-                            std::move(params),
-                            with_check_context([&](auto&, auto f, auto&& check_context) {
-                                return check_context(f);
-                            }));
+        ret = generic_eval(*this,
+                           ctx,
+                           std::move(params),
+                           with_check_context([&](auto&, auto f, auto&& check_context) {
+                               return check_context(f);
+                           }));
     }

+    if(exec_env.async)
+    {
+        ctx.finish_on(exec_env.queue);
+    }
+    return ret;
 }

 const int program_file_version = 5;
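Summarizing the contract the reshaped eval implements (a reading aid in pseudocode comments, not code from the commit):

// 1. exec_env.async => ctx.wait_for(exec_env.queue): the program's internal
//    streams first wait for work already enqueued on the caller's queue.
// 2. generic_eval runs the instructions (traced or untraced path, same as before).
// 3. exec_env.async => ctx.finish_on(exec_env.queue): the caller's queue then
//    waits on the program's streams, so later work enqueued on that queue
//    observes the program's outputs without any host-side synchronization.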
src/py/migraphx_py.cpp

@@ -355,6 +355,23 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
             }
             return p.eval(pm);
         })
+        .def("run_async",
+             [](migraphx::program& p,
+                py::dict params,
+                std::uintptr_t stream,
+                std::string stream_name) {
+                 migraphx::parameter_map pm;
+                 for(auto x : params)
+                 {
+                     std::string key      = x.first.cast<std::string>();
+                     py::buffer b         = x.second.cast<py::buffer>();
+                     py::buffer_info info = b.request();
+                     pm[key]              = migraphx::argument(to_shape(info), info.ptr);
+                 }
+                 migraphx::execution_environment exec_env{
+                     migraphx::any_ptr(reinterpret_cast<void*>(stream), stream_name), true};
+                 return p.eval(pm, exec_env);
+             })
        .def("sort", &migraphx::program::sort)
        .def("print", [](const migraphx::program& p) { std::cout << p << std::endl; })
        .def("__eq__", std::equal_to<migraphx::program>{})
src/targets/gpu/include/migraphx/gpu/context.hpp

@@ -197,7 +197,9 @@ struct hip_device
 struct context
 {
     context(std::size_t device_id = 0, std::size_t n = value_of(MIGRAPHX_NSTREAMS{}, 1))
-        : current_device(std::make_shared<hip_device>(device_id, n))
+        : current_device(std::make_shared<hip_device>(device_id, n)),
+          begin_event(create_event()),
+          finish_event(create_event())
     {
     }

@@ -274,6 +276,24 @@ struct context
         this->current_device = std::make_shared<hip_device>(0, n_streams);
     }

+    void wait_for(any_ptr queue)
+    {
+        auto status = hipEventRecord(begin_event.get(), queue.get<hipStream_t>());
+        if(status != hipSuccess)
+            MIGRAPHX_THROW("failed to record " + hip_error(status));
+        get_stream().wait(begin_event.get());
+    }
+
+    void finish_on(any_ptr queue)
+    {
+        get_stream().record(finish_event.get());
+        auto status = hipStreamWaitEvent(queue.get<hipStream_t>(), finish_event.get(), 0);
+        if(status != hipSuccess)
+            MIGRAPHX_THROW("Failed to wait on event " + hip_error(status));
+    }
+
+    any_ptr get_queue() { return get_stream().get(); }
+
     void enable_perf_measurement(bool b = true)

@@ -316,9 +336,13 @@ struct context
     // TODO: Make this a vector to support multiple devices
     std::shared_ptr<hip_device> current_device;
     std::vector<shared<hip_event_ptr>> events;
     bool measure_perf = false;
+    // for event perf timing
     shared<hip_event_ptr> start_event = nullptr;
     shared<hip_event_ptr> stop_event  = nullptr;
+    // for stream synchronization
+    shared<hip_event_ptr> begin_event  = nullptr;
+    shared<hip_event_ptr> finish_event = nullptr;
 };

 inline void migraphx_to_value(value& v, const context& ctx) { v = ctx.to_value(); }
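The event pattern behind wait_for/finish_on, reduced to raw HIP calls (names are illustrative; this mirrors, but is not, the code above):

hipEvent_t ev;
hipEventCreateWithFlags(&ev, hipEventDisableTiming); // timing not needed for sync
hipEventRecord(ev, caller_stream);          // mark the current tail of caller_stream
hipStreamWaitEvent(worker_stream, ev, 0);   // worker_stream blocks until `ev` fires
// Kernels later enqueued on worker_stream now run after the prior caller_stream
// work, with no host-side blocking; finish_on applies the same pattern in reverse.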
src/targets/gpu/jit/softmax.cpp

@@ -32,6 +32,8 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {

+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_USE_FAST_SOFTMAX)
+
 using namespace migraphx::gpu::gen; // NOLINT

 static const char* const softmax_kernel = R"__migraphx__(

@@ -81,6 +83,9 @@ struct softmax_compiler : compiler<softmax_compiler>
         options.inputs      = inputs;
         options.kernel_name = "softmax_kernel";

+        if(enabled(MIGRAPHX_USE_FAST_SOFTMAX{}))
+            options.params = "-DMIGRAPHX_USE_FAST_SOFTMAX";
+
         auto src = interpolate_string(softmax_kernel,
                                       {{"transformers", make_transformer_args(vec)},
                                        {"axis", to_string(axis)}});
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp

@@ -197,11 +197,11 @@ struct block
     struct reducer
     {
         index idx;
-        Slicer slicer;
+        Slicer slice;

         template <class Op, class T, class Read>
         __device__ auto reduce(Op op, T init, Read read) const
         {
-            return sliced(slicer, [=](auto x, auto... xs) {
+            return sliced(slice, [=](auto x, auto... xs) {
                 return block_reduce(idx, op, init, x.get_shape().elements(), [&](auto j) {
                     return vec_reduce(read(x[j], xs[j]...), op);
                 });

@@ -218,7 +218,7 @@ struct block
     template <class F>
     __device__ auto inner(F f) const
     {
-        return sliced(slicer, [=](auto x, auto... xs) {
+        return sliced(slice, [=](auto x, auto... xs) {
             idx.local_stride(x.get_shape().elements(), [&](auto j) { f(x[j], xs[j]...); });
         });
     }

@@ -226,7 +226,7 @@ struct block
     template <class Input>
     constexpr auto elements() const
     {
-        using reduce_type = decltype(slicer(Input{}));
+        using reduce_type = decltype(slice(Input{}));
         using value_type  = typename Input::type;
         constexpr auto relements = get_shape_c<reduce_type>{}.elements();
         if constexpr(vec_size<value_type>() > 1)

@@ -260,11 +260,11 @@ struct lane
     struct reducer
     {
         index idx;
-        Slicer slicer;
+        Slicer slice;

         template <class Op, class T, class Read>
         __device__ auto reduce(Op op, T init, Read read) const
         {
-            return sliced(slicer, [=](auto x, auto... xs) {
+            return sliced(slice, [=](auto x, auto... xs) {
                 using type = typename decltype(x)::type;
                 type r     = init;
                 for(index_int j = 0; j < x.get_shape().elements(); j++)

@@ -284,7 +284,7 @@ struct lane
     template <class F>
     __device__ auto inner(F f) const
     {
-        return sliced(slicer, [=](auto x, auto... xs) {
+        return sliced(slice, [=](auto x, auto... xs) {
             for(index_int j = 0; j < x.get_shape().elements(); j++)
             {
                 f(x[j], xs[j]...);

@@ -295,7 +295,7 @@ struct lane
     template <class Input>
     constexpr auto elements() const
     {
-        using reduce_type = decltype(slicer(Input{}));
+        using reduce_type = decltype(slice(Input{}));
         return get_shape_c<reduce_type>{}.elements();
     }
 };
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp

@@ -33,11 +33,15 @@ template <index_int Axis, class Input, class Output>
 __device__ void softmax(Input input, Output output)
 {
     reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) {
-        auto batch_max = r.reduce(op::max{}, lowest{}, op::id{})(input);
-        auto batch_sum =
-            r.reduce(op::sum{}, 0, [&](auto x) { return migraphx::exp(x - batch_max); })(input);
-        r.inner([&](auto& y, auto x) { y = migraphx::exp(x - batch_max) / batch_sum; })(output,
-                                                                                        input);
+#ifdef MIGRAPHX_USE_FAST_SOFTMAX
+        const auto c = vec_at(r.slice(input)[0], 0);
+#else
+        const auto c = r.reduce(op::max{}, lowest{}, op::id{})(input);
+#endif
+        auto batch_sum = r.reduce(op::sum{}, 0, [&](auto x) {
+            return migraphx::convert<float>(migraphx::exp(x - c));
+        })(input);
+        r.inner([&](auto& y, auto x) { y = migraphx::exp(x - c) / batch_sum; })(output, input);
     });
 }
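The fast path is justified by the shift invariance of softmax: for any constant c,

\[
\mathrm{softmax}(x)_i \;=\; \frac{e^{x_i - c}}{\sum_j e^{x_j - c}},
\]

since numerator and denominator both pick up the same factor \(e^{-c}\). Choosing \(c = \max_j x_j\) keeps every exponent nonpositive and so guards against overflow; the MIGRAPHX_USE_FAST_SOFTMAX path instead takes c from the first slice element, saving one full reduction at the cost of that overflow guarantee.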
src/targets/gpu/lowering.cpp

@@ -26,6 +26,8 @@
 #include <migraphx/manage_ptr.hpp>
 #include <migraphx/instruction.hpp>
 #include <migraphx/make_op.hpp>
+#include <migraphx/instruction_ref.hpp>
+#include <migraphx/stringutils.hpp>
 #include <migraphx/op/convolution.hpp>
 #include <migraphx/op/deconvolution.hpp>

@@ -171,7 +173,8 @@ struct miopen_apply
         init();
         for(auto it = mod->begin(); it != mod->end(); it++)
         {
-            auto s = it->get_shape();
+            auto s     = it->get_shape();
+            auto attrs = it->get_operator().attributes();
             if(apply_map.count(it->name()) > 0)
             {
                 check_shape(s, apply_map.at(it->name())(it));

@@ -180,11 +183,37 @@ struct miopen_apply
             {
                 check_shape(s, insert_precompile_op(it));
             }
+            else if(attrs.contains("target"))
+            {
+                check_shape(s, insert_custom_op(it, attrs));
+            }
         }
         copy_params();
     }

+    instruction_ref insert_custom_op(instruction_ref ins, const value& attrs) const
+    {
+        const auto& custom_op = ins->get_operator();
+        if(attrs.at("target") == "cpu")
+        {
+            auto s = ins->get_shape();
+            std::vector<instruction_ref> cpu_inputs;
+            auto inputs = ins->inputs();
+            auto output = inputs.back();
+            std::transform(
+                inputs.begin(), inputs.end(), std::back_inserter(cpu_inputs), [&](auto in) {
+                    return mod->insert_instruction(ins, make_op("hip::copy_from_gpu"), in);
+                });
+            cpu_inputs.front() =
+                mod->insert_instruction(ins, make_op("hip::sync_stream"), cpu_inputs);
+            auto cpu_out = mod->insert_instruction(ins, custom_op, cpu_inputs);
+            auto gpu_out =
+                mod->insert_instruction(ins, make_op("hip::copy_to_gpu"), cpu_out, output);
+            return mod->replace_instruction(ins, gpu_out);
+        }
+        return ins;
+    }
+
     instruction_ref insert_precompile_op(instruction_ref ins) const
     {
         auto output = insert_allocation(ins, ins->get_shape());
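Schematically, a custom op whose target attribute is "cpu" is rewritten by insert_custom_op into the following instruction sequence (an illustrative listing based on the code above, not verbatim compiler output):

// before:  y_gpu = my_cpu_custom_op(x0_gpu, ..., out_alloc_gpu)
// after:
//   x0_cpu  = hip::copy_from_gpu(x0_gpu)       // one copy per input
//   x0_sync = hip::sync_stream(x0_cpu, ...)    // first input also syncs the stream
//   y_cpu   = my_cpu_custom_op(x0_sync, ...)   // runs on the host
//   y_gpu   = hip::copy_to_gpu(y_cpu, out_alloc_gpu)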
test/api/test_custom_op.cpp

@@ -43,6 +43,8 @@ struct sigmoid_custom_op final : migraphx::experimental_custom_op_base
         return inputs[1];
     }

+    virtual bool runs_on_offload_target() const override { return true; }
+
     virtual migraphx::shape compute_shape(migraphx::shapes inputs) const override
     {
         if(inputs.size() != 2)

@@ -111,4 +113,45 @@ TEST_CASE(run_sigmoid_with_incorrect_shape)
         "Error in compute_shape of: sigmoid_custom_op: op must have two inputs"));
 }

+struct identity_custom_op final : migraphx::experimental_custom_op_base
+{
+    virtual std::string name() const override { return "identity_custom_op"; }
+
+    virtual migraphx::argument
+    compute(migraphx::context, migraphx::shape, migraphx::arguments inputs) const override
+    {
+        return inputs[0];
+    }
+
+    virtual bool runs_on_offload_target() const override { return true; }
+
+    virtual migraphx::shape compute_shape(migraphx::shapes inputs) const override
+    {
+        if(inputs.size() != 1)
+        {
+            throw std::runtime_error("Identity op must have only one input");
+        }
+        return inputs.back();
+    }
+
+    virtual std::vector<size_t> output_alias(migraphx::shapes) const override { return {0, 1}; }
+};
+
+TEST_CASE(run_custom_op_with_invalid_output_alias)
+{
+    identity_custom_op i_op;
+    migraphx::register_experimental_custom_op(i_op);
+    auto op = migraphx::operation("identity_custom_op");
+    EXPECT(op.name() == "identity_custom_op");
+    migraphx::program p;
+    migraphx::shape s{migraphx_shape_float_type, {12}};
+    migraphx::module m = p.get_main_module();
+    auto x             = m.add_parameter("x", s);
+    auto i_ins         = m.add_instruction(migraphx::operation("identity_custom_op"), {x});
+    migraphx_test_private_disable_exception_catch(true);
+    EXPECT(test::throws<std::exception>(
+        [&] { p.compile(migraphx::target("ref")); },
+        "Currently, CustomOps in MIGraphX only supports one output_alias"));
+}
+
 int main(int argc, const char* argv[]) { test::run(argc, argv); }
test/api/test_custom_op_gpu.cpp

@@ -24,40 +24,89 @@
 #include <hip/hip_runtime_api.h>
 #include <migraphx/migraphx.h>
 #include <migraphx/migraphx.hpp>
+#include <numeric>
+#include <stdexcept>
 #include "test.hpp"

 #define MIGRAPHX_HIP_ASSERT(x) (EXPECT(x == hipSuccess))

-struct simple_custom_op final : migraphx::experimental_custom_op_base
-{
-    virtual std::string name() const override { return "simple_custom_op"; }
-    virtual migraphx::argument
-    compute(migraphx::context ctx, migraphx::shape, migraphx::arguments inputs) const override
-    {
-        // sets first half size_bytes of the input 0, and rest of the half bytes are copied.
-        int* h_output    = nullptr;
-        auto* d_output   = reinterpret_cast<int*>(inputs[0].data());
-        auto input_bytes = inputs[0].get_shape().bytes();
-        auto* output_ptr = inputs[1].data();
-        auto copy_bytes  = input_bytes / 2;
-        MIGRAPHX_HIP_ASSERT(hipSetDevice(0));
-        MIGRAPHX_HIP_ASSERT(hipHostMalloc(&h_output, input_bytes));
-        MIGRAPHX_HIP_ASSERT(hipMemcpyAsync(
-            h_output, d_output, input_bytes, hipMemcpyDeviceToHost, ctx.get_queue<hipStream_t>()));
-        MIGRAPHX_HIP_ASSERT(hipDeviceSynchronize());
-        MIGRAPHX_HIP_ASSERT(hipMemset(h_output, 0, copy_bytes));
-        MIGRAPHX_HIP_ASSERT(hipMemcpy(output_ptr, h_output, input_bytes, hipMemcpyHostToDevice));
-        MIGRAPHX_HIP_ASSERT(hipDeviceSynchronize());
-        MIGRAPHX_HIP_ASSERT(hipHostFree(h_output));
-        return inputs[1];
-    }
-    virtual migraphx::shape compute_shape(migraphx::shapes inputs) const override
-    {
-        if(not inputs[0].standard())
-        {
-            throw std::runtime_error("first arg must be standard shaped");
-        }
-        ...
-    }
-};
-
-TEST_CASE(run_simple_custom_op)
-{
-    ...
-}
+struct half_copy_host final : migraphx::experimental_custom_op_base
+{
+    virtual std::string name() const override { return "half_copy_host"; }
+    virtual bool runs_on_offload_target() const override { return false; }
+    virtual migraphx::argument
+    compute(migraphx::context ctx, migraphx::shape, migraphx::arguments inputs) const override
+    {
+        // This custom op sets the first half of the input's bytes to 0; the rest of the
+        // bytes are copied. It does its computation on the host, so
+        // `runs_on_offload_target()` is set to false. MIGraphX injects the necessary buffer
+        // copies between GPU and host, based on the `runs_on_offload_target()` flag, for
+        // the input buffers as well as the output buffers.
+        auto* input_buffer_ptr  = inputs[0].data();
+        auto* output_buffer_ptr = inputs[1].data();
+        auto input_bytes        = inputs[0].get_shape().bytes();
+        auto copy_bytes         = input_bytes / 2;
+        MIGRAPHX_HIP_ASSERT(hipSetDevice(0));
+        MIGRAPHX_HIP_ASSERT(hipMemcpyAsync(output_buffer_ptr,
+                                           input_buffer_ptr,
+                                           input_bytes,
+                                           hipMemcpyHostToHost,
+                                           ctx.get_queue<hipStream_t>()));
+        MIGRAPHX_HIP_ASSERT(hipDeviceSynchronize());
+        MIGRAPHX_HIP_ASSERT(hipMemset(output_buffer_ptr, 0, copy_bytes));
+        MIGRAPHX_HIP_ASSERT(hipDeviceSynchronize());
+        return inputs[1];
+    }
+
+    virtual migraphx::shape compute_shape(migraphx::shapes inputs) const override
+    {
+        if(not inputs[0].standard() or not inputs[1].standard())
+        {
+            throw std::runtime_error("Input args must be standard shaped");
+        }
+        if(inputs.size() != 2)
+        {
+            throw std::runtime_error("number of inputs must be 2");
+        }
+        return inputs.back();
+    }
+};
+
+struct half_copy_device final : migraphx::experimental_custom_op_base
+{
+    virtual std::string name() const override { return "half_copy_device"; }
+    virtual bool runs_on_offload_target() const override { return true; }
+    virtual migraphx::argument
+    compute(migraphx::context ctx, migraphx::shape, migraphx::arguments inputs) const override
+    {
+        // Same computation as half_copy_host, but done on the GPU; therefore
+        // `runs_on_offload_target()` is set to true.
+        auto* input_buffer_ptr  = inputs[0].data();
+        auto* output_buffer_ptr = inputs[1].data();
+        auto input_bytes        = inputs[0].get_shape().bytes();
+        auto copy_bytes         = input_bytes / 2;
+        MIGRAPHX_HIP_ASSERT(hipSetDevice(0));
+        MIGRAPHX_HIP_ASSERT(hipMemcpyAsync(output_buffer_ptr,
+                                           input_buffer_ptr,
+                                           input_bytes,
+                                           hipMemcpyDeviceToDevice,
+                                           ctx.get_queue<hipStream_t>()));
+        MIGRAPHX_HIP_ASSERT(hipDeviceSynchronize());
+        MIGRAPHX_HIP_ASSERT(hipMemset(output_buffer_ptr, 0, copy_bytes));
+        MIGRAPHX_HIP_ASSERT(hipDeviceSynchronize());
+        return inputs[1];
+    }
+
+    virtual migraphx::shape compute_shape(migraphx::shapes inputs) const override
+    {
+        if(not inputs[0].standard() or not inputs[1].standard())
+        {
+            throw std::runtime_error("Input args must be standard shaped");
+        }
+        if(inputs.size() != 2)
+        {

@@ -67,36 +116,208 @@ struct half_copy_device final : migraphx::experimental_custom_op_base
     }
 };

+// overwrites input buffer
+struct half_copy_device_same_buffer final : migraphx::experimental_custom_op_base
+{
+    virtual std::string name() const override { return "half_copy_device_same_buffer"; }
+    virtual bool runs_on_offload_target() const override { return true; }
+    virtual migraphx::argument
+    compute(migraphx::context, migraphx::shape, migraphx::arguments inputs) const override
+    {
+        // This custom op sets the first half of the input's bytes to 0 in place; the rest
+        // of the bytes are kept. It does its computation on the device, so
+        // `runs_on_offload_target()` is set to true.
+        auto* buffer_ptr = inputs[0].data();
+        auto input_bytes = inputs[0].get_shape().bytes();
+        auto copy_bytes  = input_bytes / 2;
+        MIGRAPHX_HIP_ASSERT(hipSetDevice(0));
+        MIGRAPHX_HIP_ASSERT(hipMemset(buffer_ptr, 0, copy_bytes));
+        MIGRAPHX_HIP_ASSERT(hipDeviceSynchronize());
+        return inputs[0];
+    }
+
+    virtual migraphx::shape compute_shape(migraphx::shapes inputs) const override
+    {
+        if(not inputs[0].standard())
+        {
+            throw std::runtime_error("Input arg must be standard shaped");
+        }
+        return inputs.front();
+    }
+};
+
+TEST_CASE(register_half_copy_op)
+{
+    half_copy_host hch;
+    migraphx::register_experimental_custom_op(hch);
+    auto op = migraphx::operation("half_copy_host");
+    EXPECT(op.name() == "half_copy_host");
+
+    half_copy_device hcd;
+    migraphx::register_experimental_custom_op(hcd);
+    op = migraphx::operation("half_copy_device");
+    EXPECT(op.name() == "half_copy_device");
+
+    half_copy_device_same_buffer hcdsb;
+    migraphx::register_experimental_custom_op(hcdsb);
+    op = migraphx::operation("half_copy_device_same_buffer");
+    EXPECT(op.name() == "half_copy_device_same_buffer");
+}
+
+TEST_CASE(half_copy_custom_op_test)
+{
+    auto run_test_prog = [](const std::string& op_name, bool buffer_alloc) {
+        migraphx::program p;
+        migraphx::module m = p.get_main_module();
+        migraphx::shape s{migraphx_shape_float_type, {4, 3}};
+        auto x                        = m.add_parameter("x", s);
+        migraphx::instructions inputs = {x};
+        if(buffer_alloc)
+        {
+            auto alloc = m.add_allocation(s);
+            inputs     = {x, alloc};
+        }
+        auto half_copy_ins = m.add_instruction(migraphx::operation(op_name.c_str()), inputs);
+        m.add_return({half_copy_ins});
+        migraphx::compile_options options;
+        options.set_offload_copy();
+        p.compile(migraphx::target("gpu"), options);
+        migraphx::program_parameters pp;
+        std::vector<float> x_data(12);
+        std::iota(x_data.begin(), x_data.end(), 0);
+        pp.add("x", migraphx::argument(s, x_data.data()));
+        auto results    = p.eval(pp);
+        auto result     = results[0];
+        auto result_vec = result.as_vector<float>();
+        std::vector<float> expected_result(12, 0);
+        std::iota(expected_result.begin() + 6, expected_result.end(), 6);
+        EXPECT(bool{result == migraphx::argument(s, expected_result.data())});
+    };
+    // register all the ops
+    half_copy_host hch;
+    migraphx::register_experimental_custom_op(hch);
+    half_copy_device hcd;
+    migraphx::register_experimental_custom_op(hcd);
+    half_copy_device_same_buffer hcdsb;
+    migraphx::register_experimental_custom_op(hcdsb);
+    std::vector<std::pair<std::string, bool>> tests_config = {
+        {"half_copy_host", true},
+        {"half_copy_device", true},
+        {"half_copy_device_same_buffer", false}};
+    for(const auto& i : tests_config)
+    {
+        run_test_prog(i.first, i.second);
+    }
+}
+
+struct stride_two final : migraphx::experimental_custom_op_base
+{
+    virtual std::string name() const override { return "stride_two"; }
+    virtual migraphx::argument compute(migraphx::context,
+                                       migraphx::shape out_shape,
+                                       migraphx::arguments inputs) const override
+    {
+        return {out_shape, inputs[0].data()};
+    }
+
+    virtual migraphx::shape compute_shape(migraphx::shapes inputs) const override
+    {
+        if(inputs.size() != 1)
+        {
+            throw std::runtime_error("stride_two op must have only one input argument");
+        };
+        if(not inputs[0].standard())
+        {
+            throw std::runtime_error("stride_two op only works on the standard input shapes");
+        }
+        migraphx::shape input_s     = inputs[0];
+        std::vector<size_t> dims    = input_s.lengths();
+        std::vector<size_t> new_dims;
+        std::vector<size_t> strides = input_s.strides();
+        std::vector<size_t> new_strides;
+        std::for_each(dims.begin(), dims.end(), [&](auto i) { new_dims.push_back(i / 2); });
+        std::for_each(
+            strides.begin(), strides.end(), [&](auto i) { new_strides.push_back(i * 2); });
+        migraphx::shape output_shape{input_s.type(), new_dims, new_strides};
+        return output_shape;
+    }
+
+    virtual bool runs_on_offload_target() const override { return true; }
+    virtual std::vector<size_t> output_alias(migraphx::shapes) const override { return {0}; };
+};
+
 TEST_CASE(stride_two_custom_op_test)
 {
-    simple_custom_op simple_op;
-    migraphx::register_experimental_custom_op(simple_op);
+    stride_two st;
+    migraphx::register_experimental_custom_op(st);
     migraphx::program p;
     migraphx::module m = p.get_main_module();
+    migraphx::shape s{migraphx_shape_float_type, {4, 4, 4}};
+    auto x              = m.add_parameter("x", s);
+    auto stride_two_ins = m.add_instruction(migraphx::operation("stride_two"), {x});
+    m.add_return({stride_two_ins});
+    migraphx::compile_options options;
+    options.set_offload_copy();
+    p.compile(migraphx::target("gpu"), options);
+    migraphx::program_parameters pp;
+    std::vector<float> x_data(64);
+    std::iota(x_data.begin(), x_data.end(), 0);
+    pp.add("x", migraphx::argument(s, x_data.data()));
+    auto results    = p.eval(pp);
+    auto result     = results[0];
+    auto result_vec = result.as_vector<float>();
+    std::vector<float> expected_result = {0, 2, 8, 10, 32, 34, 40, 42};
+    EXPECT(result_vec == expected_result);
 }

 TEST_CASE(custom_op_with_pre_and_post_subgraph_test)
 {
-    simple_custom_op simple_op;
-    migraphx::register_experimental_custom_op(simple_op);
+    half_copy_host hco;
+    migraphx::register_experimental_custom_op(hco);
+    stride_two st;
+    migraphx::register_experimental_custom_op(st);
     migraphx::program p;
-    migraphx::shape s{migraphx_shape_int32_type, {4, 3}};
-    migraphx::shape trans_shape{migraphx_shape_int32_type, {3, 4}};
+    migraphx::shape s{migraphx_shape_float_type, {4, 6}};
     migraphx::module m = p.get_main_module();
     auto x             = m.add_parameter("x", s);
-    auto neg       = m.add_instruction(migraphx::operation("neg"), x);
-    auto alloc     = m.add_allocation(trans_shape);
-    auto neg_trans =
-        m.add_instruction(migraphx::operation("transpose", "{permutation: [1, 0]}"), {neg});
-    auto neg_cont = m.add_instruction(migraphx::operation("contiguous"), {neg_trans});
-    auto custom_kernel =
-        m.add_instruction(migraphx::operation("simple_custom_op"), {neg_cont, alloc});
-    auto relu = m.add_instruction(migraphx::operation("relu"), custom_kernel);
-    m.add_return({relu});
+    // pre-subgraph
+    auto neg_ins = m.add_instruction(migraphx::operation("neg"), x);
+    auto trans_ins =
+        m.add_instruction(migraphx::operation("transpose", "{permutation: [1, 0]}"), {neg_ins});
+    auto cont_ins = m.add_instruction(migraphx::operation("contiguous"), {trans_ins});
+    // custom_op
+    migraphx::shape trans_shape{migraphx_shape_float_type, {6, 4}};
+    auto alloc = m.add_allocation(trans_shape);
+    auto half_copy_ins =
+        m.add_instruction(migraphx::operation("half_copy_host"), {cont_ins, alloc});
+    // post-subgraph
+    auto abs_ins = m.add_instruction(migraphx::operation("abs"), {half_copy_ins});
+    // another custom_op
+    auto stride_two_ins = m.add_instruction(migraphx::operation("stride_two"), {abs_ins});
+    // post-subgraph
+    auto relu_ins = m.add_instruction(migraphx::operation("relu"), {stride_two_ins});
+    m.add_return({relu_ins});
     migraphx::compile_options options;
     options.set_offload_copy();
     p.compile(migraphx::target("gpu"), options);
     migraphx::program_parameters pp;
-    std::vector<int> x_data(12, -3);
+    std::vector<float> x_data(s.elements());
+    std::iota(x_data.begin(), x_data.end(), 0);
     pp.add("x", migraphx::argument(s, x_data.data()));
-    auto results    = p.eval(pp);
-    auto result     = results[0];
-    auto result_vec = result.as_vector<int>();
-    std::vector<int> expected_result(12, 0);
-    std::fill(expected_result.begin() + 6, expected_result.end(), 3);
-    EXPECT(bool{result == migraphx::argument(trans_shape, expected_result.data())});
+    auto results    = p.eval(pp);
+    auto result     = results[0];
+    auto result_vec = result.as_vector<float>();
+    std::vector<float> expected_result = {0, 0, 0, 0, 4, 16};
+    EXPECT(bool{result ==
+                migraphx::argument(migraphx::shape{migraphx_shape_float_type, {3, 2}},
+                                   expected_result.data())});
 }

 int main(int argc, const char* argv[]) { test::run(argc, argv); }
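The expected values in stride_two_custom_op_test follow directly from the shape arithmetic in compute_shape: the standard input {4, 4, 4} has strides {16, 4, 1}; halving the dims and doubling the strides yields dims {2, 2, 2} with strides {32, 8, 2}, and over iota data each element equals its buffer offset. A quick standalone check of that claim (a sketch, not part of the test file):

std::vector<float> expected;
for(size_t i = 0; i < 2; i++)
    for(size_t j = 0; j < 2; j++)
        for(size_t k = 0; k < 2; k++)
            expected.push_back(i * 32 + j * 8 + k * 2);
// expected == {0, 2, 8, 10, 32, 34, 40, 42}, matching expected_result above.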
test/api/test_gpu.cpp

@@ -25,6 +25,8 @@
 #include <hip/hip_runtime_api.h>
 #include <migraphx/migraphx.h>
 #include <migraphx/migraphx.hpp>
+#include <migraphx/manage_ptr.hpp>
+
 #include "test.hpp"

 TEST_CASE(load_and_run)

@@ -44,11 +46,67 @@ TEST_CASE(load_and_run)
     {
         pp.add(name, migraphx::argument::generate(param_shapes[name]));
     }
     auto outputs = p.eval(pp);
     CHECK(shapes_before.size() == outputs.size());
     CHECK(bool{shapes_before.front() == outputs.front().get_shape()});
 }

+using hip_ptr    = MIGRAPHX_MANAGE_PTR(void, hipFree);
+using stream_ptr = MIGRAPHX_MANAGE_PTR(hipStream_t, hipStreamDestroy);
+
+stream_ptr get_stream()
+{
+    hipStream_t stream;
+    auto err = hipStreamCreateWithFlags(&stream, 0);
+    EXPECT(err == hipSuccess);
+    return stream_ptr{stream};
+}
+
+hip_ptr get_hip_buffer(size_t size)
+{
+    void* ptr;
+    auto err = hipMalloc(&ptr, size);
+    EXPECT(err == hipSuccess);
+    return hip_ptr{ptr};
+}
+
+TEST_CASE(load_and_run_async)
+{
+    auto p             = migraphx::parse_onnx("conv_relu_maxpool_test.onnx");
+    auto shapes_before = p.get_output_shapes();
+    migraphx::compile_options options;
+    options.set_offload_copy(false);
+    p.compile(migraphx::target("gpu"), options);
+    auto shapes_after = p.get_output_shapes();
+    CHECK(shapes_before.size() == 1);
+    CHECK(shapes_before.size() == shapes_after.size());
+    CHECK(bool{shapes_before.front() == shapes_after.front()});
+    migraphx::program_parameters pp;
+    auto param_shapes = p.get_parameter_shapes();
+    stream_ptr stream = get_stream();
+    std::vector<hip_ptr> buffs;
+    std::vector<migraphx::argument> args;
+    for(auto&& name : param_shapes.names())
+    {
+        args.push_back(migraphx::argument::generate(param_shapes[name]));
+        buffs.push_back(get_hip_buffer(args.rbegin()->get_shape().bytes()));
+        auto err = hipMemcpy(buffs.rbegin()->get(),
+                             args.rbegin()->data(),
+                             args.rbegin()->get_shape().bytes(),
+                             hipMemcpyHostToDevice);
+        EXPECT(err == hipSuccess);
+        pp.add(name, migraphx::argument(args.rbegin()->get_shape(), buffs.rbegin()->get()));
+    }
+    auto outputs = p.run_async(pp, stream.get());
+    CHECK(shapes_before.size() == outputs.size());
+    CHECK(bool{shapes_before.front() == outputs.front().get_shape()});
+}
+
 TEST_CASE(load_and_run_ctx)
 {
     auto p = migraphx::parse_onnx("conv_relu_maxpool_test.onnx");

@@ -82,10 +140,10 @@ TEST_CASE(if_pl_test)
     migraphx::program_parameters pp;
     auto param_shapes = p.get_parameter_shapes();
     auto xs           = param_shapes["x"];
-    std::vector<float> xd(xs.bytes() / sizeof(float), 1.0);
+    std::vector<float> xd(xs.elements(), 1.0);
     pp.add("x", migraphx::argument(xs, xd.data()));
     auto ys = param_shapes["y"];
-    std::vector<float> yd(ys.bytes() / sizeof(float), 2.0);
+    std::vector<float> yd(ys.elements(), 2.0);
     pp.add("y", migraphx::argument(ys, yd.data()));
     char ccond = cond;
     pp.add("cond", migraphx::argument(param_shapes["cond"], &ccond));