gaoqiong / MIGraphX · Commit 2f268bc2

Merge branch 'develop' into mlir-c

Authored Jun 12, 2022 by Paul
Parents: f75c5a38 aa7ff911
Changes: 205
Showing 20 changed files with 486 additions and 204 deletions (+486 -204)
src/targets/gpu/kernels/include/migraphx/kernels/shape.hpp         +9   -1
src/targets/gpu/kernels/include/migraphx/kernels/tensor_view.hpp   +31  -8
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp   +147 -8
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp         +2   -1
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp           +73  -2
src/targets/gpu/kernels/include/migraphx/kernels/vectorize.hpp     +33  -15
src/targets/gpu/lowering.cpp                                       +20  -8
src/targets/gpu/pack_int8_args.cpp                                 +9   -3
src/targets/gpu/prefuse_ops.cpp                                    +76  -0
src/targets/gpu/quant_convolution.cpp                              +13  -6
src/targets/gpu/schedule_model.cpp                                 +8   -8
src/targets/gpu/sync_device.cpp                                    +4   -4
src/targets/gpu/target.cpp                                         +3   -0
src/targets/gpu/write_literals.cpp                                 +7   -7
src/targets/ref/gemm.cpp                                           +4   -2
src/targets/ref/lowering.cpp                                       +2   -119
test/api/CMakeLists.txt                                            +3   -2
test/api/test_array_base.cpp                                       +32  -0
test/api/test_gpu.cpp                                              +2   -0
test/api/test_module_construct.cpp                                 +8   -10
src/targets/gpu/kernels/include/migraphx/kernels/shape.hpp

@@ -83,11 +83,12 @@ struct shape
         }
     }

+    /// Convert single index into a multi-index
     constexpr index_array multi(index_int idx) const
     {
         index_array result;
         index_int tidx = idx;
-        for(std::ptrdiff_t is = result.size() - 1; is > 0; is--)
+        for(diff_int is = result.size() - 1; is > 0; is--)
         {
             result[is] = tidx % lens[is];
             tidx       = tidx / lens[is];

@@ -95,6 +96,13 @@ struct shape
         result[0] = tidx;
         return result;
     }

+    /// Convert multi-index into a single index
+    constexpr index_int single(index_array idx) const
+    {
+        if(idx.empty())
+            return 0;
+        return inner_product(lens.begin() + 1, lens.end(), idx.begin(), idx.back());
+    }
+
     constexpr shape get_shape() const { return *this; }
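The new single() inverts multi(): multi() peels digits off a flat index, single() folds them back in. A minimal host-side sketch of the same row-major arithmetic, shown for the 2-D case; std::array stands in for the kernel's index_array, which is an assumption of this sketch:

    #include <array>
    #include <cassert>
    #include <cstddef>
    #include <numeric>

    int main()
    {
        std::array<std::size_t, 2> lens = {3, 4}; // a 3x4 row-major shape

        // multi(): peel off the fastest-moving dimension, as the loop above does
        std::size_t idx = 7;
        std::array<std::size_t, 2> multi{};
        std::size_t tidx = idx;
        multi[1] = tidx % lens[1]; // 7 % 4 == 3
        tidx     = tidx / lens[1]; // 7 / 4 == 1
        multi[0] = tidx;           // multi == {1, 3}

        // single(): the exact inner_product expression the diff adds --
        // trailing lens against leading digits, seeded with the last digit
        std::size_t single =
            std::inner_product(lens.begin() + 1, lens.end(), multi.begin(), multi.back());
        assert(single == idx); // 3 + 4*1 == 7
    }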
src/targets/gpu/kernels/include/migraphx/kernels/tensor_view.hpp

@@ -11,7 +11,7 @@ template <class T>
 struct tensor_view_iterator_read
 {
     T* view;
-    constexpr auto& operator()(std::size_t n) const
+    constexpr auto& operator()(index_int n) const
     {
         MIGRAPHX_ASSERT(view != nullptr);
         return (*view)[n];

@@ -21,18 +21,31 @@ struct tensor_view_iterator_read
 template <class T, class Shape>
 struct tensor_view
 {
-    using type       = T;
-    using shape_type = Shape;
-    using iterator = basic_iota_iterator<tensor_view_iterator_read<const tensor_view>, index_int>;
+    using type        = T;
+    using shape_type  = Shape;
+    using index_array = typename Shape::index_array;
+    using iterator = basic_iota_iterator<tensor_view_iterator_read<const tensor_view>, index_int>;

     constexpr Shape get_shape() const { return Shape{}; }
     constexpr auto size() const { return get_shape().elements(); }

-    template <class U>
-    constexpr T& operator[](U i) const
-    {
-        MIGRAPHX_ASSERT(get_shape().index(i) < get_shape().element_space());
-        return x[get_shape().index(i)];
-    }
+    struct index_to_offset
+    {
+        index_int offset;
+        template <class U>
+        constexpr index_to_offset(U i) : offset(Shape{}.index(i))
+        {
+        }
+    };
+
+    constexpr T& operator[](MIGRAPHX_CAPTURE_SOURCE_LOCATION(index_to_offset) i) const
+    {
+        index_to_offset ito = i;
+        MIGRAPHX_WARN(ito.offset < get_shape().element_space(),
+                      i,
+                      "Out of bounds access at offset: ",
+                      ito.offset);
+        return x[ito.offset];
+    }

     constexpr T* data() const { return x; }

@@ -40,6 +53,13 @@ struct tensor_view
     constexpr auto begin() const { return iterator{0, {this}}; }
     constexpr auto end() const { return iterator{this->size(), {this}}; }

+    constexpr auto begin_at(index_array i) const
+    {
+        MIGRAPHX_ASSERT(get_shape().single(i) < get_shape().elements());
+        MIGRAPHX_ASSERT(get_shape().index(i) < get_shape().element_space());
+        return iterator{get_shape().single(i), {this}};
+    }
+
     template <class U>
     constexpr tensor_view<U, Shape> with(U* y) const
     {

@@ -50,6 +70,9 @@ struct tensor_view
     T* x;
 };

+template <class T>
+using get_shape_c = typename T::shape_type;
+
 template <class T, class Shape>
 constexpr tensor_view<T, Shape> make_tensor_view(T* x, Shape)
 {
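The reworked operator[] converts the subscript to a flat offset exactly once through index_to_offset, then warns, rather than asserts, on out-of-bounds offsets. A stripped-down sketch of that convert-once pattern in plain C++; view, index_to_offset, and the fprintf warning are illustrative stand-ins for tensor_view and the MIGRAPHX_* macros:

    #include <cstddef>
    #include <cstdio>

    template <std::size_t Space>
    struct view
    {
        float* x;

        struct index_to_offset
        {
            std::size_t offset;
            index_to_offset(std::size_t i) : offset(i) {} // single conversion point
        };

        float& operator[](index_to_offset ito) const
        {
            if(ito.offset >= Space) // warn instead of asserting
                std::fprintf(stderr, "Out of bounds access at offset: %zu\n", ito.offset);
            return x[ito.offset];
        }
    };

    int main()
    {
        float buf[4] = {0, 1, 2, 3};
        view<4> v{buf};
        return v[2] == 2.0f ? 0 : 1; // the subscript converts to index_to_offset implicitly
    }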
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp

@@ -6,6 +6,15 @@
 namespace migraphx {

+template <class T, class U = T&&>
+U private_declval(int);
+
+template <class T>
+T private_declval(long);
+
+template <class T>
+auto declval() noexcept -> decltype(private_declval<T>(0));
+
 template <class T>
 struct type_identity
 {

@@ -26,20 +35,88 @@ struct enable_if<true, T>
 template <bool B, class T = void>
 using enable_if_t = typename enable_if<B, T>::type;

-template <class From, class To>
-struct is_convertible : bool_constant<__is_convertible(From, To)>
-{
-};
-
-template <class T, class U>
-struct is_same : false_type
-{
-};
-
-template <class T>
-struct is_same<T, T> : true_type
-{
-};
+template <bool B, class T, class F>
+struct conditional
+{
+    using type = T;
+};
+
+template <class T, class F>
+struct conditional<false, T, F>
+{
+    using type = F;
+};
+
+template <bool B, class T, class F>
+using conditional_t = typename conditional<B, T, F>::type;
+
+// NOLINTNEXTLINE
+#define MIGRAPHX_BUILTIN_TYPE_TRAIT1(name)   \
+    template <class T>                       \
+    struct name : bool_constant<__##name(T)> \
+    {                                        \
+    }
+// NOLINTNEXTLINE
+#define MIGRAPHX_BUILTIN_TYPE_TRAIT2(name)      \
+    template <class T, class U>                 \
+    struct name : bool_constant<__##name(T, U)> \
+    {                                           \
+    }
+// NOLINTNEXTLINE
+#define MIGRAPHX_BUILTIN_TYPE_TRAITN(name)      \
+    template <class... Ts>                      \
+    struct name : bool_constant<__##name(Ts...)> \
+    {                                           \
+    }
+
+// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_arithmetic);
+// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_destructible);
+// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_nothrow_destructible);
+// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_pointer);
+// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_scalar);
+// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_signed);
+// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_void);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_abstract);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_aggregate);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_array);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_class);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_compound);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_const);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_empty);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_enum);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_final);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_floating_point);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_function);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_fundamental);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_integral);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_literal_type);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_lvalue_reference);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_member_function_pointer);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_member_object_pointer);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_member_pointer);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_object);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_pod);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_polymorphic);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_reference);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_rvalue_reference);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_standard_layout);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_trivial);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_trivially_copyable);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_trivially_destructible);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_union);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_unsigned);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_volatile);
+MIGRAPHX_BUILTIN_TYPE_TRAIT2(is_assignable);
+MIGRAPHX_BUILTIN_TYPE_TRAIT2(is_base_of);
+MIGRAPHX_BUILTIN_TYPE_TRAIT2(is_convertible);
+MIGRAPHX_BUILTIN_TYPE_TRAIT2(is_nothrow_assignable);
+MIGRAPHX_BUILTIN_TYPE_TRAIT2(is_same);
+MIGRAPHX_BUILTIN_TYPE_TRAIT2(is_trivially_assignable);
+MIGRAPHX_BUILTIN_TYPE_TRAITN(is_constructible);
+MIGRAPHX_BUILTIN_TYPE_TRAITN(is_nothrow_constructible);
+MIGRAPHX_BUILTIN_TYPE_TRAITN(is_trivially_constructible);

 template <class T>
 struct remove_reference

@@ -68,6 +145,68 @@ struct add_pointer : type_identity<typename remove_reference<T>::type*>
 template <class T>
 using add_pointer_t = typename add_pointer<T>::type;

+template <class... Ts>
+struct common_type;
+
+template <class T>
+struct common_type<T>
+{
+    using type = T;
+};
+
+template <class T, class U>
+struct common_type<T, U>
+{
+    using type = decltype(true ? declval<T>() : declval<U>());
+};
+
+template <class T, class U, class... Us>
+struct common_type<T, U, Us...>
+{
+    using type = typename common_type<typename common_type<T, U>::type, Us...>::type;
+};
+
+template <class... Ts>
+using common_type_t = typename common_type<Ts...>::type;
+
+constexpr unsigned long int_max(unsigned long n) { return (1u << (n * 8)) - 1; }
+
+template <class T>
+constexpr T numeric_max()
+{
+    if constexpr(is_integral<T>{})
+    {
+        if constexpr(is_unsigned<T>{})
+            return int_max(sizeof(T)) * 2;
+        else
+            return int_max(sizeof(T));
+    }
+    else if constexpr(is_same<T, double>{})
+        return __DBL_MAX__;
+    else if constexpr(is_same<T, float>{})
+        return __FLT_MAX__;
+    else if constexpr(is_same<T, migraphx::half>{})
+        return __FLT16_MAX__;
+    else
+        return 0;
+}
+
+template <class T>
+constexpr T numeric_lowest()
+{
+    if constexpr(is_integral<T>{})
+    {
+        if constexpr(is_unsigned<T>{})
+            return 0;
+        else
+            return -numeric_max<T>() - 1;
+    }
+    else
+    {
+        return -numeric_max<T>();
+    }
+}
+
 #define MIGRAPHX_REQUIRES(...) class = enable_if_t<__VA_ARGS__>

 } // namespace migraphx
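The variadic common_type folds pairwise: the two-type specialization deduces the type of a ternary expression over declval, and longer packs recurse through that. The same behavior can be checked against the standard library's trait, which should agree for fundamental types (an assumption; the kernel version does not decay):

    #include <type_traits>

    // The two-type case is decltype(true ? declval<T>() : declval<U>());
    // the variadic case folds left over that result.
    static_assert(std::is_same<std::common_type_t<char, int>, int>{},
                  "char and int promote to int");
    static_assert(std::is_same<std::common_type_t<char, int, long>, long>{},
                  "fold: (char, int) -> int, then (int, long) -> long");
    static_assert(std::is_same<std::common_type_t<int, float>, float>{},
                  "mixed int/float picks float");

    int main() {}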
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp

@@ -13,7 +13,8 @@ using diff_int = std::int32_t;
 template <class T, index_int N>
 using vec = T __attribute__((ext_vector_type(N)));

-using half = _Float16;
+using half  = _Float16;
+using half2 = migraphx::vec<half, 2>;

 } // namespace migraphx
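half2 is a Clang extended vector of two _Float16 lanes. A small sketch of what ext_vector_type provides (element-wise operators plus lane subscripting), using float lanes so it compiles on hosts without _Float16 support:

    // Clang extended vectors: element-wise arithmetic plus lane indexing.
    using float4 = float __attribute__((ext_vector_type(4)));

    int main()
    {
        float4 a = {1.0f, 2.0f, 3.0f, 4.0f};
        float4 b = {10.0f, 20.0f, 30.0f, 40.0f};
        float4 c = a + b;             // one add per lane
        return c[3] == 44.0f ? 0 : 1; // lane 3 holds 4 + 40
    }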
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp

@@ -46,6 +46,9 @@ constexpr auto vec_at(T x, I i)
     }
 }

+template <class T>
+using vec_type = decltype(vec_at(T{}, 0));
+
 template <class... Ts>
 constexpr auto common_vec_size()
 {

@@ -57,17 +60,26 @@ constexpr auto common_vec_size()
     })(vec_size<Ts>()...);
 }

+// Bools can not be used as a vector type so convert it to uint8
+template <class T>
+__device__ __host__ T* remove_bool(T* x)
+{
+    return x;
+}
+
+inline __device__ __host__ uint8_t* remove_bool(bool* x)
+{
+    return reinterpret_cast<uint8_t*>(x);
+}
+
 template <index_int N, class T>
 __device__ __host__ auto as_vec(T* x)
 {
-    if constexpr(N == 0)
+    if constexpr(N < 2)
         return x;
     else
         return reinterpret_cast<vec<T, N>*>(x);
 }

 template <class T, index_int N>
-using safe_vec = vec<std::conditional_t<std::is_same<T, bool>{}, uint8_t, T>, N>;
+using safe_vec = vec<conditional_t<is_same<T, bool>{}, uint8_t, T>, N>;

 template <class... Ts>
 constexpr auto vec_transform(Ts... xs)

@@ -89,5 +101,64 @@ constexpr auto vec_transform(Ts... xs)
     };
 }

+// Return a vector type of N from index i in another larger vector
+// N will be 2 for half2 packing
+template <index_int N, class T, class I>
+constexpr vec<vec_type<T>, N> vec_packed_at(T x, I i)
+{
+    if constexpr(vec_size<T>() == 0)
+        return vec<T, N>{x};
+    else
+    {
+        MIGRAPHX_ASSERT((i + N) < vec_size<T>());
+        vec<vec_type<T>, N> result = {0};
+        for(int j = 0; j < N; j++)
+        {
+            result[j] = x[i + j];
+        }
+        return result;
+    }
+}
+
+template <index_int N, class... Ts>
+constexpr auto vec_packed_transform(Ts... xs)
+{
+    return [=](auto f) {
+        if constexpr(is_any_vec<Ts...>())
+        {
+            using type          = vec_type<decltype(f(vec_packed_at<N>(xs, 0)...))>;
+            constexpr auto size = common_vec_size<Ts...>();
+            safe_vec<type, size> result = {0};
+            for(int i = 0; i < size / N; i++)
+            {
+                // Call the function with packed vectors
+                safe_vec<type, N> r = f(vec_packed_at<N>(xs, i * N)...);
+                // Copy the packed vectors to the result
+                for(int j = 0; j < N; j++)
+                    result[i * N + j] = r[j];
+            }
+            return result;
+        }
+        else
+        {
+            return f(xs...);
+        }
+    };
+}
+
+template <class T, class Op>
+constexpr auto vec_reduce(T x, Op op)
+{
+    if constexpr(vec_size<T>() < 2)
+        return x;
+    else
+    {
+        vec_type<T> result = x[0];
+        for(int i = 1; i < vec_size<T>(); i++)
+            result = op(result, x[i]);
+        return result;
+    }
+}
+
 } // namespace migraphx
 #endif // MIGRAPHX_GUARD_KERNELS_VEC_HPP
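vec_packed_at copies N consecutive lanes out of a wider vector, and vec_reduce folds all lanes into a scalar with an operator. A host-compilable sketch of both on fixed widths (the kernel versions are generic over vec_size and constexpr-friendly):

    using float4 = float __attribute__((ext_vector_type(4)));
    using float2 = float __attribute__((ext_vector_type(2)));

    // Copy lanes [i, i + 2) out of a 4-wide vector, as vec_packed_at<2> does.
    float2 packed_at(float4 x, int i)
    {
        float2 result = {0, 0};
        for(int j = 0; j < 2; j++)
            result[j] = x[i + j];
        return result;
    }

    // Fold all lanes into a scalar with op, as vec_reduce does.
    template <class Op>
    float reduce(float4 x, Op op)
    {
        float result = x[0];
        for(int i = 1; i < 4; i++)
            result = op(result, x[i]);
        return result;
    }

    int main()
    {
        float4 v  = {1, 2, 3, 4};
        float2 hi = packed_at(v, 2);                                   // {3, 4}
        float sum = reduce(v, [](float a, float b) { return a + b; }); // 1+2+3+4 == 10
        return (hi[0] == 3 and sum == 10.0f) ? 0 : 1;
    }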
src/targets/gpu/kernels/include/migraphx/kernels/vectorize.hpp

@@ -50,19 +50,10 @@ constexpr auto shape_step(Shape s, Axis)
     });
 }

-// Bools can not be used as a vector type so convert it to uint8
-template <class T>
-__device__ __host__ T* remove_bool(T* x)
-{
-    return x;
-}
-
-inline __device__ __host__ uint8_t* remove_bool(bool* x)
-{
-    return reinterpret_cast<uint8_t*>(x);
-}
-
 template <index_int N, class T, class Axis>
 __device__ __host__ auto as_vec(T x, Axis axis)
 {
-    if constexpr(N == 0)
+    if constexpr(N < 2)
         return x;
     else
         return make_tensor_view(as_vec<N>(remove_bool(x.data())),

@@ -72,7 +63,7 @@ __device__ __host__ auto as_vec(T x, Axis axis)
 template <index_int N, class T, class Axis>
 constexpr auto tensor_step(T x, Axis axis)
 {
-    if constexpr(N == 0)
+    if constexpr(N < 2)
     {
         return x;
     }

@@ -157,11 +148,11 @@ constexpr auto find_vectorize_size(P pred)
     else if constexpr(decltype(pred(_c<2>)){})
         return _c<2>;
     else
-        return _c<0>;
+        return _c<1>;
 }

 template <class T>
-__host__ __device__ auto vectorize(T x)
+__host__ __device__ auto auto_vectorize(T x)
 {
     if constexpr(tensor_vec_size<T>() == 0)
     {

@@ -194,7 +185,7 @@ inline __device__ __host__ auto auto_vectorize_impl(F f, Ts... xs)
 {
     MIGRAPHX_ASSERT(s.strides[axis] == 0 or s.strides[axis] == 1);
     MIGRAPHX_ASSERT(s.lens[axis] > 0);
-    MIGRAPHX_ASSERT(n == 0 or s.lens[axis] % n == 0);
+    MIGRAPHX_ASSERT(n == 1 or s.lens[axis] % n == 0);
     if constexpr(s.strides[axis] == 0)
         return tensor_step<n>(x, axis);
     else

@@ -215,7 +206,34 @@ inline __device__ __host__ auto auto_vectorize_impl(F f, Ts... xs)
 inline __device__ __host__ auto auto_vectorize()
 {
-    return [](auto... xs) {
-        return [=](auto f) { auto_vectorize_impl(f, xs...); };
-    };
+    return make_transform([](auto f, auto... xs) { auto_vectorize_impl(f, xs...); });
 }

+template <index_int N, index_int Axis, class T>
+__device__ __host__ auto vectorize_tensor(T x)
+{
+    constexpr auto shape = get_shape_c<T>{};
+    if constexpr(shape.lens[Axis] == 1)
+        return x;
+    else if constexpr(shape.strides[Axis] == 0)
+        return tensor_step<N>(x, _c<Axis>);
+    else
+        return as_vec<N>(x, _c<Axis>);
+}
+
+template <index_int N, index_int Axis>
+__device__ __host__ auto vectorize()
+{
+    return make_transform([](auto f, auto... xs) {
+        if constexpr(N < 2)
+        {
+            f(xs...);
+        }
+        else
+        {
+            f(vectorize_tensor<N, Axis>(xs)...);
+        }
+    });
+}
+
 } // namespace migraphx
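Note the fallback in find_vectorize_size changes from _c<0> to _c<1>, so "no vectorization" is now width 1 and the N < 2 checks treat it uniformly as scalar. A sketch of that width-selection policy (try 4, then 2, else scalar), under the assumption that the predicate is simply "the width divides the axis length":

    #include <cstddef>

    // Try width 4, then 2, else fall back to scalar width 1.
    std::size_t vectorize_size(std::size_t axis_len)
    {
        for(std::size_t n : {4u, 2u})
            if(axis_len % n == 0)
                return n;
        return 1; // scalar fallback, matching the new "n == 1 or ..." assert
    }

    int main() { return vectorize_size(6) == 2 ? 0 : 1; }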
src/targets/gpu/lowering.cpp

@@ -181,16 +181,11 @@ struct miopen_apply
         add_extend_op("pad");
         add_extend_op("pooling");
         add_extend_op("prefix_scan_sum");
-        add_extend_op("reduce_max");
-        add_extend_op("reduce_mean");
-        add_extend_op("reduce_min");
-        add_extend_op("reduce_prod");
-        add_extend_op("reduce_sum");
         add_extend_op("reverse");
         add_extend_op("rnn_var_sl_last_output");
         add_extend_op("rnn_var_sl_shift_output");
         add_extend_op("rnn_var_sl_shift_sequence");
-        add_extend_op("scatter");
+        add_extend_op("scatter_none");
         add_extend_op("softmax");
         add_extend_op("topk");

@@ -370,8 +365,22 @@ struct miopen_apply
     {
         apply_map.emplace("quant_convolution", [=](instruction_ref ins) {
             auto&& op = any_cast<op::quant_convolution>(ins->get_operator());
-            auto conv = miopen_quant_convolution{op, make_conv(op)};
-            auto ws   = conv.compile(get_context(), ins->get_shape(), to_shapes(ins->inputs()));
+            shape ws;
+            miopen_quant_convolution conv;
+            auto compile_quant_conv_with_format = [&](bool format) {
+                conv = miopen_quant_convolution{op, format, make_conv(op)};
+                ws   = conv.compile(get_context(), ins->get_shape(), to_shapes(ins->inputs()));
+            };
+
+            try
+            {
+                compile_quant_conv_with_format(int8_x4_format);
+            }
+            catch(migraphx::exception&)
+            {
+                // In case no solver supports the default format, retry using the other format.
+                compile_quant_conv_with_format(!int8_x4_format);
+            }

             auto args      = ins->inputs();
             auto workspace = insert_allocation(ins, ws, "workspace");

@@ -381,6 +390,9 @@ struct miopen_apply
         });
     }

+    // add_generic_op just constructs the operator with no fields, whereas add_extend_op copies
+    // over the fields. Since it doesn't have fields, it is default constructed.
+    void add_generic_op(const std::string& name) { add_generic_op(name, "gpu::" + name); }
+
     void add_generic_op(const std::string& op_name, const std::string& gpu_name)
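The quant_convolution path now compiles with the preferred int8x4 layout and falls back to the other layout if no solver accepts it. A self-contained sketch of that try/retry shape, where compile_with is a hypothetical stand-in for the compile_quant_conv_with_format lambda in the diff:

    #include <cstdio>
    #include <stdexcept>

    // Hypothetical stand-in: compiling with the preferred format may fail.
    void compile_with(bool int8_x4_format)
    {
        if(int8_x4_format)
            throw std::runtime_error("no solver supports this format"); // simulated failure
    }

    int main()
    {
        bool int8_x4_format = true;
        try
        {
            compile_with(int8_x4_format);
        }
        catch(const std::exception&)
        {
            // In case no solver supports the default format, retry the other one.
            compile_with(not int8_x4_format);
        }
        std::puts("compiled");
    }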
src/targets/gpu/pack_int8_args.cpp

@@ -22,10 +22,10 @@ static instruction_ref pad_ins(module& m, instruction_ref ins, int offset)
     auto pad_k    = (k + 3) / 4 * 4;
     auto pad_lens = lens;
     pad_lens[lens.size() + offset] = pad_k;
-    std::vector<int64_t> pad_dims(lens.size() * 2, 0);
     auto ret_ins = ins;
     if(pad_k != k)
     {
+        std::vector<int64_t> pad_dims(lens.size() * 2, 0);
         pad_dims[lens.size() + offset] = pad_k - k;
         shape ps{s.type(), pad_lens};
         auto ins_out =

@@ -118,7 +118,7 @@ void pack_int8_args::apply(module& m) const
             assert(val.contains("int8_x4_format"));
             if(not val.at("int8_x4_format").to<bool>())
             {
-                return;
+                continue;
             }
             auto inputs = ins->inputs();
             auto lens   = inputs.at(0)->get_shape().lens();

@@ -156,6 +156,12 @@ void pack_int8_args::apply(module& m) const
         }
         else if(ins->name() == "gpu::quant_convolution")
         {
+            auto val = ins->get_operator().to_value();
+            if(not val.at("int8_x4_format").to<bool>())
+            {
+                continue;
+            }
             auto inputs   = ins->inputs();
             auto packed_x = m.insert_instruction(ins,
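pad_ins rounds the packed dimension k up to a multiple of 4 with pure integer arithmetic; the idiom is worth spelling out:

    #include <cassert>

    // (k + 3) / 4 * 4: integer division truncates, so adding 3 first
    // rounds k up to the nearest multiple of 4.
    int round_up4(int k) { return (k + 3) / 4 * 4; }

    int main()
    {
        assert(round_up4(1) == 4);
        assert(round_up4(5) == 8);
        assert(round_up4(8) == 8); // already-aligned values are unchanged
    }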
src/targets/gpu/prefuse_ops.cpp (new file)

#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/match/layernorm.hpp>
#include <migraphx/make_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

namespace {
struct find_layernorm
{
    auto matcher() const { return match::layernorm(); }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto ins   = r.result;
        auto x_ins = r.instructions["x"];
        if(not x_ins->get_shape().standard())
            x_ins = m.insert_instruction(ins, make_op("contiguous"), x_ins);
        auto relements = x_ins->get_shape().lens().back();

        if(relements > 1024 or (relements % 4 != 0 and relements > 256))
            return;

        auto a = m.insert_instruction(
            ins, make_op("hip::allocate", {{"shape", to_value(x_ins->get_shape())}}));
        m.replace_instruction(ins, make_op("gpu::layernorm"), x_ins, a);
    }
};

struct find_triaddlayernorm
{
    auto matcher() const
    {
        auto add1 = match::name("add")(
            match::none_of(match::is_constant()),
            match::args(match::any().bind("z1"), match::any().bind("z2")));
        auto add2 = match::name("add")(match::either_arg(0, 1)(add1, match::any().bind("z3")));
        return match::layernorm()(match::var("x")(add2));
    }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto ins   = r.result;
        auto x_ins = r.instructions["z1"];
        auto y_ins = r.instructions["z2"];
        auto z_ins = r.instructions["z3"];
        for(auto* pins : {&x_ins, &y_ins, &z_ins})
        {
            if(not(*pins)->get_shape().standard())
                *pins = m.insert_instruction(ins, make_op("contiguous"), *pins);
        }
        auto relements = x_ins->get_shape().lens().back();

        if(relements > 1024 or (relements % 4 != 0 and relements > 256))
            return;

        auto a = m.insert_instruction(
            ins, make_op("hip::allocate", {{"shape", to_value(x_ins->get_shape())}}));
        m.replace_instruction(ins, make_op("gpu::triadd_layernorm"), x_ins, y_ins, z_ins, a);
    }
};
} // namespace

void prefuse_ops::apply(module& m) const
{
    match::find_matches(m, find_triaddlayernorm{}, find_layernorm{});
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
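Both matchers gate fusion on the same row-size test. A hedged reading of that guard as a standalone predicate: rows longer than 1024 never fuse, and rows above 256 must be divisible by 4 (presumably so they can be read 4-wide):

    #include <cassert>
    #include <cstddef>

    // Mirror of the guard: reject rows longer than 1024, and rows above
    // 256 that are not a multiple of 4.
    bool fusable(std::size_t relements)
    {
        return not(relements > 1024 or (relements % 4 != 0 and relements > 256));
    }

    int main()
    {
        assert(fusable(256));      // small rows always pass
        assert(fusable(512));      // 512 % 4 == 0
        assert(not fusable(301));  // 301 > 256 and 301 % 4 != 0
        assert(not fusable(2048)); // larger than 1024
    }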
src/targets/gpu/quant_convolution.cpp

@@ -16,8 +16,8 @@ argument miopen_quant_convolution::compute(context& ctx,
                                            const shape& output_shape,
                                            const std::vector<argument>& args) const
 {
-    auto x_desc = make_tensor(args[0].get_shape(), true);
-    auto w_desc = make_tensor(args[1].get_shape(), true);
+    auto x_desc = make_tensor(args[0].get_shape(), int8_x4_format);
+    auto w_desc = make_tensor(args[1].get_shape(), int8_x4_format);
     auto y_desc = make_tensor(output_shape);

     float alpha = 1;

@@ -49,8 +49,8 @@ shape miopen_quant_convolution::compile(context& ctx,
                                         std::vector<shape> inputs)
 {
     shape workspace_shape{};
-    auto x_desc = make_tensor(inputs[0], true);
-    auto w_desc = make_tensor(inputs[1], true);
+    auto x_desc = make_tensor(inputs[0], int8_x4_format);
+    auto w_desc = make_tensor(inputs[1], int8_x4_format);
     auto y_desc = make_tensor(output_shape);

     std::size_t workspace_size = 0;

@@ -62,8 +62,15 @@ shape miopen_quant_convolution::compile(context& ctx,
                               &workspace_size);
     workspace_shape = shape{shape::int8_type, {workspace_size}};

-    auto arg_vec4_x = to_gpu(generate_argument(pack_int8_shape(inputs[0])));
-    auto arg_vec4_w = to_gpu(generate_argument(pack_int8_shape(inputs[1])));
+    auto x_shape = inputs[0];
+    auto w_shape = inputs[1];
+    if(int8_x4_format)
+    {
+        x_shape = pack_int8_shape(x_shape);
+        w_shape = pack_int8_shape(w_shape);
+    }
+    auto arg_vec4_x = to_gpu(generate_argument(x_shape));
+    auto arg_vec4_w = to_gpu(generate_argument(w_shape));
     auto y          = allocate_gpu(output_shape);
     auto workspace  = allocate_gpu(workspace_shape);
src/targets/gpu/schedule_model.cpp

@@ -77,28 +77,28 @@ MIGRAPHX_REGISTER_OP(wait_event)
 MIGRAPHX_REGISTER_OP(set_stream)

 std::size_t schedule_model::concurrency() const { return streams; }
-void schedule_model::sched(module& p, instruction_ref ins, std::size_t n) const
+void schedule_model::sched(module& m, instruction_ref ins, std::size_t n) const
 {
     auto last_stream =
         std::find_if(std::make_reverse_iterator(ins),
-                     std::make_reverse_iterator(p.begin()),
+                     std::make_reverse_iterator(m.begin()),
                      [&](auto&& i) { return i.name() == "gpu::set_stream"; });
-    if(last_stream != std::make_reverse_iterator(p.begin()))
+    if(last_stream != std::make_reverse_iterator(m.begin()))
     {
         auto&& op = any_cast<set_stream>(last_stream->get_operator());
         // If the same stream was set earlier then skip
         if(op.stream == n)
             return;
     }
-    p.insert_instruction(ins, set_stream{n});
+    m.insert_instruction(ins, set_stream{n});
 }
-void schedule_model::wait(module& p, instruction_ref ins, std::size_t wait_id) const
+void schedule_model::wait(module& m, instruction_ref ins, std::size_t wait_id) const
 {
-    p.insert_instruction(ins, wait_event{wait_id});
+    m.insert_instruction(ins, wait_event{wait_id});
 }
-void schedule_model::record(module& p, instruction_ref ins, std::size_t wait_id) const
+void schedule_model::record(module& m, instruction_ref ins, std::size_t wait_id) const
 {
-    p.insert_instruction(std::next(ins), record_event{wait_id});
+    m.insert_instruction(std::next(ins), record_event{wait_id});
 }

 static std::unordered_map<std::string, std::size_t> create_weight_map()
src/targets/gpu/sync_device.cpp

@@ -8,9 +8,9 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {

-void sync_device::apply(module& p) const
+void sync_device::apply(module& m) const
 {
-    auto last = std::prev(p.end());
+    auto last = std::prev(m.end());
     if(last->name() == "@return")
     {
         auto inputs = last->inputs();

@@ -18,10 +18,10 @@ void sync_device::apply(module& p) const
                return (i->name() == "hip::copy_from_gpu");
            }))
         {
-            auto sync_in = p.insert_instruction(last, make_op("hip::sync_stream"), inputs);
+            auto sync_in = m.insert_instruction(last, make_op("hip::sync_stream"), inputs);
             if(not inputs.empty())
             {
-                p.replace_instruction(inputs.front(), sync_in);
+                m.replace_instruction(inputs.front(), sync_in);
             }
         }
     }
src/targets/gpu/target.cpp

@@ -32,6 +32,7 @@
 #include <migraphx/gpu/eliminate_workspace.hpp>
 #include <migraphx/gpu/fuse_mlir.hpp>
 #include <migraphx/gpu/fuse_ops.hpp>
+#include <migraphx/gpu/prefuse_ops.hpp>
 #include <migraphx/gpu/lowering.hpp>
 #include <migraphx/gpu/pack_int8_args.hpp>
 #include <migraphx/gpu/schedule_model.hpp>

@@ -96,6 +97,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
         simplify_algebra{},
         simplify_reshapes{},
         simplify_algebra{},
+        prefuse_ops{},
+        dead_code_elimination{},
         auto_contiguous{},
         simplify_reshapes{},
         propagate_constant{},
src/targets/gpu/write_literals.cpp

@@ -11,25 +11,25 @@ namespace gpu {
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_COPY_LITERALS)

-void write_literals::apply(module& p) const
+void write_literals::apply(module& m) const
 {
     assert(ctx != nullptr);
     std::size_t n = 0;
-    for(auto ins : iterator_for(p))
+    for(auto ins : iterator_for(m))
     {
         if(ins->name() == "@literal")
         {
             if(enabled(MIGRAPHX_COPY_LITERALS{}))
             {
                 literal l  = ins->get_literal();
-                auto pre   = p.add_literal(l);
-                auto alloc = p.insert_instruction(std::next(pre), hip_allocate{l.get_shape()});
-                p.replace_instruction(ins, hip_copy_to_gpu{}, pre, alloc);
+                auto pre   = m.add_literal(l);
+                auto alloc = m.insert_instruction(std::next(pre), hip_allocate{l.get_shape()});
+                m.replace_instruction(ins, hip_copy_to_gpu{}, pre, alloc);
             }
             else
             {
-                std::string id = p.name() + ":@literal:" + std::to_string(n);
-                p.replace_instruction(ins, hip_copy_literal{ins->get_literal(), id});
+                std::string id = m.name() + ":@literal:" + std::to_string(n);
+                m.replace_instruction(ins, hip_copy_literal{ins->get_literal(), id});
                 n++;
             }
         }
src/targets/ref/gemm.cpp

 #include <migraphx/ref/gemm.hpp>
 #include <migraphx/dfor.hpp>
 #include <migraphx/requires.hpp>
-#include <migraphx/shape_for_each.hpp>
+#include <migraphx/par_for.hpp>
 #include <blaze/math/CustomMatrix.h>

 namespace migraphx {

@@ -74,8 +74,10 @@ void migemm_impl(
     assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]);
     assert(cmat.get_shape().lens()[dim_0] == amat.get_shape().lens()[dim_0]);
     assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]);

-    shape_for_each(cmat.get_shape(), [&](const auto& c_idx) {
+    auto cs = cmat.get_shape();
+    par_for(cs.elements(), [&](auto i) {
+        auto c_idx = cs.multi(i);
         auto a_idx = c_idx;
         auto b_idx = c_idx;
         double s   = 0.0;
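The rewrite trades the sequential shape walk for a parallel loop over flat element indices, recovering the multi-index with cs.multi(i) inside the body. par_for's exact chunking is MIGraphX-internal, so this is only an approximation of the pattern with standard threads:

    #include <algorithm>
    #include <cstddef>
    #include <thread>
    #include <vector>

    // Approximate par_for: strided split of [0, n) across hardware threads.
    template <class F>
    void par_for(std::size_t n, F f)
    {
        std::size_t nthreads = std::max(1u, std::thread::hardware_concurrency());
        std::vector<std::thread> threads;
        for(std::size_t t = 0; t < nthreads; t++)
            threads.emplace_back([=] {
                for(std::size_t i = t; i < n; i += nthreads)
                    f(i); // the body would recover the multi-index from i
            });
        for(auto& th : threads)
            th.join();
    }

    int main()
    {
        std::vector<int> out(16, 0);
        par_for(out.size(), [&](std::size_t i) { out[i] = static_cast<int>(i) * 2; });
        return out[5] == 10 ? 0 : 1;
    }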
src/targets/ref/lowering.cpp

@@ -16,7 +16,6 @@
 #include <migraphx/op/loop.hpp>
 #include <migraphx/op/lrn.hpp>
 #include <migraphx/op/pad.hpp>
-#include <migraphx/op/pooling.hpp>
 #include <migraphx/op/softmax.hpp>
 #include <migraphx/op/argmax.hpp>
 #include <migraphx/op/argmin.hpp>

@@ -335,109 +334,6 @@ struct ref_im2col
 };
 MIGRAPHX_REGISTER_OP(ref_im2col)

-struct max_pool
-{
-    static std::string name() { return "max"; }
-    template <class T>
-    static T start()
-    {
-        return std::numeric_limits<T>::lowest();
-    }
-
-    static double apply(double x, double y)
-    {
-        double m = std::max(x, y);
-        return (m);
-    }
-
-    static double final(double x, std::size_t) { return (x); }
-};
-
-struct avg_pool
-{
-    static std::string name() { return "average"; }
-
-    template <class T>
-    static double start()
-    {
-        return 0.0;
-    }
-
-    static double apply(double x, double y) { return x + y; }
-
-    static double final(double x, std::size_t y) { return (y == 0) ? 0.0 : (x / y); }
-};
-
-template <class Op>
-struct ref_pooling : auto_register_op<ref_pooling<Op>>
-{
-    ref_pooling() = default;
-    ref_pooling(op::pooling pop) : op(std::move(pop)) {}
-    op::pooling op;
-
-    template <class Self, class F>
-    static auto reflect(Self& self, F f)
-    {
-        return migraphx::reflect(self.op, f);
-    }
-
-    std::string name() const { return "ref::pooling_" + Op::name(); }
-
-    shape compute_shape(const std::vector<shape>& inputs) const
-    {
-        return op.normalize_compute_shape(inputs);
-    }
-
-    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
-    {
-        argument result{output_shape};
-        visit_all(result, args[0])([&](auto output, auto input) {
-            using type   = typename decltype(output)::value_type;
-            auto in_s    = input.get_shape();
-            auto in_lens = in_s.lens();
-            std::vector<std::size_t> vec_len(in_lens.begin() + 2, in_lens.end());
-
-            par_for(output_shape.elements(), [&](auto i) {
-                auto idx_o = output_shape.multi(i);
-                auto n_dim = idx_o.size();
-                std::vector<std::size_t> win_start;
-                std::vector<std::size_t> win_size;
-                for(std::size_t dim = 2; dim < n_dim; ++dim)
-                {
-                    auto d_2  = dim - 2;
-                    int start = static_cast<int>(idx_o[dim] * op.stride[d_2]) -
-                                static_cast<int>(op.padding[d_2]);
-                    int end = std::min(start + op.lengths[d_2], in_lens[dim]);
-                    start   = std::max(start, 0);
-                    win_start.push_back(start);
-                    win_size.push_back(end - start);
-                }
-                shape win_shape{output_shape.type(), win_size};
-
-                auto pool_size = win_shape.elements();
-                double acc     = Op::template start<type>();
-                shape_for_each(win_shape, [&](auto idx_w) {
-                    auto idx = idx_o;
-                    std::transform(idx_w.begin(),
-                                   idx_w.end(),
-                                   win_start.begin(),
-                                   idx.begin() + 2,
-                                   [](auto ii, auto jj) { return ii + jj; });
-                    if(std::all_of(idx.begin() + 2, idx.end(), [&](auto ii) { return ii >= 0; })
-                       and idx < in_lens)
-                    {
-                        acc = Op::apply(acc, input[in_s.index(idx)]);
-                    }
-                });
-                output[i] = type(Op::final(acc, pool_size));
-            });
-        });
-        return result;
-    }
-};
-
 struct ref_op
 {
     operation op = op::identity{};

@@ -609,7 +505,7 @@ struct ref_unary : auto_register_op<ref_unary<Op>>
     shape compute_shape(const std::vector<shape>& inputs) const
     {
         check_shapes{inputs, *this}.has(1);
-        auto s = inputs.at(0);
+        const auto& s = inputs.at(0);
         return {s.type(), s.lens()};
     }

@@ -783,11 +679,7 @@ struct ref_apply
         init();
         for(auto it : iterator_for(*mod))
         {
-            if(it->name() == "pooling")
-            {
-                apply_pooling(it);
-            }
-            else if(apply_map.count(it->name()) > 0)
+            if(apply_map.count(it->name()) > 0)
             {
                 apply_map.at(it->name())(it);
             }

@@ -815,15 +707,6 @@ struct ref_apply
         auto&& op = any_cast<Op>(ins->get_operator());
         mod->replace_instruction(ins, T{op}, ins->inputs());
     }
-
-    void apply_pooling(instruction_ref ins) const
-    {
-        auto&& op = any_cast<op::pooling>(ins->get_operator());
-        if(op.mode == op::pooling_mode::max)
-            mod->replace_instruction(ins, ref_pooling<max_pool>{op}, ins->inputs());
-        else if(op.mode == op::pooling_mode::average)
-            mod->replace_instruction(ins, ref_pooling<avg_pool>{op}, ins->inputs());
-    }
 };

 void lowering::apply(module& m) const { ref_apply{&m}.apply(); }
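The removed reference pooling clamped each output window against the padded input edges before accumulating. The per-dimension index arithmetic stands on its own as a sketch:

    #include <algorithm>
    #include <cassert>

    struct window
    {
        int start;
        int size;
    };

    // Window [start, start + size) for one output position along one
    // spatial dimension, clamped to the input, as the removed ref_pooling
    // computed per dimension.
    window pool_window(int out_idx, int stride, int padding, int length, int in_len)
    {
        int start = out_idx * stride - padding;
        int end   = std::min(start + length, in_len); // clip at the input's end
        start     = std::max(start, 0);               // clip the padded front
        return {start, end - start};
    }

    int main()
    {
        // length-3 window, stride 2, padding 1, input length 5
        auto w0 = pool_window(0, 2, 1, 3, 5);
        assert(w0.start == 0 and w0.size == 2); // front-clipped: covers {0, 1}
        auto w2 = pool_window(2, 2, 1, 3, 5);
        assert(w2.start == 3 and w2.size == 2); // back-clipped: covers {3, 4}
    }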
test/api/CMakeLists.txt

 function(add_api_test TEST_NAME TEST_SRC TEST_DIR)
     set(NAME test_api_${TEST_NAME})
     add_executable(${NAME} EXCLUDE_FROM_ALL ${TEST_SRC})

@@ -10,6 +9,7 @@ function(add_api_test TEST_NAME TEST_SRC TEST_DIR)
     add_dependencies(check ${NAME})
 endfunction()

+add_api_test(array_base test_array_base.cpp ${TEST_ONNX_DIR})
 add_api_test(assign test_assign.cpp ${TEST_ONNX_DIR})
 add_api_test(custom_op test_custom_op.cpp ${TEST_ONNX_DIR})
 add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR})

@@ -19,7 +19,8 @@ add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR})
 add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR})
 add_api_test(op test_op_construct.cpp ${TEST_ONNX_DIR})
 add_api_test(tf_parser test_tf_parser.cpp ${TEST_TF_DIR})
-# GPU-based tests
 if(MIGRAPHX_ENABLE_GPU)
+    # GPU-based tests
     add_api_test(gpu test_gpu.cpp ${TEST_ONNX_DIR})
+    target_link_libraries(test_api_gpu migraphx_gpu)
 endif()
test/api/test_array_base.cpp (new file)

#include <migraphx/migraphx.hpp>
#include "test.hpp"

struct array2 : migraphx::array_base<array2>
{
    std::vector<int> v;
    array2() = default;
    array2(std::initializer_list<int> x) : v(x) {}
    std::size_t size() const { return v.size(); }
    int operator[](std::size_t i) const { return v[i]; }
};

TEST_CASE(iterators)
{
    array2 a = {1, 2, 3};
    EXPECT(bool{std::equal(a.begin(), a.end(), a.v.begin())});
}

TEST_CASE(front_back)
{
    array2 a = {1, 2, 3};
    EXPECT(a.front() == 1);
    EXPECT(a.back() == 3);
}

TEST_CASE(empty)
{
    array2 a = {1, 2, 3};
    EXPECT(not a.empty());
}

int main(int argc, const char* argv[]) { test::run(argc, argv); }
test/api/test_gpu.cpp

 #include <numeric>
+#include <hip/hip_runtime_api.h>
 #include <migraphx/migraphx.h>
 #include <migraphx/migraphx.hpp>
 #include "test.hpp"

@@ -38,6 +39,7 @@ TEST_CASE(load_and_run_ctx)
         pp.add(name, migraphx::argument::generate(param_shapes[name]));
     }
     auto ctx = p.experimental_get_context();
+    EXPECT(ctx.get_queue<hipStream_t>() != nullptr);
     p.eval(pp);
     ctx.finish();
 }
test/api/test_module_construct.cpp

@@ -3,23 +3,21 @@
 #include <migraphx/migraphx.hpp>
 #include "test.hpp"

-TEST_CASE(add_op)
+TEST_CASE(add_literals)
 {
     migraphx::program p;
     migraphx::module m = p.get_main_module();
     migraphx::shape param_shape{migraphx_shape_float_type, {3, 3}};
-    auto x = m.add_parameter("x", param_shape);
-    auto y = m.add_parameter("y", param_shape);
+    std::vector<float> x_values(9, 1);
+    auto x = m.add_literal(param_shape, x_values.data());
+    std::vector<float> y_values(9, -1);
+    auto y = m.add_literal(param_shape, y_values.data());
     auto add_op = migraphx::operation("add");
     auto r      = m.add_instruction(add_op, {x, y});
     m.add_return({r});
     // run on ref target
     p.compile(migraphx::target("ref"));
     migraphx::program_parameters pp;
-    std::vector<float> x_data(9, 1);
-    std::vector<float> y_data(9, -1);
-    pp.add("x", migraphx::argument(param_shape, x_data.data()));
-    pp.add("y", migraphx::argument(param_shape, y_data.data()));
     auto outputs = p.eval(pp);
     auto output  = outputs[0];
     std::vector<float> expected(9, 0);

@@ -60,16 +58,16 @@ TEST_CASE(if_then_else_op)
         p.compile(migraphx::target("ref"));
         auto outputs = p.eval(
             {{"cond", migraphx::argument(cond_s, &cond)}, {"x", x_arg}, {"y", y_arg}});
-        return outputs;
+        return outputs[0];
     };

     // then branch
     auto then_res = run_prog(true);
-    CHECK(bool{then_res[0] == x_arg});
+    CHECK(bool{then_res == x_arg});

     // else branch
     auto else_res = run_prog(false);
-    CHECK(bool{else_res[0] == y_arg});
+    CHECK(bool{else_res == y_arg});
 }

 int main(int argc, const char* argv[]) { test::run(argc, argv); }