gaoqiong / MIGraphX / Commits / 65e14286

Commit 65e14286, authored Sep 28, 2022 by charlie
Unary ops changes and tests
Parent: 30243d2c

Showing 4 changed files with 471 additions and 72 deletions
src/include/migraphx/op/elu.hpp          +6    -7
src/include/migraphx/op/leaky_relu.hpp   +5    -4
src/targets/ref/lowering.cpp             +0    -61
test/ref_ops_test.cpp                    +460  -0
src/include/migraphx/op/elu.hpp

@@ -32,21 +32,20 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace op {
 
-struct elu
+struct elu : unary<elu>
 {
     std::string name() const { return "elu"; }
     float alpha = 1;
-    shape compute_shape(std::vector<shape> inputs) const
-    {
-        check_shapes{inputs, *this}.has(1);
-        return inputs.front();
-    }
 
     template <class Self, class F>
     static auto reflect(Self& self, F f)
     {
         return pack(f(self.alpha, "alpha"));
     }
+
+    auto apply() const
+    {
+        return [&](auto x) { return x > 0 ? x : alpha * std::expm1(x); };
+    }
 };
 
 } // namespace op
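The deleted compute_shape is explained by the new unary<elu> base class: MIGraphX's unary CRTP base (src/include/migraphx/op/unary.hpp) presumably supplies the single-input shape check and the elementwise evaluation that calls the derived op's apply() lambda, so each unary operator only has to declare its name and its scalar rule. The real base class is not part of this diff; the following is only a minimal sketch of the pattern with simplified stand-in types (unary_sketch and elu_sketch are illustrative names, not MIGraphX API):

    // Illustrative sketch only: a simplified stand-in for migraphx::op::unary<Derived>,
    // not the contents of the real header. Real MIGraphX shapes/arguments are richer.
    #include <algorithm>
    #include <cassert>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    template <class Derived>
    struct unary_sketch
    {
        // One input; output shape equals input shape (what elu's old compute_shape did).
        std::vector<std::size_t>
        compute_shape(const std::vector<std::vector<std::size_t>>& inputs) const
        {
            assert(inputs.size() == 1);
            return inputs.front();
        }

        // Elementwise evaluation driven by the derived op's scalar apply() lambda.
        std::vector<float> compute(const std::vector<float>& input) const
        {
            std::vector<float> output(input.size());
            auto f = static_cast<const Derived&>(*this).apply();
            std::transform(input.begin(), input.end(), output.begin(), f);
            return output;
        }
    };

    // A derived op then only supplies the scalar rule, as elu does in this commit:
    struct elu_sketch : unary_sketch<elu_sketch>
    {
        float alpha = 1;
        auto apply() const
        {
            return [&](float x) { return x > 0 ? x : alpha * std::expm1(x); };
        }
    };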
src/include/migraphx/op/leaky_relu.hpp

@@ -26,12 +26,13 @@
 #include <migraphx/check_shapes.hpp>
 #include <migraphx/config.hpp>
+#include <migraphx/op/unary.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace op {
 
-struct leaky_relu
+struct leaky_relu : unary<leaky_relu>
 {
     float alpha = 0.01;

@@ -42,10 +43,10 @@ struct leaky_relu
     }
     std::string name() const { return "leaky_relu"; }
-    shape compute_shape(std::vector<shape> inputs) const
+    auto apply() const
     {
-        check_shapes{inputs, *this}.has(1);
-        return inputs.front();
+        return [&](auto x) { return x > 0 ? x : x * alpha; };
     }
 };
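For a concrete sense of the two scalar rules introduced by these headers (alpha defaults to 1 for elu and 0.01 for leaky_relu), here is a small standalone check of the formulas themselves, using std::expm1 directly rather than any MIGraphX class:

    #include <cmath>
    #include <cstdio>
    #include <initializer_list>

    int main()
    {
        const float elu_alpha   = 1.0f;   // default in elu.hpp
        const float leaky_alpha = 0.01f;  // default in leaky_relu.hpp

        for (float x : {-2.0f, -1.0f, 0.0f, 1.5f})
        {
            float e = x > 0 ? x : elu_alpha * std::expm1(x); // elu: x > 0 ? x : alpha * (e^x - 1)
            float l = x > 0 ? x : x * leaky_alpha;           // leaky_relu: x > 0 ? x : alpha * x
            std::printf("x = %5.2f  elu = %8.5f  leaky_relu = %8.5f\n", x, e, l);
        }
        // e.g. x = -1.0: elu ~ -0.63212, leaky_relu = -0.01000
        return 0;
    }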
src/targets/ref/lowering.cpp

@@ -507,65 +507,6 @@ struct ref_quant_gemm
 };
 MIGRAPHX_REGISTER_OP(ref_gemm)
 
-struct leaky_relu_op
-{
-    op::leaky_relu op;
-    std::string name() const { return "ref::leaky_relu"; }
-    auto fcn() const
-    {
-        auto a = op.alpha;
-        return [a](auto x) { return x > 0 ? x : x * a; };
-    }
-};
-
-struct elu_op
-{
-    op::elu op;
-    std::string name() const { return "ref::elu"; }
-    auto fcn() const
-    {
-        auto a = op.alpha;
-        return [a](auto x) { return x > 0 ? x : a * std::expm1(x); };
-    }
-};
-
-template <typename Op>
-struct ref_unary : auto_register_op<ref_unary<Op>>
-{
-    ref_unary() = default;
-
-    template <class T>
-    ref_unary(T pop) : op(Op{std::move(pop)})
-    {
-    }
-
-    Op op;
-
-    template <class Self, class F>
-    static auto reflect(Self& self, F f)
-    {
-        return migraphx::reflect(self.op.op, f);
-    }
-
-    std::string name() const { return op.name(); }
-
-    shape compute_shape(const std::vector<shape>& inputs) const
-    {
-        check_shapes{inputs, *this}.has(1);
-        const auto& s = inputs.at(0);
-        return {s.type(), s.lens()};
-    }
-
-    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
-    {
-        argument result{output_shape};
-        visit_all(result, args[0])([&](auto output, auto input) {
-            assert(input.get_shape().standard());
-            std::transform(input.begin(), input.end(), output.begin(), op.fcn());
-        });
-        return result;
-    }
-};
-
 template <class Op>
 struct ref_softmax : auto_register_op<ref_softmax<Op>>
 {

@@ -708,9 +649,7 @@ struct ref_apply
         apply_map["quant_dot"] = extend_op<ref_quant_gemm, op::quant_dot>();
         apply_map["quant_convolution"] =
             extend_op<ref_convolution<op::quant_convolution>, op::quant_convolution>();
-        apply_map["elu"] = extend_op<ref_unary<elu_op>, op::elu>();
         apply_map["im2col"] = extend_op<ref_im2col, op::im2col>();
-        apply_map["leaky_relu"] = extend_op<ref_unary<leaky_relu_op>, op::leaky_relu>();
         apply_map["logsoftmax"] = extend_op<ref_softmax<op::logsoftmax>, op::logsoftmax>();
         apply_map["lrn"] = extend_op<ref_lrn, op::lrn>();
         apply_map["pad"] = extend_op<ref_pad, op::pad>();
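Everything removed above reduces to one pattern: wrap the operator, and at evaluation time run its scalar functor over each element with std::transform. With elu and leaky_relu now deriving from unary<>, that loop presumably comes from the shared base, so the ref target no longer needs leaky_relu_op, elu_op, ref_unary, or their apply_map entries. One detail worth noting is the lambda capture: the deleted fcn() methods copied alpha into a local and captured it by value, while the new apply() methods capture [&], i.e. the op object itself. A minimal standalone illustration of that difference (act_sketch is an illustrative name, not MIGraphX code):

    #include <cmath>

    struct act_sketch
    {
        float alpha = 1.0f;

        // Style of the deleted ref wrappers: copy alpha into a local and capture it
        // by value, so the returned lambda does not depend on this object's lifetime.
        auto fcn_by_value() const
        {
            auto a = alpha;
            return [a](float x) { return x > 0 ? x : a * std::expm1(x); };
        }

        // Style of the new apply(): capture [&] (i.e. this), so the lambda reads alpha
        // through the op and must only be called while the op is still alive.
        auto fcn_by_reference() const
        {
            return [&](float x) { return x > 0 ? x : alpha * std::expm1(x); };
        }
    };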
test/ref_ops_test.cpp

(This diff is collapsed in the web view and not shown here: 460 added lines of reference-op tests.)
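The new tests themselves are collapsed in this view, but MIGraphX reference-op tests generally build a one-instruction program, compile it for the ref target, evaluate it, and compare the result against hand-computed values. The following is only a sketch of what an elu test in that style could look like; the actual names, literals, and expected values in ref_ops_test.cpp may differ:

    #include <vector>
    #include <migraphx/program.hpp>
    #include <migraphx/make_op.hpp>
    #include <migraphx/literal.hpp>
    #include <migraphx/verify.hpp>
    #include <migraphx/register_target.hpp>
    #include "test.hpp"

    TEST_CASE(elu_test_sketch)
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape s{migraphx::shape::float_type, {3}};
        auto x = mm->add_literal(migraphx::literal{s, {-1.0f, 0.0f, 2.0f}});
        mm->add_instruction(migraphx::make_op("elu", {{"alpha", 0.5}}), x);
        p.compile(migraphx::make_target("ref"));

        auto result = p.eval({}).back();
        std::vector<float> out;
        result.visit([&](auto v) { out.assign(v.begin(), v.end()); });

        // alpha * expm1(-1) = 0.5 * -0.63212 = -0.31606
        std::vector<float> gold = {-0.31606f, 0.0f, 2.0f};
        EXPECT(migraphx::verify_range(out, gold));
    }

    int main(int argc, const char* argv[]) { test::run(argc, argv); }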