gaoqiong / MIGraphX / Commits / 9bff4331

Commit 9bff4331 · authored Mar 21, 2023 by Paul · Merge · parents 214b313f 94a7f6ee

Changes 274 · Showing 20 changed files with 540 additions and 210 deletions (+540 −210)
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp       +292  −51
src/targets/gpu/kernels/include/migraphx/kernels/shape.hpp          +1   −8
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp        +9   −7
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp   +24   −0
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp         +39   −2
src/targets/gpu/lowering.cpp                                       +16   −3
src/targets/gpu/mlir.cpp                                           +24  −60
src/targets/gpu/perfdb.cpp                                          +4   −0
src/targets/gpu/prefuse_ops.cpp                                    +13   −6
src/targets/gpu/target.cpp                                          +4  −11
src/targets/ref/include/migraphx/ref/target.hpp                     +0   −2
src/targets/ref/lowering.cpp                                        +9   −9
test/CMakeLists.txt                                                +39  −38
test/api/CMakeLists.txt                                             +4   −2
test/api/test_cpu.cpp                                               +2   −2
test/api/test_gpu.cpp                                               +1   −0
test/api/test_save_load.cpp                                         +0   −1
test/fpga/get_target_assignments.cpp                                +0   −1
test/gpu/hip.cpp                                                   +50   −0
test/gpu/jit.cpp                                                    +9   −7
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp

@@ -103,10 +103,10 @@ __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
 #else
     constexpr index_int lanes_per_thread = 64;
 #endif
-    using type = decltype(f(0));
+    using type = decltype(index::invoke_loop(f, 0, _c<0>));
     __shared__ type buffer[idx.max_nlocal() / lanes_per_thread];
     type x = init;
-    idx.local_stride(n, [&](auto i) { x = op(x, f(i)); });
+    idx.local_stride(n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); });
     dpp_reduce(x, op);
     const auto ldsidx = idx.local / lanes_per_thread;
@@ -128,10 +128,10 @@ template <class Op, class T, class Index, class F>
 __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
 {
     MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
-    using type = decltype(f(0));
+    using type = decltype(index::invoke_loop(f, 0, _c<0>));
     __shared__ type buffer[idx.max_nlocal()];
     type x = init;
-    idx.local_stride(n, [&](auto i) { x = op(x, f(i)); });
+    idx.local_stride(n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); });
     buffer[idx.local] = x;
     __syncthreads();
@@ -167,6 +167,25 @@ constexpr auto reduce_slice(Input input, T i)
 namespace reduce {

+struct inner_storage_tag
+{
+};
+
+template <class T>
+using is_inner_storage = is_base_of<inner_storage_tag, remove_cv_t<remove_reference_t<T>>>;
+
+template <class R, class F>
+struct storage_access : F
+{
+    using type = R;
+};
+
+template <class R, class F>
+constexpr storage_access<R, F> make_storage_access(F f)
+{
+    return {{f}};
+}
+
 template <class Slicer, class F>
 constexpr auto sliced(Slicer slicer, F f)
 {
@@ -191,20 +210,140 @@ constexpr auto compute_reduce_axis()
 template <class Input, index_int Axis>
 using with_axis = decltype(compute_reduce_axis<Input, Axis>());

+template <class Derived>
+struct reducer_base
+{
+    template <class T>
+    __device__ auto make_inner_slice(T x) const
+    {
+        if constexpr(is_inner_storage<T>{})
+        {
+            return x;
+        }
+        else
+        {
+            auto&& derived = static_cast<const Derived&>(*this);
+            auto t         = derived.slice(x);
+            return make_storage_access<typename decltype(t)::type>(
+                [=](auto i, auto...) -> auto& { return t[i]; });
+        }
+    }
+
+    template <class T, class... Ts>
+    constexpr auto get_size(T&& x, [[maybe_unused]] Ts&&... xs) const
+    {
+        MIGRAPHX_ASSERT(get_size(x) == get_size(xs...));
+        return get_size(x);
+    }
+
+    template <class T>
+    constexpr auto get_size(T&& x) const
+    {
+        if constexpr(is_inner_storage<T>{})
+        {
+            return x.rsize();
+        }
+        else
+        {
+            auto&& derived = static_cast<const Derived&>(*this);
+            auto t         = derived.slice(x);
+            return t.size();
+        }
+    }
+
+    template <class F>
+    __device__ auto inner_sliced(F f) const
+    {
+        return [=](auto&&... xs) { return f(get_size(xs...), make_inner_slice(xs)...); };
+    }
+
+    template <class T>
+    static __device__ typename T::type& decl_inner_storage(const T&);
+
+    template <class F>
+    __device__ auto inner(F f) const
+    {
+        return this->inner_sliced([=](auto n, auto&&... xs) {
+            using result_type = decltype(f(decl_inner_storage(xs)...));
+            auto&& derived    = static_cast<const Derived&>(*this);
+            if constexpr(is_void<result_type>{})
+            {
+                derived.inner_void_impl(f, n, xs...);
+            }
+            else
+            {
+                return derived.template inner_impl<result_type>(f, n, xs...);
+            }
+        });
+    }
+
+    template <class Op, class T, class Read>
+    __device__ auto reduce(Op op, T init, Read read) const
+    {
+        return this->inner_sliced([=](auto n, auto&&... xs) {
+            auto&& derived = static_cast<const Derived&>(*this);
+            return derived.reduce_impl(op, init, read, n, xs...);
+        });
+    }
+
+    template <class Op, class T>
+    __device__ auto reduce(Op op, T init) const
+    {
+        return this->reduce(op, init, op::id{});
+    }
+
+    template <class F>
+    __device__ void outer(F f) const
+    {
+        f();
+    }
+
+    template <class Input>
+    constexpr auto elements() const
+    {
+        auto&& derived           = static_cast<const Derived&>(*this);
+        using reduce_type        = decltype(derived.slice(Input{}));
+        using value_type         = typename Input::type;
+        constexpr auto relements = get_shape_c<reduce_type>{}.elements();
+        if constexpr(vec_size<value_type>() > 1)
+            return relements * vec_size<value_type>();
+        else
+            return relements;
+    }
+};
+
 struct block
 {
     template <class Slicer>
-    struct reducer
+    struct reducer : reducer_base<reducer<Slicer>>
     {
         index idx;
         Slicer slice;

-        template <class Op, class T, class Read>
-        __device__ auto reduce(Op op, T init, Read read) const
+        template <class T, index_int N, class Size>
+        struct inner_storage : inner_storage_tag
+        {
+            using type = T;
+            array<T, N> arr;
+            constexpr Size rsize() const { return {}; }
+            template <class U, class V>
+            constexpr auto& operator()(U, V d) const
+            {
+                return arr[d];
+            }
+            template <class U, class V>
+            constexpr auto& operator()(U, V d)
+            {
+                return arr[d];
+            }
+        };
+
+        template <class Op, class T, class Read, class N, class... Ts>
+        __device__ auto reduce_impl(Op op, T init, Read read, N n, Ts&&... xs) const
         {
-            return sliced(slice, [=](auto x, auto... xs) {
-                return block_reduce(idx, op, init, x.get_shape().elements(), [&](auto j) {
-                    return vec_reduce(read(x[j], xs[j]...), op);
-                });
-            });
+            return block_reduce(idx, op, init, n, [&](auto j, auto d) {
+                return vec_reduce(read(xs(j, d)...), op);
+            });
         }
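(The `reducer_base` added in this hunk is a CRTP base: it owns the user-facing `inner`/`reduce` entry points and forwards to `reduce_impl`/`inner_impl` on the derived reducer. A minimal standalone sketch of that dispatch pattern — names mirror the diff, everything else is simplified host code, not the kernel API:

#include <iostream>
#include <vector>

// The base provides the public entry point and casts itself to the derived
// class to reach the strategy-specific *_impl member, as reducer_base does.
template <class Derived>
struct reducer_base
{
    template <class Op, class T>
    T reduce(Op op, T init, const std::vector<T>& xs) const
    {
        auto&& derived = static_cast<const Derived&>(*this);
        return derived.reduce_impl(op, init, xs);
    }
};

// One concrete strategy, analogous to block::reducer / lane::reducer.
struct serial_reducer : reducer_base<serial_reducer>
{
    template <class Op, class T>
    T reduce_impl(Op op, T init, const std::vector<T>& xs) const
    {
        T r = init;
        for(const auto& x : xs)
            r = op(r, x);
        return r;
    }
};

int main()
{
    serial_reducer r;
    std::vector<int> v = {1, 2, 3, 4};
    // Calls reducer_base::reduce, which dispatches to serial_reducer::reduce_impl.
    std::cout << r.reduce([](int a, int b) { return a + b; }, 0, v) << "\n"; // prints 10
}

)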
@@ -215,31 +354,99 @@ struct block
             f();
         }

-        template <class F>
-        __device__ auto inner(F f) const
-        {
-            return sliced(slice, [=](auto x, auto... xs) {
-                idx.local_stride(x.get_shape().elements(), [&](auto j) { f(x[j], xs[j]...); });
-            });
-        }
+        template <class F, class N, class... Ts>
+        __device__ void inner_void_impl(F f, N n, Ts&&... xs) const
+        {
+            idx.local_stride(n, [&](auto j, auto d) { f(xs(j, d)...); });
+        }
+
+        template <class R, class F, class N, class... Ts>
+        __device__ auto inner_impl(F f, N n, Ts&&... xs) const
+        {
+            using max_iterations = decltype(idx.max_local_stride_iterations(n));
+            inner_storage<R, max_iterations{}, N> storage;
+            idx.local_stride(n, [&](auto j, auto d) { storage(j, d) = f(xs(j, d)...); });
+            return storage;
+        }
     };

     template <class Slicer>
     static __device__ auto make(index idx, Slicer slicer)
     {
         return reducer<Slicer>{{}, idx, slicer};
     }

     template <class Output, class F>
     static __device__ void run(F f)
     {
         auto idx                 = make_index();
         constexpr auto nelements = get_shape_c<Output>{}.elements();
         idx.global_stride(nelements * idx.nlocal(), [&](auto i) {
             const auto out_idx = get_shape_c<Output>{}.multi(i / idx.nlocal());
             f(out_idx,
               make(idx, [&](auto input) { return reduce_slice<Output>(input, out_idx); }));
         });
     }
 };

+struct block_large
+{
+    template <class Slicer>
+    struct reducer : reducer_base<reducer<Slicer>>
+    {
+        index idx;
+        Slicer slice;
+
+        template <class Size, class F>
+        struct inner_storage : inner_storage_tag
+        {
+            using type = remove_reference_t<decltype(declval<F>()(0, _c<0>))>;
+            F f;
+            constexpr Size rsize() const { return {}; }
+            template <class U, class V>
+            constexpr auto operator()(U j, V d) const
+            {
+                return f(j, d);
+            }
+        };
+
+        template <class Size, class F>
+        static constexpr inner_storage<Size, F> make_inner_storage(Size, F f)
+        {
+            return {{}, {f}};
+        }
+
+        template <class Op, class T, class Read, class N, class... Ts>
+        __device__ auto reduce_impl(Op op, T init, Read read, N n, Ts&&... xs) const
+        {
+            return block_reduce(idx, op, init, index_int{n}, [&](auto j, auto d) {
+                return vec_reduce(read(xs(j, d)...), op);
+            });
+        }
+
-        template <class Input>
-        constexpr auto elements() const
-        {
-            using reduce_type        = decltype(slice(Input{}));
-            using value_type         = typename Input::type;
-            constexpr auto relements = get_shape_c<reduce_type>{}.elements();
-            if constexpr(vec_size<value_type>() > 1)
-                return relements * vec_size<value_type>();
-            else
-                return relements;
-        }
+        template <class F>
+        __device__ void outer(F f) const
+        {
+            if(idx.local == 0)
+                f();
+        }
+
+        template <class F, class N, class... Ts>
+        __device__ void inner_void_impl(F f, N n, Ts&&... xs) const
+        {
+            idx.local_stride(index_int{n}, [&](auto j, auto d) { f(xs(j, d)...); });
+        }
+
+        template <class R, class F, class N, class... Ts>
+        __device__ auto inner_impl(F f, N n, Ts&&... xs) const
+        {
+            return make_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
+        }
+    };
+
+    template <class Slicer>
+    static __device__ auto make(index idx, Slicer slicer)
+    {
-        return reducer<Slicer>{idx, slicer};
+        return reducer<Slicer>{{}, idx, slicer};
+    }
+
+    template <class Output, class F>
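(The two reducers differ mainly in their `inner_storage`: `block::reducer` materializes per-iteration results into an array, while `block_large::reducer` keeps only the generating functor and recomputes on demand. A simplified host-side sketch of the two storage styles — hypothetical names, no device code:

#include <array>
#include <iostream>

// Eager storage: results are computed once and kept in an array,
// like block::reducer::inner_storage in the diff.
template <class T, std::size_t N>
struct eager_storage
{
    std::array<T, N> arr{};
    T operator()(std::size_t d) const { return arr[d]; }
};

// Lazy storage: only the functor is kept and every access recomputes,
// like block_large::reducer::inner_storage; trades recomputation for memory.
template <class F>
struct lazy_storage
{
    F f;
    auto operator()(std::size_t d) const { return f(d); }
};

template <class F>
lazy_storage<F> make_lazy(F f) { return {f}; }

int main()
{
    eager_storage<int, 4> e{{1, 4, 9, 16}};
    auto l = make_lazy([](std::size_t d) { return int((d + 1) * (d + 1)); });
    for(std::size_t d = 0; d < 4; d++)
        std::cout << e(d) << " == " << l(d) << "\n";
}

)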
@@ -257,22 +464,40 @@ struct block
 struct lane
 {
     template <class Slicer>
-    struct reducer
+    struct reducer : reducer_base<reducer<Slicer>>
     {
         index idx;
         Slicer slice;

-        template <class Op, class T, class Read>
-        __device__ auto reduce(Op op, T init, Read read) const
-        {
-            return sliced(slice, [=](auto x, auto... xs) {
-                using type = typename decltype(x)::type;
-                type r     = init;
-                for(index_int j = 0; j < x.get_shape().elements(); j++)
-                {
-                    r = op(r, read(x[j], xs[j]...));
-                }
-                return r;
-            });
-        }
+        template <class Size, class F>
+        struct inner_storage : inner_storage_tag
+        {
+            using type = remove_reference_t<decltype(declval<F>()(0, _c<0>))>;
+            F f;
+            constexpr Size rsize() const { return {}; }
+            template <class U, class V>
+            constexpr auto operator()(U j, V d) const
+            {
+                return f(j, d);
+            }
+        };
+
+        template <class Size, class F>
+        static constexpr inner_storage<Size, F> make_inner_storage(Size, F f)
+        {
+            return {{}, {f}};
+        }
+
+        template <class Op, class T, class Read, class N, class U, class... Us>
+        __device__ auto reduce_impl(Op op, T init, Read read, N n, U&& x, Us&&... xs) const
+        {
+            using type = remove_reference_t<decltype(x(0, _c<0>))>;
+            type r     = init;
+            for(index_int j = 0; j < n; j++)
+            {
+                r = op(r, read(x(j, _c<0>), xs(j, _c<0>)...));
+            }
+            return r;
+        }

         template <class F>
@@ -281,29 +506,25 @@ struct lane
             f();
         }

-        template <class F>
-        __device__ auto inner(F f) const
-        {
-            return sliced(slice, [=](auto x, auto... xs) {
-                for(index_int j = 0; j < x.get_shape().elements(); j++)
-                {
-                    f(x[j], xs[j]...);
-                }
-            });
-        }
-
-        template <class Input>
-        constexpr auto elements() const
-        {
-            using reduce_type = decltype(slice(Input{}));
-            return get_shape_c<reduce_type>{}.elements();
-        }
+        template <class F, class N, class... Ts>
+        __device__ void inner_void_impl(F f, N n, Ts&&... xs) const
+        {
+            for(index_int j = 0; j < n; j++)
+            {
+                f(xs(j, _c<0>)...);
+            }
+        }
+
+        template <class R, class F, class N, class... Ts>
+        __device__ auto inner_impl(F f, N n, Ts&&... xs) const
+        {
+            return make_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
+        }
     };

     template <class Slicer>
     static __device__ auto make(index idx, Slicer slicer)
     {
-        return reducer<Slicer>{idx, slicer};
+        return reducer<Slicer>{{}, idx, slicer};
     }

     template <class Output, class F>
@@ -318,6 +539,26 @@ struct lane
     }
 };

 // TODO: Remove these in the future when they can be selected in the compiler class
+template <index_int RElements>
+constexpr auto pick_block()
+{
+    using nlocal = decltype(index{}.max_nlocal());
+    if constexpr(RElements < nlocal{} * 256)
+        return block{};
+    else
+        return block_large{};
+}
+
+template <index_int RElements>
+using auto_block = decltype(pick_block<RElements>());
+
+template <class Input, index_int Axis>
+constexpr auto reduce_elements_with_axis()
+{
+    constexpr auto s = get_shape_c<Input>{};
+    return s.lens[Axis];
+}
+
 } // namespace reduce

 template <class Algo,
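(`pick_block` selects between `block` and `block_large` entirely at compile time by returning different types from an `if constexpr` branch and recovering the choice with `decltype`. A standalone sketch of that idiom — the threshold and tag types here are illustrative, not the kernel's:

#include <type_traits>

struct small_algo {};
struct large_algo {};

// The return type differs per branch; decltype captures the selected type,
// mirroring pick_block/auto_block in the diff.
template <unsigned N>
constexpr auto pick()
{
    if constexpr(N < 1024) // illustrative threshold, not the kernel's nlocal * 256
        return small_algo{};
    else
        return large_algo{};
}

template <unsigned N>
using auto_algo = decltype(pick<N>());

static_assert(std::is_same<auto_algo<8>, small_algo>{}, "small input picks small_algo");
static_assert(std::is_same<auto_algo<(1u << 20)>, large_algo>{}, "large input picks large_algo");

int main() {}

)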
src/targets/gpu/kernels/include/migraphx/kernels/shape.hpp

@@ -76,14 +76,6 @@ struct shape
     constexpr index_int index(index_array x) const { return x.dot(strides); }

-    constexpr index_int index(std::initializer_list<index_int> x) const
-    {
-        index_int idx = 0;
-        for(index_int i = 0; i < x.size(); i++)
-            idx += *(x.begin() + i) * strides[i];
-        return idx;
-    }
-
     constexpr index_int index(index_int i) const
     {
         if(this->standard())
@@ -128,6 +120,7 @@ struct shape
         result[0] = tidx;
         return result;
     }

+    /// Convert multi-index into a single index
     constexpr index_int single(index_array idx) const
     {
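(The removed `initializer_list` overload and the retained `index(index_array)` compute the same thing: the flat offset is the dot product of the multi-index with the strides. A quick standalone illustration for a row-major 2x3 shape — values chosen for the example:

#include <array>
#include <cassert>
#include <cstdint>

using index_int = std::uint32_t;

// The flat offset of a multi-index is multi . strides, as in shape::index.
index_int flat_index(std::array<index_int, 2> multi, std::array<index_int, 2> strides)
{
    return multi[0] * strides[0] + multi[1] * strides[1];
}

int main()
{
    // Row-major 2x3 has strides {3, 1}; element (1, 2) lives at offset 1*3 + 2*1 = 5.
    assert(flat_index({1, 2}, {3, 1}) == 5);
}

)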
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp

@@ -30,18 +30,20 @@
 namespace migraphx {

 template <index_int Axis, class Input, class Output>
-__device__ void softmax(Input input, Output output)
+__device__ void softmax(Input input1, Output output)
 {
-    reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) {
+    using block = reduce::auto_block<reduce::reduce_elements_with_axis<Input, Axis>()>;
+    block::template run<reduce::with_axis<Input, Axis>>([&](auto, auto r) {
+        auto input = r.inner(op::id{})(input1);
 #ifdef MIGRAPHX_USE_FAST_SOFTMAX
-        const auto c = vec_at(r.slice(input)[0], 0);
+        const auto c = vec_at(r.slice(input1)[0], 0);
 #else
         const auto c = r.reduce(op::max{}, lowest{}, op::id{})(input);
 #endif
-        auto batch_sum =
-            r.reduce(op::sum{}, 0, [&](auto x) {
-                return migraphx::convert<float>(migraphx::exp(x - c));
-            })(input);
-        r.inner([&](auto& y, auto x) { y = migraphx::exp(x - c) / batch_sum; })(output, input);
+        auto exp_in = r.inner([&](auto x) { return migraphx::exp(x - c); })(input);
+        auto batch_sum =
+            r.reduce(op::sum{}, 0, [](auto x) { return migraphx::convert<float>(x); })(exp_in);
+        r.inner([&](auto& y, auto x) { y = x / batch_sum; })(output, exp_in);
     });
 }
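(Outside the `MIGRAPHX_USE_FAST_SOFTMAX` path, the kernel follows the standard numerically stable recipe: subtract the row max before exponentiating, then normalize by the sum. A plain host-side reference of the same math, with no MIGraphX types:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Stable softmax: exp(x - max) / sum(exp(x - max)); subtracting the max
// keeps exp() from overflowing, matching the reduction order in the kernel.
std::vector<float> softmax(const std::vector<float>& x)
{
    float c = *std::max_element(x.begin(), x.end());
    std::vector<float> y(x.size());
    float sum = 0;
    for(std::size_t i = 0; i < x.size(); i++)
    {
        y[i] = std::exp(x[i] - c);
        sum += y[i];
    }
    for(auto& v : y)
        v /= sum;
    return y;
}

int main()
{
    // Would overflow float exp() without the max-subtraction trick.
    for(float v : softmax({1000.0f, 1001.0f, 1002.0f}))
        std::printf("%f\n", v);
}

)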
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp

@@ -141,6 +141,25 @@ MIGRAPHX_BUILTIN_TYPE_TRAITN(is_constructible);
 MIGRAPHX_BUILTIN_TYPE_TRAITN(is_nothrow_constructible);
 MIGRAPHX_BUILTIN_TYPE_TRAITN(is_trivially_constructible);

+template <class T>
+struct remove_cv
+{
+    using type = T;
+};
+
+template <class T>
+struct remove_cv<const T> : remove_cv<T>
+{
+};
+
+template <class T>
+struct remove_cv<volatile T> : remove_cv<T>
+{
+};
+
+template <class T>
+using remove_cv_t = typename remove_cv<T>::type;
+
 template <class T>
 struct remove_reference
 {
@@ -168,6 +187,11 @@ struct add_pointer : type_identity<typename remove_reference<T>::type*>
 template <class T>
 using add_pointer_t = typename add_pointer<T>::type;

+template <class T>
+struct is_void : is_same<void, remove_cv_t<T>>
+{
+};
+
 template <class... Ts>
 struct common_type;
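(These additions mirror the standard `<type_traits>` facilities for device code, where the host standard library is not available. Their expected behavior can be pinned down against the std equivalents in a standalone check:

#include <type_traits>

// The diff's remove_cv / remove_cv_t / is_void reimplement these standard
// traits; their contracts are the same:
static_assert(std::is_same<std::remove_cv_t<const volatile int>, int>{},
              "remove_cv strips both const and volatile");
static_assert(std::is_void<const void>{}, "is_void sees through cv-qualifiers");
static_assert(!std::is_void<int*>{}, "a pointer to anything is not void");

int main() {}

)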
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp

@@ -28,8 +28,45 @@
 namespace migraphx {

-using index_int = std::uint32_t;
-using diff_int  = std::int32_t;
+#if defined(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS) and defined(MIGRAPHX_USE_HIPRTC)
+using int8_t   = signed char;
+using uint8_t  = unsigned char;
+using int16_t  = signed short;
+using uint16_t = unsigned short;
+using int32_t  = signed int;
+using uint32_t = unsigned int;
+using int64_t  = signed long long;
+using uint64_t = unsigned long long;
+#elif defined(MIGRAPHX_USE_HIPRTC)
+using int8_t   = __hip_int8_t;
+using uint8_t  = __hip_uint8_t;
+using int16_t  = __hip_int16_t;
+using uint16_t = __hip_uint16_t;
+using int32_t  = __hip_int32_t;
+using uint32_t = __hip_uint32_t;
+using int64_t  = __hip_int64_t;
+using uint64_t = __hip_uint64_t;
+#else
+using int8_t   = std::int8_t;
+using uint8_t  = std::uint8_t;
+using int16_t  = std::int16_t;
+using uint16_t = std::uint16_t;
+using int32_t  = std::int32_t;
+using uint32_t = std::uint32_t;
+using int64_t  = std::int64_t;
+using uint64_t = std::uint64_t;
+#endif // MIGRAPHX_USE_HIPRTC
+
+using index_int = uint32_t;
+using diff_int  = int32_t;
+
+static_assert(sizeof(int8_t) == 1, "int8_t must be 1 bytes");
+static_assert(sizeof(uint8_t) == 1, "uint8_t must be 1 bytes");
+static_assert(sizeof(int16_t) == 2, "int16_t must be 2 bytes");
+static_assert(sizeof(uint16_t) == 2, "uint16_t must be 2 bytes");
+static_assert(sizeof(int32_t) == 4, "int32_t must be 4 bytes");
+static_assert(sizeof(uint32_t) == 4, "uint32_t must be 4 bytes");
+static_assert(sizeof(int64_t) == 8, "int64_t must be 8 bytes");
+static_assert(sizeof(uint64_t) == 8, "uint64_t must be 8 bytes");

 #define MIGRAPHX_DEVICE_CONSTEXPR constexpr __device__ __host__ // NOLINT
src/targets/gpu/lowering.cpp

@@ -83,14 +83,12 @@ struct miopen_apply
         auto& ctx      = get_context();
         int8_x4_format = get_int8_x4_format(ctx);
         compute_fp32   = get_compute_fp32_flag();
         offload_copy   = (mod->name() == "main") ? pass->offload_copy : false;

         add_generic_op("contiguous");

         add_extend_op("argmax");
         add_extend_op("argmin");
         add_extend_op("gather");
         add_extend_op("logsoftmax");
         add_extend_op("lrn");
         add_extend_op("multinomial");
@@ -113,6 +111,7 @@ struct miopen_apply
         add_loop_op();
         add_neg_op();
         add_nms_op();
+        add_select_module_op();
     }

     void copy_params() const
@@ -360,6 +359,20 @@ struct miopen_apply
             return mod->replace_instruction(ins, gpu_out);
         });
     }

+    /**
+     * Adds dynamic allocation for submodule output parameter.
+     */
+    void add_select_module_op()
+    {
+        apply_map.emplace("select_module", [=](instruction_ref ins) {
+            auto s      = ins->get_shape();
+            auto output = insert_allocation(ins, s);
+            std::vector<instruction_ref> inputs = ins->inputs();
+            inputs.push_back(output);
+            return mod->replace_instruction(
+                ins, ins->get_operator(), inputs, ins->module_inputs());
+        });
+    }
 };

 void lowering::apply(module& m) const { miopen_apply{&m, this}.apply(); }
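(`miopen_apply` keeps a map from operator name to a replacement callback — `apply_map.emplace("select_module", ...)` above — and walks the module applying whichever entry matches. A self-contained sketch of that dispatch shape, with a toy instruction type rather than the real `instruction_ref`/`module` API:

#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

struct instruction { std::string op; };

int main()
{
    // Name -> rewrite callback, like miopen_apply::apply_map in lowering.cpp.
    std::unordered_map<std::string, std::function<void(instruction&)>> apply_map;
    apply_map.emplace("select_module", [](instruction& ins) {
        ins.op = "gpu::select_module"; // stand-in for allocate-output-then-replace
    });

    std::vector<instruction> mod = {{"add"}, {"select_module"}};
    for(auto& ins : mod)
    {
        auto it = apply_map.find(ins.op);
        if(it != apply_map.end())
            it->second(ins);
        std::cout << ins.op << "\n";
    }
}

)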
src/targets/gpu/mlir.cpp

@@ -32,7 +32,13 @@
 #include <mlir-c/Dialect/MIGraphX.h>
 #include <mlir-c/IntegerSet.h>
 #include <mlir-c/Pass.h>
-#include <mlir-c/Registration.h>
+#include <mutex>
+#if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || MLIR_MIGRAPHX_DIALECT_API_VERSION != 3
+#warning "Incompatible version of rocMLIR library used, disabling"
+#undef MIGRAPHX_MLIR
+#else
+#include <mlir-c/RegisterRocMLIR.h>
+#endif
 #endif
 #include <migraphx/env.hpp>
@@ -50,10 +56,6 @@
 #include <deque>
 #include <variant>

-#if defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) && MLIR_MIGRAPHX_DIALECT_API_VERSION >= 2
-#define MIGRAPHX_MLIR_BARE_POINTER
-#endif
-
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
@@ -168,9 +170,11 @@ struct mlir_program
           location(mlirLocationUnknownGet(ctx.get())),
           mmodule(mlirModuleCreateEmpty(location))
     {
-        MlirDialectHandle mixr_handle = mlirGetDialectHandle__migraphx__();
-        mlirDialectHandleRegisterDialect(mixr_handle, ctx.get());
-        mlirRegisterAllDialects(ctx.get());
+        MlirDialectRegistry registry = mlirDialectRegistryCreate();
+        mlirRegisterRocMLIRDialects(registry);
+        mlirContextAppendDialectRegistry(ctx.get(), registry);
+        mlirContextLoadAllAvailableDialects(ctx.get());
+        mlirDialectRegistryDestroy(registry);
         mlirContextSetAllowUnregisteredDialects(ctx.get(), true /*allow*/);
     }
@@ -452,7 +456,8 @@ struct mlir_program
         auto ops = create_operation_state("func.func");
         ops.add_attributes({{"function_type", make_function_type(inputs, outputs)},
                             {"sym_name", std::string("main")},
-                            {"kernel", std::string("mixr")}});
+                            {"kernel", std::string("mixr")},
+                            {"arch", target_arch}});
         ops.add_region(std::move(region));
         insert(body, std::move(ops));
@@ -512,7 +517,8 @@ struct mlir_program
         pp = problem_params{ins->get_operator(), to_shapes(ins->inputs()), ins->get_shape()};
         // check if HW supports xdlops
-        bool xdlops = contains(get_xdlops_archs(), target_name);
+        auto target_chip = trim(split_string(target_arch, ':').front());
+        bool xdlops      = contains(get_xdlops_archs(), target_chip);
         std::string tuned = get_tune_params(xdlops);
         if(not tuned.empty())
             ops.add_attributes({{"perf_config", tuned}});
@@ -540,7 +546,7 @@ struct mlir_program
         // 1st pipeline to call
         mlirMIGraphXAddHighLevelPipeline(pm.get());
         // 2nd pipeline to call
-        mlirMIGraphXAddBackendPipeline(pm.get(), target_name.c_str(), "amdgcn-amd-amdhsa", "");
+        mlirMIGraphXAddBackendPipeline(pm.get(), target_arch.c_str());
         mlirPassManagerRun(pm.get(), mmodule.get());

         code_object_op op{};
@@ -550,16 +556,7 @@ struct mlir_program
         return op;
     }

-    void find_target()
-    {
-        std::string tname = get_device_name();
-        // HACK: Since MLIR can't handle the full target name
-        target_name = trim(split_string(tname, ':').front());
-        if(tname.size() != target_name.size())
-            std::cout << "*************** WARNING: MLIR may not compile the correct target features for: "
-                      << tname << std::endl;
-    }
+    void find_target() { target_arch = get_device_name(); }

     std::pair<std::size_t, std::size_t> get_launch_params() const
     {
@@ -588,7 +585,7 @@ struct mlir_program
     mlir_module mmodule;
     problem_params pp;
     std::deque<std::string> strings{};
-    std::string target_name;
+    std::string target_arch;
 };

 std::string dump_mlir(const module& m)
@@ -650,6 +647,10 @@ code_object_op compile_mlir(const context&, module m, const std::vector<instruct
     const bool trace = enabled(MIGRAPHX_TRACE_MLIR{});
     if(trace)
         std::cout << m << std::endl;
+    // set mutex while llvm thread support is disabled.
+    static std::mutex g_mlirc_mutex; // NOLINT
+    const std::lock_guard<std::mutex> lock(g_mlirc_mutex);
     mlir_program mp;
     mp.find_target();
     mp.parse(m);
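(The guard added to `compile_mlir` is the usual function-local static mutex pattern: every caller serializes on one lock for as long as the underlying library is not thread-safe. The same shape in isolation:

#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

// A function-local static mutex serializes all concurrent callers of this
// one function, as the guard added in compile_mlir does.
void compile_serialized(int id)
{
    static std::mutex g_mutex; // NOLINT
    const std::lock_guard<std::mutex> lock(g_mutex);
    std::cout << "compiling " << id << "\n"; // never interleaved across threads
}

int main()
{
    std::vector<std::thread> ts;
    for(int i = 0; i < 4; i++)
        ts.emplace_back(compile_serialized, i);
    for(auto& t : ts)
        t.join();
}

)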
@@ -669,46 +670,9 @@ instruction_ref insert_mlir(module& m,
     std::vector<instruction_ref> refs;
     std::size_t last = 0;
-#ifdef MIGRAPHX_MLIR_BARE_POINTER
     refs.reserve(inputs.size());
     std::copy(inputs.begin(), inputs.end(), std::back_inserter(refs));
-    last = refs.size() - 1;
-#else
-    refs.reserve(inputs.size() * 15);
-    std::unordered_map<uint64_t, instruction_ref> literal_map{};
-    auto get_literal = [&](uint64_t value) {
-        auto fi = literal_map.find(value);
-        if(fi != literal_map.end())
-            return fi->second;
-        auto lit = m.add_literal(value);
-        literal_map.emplace(value, lit);
-        return lit;
-    };
-
-    for(auto input : inputs)
-    {
-        const size_t offset = 0;
-        auto s = input->get_shape();
-        last   = refs.size();
-        refs.push_back(input);
-        refs.push_back(input);
-        refs.push_back(get_literal(offset)); // offset
-        // dim sizes
-        std::transform(s.lens().begin(),
-                       s.lens().end(),
-                       std::back_inserter(refs),
-                       [&](const auto& lval) { return get_literal(lval); });
-        // refs.push_back(get_literal(1)); // G
-        // dim strides
-        std::transform(s.strides().begin(),
-                       s.strides().end(),
-                       std::back_inserter(refs),
-                       [&](const auto& lval) { return get_literal(lval); });
-        // refs.push_back(get_literal(1)); // G
-    }
-#endif
+    last = refs.size() - 1;
     co.expected_inputs = to_shapes(refs);
     co.output_arg      = last;
     return m.insert_instruction(ins, co, refs);
src/targets/gpu/perfdb.cpp

@@ -27,6 +27,7 @@
 #include <migraphx/stringutils.hpp>
 #include <migraphx/permutation.hpp>
 #include <fstream>
+#include <mutex>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -88,6 +89,9 @@ std::string generate_miopen_config(const problem_params& pp)
 auto query_miopen_db(const std::string& query)
 {
+    static std::mutex g_db_mutex; // NOLINT
+    const std::lock_guard<std::mutex> lock(g_db_mutex);
     // TODO: Store db as a static variable
     const auto dbpath = fs::path{"/opt"} / "rocm" / "share" / "miopen" / "db" / "miopen.db";
     // Check if db file exists.
src/targets/gpu/prefuse_ops.cpp

@@ -26,6 +26,8 @@
 #include <migraphx/check_shapes.hpp>
 #include <migraphx/make_op.hpp>
 #include <migraphx/register_op.hpp>
+#include <migraphx/pass_manager.hpp>
+#include <migraphx/dead_code_elimination.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -90,7 +92,9 @@ struct find_layernorm
     {
         auto ins   = r.result;
         auto x_ins = r.instructions["x"];
-        auto eps   = r.instructions["eps"]->eval().at<float>();
+        float eps  = 0;
+        if(contains(r.instructions, "eps"))
+            eps = r.instructions["eps"]->eval().at<float>();

         m.replace_instruction(ins, layernorm{eps}, x_ins);
     }
@@ -100,23 +104,26 @@ struct find_add_layernorm
 {
     auto matcher() const
     {
-        return match::layernorm()(match::var("x")(match::name("add").bind("add")));
+        return match::name("gpu::prelayernorm")(
+            match::args(match::name("add")(match::used_once()).bind("add")));
     }

     void apply(module& m, const match::matcher_result& r) const
     {
         auto ins     = r.result;
         auto add_ins = r.instructions["add"];
-        auto eps     = r.instructions["eps"]->eval().at<float>();
+        auto op      = any_cast<layernorm>(ins->get_operator());

-        m.replace_instruction(ins, add_layernorm{eps}, add_ins->inputs());
+        m.replace_instruction(ins, add_layernorm{op.epsilon}, add_ins->inputs());
     }
 };

 } // namespace

-void prefuse_ops::apply(module& m) const
+void prefuse_ops::apply(module_pass_manager& mpm) const
 {
-    match::find_matches(m, find_add_layernorm{}, find_layernorm{});
+    match::find_matches(mpm.get_module(), find_layernorm{});
+    mpm.run_pass(dead_code_elimination{});
+    match::find_matches(mpm.get_module(), find_add_layernorm{});
 }

 } // namespace gpu
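(The reworked `prefuse_ops::apply` runs one matcher, cleans up dead code, then runs the second, so `find_add_layernorm` sees the `gpu::prelayernorm` instructions that `find_layernorm` just produced. A toy two-phase rewrite over strings shows that ordering dependency — not the real matcher API:

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Phase 1 rewrites "layernorm" to "prelayernorm"; phase 2 can then fuse a
// preceding "add" into it -- the same run-then-run ordering as prefuse_ops.
int main()
{
    std::vector<std::string> mod = {"add", "layernorm"};

    std::replace(mod.begin(), mod.end(), std::string("layernorm"),
                 std::string("prelayernorm")); // find_layernorm analogue

    for(std::size_t i = 1; i < mod.size(); i++) // find_add_layernorm analogue
    {
        if(mod[i - 1] == "add" and mod[i] == "prelayernorm")
        {
            mod[i - 1] = "add_layernorm";
            mod.erase(mod.begin() + i);
            break;
        }
    }
    for(const auto& op : mod)
        std::cout << op << "\n"; // prints: add_layernorm
}

)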
src/targets/gpu/target.cpp

@@ -38,6 +38,7 @@
 #include <migraphx/layout_nhwc.hpp>
 #include <migraphx/memory_coloring.hpp>
 #include <migraphx/normalize_ops.hpp>
+#include <migraphx/optimize_module.hpp>
 #include <migraphx/preallocate_param.hpp>
 #include <migraphx/propagate_constant.hpp>
 #include <migraphx/register_target.hpp>
@@ -74,7 +75,6 @@ namespace gpu {
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION)
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC)

 struct id_pass
 {
     std::string name() const { return "id"; }
@@ -91,6 +91,7 @@ pass enable_pass(bool enabled, pass p)
 std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options& options) const
 {
     auto& ctx = any_cast<context>(gctx);
+    ctx.set_exhaustive_tune_flag(options.exhaustive_tune);
     std::set<shape::type_t> unsupported_types(shape::types().begin(), shape::types().end());
     unsupported_types.erase(shape::type_t::float_type);
     unsupported_types.erase(shape::type_t::half_type);
@@ -119,21 +120,13 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
         rewrite_pooling{},
         dead_code_elimination{},
         rewrite_gelu{},
         dead_code_elimination{},
-        eliminate_common_subexpression{},
-        dead_code_elimination{},
-        simplify_algebra{},
-        simplify_reshapes{},
+        optimize_module{},
         enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), layout_nhwc{}),
         dead_code_elimination{},
-        simplify_reshapes{},
-        simplify_algebra{},
         prefuse_ops{},
         dead_code_elimination{},
         auto_contiguous{},
-        simplify_reshapes{},
-        propagate_constant{},
-        dead_code_elimination{},
+        optimize_module{},
         enable_pass(not enabled(MIGRAPHX_DISABLE_POINTWISE_FUSION{}), fuse_pointwise{}),
         dead_code_elimination{},
         fuse_mlir{&ctx},
src/targets/ref/include/migraphx/ref/target.hpp

@@ -46,8 +46,6 @@ struct target
     argument allocate(const shape& s) const;
 };

-MIGRAPHX_REGISTER_TARGET(target);
-
 } // namespace ref
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
src/targets/ref/lowering.cpp

@@ -346,10 +346,10 @@ struct ref_pad
     std::string name() const { return "ref::pad"; }
     shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
-    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
+    argument compute(context&, const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        assert(output_shape.standard());
-        argument result{output_shape};
+        assert(dyn_out.computed_shape.standard());
+        argument result{dyn_out.computed_shape};
         result.visit([&](auto output) {
             using type = typename decltype(output)::value_type;
             std::fill(output.begin(), output.end(), pad_clamp<type>(op.value));
@@ -383,9 +383,9 @@ struct ref_gemm
     std::string name() const { return "ref::dot"; }
     shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
-    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
+    argument compute(context&, const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        argument result{output_shape};
+        argument result{dyn_out.computed_shape};
         migemm(result, args[0], args[1], 1.0f, 0.0f);

         return result;
@@ -449,10 +449,10 @@ struct ref_softmax : auto_register_op<ref_softmax<Op>>
     {
         return op.normalize_compute_shape(inputs);
     }
-    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
+    argument compute(context&, const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        argument result{output_shape};
-        auto batch_lens = output_shape.lens();
+        argument result{dyn_out.computed_shape};
+        auto batch_lens    = dyn_out.computed_shape.lens();
         int64_t tuned_axis = tune_axis(args[0].get_shape().lens().size(), op.axis, op.name());
         std::size_t n_dims = batch_lens[tuned_axis];
         batch_lens[tuned_axis] = 1;
@@ -475,7 +475,7 @@ struct ref_softmax : auto_register_op<ref_softmax<Op>>
             for(std::size_t j = 0; j < n_dims; ++j)
             {
                 idx[tuned_axis] = j;
-                std::size_t index = output_shape.index(idx);
+                std::size_t index = dyn_out.computed_shape.index(idx);
                 output[index] = std::exp(input[index] - batch_max[i]);
             }
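(Across `ref_pad`, `ref_gemm`, and `ref_softmax` the `compute` signature swaps a statically known `output_shape` for a `dyn_output` whose `computed_shape` is produced at evaluation time; the bodies otherwise only rename the shape they allocate from. A compressed sketch of that migration with a toy shape type, not the real MIGraphX types:

#include <cstddef>
#include <iostream>
#include <vector>

struct shape { std::size_t elements; };

// After the change, compute receives a shape resolved at eval time instead
// of one fixed when the program was compiled.
struct dyn_output { shape computed_shape; };

std::vector<float> compute(const dyn_output& dyn_out)
{
    // analogue of: argument result{dyn_out.computed_shape};
    std::vector<float> result(dyn_out.computed_shape.elements);
    return result;
}

int main()
{
    dyn_output out{shape{6}}; // e.g. resolved from dynamic input dims at runtime
    std::cout << compute(out).size() << "\n"; // prints 6
}

)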
test/CMakeLists.txt

-#####################################################################################
+####################################################################################
 # The MIT License (MIT)
 #
 # Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.

@@ -20,7 +20,7 @@
 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 # THE SOFTWARE.
-#####################################################################################
+####################################################################################
 cmake_policy(SET CMP0057 NEW)
@@ -49,27 +49,31 @@ function(add_test_command NAME EXE)
         set_tests_properties(${NAME} PROPERTIES DISABLED On)
     elseif(WIN32)
         set(WINPATH)
         foreach(PATH ${CMAKE_FIND_ROOT_PATH})
             list(APPEND WINPATH ${PATH}/bin)
         endforeach()
         file(GENERATE OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/test_${NAME}.cmd" CONTENT
             "set PATH=${WINPATH};%PATH%
             %1 ${ARGN}")
         add_test(NAME ${NAME} COMMAND ${WINE_CMD} cmd /c "${CMAKE_CURRENT_BINARY_DIR}/test_${NAME}.cmd" $<TARGET_FILE:${EXE}>)
     else()
         if(MIGRAPHX_TEST_GDB)
             # add_test(NAME ${NAME} COMMAND ${MIGRAPHX_GDB}
             #     --batch
             #     --return-child-result
             #     -ex "set disable-randomization off"
             #     -ex run
             #     -ex backtrace
             #     --args $<TARGET_FILE:${EXE}> ${ARGN})
             set(TEST_DIR ${CMAKE_CURRENT_BINARY_DIR}/gdb/test_${NAME})
             file(MAKE_DIRECTORY ${TEST_DIR})
+            if(NOT EXISTS ${TEST_DIR})
+                message(FATAL_ERROR "Failed to create test directory: ${TEST_DIR}")
+            endif()
             file(GENERATE OUTPUT "${TEST_DIR}/run.cmake" CONTENT "
 # Remove previous core dump
@@ -90,26 +94,23 @@ function(add_test_command NAME EXE)
             add_test(NAME ${NAME} COMMAND ${EXE} ${ARGN})
         endif()
     endif()
     set_tests_properties(${NAME} PROPERTIES FAIL_REGULAR_EXPRESSION "FAILED")
 endfunction()

 function(add_test_executable TEST_NAME)
     add_executable(${TEST_NAME} EXCLUDE_FROM_ALL ${ARGN})
     target_link_libraries(${TEST_NAME} ${CMAKE_THREAD_LIBS_INIT})
     # Cmake does not add flags correctly for gcc
     if(CMAKE_CXX_COMPILER_ID MATCHES "GNU")
         set_target_properties(${TEST_NAME} PROPERTIES COMPILE_FLAGS -pthread LINK_FLAGS -pthread)
     endif()
-    separate_arguments(MIOPEN_TEST_FLAGS_ARGS UNIX_COMMAND ${MIOPEN_TEST_FLAGS})
-    if(MIOPEN_TEST_ALL)
-        set(TEST_COMMAND ${TEST_NAME} ${MIOPEN_TEST_FLOAT_ARG} --all ${MIOPEN_TEST_FLAGS_ARGS})
-    else()
-        set(TEST_COMMAND ${TEST_NAME} ${MIOPEN_TEST_FLOAT_ARG} ${MIOPEN_TEST_FLAGS_ARGS})
-    endif()
+    set(TEST_COMMAND ${TEST_NAME})
     add_test_command(${TEST_NAME} ${TEST_COMMAND})
     add_dependencies(tests ${TEST_NAME})
     add_dependencies(check ${TEST_NAME})
-    target_link_libraries(${TEST_NAME} migraphx migraphx_ref migraphx_onnx)
+    target_link_libraries(${TEST_NAME} migraphx migraphx_onnx)
     target_include_directories(${TEST_NAME} PUBLIC include)
 endfunction(add_test_executable)
@@ -129,11 +130,11 @@ if(MIGRAPHX_ENABLE_GPU)
         get_filename_component(BASE_NAME ${TEST} NAME_WE)
         add_test_executable(test_gpu_${BASE_NAME} ${TEST})
         rocm_clang_tidy_check(test_gpu_${BASE_NAME})
         set_tests_properties(test_gpu_${BASE_NAME} PROPERTIES
             COST 10
             RESOURCE_LOCK gpu
         )
-        target_link_libraries(test_gpu_${BASE_NAME} migraphx_gpu)
+        target_link_libraries(test_gpu_${BASE_NAME} migraphx_gpu migraphx_kernels)
     endforeach()
 endif()
@@ -145,8 +146,8 @@ if(MIGRAPHX_ENABLE_FPGA)
         get_filename_component(BASE_NAME ${TEST} NAME_WE)
         add_test_executable(test_fpga_${BASE_NAME} ${TEST})
         rocm_clang_tidy_check(test_fpga_${BASE_NAME})
         set_tests_properties(test_fpga_${BASE_NAME} PROPERTIES
             COST 10
             RESOURCE_LOCK fpga
         )
         target_link_libraries(test_fpga_${BASE_NAME} migraphx_fpga)
@@ -155,15 +156,16 @@ endif()
 # Onnx test
 set(TEST_ONNX_DIR ${CMAKE_CURRENT_SOURCE_DIR}/onnx)
 file(GLOB ONNX_TESTS ${TEST_ONNX_DIR}/*.cpp)
 foreach(ONNX_TEST ${ONNX_TESTS})
     get_filename_component(BASE_NAME ${ONNX_TEST} NAME_WE)
     set(TEST_NAME test_${BASE_NAME})
     add_executable(${TEST_NAME} ${ONNX_TEST})
     rocm_clang_tidy_check(${TEST_NAME})
-    target_link_libraries(${TEST_NAME} migraphx_onnx migraphx_ref)
+    target_link_libraries(${TEST_NAME} migraphx_onnx)
     target_include_directories(${TEST_NAME} PUBLIC include)
     add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}> WORKING_DIRECTORY ${TEST_ONNX_DIR})
     add_dependencies(tests ${TEST_NAME})
     add_dependencies(check ${TEST_NAME})
 endforeach()
@@ -172,28 +174,28 @@ endforeach()
 set(TEST_TF_DIR ${CMAKE_CURRENT_SOURCE_DIR}/tf)
 add_executable(test_tf tf/tf_test.cpp)
 rocm_clang_tidy_check(test_tf)
-target_link_libraries(test_tf migraphx_tf migraphx_ref)
+target_link_libraries(test_tf migraphx_tf)
 target_include_directories(test_tf PUBLIC include)
 add_test(NAME test_tf COMMAND $<TARGET_FILE:test_tf> WORKING_DIRECTORY ${TEST_TF_DIR})
 add_dependencies(tests test_tf)
 add_dependencies(check test_tf)

 add_subdirectory(api)
 add_subdirectory(verify)
 if(MIGRAPHX_ENABLE_PYTHON)
     add_subdirectory(py)
 endif()

 function(test_header NAME HEADER)
     file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/header-main-include-${NAME}.cpp
         "#include <${HEADER}>\nint main() {}\n")
     file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/header-static-include-${NAME}.cpp
         "#include <${HEADER}>\n")
     add_test_executable(${NAME}
         ${CMAKE_CURRENT_BINARY_DIR}/header-main-include-${NAME}.cpp
         ${CMAKE_CURRENT_BINARY_DIR}/header-static-include-${NAME}.cpp)
 endfunction()
@@ -206,14 +208,13 @@ function(test_headers PREFIX)
         string(MAKE_C_IDENTIFIER ${HEADER_REL} TEST_NAME)
         get_filename_component(BASE_NAME ${HEADER} NAME_WE)
         test_header(header_${TEST_NAME} ${PREFIX}/${BASE_NAME}.hpp)
-        if(MIGRAPHX_ENABLE_GPU)
-            target_link_libraries(header_${TEST_NAME} migraphx_gpu)
-        endif()
+        target_link_libraries(header_${TEST_NAME} migraphx_all_targets)
     endforeach()
 endfunction()

 test_headers(migraphx ${CMAKE_SOURCE_DIR}/src/include/migraphx/*.hpp)
 test_headers(migraphx/ref ${CMAKE_SOURCE_DIR}/src/targets/ref/include/migraphx/ref/*.hpp)
 if(MIGRAPHX_ENABLE_GPU)
     test_headers(migraphx/gpu ${CMAKE_SOURCE_DIR}/src/targets/gpu/include/migraphx/gpu/*.hpp)
 endif()
test/api/CMakeLists.txt

@@ -56,8 +56,10 @@ add_api_test(custom_op test_custom_op.cpp ${TEST_ONNX_DIR})
 add_api_test(tf_parser test_tf_parser.cpp ${TEST_TF_DIR})
 # GPU-based tests
 if(MIGRAPHX_ENABLE_GPU)
+    list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
+    find_package(hip)
     add_api_test(gpu test_gpu.cpp ${TEST_ONNX_DIR})
-    target_link_libraries(test_api_gpu migraphx_gpu)
+    target_link_libraries(test_api_gpu hip::host)
     add_api_test(custom_op_gpu test_custom_op_gpu.cpp ${TEST_ONNX_DIR})
-    target_link_libraries(test_api_custom_op_gpu migraphx_gpu)
+    target_link_libraries(test_api_custom_op_gpu hip::host)
 endif()
View file @
9bff4331
...
...
@@ -66,7 +66,7 @@ TEST_CASE(load_and_run_init_list)
TEST_CASE
(
quantize_fp16
)
{
auto
p1
=
migraphx
::
parse_onnx
(
"gemm_
ex_
test.onnx"
);
auto
p1
=
migraphx
::
parse_onnx
(
"gemm_test.onnx"
);
const
auto
&
p2
=
p1
;
const
auto
&
p3
=
p1
;
migraphx
::
quantize_fp16
(
p1
);
...
...
@@ -82,7 +82,7 @@ TEST_CASE(quantize_fp16)
 TEST_CASE(quantize_int8)
 {
-    auto p1 = migraphx::parse_onnx("gemm_ex_test.onnx");
+    auto p1 = migraphx::parse_onnx("gemm_test.onnx");
     const auto& p2 = p1;
     auto t = migraphx::target("ref");
     migraphx::quantize_int8_options options;
test/api/test_gpu.cpp

@@ -35,6 +35,7 @@ TEST_CASE(load_and_run)
     auto shapes_before = p.get_output_shapes();
     migraphx::compile_options options;
     options.set_offload_copy();
+    options.set_exhaustive_tune_flag();
     p.compile(migraphx::target("gpu"), options);
     auto shapes_after = p.get_output_shapes();
     CHECK(shapes_before.size() == 1);
test/api/test_save_load.cpp

@@ -30,7 +30,6 @@ TEST_CASE(load_save_default)
     std::string filename = "migraphx_api_load_save.mxr";
     auto p1 = migraphx::parse_onnx("conv_relu_maxpool_test.onnx");
     auto s1 = p1.get_output_shapes();
     migraphx::save(p1, filename.c_str());
     auto p2 = migraphx::load(filename.c_str());
     auto s2 = p2.get_output_shapes();
test/fpga/get_target_assignments.cpp

@@ -26,7 +26,6 @@
 #include <migraphx/make_op.hpp>
 #include <migraphx/program.hpp>
 #include <migraphx/register_target.hpp>
-#include <migraphx/fpga/target.hpp>
 #include <migraphx/target_assignments.hpp>
 #include <migraphx/iterator_for.hpp>
test/gpu/hip.cpp (new file, 0 → 100644)
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <test.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/gpu/target.hpp>

TEST_CASE(tuple_to_from_gpu)
{
    migraphx::shape s1{migraphx::shape::float_type, {2, 3}};
    migraphx::shape s2{migraphx::shape::int32_type, {2, 4}};
    std::vector<float> p1_data = {1.1, 2.2, 3.3, 4.4, 5.5, 6.6};
    std::vector<int> p2_data   = {1, 2, 3, 4, 5, 6, 7, 8};
    auto p1 = migraphx::argument{s1, p1_data.data()};
    auto p2 = migraphx::argument{s2, p2_data.data()};

    auto p1_gpu  = migraphx::gpu::to_gpu(p1);
    auto p2_gpu  = migraphx::gpu::to_gpu(p2);
    auto p_tuple = migraphx::gpu::from_gpu(migraphx::argument({p1_gpu, p2_gpu}));

    std::vector<migraphx::argument> results = p_tuple.get_sub_objects();
    std::vector<float> result1;
    results[0].visit([&](auto output) { result1.assign(output.begin(), output.end()); });
    std::vector<int> result2;
    results[1].visit([&](auto output) { result2.assign(output.begin(), output.end()); });

    EXPECT(result1 == p1_data);
    EXPECT(result2 == p2_data);
}

int main(int argc, const char* argv[]) { test::run(argc, argv); }
test/gpu/jit.cpp

@@ -27,21 +27,22 @@
 #include <migraphx/generate.hpp>
 #include <migraphx/program.hpp>
 #include <migraphx/par_for.hpp>
+#include <migraphx/register_target.hpp>
 #include <migraphx/gpu/kernel.hpp>
 #include <migraphx/gpu/target.hpp>
 #include <migraphx/gpu/hip.hpp>
 #include <migraphx/gpu/context.hpp>
 #include <migraphx/gpu/device_name.hpp>
 #include <migraphx/gpu/compile_hip.hpp>
 #include <migraphx/gpu/compile_hip_code_object.hpp>
 #include <migraphx/gpu/compiler.hpp>
 #include <migraphx_kernels.hpp>

 // NOLINTNEXTLINE
 const std::string write_2s = R"__migraphx__(
 #include <hip/hip_runtime.h>

 extern "C" {
-__global__ void write(int8_t* data)
+__global__ void write(char* data)
 {
     int num = threadIdx.x + blockDim.x * blockIdx.x;
     data[num] = 2;
@@ -58,7 +59,7 @@ const std::string add_2s_binary = R"__migraphx__(
 #include <hip/hip_runtime.h>

 extern "C" {
-__global__ void add_2(std::int8_t* x, std::int8_t* y)
+__global__ void add_2(char* x, char* y)
 {
     int num = threadIdx.x + blockDim.x * blockIdx.x;
     y[num] = x[num] + 2;
@@ -137,7 +138,8 @@ int main() {}
 const std::string math_template = R"__migraphx__(
 #include <migraphx/kernels/pointwise.hpp>
 #include <migraphx/kernels/math.hpp>
+#include <migraphx/kernels/types.hpp>
 using namespace migraphx;
 extern "C" {
 __global__ void kernel(${type}* p)
 {
@@ -233,7 +235,7 @@ TEST_CASE(code_object_hip)
     auto y = mm->add_parameter("output", input);
     mm->add_instruction(co, x, y);
     migraphx::compile_options options;
-    p.compile(migraphx::gpu::target{}, options);
+    p.compile(migraphx::make_target("gpu"), options);
     auto result =
         migraphx::gpu::from_gpu(p.eval({{"output", migraphx::gpu::allocate_gpu(input)}}).front());
@@ -259,7 +261,7 @@ TEST_CASE(compile_code_object_hip)
     auto x = mm->add_literal(input_literal);
     auto y = mm->add_parameter("output", input);
     mm->add_instruction(co, x, y);
-    p.compile(migraphx::gpu::target{}, migraphx::compile_options{});
+    p.compile(migraphx::make_target("gpu"), migraphx::compile_options{});
     auto result =
         migraphx::gpu::from_gpu(p.eval({{"output", migraphx::gpu::allocate_gpu(input)}}).front());
@@ -282,7 +284,7 @@ TEST_CASE(compile_pointwise)
     auto x = mm->add_literal(input_literal);
     auto y = mm->add_parameter("output", input);
     mm->add_instruction(co, x, y);
-    p.compile(migraphx::gpu::target{}, migraphx::compile_options{});
+    p.compile(migraphx::make_target("gpu"), migraphx::compile_options{});
     auto result =
         migraphx::gpu::from_gpu(p.eval({{"output", migraphx::gpu::allocate_gpu(input)}}).front());
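(The test updates replace the concrete `migraphx::gpu::target{}` with `migraphx::make_target("gpu")`, i.e. a lookup in a name-keyed target registry — hence the new `register_target.hpp` include. A toy registry illustrating the pattern, not MIGraphX's actual implementation:

#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

struct target
{
    virtual std::string name() const = 0;
    virtual ~target()                = default;
};

struct gpu_target : target
{
    std::string name() const override { return "gpu"; }
};

// Name -> factory map, the shape behind make_target("gpu") in the tests.
std::map<std::string, std::function<std::unique_ptr<target>()>>& registry()
{
    static std::map<std::string, std::function<std::unique_ptr<target>()>> r;
    return r;
}

std::unique_ptr<target> make_target(const std::string& name) { return registry().at(name)(); }

int main()
{
    registry()["gpu"] = [] { return std::make_unique<gpu_target>(); };
    std::cout << make_target("gpu")->name() << "\n"; // prints gpu
}

)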