Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
99d1fed4
Commit
99d1fed4
authored
Jun 20, 2019
by
Shucai Xiao
Browse files
add cpu implmentations of the argmax and argmin operators.
parent
66bae091
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
192 additions
and
0 deletions
+192
-0
src/include/migraphx/op/argmax.hpp
src/include/migraphx/op/argmax.hpp
+56
-0
src/include/migraphx/op/argmin.hpp
src/include/migraphx/op/argmin.hpp
+56
-0
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+2
-0
src/targets/cpu/lowering.cpp
src/targets/cpu/lowering.cpp
+78
-0
No files found.
src/include/migraphx/op/argmax.hpp
0 → 100644
View file @
99d1fed4
#ifndef MIGRAPHX_GUARD_OPERATORS_ARGMAX_HPP
#define MIGRAPHX_GUARD_OPERATORS_ARGMAX_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
argmax
{
int
axis
=
0
;
int
keep_dims
=
1
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
axis
,
"axis"
),
f
(
self
.
keep_dims
,
"keep_dims"
));
}
std
::
string
name
()
const
{
return
"argmax"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard
();
auto
lens
=
inputs
[
0
].
lens
();
int
n_dim
=
static_cast
<
int
>
(
lens
.
size
());
if
(
axis
>=
n_dim
||
axis
<
0
)
{
MIGRAPHX_THROW
(
"ARGMAX: axis is out of range."
);
}
lens
[
axis
]
=
1
;
if
(
!
keep_dims
)
{
lens
.
erase
(
lens
.
begin
()
+
axis
);
}
return
{
shape
::
int64_type
,
lens
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/argmin.hpp
0 → 100644
View file @
99d1fed4
#ifndef MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP
#define MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
argmin
{
int
axis
=
0
;
int
keep_dims
=
1
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
axis
,
"axis"
),
f
(
self
.
keep_dims
,
"keep_dims"
));
}
std
::
string
name
()
const
{
return
"argmin"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard
();
auto
lens
=
inputs
[
0
].
lens
();
int
n_dim
=
static_cast
<
int
>
(
lens
.
size
());
if
(
axis
>=
n_dim
||
axis
<
0
)
{
MIGRAPHX_THROW
(
"ARGMIN: axis is out of range."
);
}
lens
[
axis
]
=
1
;
if
(
!
keep_dims
)
{
lens
.
erase
(
lens
.
begin
()
+
axis
);
}
return
{
shape
::
int64_type
,
lens
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/operators.hpp
View file @
99d1fed4
...
...
@@ -5,6 +5,8 @@
#include <migraphx/op/abs.hpp>
#include <migraphx/op/acos.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp>
#include <migraphx/op/asin.hpp>
#include <migraphx/op/as_shape.hpp>
#include <migraphx/op/atan.hpp>
...
...
src/targets/cpu/lowering.cpp
View file @
99d1fed4
...
...
@@ -637,6 +637,82 @@ struct cpu_logsoftmax
}
};
struct
cpu_argmax
{
op
::
argmax
op
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
migraphx
::
reflect
(
self
.
op
,
f
);
}
std
::
string
name
()
const
{
return
"cpu::argmax"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
return
op
.
compute_shape
(
inputs
);
}
argument
compute
(
context
&
,
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
result
.
visit
([
&
](
auto
output
)
{
args
[
0
].
visit
([
&
](
auto
input
)
{
using
value_type
=
batch_max
(
output_shape
.
elements
(),
std
::
numeric_limits
<
value_type
>::
lowest
());
auto
data_shape
=
args
[
0
].
get_shape
();
shape_for_each
(
data_shape
,
[
&
](
auto
idx
)
{
auto
data_index
=
data_shape
.
index
(
idx
);
idx
[
axis
]
=
0
;
auto
out_index
=
data_shape
.
index
(
idx
);
if
(
batch_max
[
index
]
<
input
[
data_index
])
{
batch_max
[
index
]
=
input
[
data_index
];
output
[
index
]
=
static_cast
<
int64_t
>
(
data_index
);
}
});
});
});
return
result
;
}
};
struct
cpu_argmin
{
op
::
argmin
op
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
migraphx
::
reflect
(
self
.
op
,
f
);
}
std
::
string
name
()
const
{
return
"cpu::argmin"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
return
op
.
compute_shape
(
inputs
);
}
argument
compute
(
context
&
,
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
result
.
visit
([
&
](
auto
output
)
{
args
[
0
].
visit
([
&
](
auto
input
)
{
using
value_type
=
batch_min
(
output_shape
.
elements
(),
std
::
numeric_limits
<
value_type
>::
max
());
auto
data_shape
=
args
[
0
].
get_shape
();
shape_for_each
(
data_shape
,
[
&
](
auto
idx
)
{
auto
data_index
=
data_shape
.
index
(
idx
);
idx
[
axis
]
=
0
;
auto
out_index
=
data_shape
.
index
(
idx
);
if
(
batch_min
[
index
]
>
input
[
data_index
])
{
batch_min
[
index
]
=
input
[
data_index
];
output
[
index
]
=
static_cast
<
int64_t
>
(
data_index
);
}
});
});
});
return
result
;
}
};
struct
cpu_apply
{
program
*
prog
;
...
...
@@ -656,6 +732,8 @@ struct cpu_apply
void
init
()
{
apply_map
[
"argmax"
]
=
extend_op
<
cpu_argmax
,
op
::
argmax
>
();
apply_map
[
"argmin"
]
=
extend_op
<
cpu_argmin
,
op
::
argmin
>
();
apply_map
[
"batch_norm_inference"
]
=
extend_op
<
cpu_batch_norm_inference
,
op
::
batch_norm_inference
>
();
apply_map
[
"convolution"
]
=
extend_op
<
cpu_convolution
,
op
::
convolution
>
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment