Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
e69e915b
Commit
e69e915b
authored
Jul 03, 2019
by
Shucai Xiao
Browse files
change axis attribute from int to int64_t
parent
ee46bc9f
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
15 additions
and
20 deletions
+15
-20
src/include/migraphx/op/argmax.hpp
src/include/migraphx/op/argmax.hpp
+2
-2
src/include/migraphx/op/argmin.hpp
src/include/migraphx/op/argmin.hpp
+2
-7
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+6
-6
src/targets/gpu/device/argmax.cpp
src/targets/gpu/device/argmax.cpp
+1
-1
src/targets/gpu/device/argmin.cpp
src/targets/gpu/device/argmin.cpp
+1
-1
src/targets/gpu/include/migraphx/gpu/device/arg_op.hpp
src/targets/gpu/include/migraphx/gpu/device/arg_op.hpp
+1
-1
src/targets/gpu/include/migraphx/gpu/device/argmax.hpp
src/targets/gpu/include/migraphx/gpu/device/argmax.hpp
+1
-1
src/targets/gpu/include/migraphx/gpu/device/argmin.hpp
src/targets/gpu/include/migraphx/gpu/device/argmin.hpp
+1
-1
No files found.
src/include/migraphx/op/argmax.hpp
View file @
e69e915b
...
@@ -12,7 +12,7 @@ namespace op {
...
@@ -12,7 +12,7 @@ namespace op {
struct
argmax
struct
argmax
{
{
int
axis
=
0
;
int
64_t
axis
=
0
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
...
@@ -26,7 +26,7 @@ struct argmax
...
@@ -26,7 +26,7 @@ struct argmax
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard
();
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard
();
auto
lens
=
inputs
[
0
].
lens
();
auto
lens
=
inputs
[
0
].
lens
();
int
n_dim
=
static_cast
<
int
>
(
lens
.
size
());
int
64_t
n_dim
=
static_cast
<
int
64_t
>
(
lens
.
size
());
if
(
axis
>=
n_dim
||
axis
<
0
)
if
(
axis
>=
n_dim
||
axis
<
0
)
{
{
MIGRAPHX_THROW
(
"ARGMAX: axis is out of range."
);
MIGRAPHX_THROW
(
"ARGMAX: axis is out of range."
);
...
...
src/include/migraphx/op/argmin.hpp
View file @
e69e915b
#ifndef MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP
#ifndef MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP
#define MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP
#define MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP
//#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
//#include <migraphx/stringutils.hpp>
//#include <migraphx/literal.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
//#include <cmath>
//#include <utility>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -17,7 +12,7 @@ namespace op {
...
@@ -17,7 +12,7 @@ namespace op {
struct
argmin
struct
argmin
{
{
int
axis
=
0
;
int
64_t
axis
=
0
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
...
@@ -31,7 +26,7 @@ struct argmin
...
@@ -31,7 +26,7 @@ struct argmin
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard
();
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard
();
auto
lens
=
inputs
[
0
].
lens
();
auto
lens
=
inputs
[
0
].
lens
();
int
n_dim
=
static_cast
<
int
>
(
lens
.
size
());
int
64_t
n_dim
=
static_cast
<
int
64_t
>
(
lens
.
size
());
if
(
axis
>=
n_dim
||
axis
<
0
)
if
(
axis
>=
n_dim
||
axis
<
0
)
{
{
MIGRAPHX_THROW
(
"ARGMIN: axis is out of range."
);
MIGRAPHX_THROW
(
"ARGMIN: axis is out of range."
);
...
...
src/onnx/onnx.cpp
View file @
e69e915b
...
@@ -273,10 +273,10 @@ struct onnx_parser
...
@@ -273,10 +273,10 @@ struct onnx_parser
const
attribute_map
&
attributes
,
const
attribute_map
&
attributes
,
std
::
vector
<
instruction_ref
>
args
)
std
::
vector
<
instruction_ref
>
args
)
{
{
int
axis
=
0
;
int
64_t
axis
=
0
;
if
(
contains
(
attributes
,
"axis"
))
if
(
contains
(
attributes
,
"axis"
))
{
{
axis
=
parse_value
(
attributes
.
at
(
"axis"
)).
at
<
int
>
();
axis
=
static_cast
<
int64_t
>
(
parse_value
(
attributes
.
at
(
"axis"
)).
at
<
int
>
()
)
;
}
}
int
keep_dims
=
1
;
int
keep_dims
=
1
;
...
@@ -288,7 +288,7 @@ struct onnx_parser
...
@@ -288,7 +288,7 @@ struct onnx_parser
if
(
keep_dims
==
0
)
if
(
keep_dims
==
0
)
{
{
auto
ins
=
prog
.
add_instruction
(
op
::
argmax
{
axis
},
std
::
move
(
args
));
auto
ins
=
prog
.
add_instruction
(
op
::
argmax
{
axis
},
std
::
move
(
args
));
return
prog
.
add_instruction
(
op
::
squeeze
{{
static_cast
<
int64_t
>
(
axis
)
}},
ins
);
return
prog
.
add_instruction
(
op
::
squeeze
{{
axis
}},
ins
);
}
}
else
else
{
{
...
@@ -300,10 +300,10 @@ struct onnx_parser
...
@@ -300,10 +300,10 @@ struct onnx_parser
const
attribute_map
&
attributes
,
const
attribute_map
&
attributes
,
std
::
vector
<
instruction_ref
>
args
)
std
::
vector
<
instruction_ref
>
args
)
{
{
int
axis
=
0
;
int
64_t
axis
=
0
;
if
(
contains
(
attributes
,
"axis"
))
if
(
contains
(
attributes
,
"axis"
))
{
{
axis
=
parse_value
(
attributes
.
at
(
"axis"
)).
at
<
int
>
();
axis
=
static_cast
<
int64_t
>
(
parse_value
(
attributes
.
at
(
"axis"
)).
at
<
int
>
()
)
;
}
}
int
keep_dims
=
1
;
int
keep_dims
=
1
;
...
@@ -315,7 +315,7 @@ struct onnx_parser
...
@@ -315,7 +315,7 @@ struct onnx_parser
if
(
keep_dims
==
0
)
if
(
keep_dims
==
0
)
{
{
auto
ins
=
prog
.
add_instruction
(
op
::
argmin
{
axis
},
std
::
move
(
args
));
auto
ins
=
prog
.
add_instruction
(
op
::
argmin
{
axis
},
std
::
move
(
args
));
return
prog
.
add_instruction
(
op
::
squeeze
{{
static_cast
<
int64_t
>
(
axis
)
}},
ins
);
return
prog
.
add_instruction
(
op
::
squeeze
{{
axis
}},
ins
);
}
}
else
else
{
{
...
...
src/targets/gpu/device/argmax.cpp
View file @
e69e915b
...
@@ -12,7 +12,7 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -12,7 +12,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
gpu
{
namespace
device
{
namespace
device
{
void
argmax
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
axis
)
void
argmax
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
64_t
axis
)
{
{
arg_op
(
argmax_op
{},
stream
,
result
,
arg
,
axis
);
arg_op
(
argmax_op
{},
stream
,
result
,
arg
,
axis
);
}
}
...
...
src/targets/gpu/device/argmin.cpp
View file @
e69e915b
...
@@ -12,7 +12,7 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -12,7 +12,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
gpu
{
namespace
device
{
namespace
device
{
void
argmin
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
axis
)
void
argmin
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
64_t
axis
)
{
{
arg_op
(
argmin_op
{},
stream
,
result
,
arg
,
axis
);
arg_op
(
argmin_op
{},
stream
,
result
,
arg
,
axis
);
}
}
...
...
src/targets/gpu/include/migraphx/gpu/device/arg_op.hpp
View file @
e69e915b
...
@@ -70,7 +70,7 @@ struct argmin_op
...
@@ -70,7 +70,7 @@ struct argmin_op
};
};
template
<
class
Op
>
template
<
class
Op
>
void
arg_op
(
Op
op
,
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
axis
)
void
arg_op
(
Op
op
,
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
64_t
axis
)
{
{
auto
arg_shape
=
arg
.
get_shape
();
auto
arg_shape
=
arg
.
get_shape
();
auto
lens
=
arg_shape
.
lens
();
auto
lens
=
arg_shape
.
lens
();
...
...
src/targets/gpu/include/migraphx/gpu/device/argmax.hpp
View file @
e69e915b
...
@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
gpu
{
namespace
device
{
namespace
device
{
void
argmax
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
axis
);
void
argmax
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
64_t
axis
);
}
// namespace device
}
// namespace device
}
// namespace gpu
}
// namespace gpu
...
...
src/targets/gpu/include/migraphx/gpu/device/argmin.hpp
View file @
e69e915b
...
@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
gpu
{
namespace
device
{
namespace
device
{
void
argmin
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
axis
);
void
argmin
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
64_t
axis
);
}
// namespace device
}
// namespace device
}
// namespace gpu
}
// namespace gpu
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment