Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
65e14286
Commit
65e14286
authored
Sep 28, 2022
by
charlie
Browse files
Unary ops changes and tests
parent
30243d2c
Changes
4
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
471 additions
and
72 deletions
+471
-72
src/include/migraphx/op/elu.hpp
src/include/migraphx/op/elu.hpp
+6
-7
src/include/migraphx/op/leaky_relu.hpp
src/include/migraphx/op/leaky_relu.hpp
+5
-4
src/targets/ref/lowering.cpp
src/targets/ref/lowering.cpp
+0
-61
test/ref_ops_test.cpp
test/ref_ops_test.cpp
+460
-0
No files found.
src/include/migraphx/op/elu.hpp
View file @
65e14286
...
...
@@ -32,21 +32,20 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
elu
struct
elu
:
unary
<
elu
>
{
std
::
string
name
()
const
{
return
"elu"
;
}
float
alpha
=
1
;
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
return
inputs
.
front
();
}
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
alpha
,
"alpha"
));
}
auto
apply
()
const
{
return
[
&
](
auto
x
)
{
return
x
>
0
?
x
:
alpha
*
std
::
expm1
(
x
);
};
}
};
}
// namespace op
...
...
src/include/migraphx/op/leaky_relu.hpp
View file @
65e14286
...
...
@@ -26,12 +26,13 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <migraphx/op/unary.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
leaky_relu
struct
leaky_relu
:
unary
<
leaky_relu
>
{
float
alpha
=
0.01
;
...
...
@@ -42,10 +43,10 @@ struct leaky_relu
}
std
::
string
name
()
const
{
return
"leaky_relu"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
auto
apply
()
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
return
inputs
.
front
();
return
[
&
](
auto
x
)
{
return
x
>
0
?
x
:
x
*
alpha
;
};
}
};
...
...
src/targets/ref/lowering.cpp
View file @
65e14286
...
...
@@ -507,65 +507,6 @@ struct ref_quant_gemm
};
MIGRAPHX_REGISTER_OP
(
ref_gemm
)
struct
leaky_relu_op
{
op
::
leaky_relu
op
;
std
::
string
name
()
const
{
return
"ref::leaky_relu"
;
}
auto
fcn
()
const
{
auto
a
=
op
.
alpha
;
return
[
a
](
auto
x
)
{
return
x
>
0
?
x
:
x
*
a
;
};
}
};
struct
elu_op
{
op
::
elu
op
;
std
::
string
name
()
const
{
return
"ref::elu"
;
}
auto
fcn
()
const
{
auto
a
=
op
.
alpha
;
return
[
a
](
auto
x
)
{
return
x
>
0
?
x
:
a
*
std
::
expm1
(
x
);
};
}
};
template
<
typename
Op
>
struct
ref_unary
:
auto_register_op
<
ref_unary
<
Op
>>
{
ref_unary
()
=
default
;
template
<
class
T
>
ref_unary
(
T
pop
)
:
op
(
Op
{
std
::
move
(
pop
)})
{
}
Op
op
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
migraphx
::
reflect
(
self
.
op
.
op
,
f
);
}
std
::
string
name
()
const
{
return
op
.
name
();
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
const
auto
&
s
=
inputs
.
at
(
0
);
return
{
s
.
type
(),
s
.
lens
()};
}
argument
compute
(
context
&
,
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
assert
(
input
.
get_shape
().
standard
());
std
::
transform
(
input
.
begin
(),
input
.
end
(),
output
.
begin
(),
op
.
fcn
());
});
return
result
;
}
};
template
<
class
Op
>
struct
ref_softmax
:
auto_register_op
<
ref_softmax
<
Op
>>
{
...
...
@@ -708,9 +649,7 @@ struct ref_apply
apply_map
[
"quant_dot"
]
=
extend_op
<
ref_quant_gemm
,
op
::
quant_dot
>
();
apply_map
[
"quant_convolution"
]
=
extend_op
<
ref_convolution
<
op
::
quant_convolution
>
,
op
::
quant_convolution
>
();
apply_map
[
"elu"
]
=
extend_op
<
ref_unary
<
elu_op
>
,
op
::
elu
>
();
apply_map
[
"im2col"
]
=
extend_op
<
ref_im2col
,
op
::
im2col
>
();
apply_map
[
"leaky_relu"
]
=
extend_op
<
ref_unary
<
leaky_relu_op
>
,
op
::
leaky_relu
>
();
apply_map
[
"logsoftmax"
]
=
extend_op
<
ref_softmax
<
op
::
logsoftmax
>
,
op
::
logsoftmax
>
();
apply_map
[
"lrn"
]
=
extend_op
<
ref_lrn
,
op
::
lrn
>
();
apply_map
[
"pad"
]
=
extend_op
<
ref_pad
,
op
::
pad
>
();
...
...
test/ref_ops_test.cpp
View file @
65e14286
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment