MIGraphX commit 2b936b13 (unverified)
Merge branch 'develop' into refactor_auto_pad_conv

Authored Sep 19, 2022 by Charlie Lin; committed by GitHub on Sep 19, 2022.
Parents: ca360585, 97a1ed2d
Showing 11 changed files with 286 additions and 88 deletions (+286 / -88):
.github/workflows/performance.yaml                              +0    -2
src/targets/gpu/compile_hip_code_object.cpp                     +5    -5
src/targets/gpu/fuse_mlir.cpp                                   +13   -1
src/targets/gpu/fuse_ops.cpp                                    +1    -1
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp      +110  -46
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp      +81   -5
src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp  +17   -13
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp     +11   -13
test/fuse_pointwise.cpp                                         +1    -1
test/gpu/pack_int8_args.cpp                                     +1    -1
test/verify/test_conv_group_add.cpp                             +46   -0
.github/workflows/performance.yaml
@@ -26,8 +26,6 @@ on:
       required: true
-      default:
-        '-s'
 
 concurrency: benchmark
 jobs:
   release:
     uses: rocmsoftwareplatform/migraphx-benchmark/.github/workflows/perf-test.yml@main
src/targets/gpu/compile_hip_code_object.cpp
@@ -138,16 +138,16 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over)
         std::size_t groups     = (n + local - 1) / local;
         std::size_t max_blocks = max_global / local;
         std::size_t nglobal    = std::min(max_blocks * over, groups) * local;
-        return nglobal;
+        return std::min(nglobal, n);
     };
 }
 
 std::size_t compute_block_size(std::size_t n, std::size_t max_block_size)
 {
-    size_t block_size = 128;
-    while(block_size <= max_block_size and block_size <= n)
-        block_size *= 2;
-    return block_size / 2;
+    const std::size_t min_block_size  = 64;
+    const std::size_t base_block_size = 32;
+    auto block_size = (((n - 1) / base_block_size + 1)) * base_block_size;
+    return std::min(std::max(min_block_size, block_size), max_block_size);
 }
 
 operation compile_hip_code_object(const std::string& content, hip_compile_options options)
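Note: compute_global_for is now clamped so it never exceeds n, and compute_block_size switches from doubling a power of two to rounding n up to a multiple of a 32-thread base and clamping to [64, max_block_size]. A standalone sketch of the new arithmetic (detached from the MIGraphX build, just to show the resulting block sizes):

#include <algorithm>
#include <cstddef>
#include <cstdio>

// Same arithmetic as the new compute_block_size above, reproduced only for
// illustration; it is not the MIGraphX entry point.
std::size_t compute_block_size(std::size_t n, std::size_t max_block_size)
{
    const std::size_t min_block_size  = 64;
    const std::size_t base_block_size = 32;
    auto block_size = ((n - 1) / base_block_size + 1) * base_block_size;
    return std::min(std::max(min_block_size, block_size), max_block_size);
}

int main()
{
    // n is rounded up to a multiple of 32, then clamped to [64, 1024].
    for(std::size_t n : {1, 50, 100, 300, 5000})
        std::printf("n=%zu -> block_size=%zu\n", n, compute_block_size(n, 1024));
    // Prints block sizes 64, 64, 128, 320, 1024.
    return 0;
}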
src/targets/gpu/fuse_mlir.cpp
@@ -61,13 +61,25 @@ struct mlir_conv
 MIGRAPHX_REGISTER_OP(mlir_conv);
 
 namespace {
 
+MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins)
+{
+    if(ins->name() != "convolution")
+        return false;
+    value v    = ins->get_operator().to_value();
+    auto group = v.at("group").to<int>();
+    if(group != 1)
+        return false;
+    return true;
+}
+
 struct find_conv_pointwise
 {
     // Find a convolution followed by a pointwise operation.
     auto matcher() const
     {
         auto convolution =
-            match::skip(match::name("contiguous"))(match::name("convolution").bind("convolution"));
+            match::skip(match::name("contiguous"))(is_mlir_conv().bind("convolution"));
         return match::name("pointwise")(match::any_of[match::inputs()](convolution.bind("x")));
     }
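Note: the new is_mlir_conv predicate only accepts plain (group == 1) convolutions, so grouped convolutions are no longer matched for MLIR fusion. A self-contained toy analogue of that check (toy_instruction is invented for illustration; the real matcher operates on MIGraphX instruction_refs):

#include <cstdio>
#include <map>
#include <string>

// Toy stand-in for an instruction: a name plus operator attributes. These
// types are hypothetical and only mirror the checks done in is_mlir_conv.
struct toy_instruction
{
    std::string name;
    std::map<std::string, int> attributes;
};

bool is_mlir_conv(const toy_instruction& ins)
{
    if(ins.name != "convolution")
        return false;
    // Grouped convolutions (group != 1) are excluded from the MLIR fusion path.
    return ins.attributes.at("group") == 1;
}

int main()
{
    toy_instruction plain{"convolution", {{"group", 1}}};
    toy_instruction grouped{"convolution", {{"group", 4}}};
    std::printf("plain: %d, grouped: %d\n", is_mlir_conv(plain), is_mlir_conv(grouped)); // plain: 1, grouped: 0
    return 0;
}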
src/targets/gpu/fuse_ops.cpp
@@ -259,7 +259,7 @@ struct hip_add_relu : binary_device<hip_add_relu, &device::add_relu>
 };
 MIGRAPHX_REGISTER_OP(hip_add_relu)
 
-struct hip_add_sigmoid : binary_device<hip_add_relu, &device::add_sigmoid>
+struct hip_add_sigmoid : binary_device<hip_add_sigmoid, &device::add_sigmoid>
 {
 };
 MIGRAPHX_REGISTER_OP(hip_add_sigmoid)
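Note: this one-line change fixes a copy-paste CRTP slip; binary_device takes the derived operation as its first template parameter, and the old line passed hip_add_relu as the derived type of hip_add_sigmoid. A generic sketch of why that kind of mismatch compiles cleanly but misreports the type (op_base and the toy structs below are invented for illustration, not the MIGraphX binary_device template):

#include <cstdio>
#include <typeinfo>

// Minimal CRTP base: the base reports properties of the Derived type it is given,
// so passing the wrong Derived silently describes the wrong operation.
template <class Derived>
struct op_base
{
    const char* type_name() const { return typeid(Derived).name(); }
};

struct add_relu : op_base<add_relu>
{
};

// Copy-paste bug: derives from op_base<add_relu> instead of op_base<add_sigmoid_buggy>.
struct add_sigmoid_buggy : op_base<add_relu>
{
};

struct add_sigmoid_fixed : op_base<add_sigmoid_fixed>
{
};

int main()
{
    // The buggy class compiles fine but reports a (mangled) name containing add_relu.
    std::printf("buggy reports: %s\n", add_sigmoid_buggy{}.type_name());
    std::printf("fixed reports: %s\n", add_sigmoid_fixed{}.type_name());
    return 0;
}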
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
@@ -33,49 +33,95 @@
 namespace migraphx {
 
 // NOLINTNEXTLINE
-#define MIGRAPHX_DEVICE_ARRAY_OP(op, binary_op)                                      \
-    template <class U>                                                               \
-    constexpr array& operator op(const array<U, N>& x)                               \
-    {                                                                                \
-        for(index_int i = 0; i < N; i++)                                             \
-            d[i] op x[i];                                                            \
-        return *this;                                                                \
-    }                                                                                \
-    template <class U, MIGRAPHX_REQUIRES(is_convertible<U, T>{})>                    \
-    constexpr array& operator op(const U& x)                                         \
-    {                                                                                \
-        for(index_int i = 0; i < N; i++)                                             \
-            d[i] op x;                                                               \
-        return *this;                                                                \
-    }                                                                                \
-    template <class U>                                                               \
-    friend constexpr auto operator binary_op(const array& x, const array<U, N>& y)   \
-    {                                                                                \
-        array<decltype(T{} binary_op U{}), N> z{};                                   \
-        for(index_int i = 0; i < N; i++)                                             \
-            z[i] = x[i] binary_op y[i];                                              \
-        return z;                                                                    \
-    }                                                                                \
-    template <class U, MIGRAPHX_REQUIRES(is_convertible<U, T>{})>                    \
-    friend constexpr auto operator binary_op(const array& x, const U& y)             \
-    {                                                                                \
-        array<decltype(T{} binary_op U{}), N> z{};                                   \
-        for(index_int i = 0; i < N; i++)                                             \
-            z[i] = x[i] binary_op y;                                                 \
-        return z;                                                                    \
-    }                                                                                \
-    template <class U, MIGRAPHX_REQUIRES(is_convertible<U, T>{})>                    \
-    friend constexpr auto operator binary_op(const U& x, const array& y)             \
-    {                                                                                \
-        array<decltype(T{} binary_op U{}), N> z{};                                   \
-        for(index_int i = 0; i < N; i++)                                             \
-            z[i] = x binary_op y[i];                                                 \
-        return z;                                                                    \
+#define MIGRAPHX_DEVICE_ARRAY_OP(op, binary_op)                                               \
+    template <class U>                                                                        \
+    constexpr array& operator op(const array<U, N>& x)                                        \
+    {                                                                                         \
+        array_detail::array_for_each(*this, x)([](auto& sy, auto sx) { sy op sx; });          \
+        return *this;                                                                         \
+    }                                                                                         \
+    template <class U, MIGRAPHX_REQUIRES(is_convertible<U, T>{})>                             \
+    constexpr array& operator op(const U& x)                                                  \
+    {                                                                                         \
+        array_detail::array_for_each(*this)([&](auto& sy) { sy op x; });                      \
+        return *this;                                                                         \
+    }                                                                                         \
+    template <class U>                                                                        \
+    friend constexpr auto operator binary_op(const array& x, const array<U, N>& y)            \
+    {                                                                                         \
+        array<decltype(T{} binary_op U{}), N> z{};                                            \
+        array_detail::array_for_each(z, x, y)(                                                \
+            [&](auto& sz, auto sx, auto sy) { sz = sx binary_op sy; });                       \
+        return z;                                                                             \
+    }                                                                                         \
+    template <class U, MIGRAPHX_REQUIRES(is_convertible<U, T>{})>                             \
+    friend constexpr auto operator binary_op(const array& x, const U& y)                      \
+    {                                                                                         \
+        array<decltype(T{} binary_op U{}), N> z{};                                            \
+        array_detail::array_for_each(z, x)([&](auto& sz, auto sx) { sz = sx binary_op y; });  \
+        return z;                                                                             \
+    }                                                                                         \
+    template <class U, MIGRAPHX_REQUIRES(is_convertible<U, T>{})>                             \
+    friend constexpr auto operator binary_op(const U& x, const array& y)                      \
+    {                                                                                         \
+        array<decltype(T{} binary_op U{}), N> z{};                                            \
+        array_detail::array_for_each(z, y)([&](auto& sz, auto sy) { sz = x binary_op sy; });  \
+        return z;                                                                             \
     }
+namespace array_detail {
+
+template <class T>
+constexpr auto is_vectorizable()
+{
+    return not is_same<T, bool>{} and (is_fundamental<T>{} or is_same<T, half>{});
+}
+
+template <class T>
+__device__ auto& array2vec(T& x)
+{
+    using value_type    = typename T::value_type;
+    constexpr auto size = decltype(x.size()){};
+    using type          = vec<value_type, size>;
+    if constexpr(is_const<T>{})
+        return reinterpret_cast<const type&>(x);
+    else
+        return reinterpret_cast<type&>(x);
+}
+
+template <class T, class... Ts>
+constexpr auto array_for_each(T& x, Ts&... xs)
+{
+    MIGRAPHX_ASSERT(((x.size() == xs.size()) and ...));
+    return [&](auto f) {
+        constexpr auto size = decltype(x.size()){};
+        if constexpr((is_vectorizable<typename T::value_type>() or
+                      (is_vectorizable<typename Ts::value_type>() or ...)) and
+                     size <= 8 and size > 1 and (size % 2 == 0))
+        {
+            if(__builtin_is_constant_evaluated())
+            {
+                for(index_int i = 0; i < size; i++)
+                    f(x[i], xs[i]...);
+            }
+            else
+            {
+                using vec_type = std::remove_reference_t<decltype(array2vec(x))>;
+                f(array2vec(x), __builtin_convertvector(array2vec(xs), vec_type)...);
+            }
+        }
+        else
+        {
+            for(index_int i = 0; i < size; i++)
+                f(x[i], xs[i]...);
+        }
+    };
+}
+
+} // namespace array_detail
+
 template <class T, index_int N>
 struct array
 {
     using value_type = T;
     T d[N];
 
     constexpr T& operator[](index_int i)
     {
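Note: array_for_each drives all of the rewritten operators above. When every element type is vectorizable and the array length is a small even size (2 to 8), it reinterprets the array as a native vector and invokes the callback once; otherwise, or during constant evaluation, it falls back to a per-element loop. A rough host-side sketch of the same idea, assuming clang/GCC vector extensions (farray4 and the helper names are invented, not the kernel types):

#include <cstdio>

// Hypothetical 4-float vector type via the compiler vector extension.
typedef float float4 __attribute__((vector_size(16)));

// Toy aligned array, standing in for array<float, 4>.
struct alignas(16) farray4
{
    float d[4];
};

// Scalar path: the per-element loop array_for_each falls back to.
void add_scalar(farray4& y, const farray4& x)
{
    for(int i = 0; i < 4; i++)
        y.d[i] += x.d[i];
}

// Vector path: reinterpret the whole array as one native vector and add once,
// which is what array2vec plus the vectorized branch of array_for_each do.
void add_vectorized(farray4& y, const farray4& x)
{
    reinterpret_cast<float4&>(y) += reinterpret_cast<const float4&>(x);
}

int main()
{
    farray4 a{{1, 2, 3, 4}};
    const farray4 b{{10, 20, 30, 40}};
    add_scalar(a, b);     // a = {11, 22, 33, 44}
    add_vectorized(a, b); // a = {21, 42, 63, 84}
    for(float v : a.d)
        std::printf("%g ", v);
    std::printf("\n");
    return 0;
}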
@@ -108,18 +154,13 @@ struct array
     constexpr T dot(const array& x) const
     {
-        T result = 0;
-        for(index_int i = 0; i < N; i++)
-            result += x[i] * d[i];
-        return result;
+        auto r = x * (*this);
+        return r.reduce([](auto a, auto b) { return a + b; }, 0);
     }
 
     constexpr T product() const
     {
-        T result = 1;
-        for(index_int i = 0; i < N; i++)
-            result *= d[i];
-        return result;
+        return reduce([](auto x, auto y) { return x * y; }, 1);
     }
 
     constexpr T single(index_int width = 100) const
@@ -134,6 +175,24 @@ struct array
         return result;
     }
 
+    template <class F>
+    constexpr auto apply(F f) const
+    {
+        array<decltype(f(d[0])), N> result;
+        for(index_int i = 0; i < N; i++)
+            result[i] = f(d[i]);
+        return result;
+    }
+
+    template <class F>
+    constexpr auto reduce(F f, T init) const
+    {
+        T result = init;
+        for(index_int i = 0; i < N; i++)
+            result = f(result, d[i]);
+        return result;
+    }
+
     MIGRAPHX_DEVICE_ARRAY_OP(+=, +)
     MIGRAPHX_DEVICE_ARRAY_OP(-=, -)
     MIGRAPHX_DEVICE_ARRAY_OP(*=, *)
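Note: with apply() and reduce() available, dot() and product() above become one-liners. A self-contained sketch of how the two new members compose (mini_array is an invented stand-in, not the kernel array type; compile with -std=c++17):

#include <cstdio>

template <class T, int N>
struct mini_array
{
    T d[N];

    // Maps a callable over every element, producing a new array.
    template <class F>
    constexpr auto apply(F f) const
    {
        mini_array<decltype(f(d[0])), N> result{};
        for(int i = 0; i < N; i++)
            result.d[i] = f(d[i]);
        return result;
    }

    // Folds the elements with a binary callable, starting from init.
    template <class F>
    constexpr T reduce(F f, T init) const
    {
        T result = init;
        for(int i = 0; i < N; i++)
            result = f(result, d[i]);
        return result;
    }
};

int main()
{
    constexpr mini_array<int, 4> a{{1, 2, 3, 4}};
    // product() in the patch is just reduce with multiplication and an init of 1.
    constexpr int product = a.reduce([](int x, int y) { return x * y; }, 1);
    // apply() maps a callable over every element, here squaring them.
    constexpr auto squares = a.apply([](int x) { return x * x; });
    std::printf("product=%d first_square=%d\n", product, squares.d[0]); // product=24 first_square=1
    return 0;
}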
@@ -201,6 +260,11 @@ struct array
     }
 };
 
+template <class T, class... Ts>
+constexpr array<T, sizeof...(Ts) + 1> make_array(T x, Ts... xs)
+{
+    return {x, static_cast<T>(xs)...};
+}
+
 template <class T, T... Xs>
 struct integral_const_array : array<T, sizeof...(Xs)>
 {
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
@@ -28,9 +28,60 @@
 #include <migraphx/kernels/types.hpp>
 #include <migraphx/kernels/integral_constant.hpp>
 #include <migraphx/kernels/type_traits.hpp>
+#include <migraphx/kernels/debug.hpp>
 
 namespace migraphx {
 
+#if defined(MIGRAPHX_NGLOBAL) && defined(MIGRAPHX_NLOCAL)
+#define MIGRAPHX_NGROUP ((MIGRAPHX_NGLOBAL + MIGRAPHX_NLOCAL - 1) / MIGRAPHX_NLOCAL)
+#endif
+
+inline __device__ __attribute__((const)) index_int compute_global_size()
+{
+#ifdef MIGRAPHX_NGLOBAL
+    return MIGRAPHX_NGLOBAL;
+#else
+    // This actually works even when global is not divisible by the local size.
+    // It doesn't really do a multiplication: it calls a device function to get
+    // the global size, which is why it works.
+    return blockDim.x * gridDim.x; // NOLINT
+#endif
+}
+
+// We can't just use blockDim.x to get the local size since it's broken on HIP
+// when global is not divisible by the local size. In that case, we calculate
+// the size of the last group.
+inline __device__ __attribute__((const)) index_int compute_local_size()
+{
+#ifdef MIGRAPHX_NLOCAL
+    const auto nlocal = MIGRAPHX_NLOCAL;
+#else
+    const auto nlocal = blockDim.x; // NOLINT
+#endif
+#ifdef MIGRAPHX_NGROUP
+    const auto ngroup = MIGRAPHX_NGROUP;
+#else
+    const auto ngroup = gridDim.x; // NOLINT
+#endif
+    const auto group_id = blockIdx.x; // NOLINT
+    const auto nglobal  = compute_global_size();
+    if(group_id == ngroup - 1)
+    {
+        return 1 + (nglobal - 1) % nlocal;
+    }
+    else
+    {
+        return nlocal; // NOLINT
+    }
+}
+
+#ifdef MIGRAPHX_NGROUP
+// If global is divisible by local then local can be a const
+#if(MIGRAPHX_NGLOBAL % MIGRAPHX_NLOCAL == 0) || (MIGRAPHX_NGROUP == 1)
+#define MIGRAPHX_HAS_CONST_LOCAL 1
+#endif
+#endif
+
 struct index
 {
     index_int global = 0;
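Note: compute_local_size differs from the nominal local size only in the last group; when the global size is not a multiple of the local size, the last block gets the remainder, 1 + (nglobal - 1) % nlocal. A quick host-side check of that formula (plain C++, no HIP required):

#include <cstdio>

// Same remainder formula as compute_local_size above, evaluated on the host.
unsigned last_group_size(unsigned nglobal, unsigned nlocal)
{
    return 1 + (nglobal - 1) % nlocal;
}

int main()
{
    // 70 threads split into groups of 32: groups of 32, 32, and then 6.
    std::printf("%u\n", last_group_size(70, 32)); // 6
    // When global is divisible by local, the last group is full.
    std::printf("%u\n", last_group_size(64, 32)); // 32
    return 0;
}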
@@ -38,20 +89,44 @@ struct index
     index_int group = 0;
 
 #ifdef MIGRAPHX_NGLOBAL
-    constexpr index_constant<MIGRAPHX_NGLOBAL> nglobal() const { return {}; }
+    constexpr index_constant<MIGRAPHX_NGLOBAL> nglobal() const
+    {
+        static_assert(MIGRAPHX_NGLOBAL > 0, "Global size must be greater than 0");
+        return {};
+    }
 #else
     __device__ index_int nglobal() const
     {
-        return blockDim.x * gridDim.x; // NOLINT
+        MIGRAPHX_ASSERT(compute_global_size() > 0);
+        return compute_global_size(); // NOLINT
     }
 #endif
 
-#ifdef MIGRAPHX_NLOCAL
-    constexpr index_constant<MIGRAPHX_NLOCAL> nlocal() const { return {}; }
+#ifdef MIGRAPHX_HAS_CONST_LOCAL
+    constexpr index_constant<MIGRAPHX_NLOCAL> nlocal() const
+    {
+        static_assert(MIGRAPHX_NLOCAL > 0, "Local size must be greater than 0");
+        return {};
+    }
 #else
     __device__ index_int nlocal() const
     {
-        return blockDim.x; // NOLINT
+#ifdef MIGRAPHX_NGROUP
+        static_assert((MIGRAPHX_NGLOBAL % MIGRAPHX_NLOCAL != 0) and (MIGRAPHX_NGROUP > 1),
+                      "Local size should be const");
+#endif
+        MIGRAPHX_ASSERT(compute_local_size() > 0);
+        return compute_local_size(); // NOLINT
     }
 #endif
 
+#ifdef MIGRAPHX_NLOCAL
+    constexpr index_constant<MIGRAPHX_NLOCAL> max_nlocal() const { return {}; }
+#else
+    __device__ index_int max_nlocal() const
+    {
+        MIGRAPHX_ASSERT(blockDim.x > 0);
+        return blockDim.x;
+    }
+#endif
+
     template <class N, class Stride>
@@ -63,6 +138,7 @@ struct index
     template <class F, class N, class Stride>
     static constexpr void for_stride(index_int start, N n, Stride stride, F f)
     {
+        MIGRAPHX_ASSERT(start < stride);
         if constexpr(not is_integral<N>{} and not is_integral<Stride>{} and
                      max_stride_iterations(n, stride) == 1)
         {
src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
@@ -29,6 +29,12 @@
 namespace migraphx {
 
+template <class T, index_int N, class Op>
+constexpr auto vec_reduce(const array<T, N>& a, Op op)
+{
+    return a.apply([&](auto x) { return vec_reduce(x, op); });
+}
+
 template <index_int Axis,
           class F,
           class BinOp,
@@ -43,23 +49,21 @@ __device__ void generic_binary_layernorm(
     reduce::block::run<reduce_output>([&](auto, auto r) {
         using value_type         = typename Input1::type;
         constexpr auto relements = r.template elements<Input1>();
-        auto mean                = [&](auto f) {
-            return r.reduce(op::sum{}, 0, [&](auto x1, auto x2) {
-                return f(x1, x2) / value_type{relements};
-            })(input1, input2);
-        };
-        // mean(x)
-        auto mean_x = mean(op);
-        // mean(m ^ 2)
-        auto mean_m2 = mean([&](auto x1, auto x2) {
-            auto m = op(x1, x2) - mean_x;
-            return m * m;
-        });
+        auto means               = r.reduce(op::sum{},
+                              make_array<vec_type<value_type>>(0, 0),
+                              [&](auto x1, auto x2) {
+                                  auto x = op(x1, x2);
+                                  return make_array(x, x * x) *
+                                         vec_type<value_type>{1.0 / relements};
+                              })(input1, input2);
+        auto mean_x   = means[0];
+        auto mean_x2  = means[1];
+        auto variance = mean_x2 - (mean_x * mean_x);
 
         r.inner([&](auto& y, auto x1, auto x2, auto... xs) {
-            auto m = op(x1, x2) - mean_x;
+            auto x = op(x1, x2);
+            auto m = x - mean_x;
             // m * rsqrt(mean(m ^ 2) + 1e-12)
-            y = compute(m * rsqrt(mean_m2 + value_type{1e-12}), xs...);
+            y = compute(m * rsqrt(variance + value_type{1e-12}), xs...);
         })(output, input1, input2, inputs...);
     });
 }
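Note: the rewrite folds the two block reductions (mean, then mean of squared deviations) into one reduction that accumulates mean(x) and mean(x^2) as a 2-element array, then forms the variance as E[x^2] - E[x]^2. The trade-off is a single pass over the data instead of two, using the textbook shortcut formula for variance. A small host-side sketch of that identity on toy data (plain C++, not the kernel reduction):

#include <cstdio>

int main()
{
    const double x[4] = {1.0, 2.0, 3.0, 4.0};
    const int n       = 4;

    // One pass accumulating both moments, as the fused reduce now does.
    double sum = 0, sum_sq = 0;
    for(int i = 0; i < n; i++)
    {
        sum += x[i] / n;
        sum_sq += (x[i] * x[i]) / n;
    }
    double mean     = sum;                  // E[x]
    double variance = sum_sq - mean * mean; // E[x^2] - E[x]^2

    // Two-pass reference: mean first, then mean of squared deviations.
    double mean_ref = 0;
    for(int i = 0; i < n; i++)
        mean_ref += x[i] / n;
    double var_ref = 0;
    for(int i = 0; i < n; i++)
        var_ref += (x[i] - mean_ref) * (x[i] - mean_ref) / n;

    std::printf("one pass: mean=%g variance=%g\n", mean, variance);    // 2.5, 1.25
    std::printf("two pass: mean=%g variance=%g\n", mean_ref, var_ref); // 2.5, 1.25
    return 0;
}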
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
@@ -94,16 +94,17 @@ MIGRAPHX_DPP_REDUCE(op::max, v_max)
 MIGRAPHX_DPP_REDUCE(op::min, v_min)
 MIGRAPHX_DPP_REDUCE(op::product, v_mul)
 
-template <class Op, class T, class F>
-__device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
+template <class Op, class T, class Index, class F>
+__device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
 {
+    MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
 #if __AMDGCN_WAVEFRONT_SIZE == 32
     constexpr index_int lanes_per_thread = 16;
 #else
     constexpr index_int lanes_per_thread = 64;
 #endif
     using type = decltype(f(0));
-    __shared__ type buffer[idx.nlocal() / lanes_per_thread];
+    __shared__ type buffer[idx.max_nlocal() / lanes_per_thread];
     type x = init;
     idx.local_stride(n, [&](auto i) { x = op(x, f(i)); });
     dpp_reduce(x, op);
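Note: the shared-memory buffer is now sized with max_nlocal(), which, when MIGRAPHX_NLOCAL is defined, returns an index_constant, a value carried in its type, so it can still appear as an array bound; templating block_reduce on Index likewise lets n carry a compile-time value when callers pass one. A stripped-down sketch of that pattern using std::integral_constant (not the MIGraphX index types):

#include <cstdio>
#include <type_traits>

// A value wrapped in its type: the argument carries no runtime data, so the
// callee can recover the number as a constant expression.
template <int N>
using int_const = std::integral_constant<int, N>;

// Because Size is a type-level constant, Size{} can size a stack array.
template <class Size>
int sum_first(Size size, const int* data)
{
    int scratch[Size{}]; // constant bound, like __shared__ buffer[idx.max_nlocal() / ...]
    int total = 0;
    for(int i = 0; i < Size{}; i++)
    {
        scratch[i] = data[i];
        total += scratch[i];
    }
    return total;
}

int main()
{
    const int data[4] = {1, 2, 3, 4};
    std::printf("%d\n", sum_first(int_const<4>{}, data)); // 10
    return 0;
}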
@@ -123,12 +124,12 @@ __device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
     return y;
 }
 #else
-template <class Op, class T, class F>
-__device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
+template <class Op, class T, class Index, class F>
+__device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
 {
+    MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
     using type = decltype(f(0));
-    __shared__ type buffer[idx.nlocal()];
+    __shared__ type buffer[idx.max_nlocal()];
     type x = init;
     idx.local_stride(n, [&](auto i) { x = op(x, f(i)); });
     buffer[idx.local] = x;
@@ -201,12 +202,9 @@ struct block
     __device__ auto reduce(Op op, T init, Read read) const
     {
         return sliced(slicer, [=](auto x, auto... xs) {
-            return vec_reduce(block_reduce(idx,
-                                           op,
-                                           init,
-                                           x.get_shape().elements(),
-                                           [&](auto j) { return read(x[j], xs[j]...); }),
-                              op);
+            return block_reduce(idx, op, init, x.get_shape().elements(), [&](auto j) {
+                return vec_reduce(read(x[j], xs[j]...), op);
+            });
         });
     }
test/fuse_pointwise.cpp
@@ -21,7 +21,7 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
-#include "migraphx/dead_code_elimination.hpp"
+#include <migraphx/dead_code_elimination.hpp>
 #include <migraphx/fuse_pointwise.hpp>
 #include <migraphx/instruction.hpp>
 #include <migraphx/pass_manager.hpp>
test/gpu/pack_int8_args.cpp
@@ -21,7 +21,7 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
-#include "migraphx/instruction_ref.hpp"
+#include <migraphx/instruction_ref.hpp>
 #include <migraphx/gpu/context.hpp>
 #include <migraphx/gpu/lowering.hpp>
 #include <migraphx/gpu/target.hpp>
test/verify/test_conv_group_add.cpp (new file, mode 100644)
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_conv_group_add : verify_program<test_conv_group_add>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape s{migraphx::shape::float_type, {1, 68, 28, 28}};
        auto x    = mm->add_parameter("x", s);
        auto w    = mm->add_parameter("w", {migraphx::shape::float_type, {68, 17, 1, 1}});
        auto b    = mm->add_parameter("b", {migraphx::shape::float_type, {68}});
        auto conv = mm->add_instruction(migraphx::make_op("convolution", {{"group", 4}}), x, w);
        auto bb   = mm->add_instruction(
            migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 68, 28, 28}}}), b);
        mm->add_instruction(migraphx::make_op("add"), conv, bb);
        return p;
    }
};