Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
99d1fed4
"git@developer.sourcefind.cn:orangecat/ollama.git" did not exist on "8f827641b0bf80f7a804f726b8e5342c68d4d7f4"
Commit
99d1fed4
authored
Jun 20, 2019
by
Shucai Xiao
Browse files
add cpu implmentations of the argmax and argmin operators.
parent
66bae091
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
192 additions
and
0 deletions
+192
-0
src/include/migraphx/op/argmax.hpp
src/include/migraphx/op/argmax.hpp
+56
-0
src/include/migraphx/op/argmin.hpp
src/include/migraphx/op/argmin.hpp
+56
-0
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+2
-0
src/targets/cpu/lowering.cpp
src/targets/cpu/lowering.cpp
+78
-0
No files found.
src/include/migraphx/op/argmax.hpp
0 → 100644
View file @
99d1fed4
#ifndef MIGRAPHX_GUARD_OPERATORS_ARGMAX_HPP
#define MIGRAPHX_GUARD_OPERATORS_ARGMAX_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
argmax
{
int
axis
=
0
;
int
keep_dims
=
1
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
axis
,
"axis"
),
f
(
self
.
keep_dims
,
"keep_dims"
));
}
std
::
string
name
()
const
{
return
"argmax"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard
();
auto
lens
=
inputs
[
0
].
lens
();
int
n_dim
=
static_cast
<
int
>
(
lens
.
size
());
if
(
axis
>=
n_dim
||
axis
<
0
)
{
MIGRAPHX_THROW
(
"ARGMAX: axis is out of range."
);
}
lens
[
axis
]
=
1
;
if
(
!
keep_dims
)
{
lens
.
erase
(
lens
.
begin
()
+
axis
);
}
return
{
shape
::
int64_type
,
lens
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/argmin.hpp
0 → 100644
View file @
99d1fed4
#ifndef MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP
#define MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
argmin
{
int
axis
=
0
;
int
keep_dims
=
1
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
axis
,
"axis"
),
f
(
self
.
keep_dims
,
"keep_dims"
));
}
std
::
string
name
()
const
{
return
"argmin"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard
();
auto
lens
=
inputs
[
0
].
lens
();
int
n_dim
=
static_cast
<
int
>
(
lens
.
size
());
if
(
axis
>=
n_dim
||
axis
<
0
)
{
MIGRAPHX_THROW
(
"ARGMIN: axis is out of range."
);
}
lens
[
axis
]
=
1
;
if
(
!
keep_dims
)
{
lens
.
erase
(
lens
.
begin
()
+
axis
);
}
return
{
shape
::
int64_type
,
lens
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/operators.hpp
View file @
99d1fed4
...
@@ -5,6 +5,8 @@
...
@@ -5,6 +5,8 @@
#include <migraphx/op/abs.hpp>
#include <migraphx/op/abs.hpp>
#include <migraphx/op/acos.hpp>
#include <migraphx/op/acos.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp>
#include <migraphx/op/asin.hpp>
#include <migraphx/op/asin.hpp>
#include <migraphx/op/as_shape.hpp>
#include <migraphx/op/as_shape.hpp>
#include <migraphx/op/atan.hpp>
#include <migraphx/op/atan.hpp>
...
...
src/targets/cpu/lowering.cpp
View file @
99d1fed4
...
@@ -637,6 +637,82 @@ struct cpu_logsoftmax
...
@@ -637,6 +637,82 @@ struct cpu_logsoftmax
}
}
};
};
struct
cpu_argmax
{
op
::
argmax
op
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
migraphx
::
reflect
(
self
.
op
,
f
);
}
std
::
string
name
()
const
{
return
"cpu::argmax"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
return
op
.
compute_shape
(
inputs
);
}
argument
compute
(
context
&
,
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
result
.
visit
([
&
](
auto
output
)
{
args
[
0
].
visit
([
&
](
auto
input
)
{
using
value_type
=
batch_max
(
output_shape
.
elements
(),
std
::
numeric_limits
<
value_type
>::
lowest
());
auto
data_shape
=
args
[
0
].
get_shape
();
shape_for_each
(
data_shape
,
[
&
](
auto
idx
)
{
auto
data_index
=
data_shape
.
index
(
idx
);
idx
[
axis
]
=
0
;
auto
out_index
=
data_shape
.
index
(
idx
);
if
(
batch_max
[
index
]
<
input
[
data_index
])
{
batch_max
[
index
]
=
input
[
data_index
];
output
[
index
]
=
static_cast
<
int64_t
>
(
data_index
);
}
});
});
});
return
result
;
}
};
struct
cpu_argmin
{
op
::
argmin
op
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
migraphx
::
reflect
(
self
.
op
,
f
);
}
std
::
string
name
()
const
{
return
"cpu::argmin"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
return
op
.
compute_shape
(
inputs
);
}
argument
compute
(
context
&
,
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
result
.
visit
([
&
](
auto
output
)
{
args
[
0
].
visit
([
&
](
auto
input
)
{
using
value_type
=
batch_min
(
output_shape
.
elements
(),
std
::
numeric_limits
<
value_type
>::
max
());
auto
data_shape
=
args
[
0
].
get_shape
();
shape_for_each
(
data_shape
,
[
&
](
auto
idx
)
{
auto
data_index
=
data_shape
.
index
(
idx
);
idx
[
axis
]
=
0
;
auto
out_index
=
data_shape
.
index
(
idx
);
if
(
batch_min
[
index
]
>
input
[
data_index
])
{
batch_min
[
index
]
=
input
[
data_index
];
output
[
index
]
=
static_cast
<
int64_t
>
(
data_index
);
}
});
});
});
return
result
;
}
};
struct
cpu_apply
struct
cpu_apply
{
{
program
*
prog
;
program
*
prog
;
...
@@ -656,6 +732,8 @@ struct cpu_apply
...
@@ -656,6 +732,8 @@ struct cpu_apply
void
init
()
void
init
()
{
{
apply_map
[
"argmax"
]
=
extend_op
<
cpu_argmax
,
op
::
argmax
>
();
apply_map
[
"argmin"
]
=
extend_op
<
cpu_argmin
,
op
::
argmin
>
();
apply_map
[
"batch_norm_inference"
]
=
apply_map
[
"batch_norm_inference"
]
=
extend_op
<
cpu_batch_norm_inference
,
op
::
batch_norm_inference
>
();
extend_op
<
cpu_batch_norm_inference
,
op
::
batch_norm_inference
>
();
apply_map
[
"convolution"
]
=
extend_op
<
cpu_convolution
,
op
::
convolution
>
();
apply_map
[
"convolution"
]
=
extend_op
<
cpu_convolution
,
op
::
convolution
>
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment