gaoqiong / MIGraphX

Commit cfd36b63, authored Jul 09, 2019 by Paul

    Merge branch 'propogate-broadcast' into multiply-add

Parents: db3f1478, d5685340

Changes: 34. Showing 20 changed files with 609 additions and 18 deletions (+609, -18).
Changed files shown on this page:

    src/include/migraphx/op/argmax.hpp                              +81   -0
    src/include/migraphx/op/argmin.hpp                              +81   -0
    src/include/migraphx/op/erf.hpp                                 +23   -0
    src/include/migraphx/op/logsoftmax.hpp                           +0   -7
    src/include/migraphx/op/softmax.hpp                              +0   -7
    src/include/migraphx/operators.hpp                               +3   -0
    src/onnx/onnx.cpp                                               +92   -0
    src/propagate_constant.cpp                                       +3   -3
    src/targets/cpu/lowering.cpp                                     +2   -0
    src/targets/gpu/CMakeLists.txt                                   +5   -0
    src/targets/gpu/argmax.cpp                                      +23   -0
    src/targets/gpu/argmin.cpp                                      +23   -0
    src/targets/gpu/device/argmax.cpp                               +23   -0
    src/targets/gpu/device/argmin.cpp                               +23   -0
    src/targets/gpu/device/erf.cpp                                  +18   -0
    src/targets/gpu/device/include/migraphx/gpu/device/reduce.hpp    +1   -1
    src/targets/gpu/include/migraphx/gpu/argmax.hpp                 +37   -0
    src/targets/gpu/include/migraphx/gpu/argmin.hpp                 +37   -0
    src/targets/gpu/include/migraphx/gpu/device/arg_op.hpp         +114   -0
    src/targets/gpu/include/migraphx/gpu/device/argmax.hpp          +20   -0
src/include/migraphx/op/argmax.hpp (new file, mode 100644)

#ifndef MIGRAPHX_GUARD_OPERATORS_ARGMAX_HPP
#define MIGRAPHX_GUARD_OPERATORS_ARGMAX_HPP

#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct argmax
{
    int64_t axis = 0;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.axis, "axis"));
    }

    std::string name() const { return "argmax"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1).standard();
        auto lens     = inputs[0].lens();
        int64_t n_dim = static_cast<int64_t>(lens.size());
        if(axis >= n_dim || axis < 0)
        {
            MIGRAPHX_THROW("ARGMAX: axis is out of range.");
        }

        lens[axis] = 1;
        return {shape::int64_type, lens};
    }

    template <class T>
    int64_t calc_argmax(T& input, std::vector<std::size_t>& indices, size_t item_num) const
    {
        auto max_val      = input(indices.begin(), indices.end());
        int64_t max_index = 0;
        for(std::size_t i = 1; i < item_num; ++i)
        {
            indices[axis] = i;
            auto cur_val  = input(indices.begin(), indices.end());
            if(max_val < cur_val)
            {
                max_val   = cur_val;
                max_index = i;
            }
        }

        return max_index;
    }

    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        auto batch_item_num = args.front().get_shape().lens()[axis];

        result.visit([&](auto output) {
            args[0].visit([&](auto input) {
                par_for(output_shape.elements(), [&](auto i) {
                    auto data_idx = output_shape.multi(i);
                    output[i]     = this->calc_argmax(input, data_idx, batch_item_num);
                });
            });
        });

        return result;
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
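For orientation, here is a small standalone sketch (plain C++, not the MIGraphX tensor API) of the behaviour calc_argmax implements, shown for a 2x3 row-major input reduced along axis 1; because the comparison is strict (max_val < cur_val), ties keep the earliest index.

#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative only: mirrors op::argmax semantics for axis = 1 on a 2x3 matrix.
int main()
{
    std::vector<float> data = {0.5f, 2.0f, 2.0f,   // row 0 -> argmax 1 (first max wins on ties)
                               7.0f, 1.0f, 3.0f};  // row 1 -> argmax 0
    std::size_t rows = 2, cols = 3;
    std::vector<int64_t> out(rows); // output shape {2, 1} of int64, flattened here

    for(std::size_t r = 0; r < rows; ++r)
    {
        int64_t best = 0;
        for(std::size_t c = 1; c < cols; ++c)
            if(data[r * cols + best] < data[r * cols + c])
                best = static_cast<int64_t>(c);
        out[r] = best;
    }

    for(auto v : out)
        std::cout << v << "\n"; // prints 1 then 0
    return 0;
}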
src/include/migraphx/op/argmin.hpp (new file, mode 100644)

#ifndef MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP
#define MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP

#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct argmin
{
    int64_t axis = 0;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.axis, "axis"));
    }

    std::string name() const { return "argmin"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1).standard();
        auto lens     = inputs[0].lens();
        int64_t n_dim = static_cast<int64_t>(lens.size());
        if(axis >= n_dim || axis < 0)
        {
            MIGRAPHX_THROW("ARGMIN: axis is out of range.");
        }

        lens[axis] = 1;
        return {shape::int64_type, lens};
    }

    template <class T>
    int64_t calc_argmin(T& input, std::vector<std::size_t>& indices, size_t item_num) const
    {
        auto min_val      = input(indices.begin(), indices.end());
        int64_t min_index = 0;
        for(std::size_t i = 1; i < item_num; ++i)
        {
            indices[axis] = i;
            auto cur_val  = input(indices.begin(), indices.end());
            if(min_val > cur_val)
            {
                min_val   = cur_val;
                min_index = i;
            }
        }

        return min_index;
    }

    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        std::size_t batch_item_num = args.front().get_shape().lens()[axis];

        result.visit([&](auto output) {
            args[0].visit([&](auto input) {
                par_for(output_shape.elements(), [&](auto i) {
                    auto data_idx = output_shape.multi(i);
                    output[i]     = this->calc_argmin(input, data_idx, batch_item_num);
                });
            });
        });

        return result;
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
src/include/migraphx/op/erf.hpp (new file, mode 100644)

#ifndef MIGRAPHX_GUARD_OPERATORS_ERF_HPP
#define MIGRAPHX_GUARD_OPERATORS_ERF_HPP

#include <migraphx/op/unary.hpp>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct erf : unary<erf>
{
    auto apply() const
    {
        return [](auto x) { return std::erf(x); };
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
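The unary<erf> CRTP base (defined in migraphx/op/unary.hpp, not shown in this diff) presumably applies the lambda returned by apply() to every element of the input; the net effect is the same as mapping std::erf over a buffer, roughly:

#include <algorithm>
#include <cmath>
#include <vector>

// Rough elementwise analogue of op::erf, assuming unary<> applies the
// lambda returned by apply() to each element of the input tensor.
std::vector<double> erf_elementwise(const std::vector<double>& in)
{
    std::vector<double> out(in.size());
    std::transform(in.begin(), in.end(), out.begin(),
                   [](double x) { return std::erf(x); });
    return out;
}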
src/include/migraphx/op/logsoftmax.hpp

#ifndef MIGRAPHX_GUARD_OPERATORS_LOGSOFTMAX_HPP
#define MIGRAPHX_GUARD_OPERATORS_LOGSOFTMAX_HPP

#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
...
src/include/migraphx/op/softmax.hpp

#ifndef MIGRAPHX_GUARD_OPERATORS_SOFTMAX_HPP
#define MIGRAPHX_GUARD_OPERATORS_SOFTMAX_HPP

#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
...
src/include/migraphx/operators.hpp

@@ -5,6 +5,8 @@
#include <migraphx/op/abs.hpp>
#include <migraphx/op/acos.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp>
#include <migraphx/op/asin.hpp>
#include <migraphx/op/as_shape.hpp>
#include <migraphx/op/atan.hpp>
@@ -22,6 +24,7 @@
#include <migraphx/op/div.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/elu.hpp>
#include <migraphx/op/erf.hpp>
#include <migraphx/op/exp.hpp>
#include <migraphx/op/flatten.hpp>
#include <migraphx/op/gather.hpp>
src/onnx/onnx.cpp

@@ -40,6 +40,7 @@ struct onnx_parser
        add_generic_op("Sigmoid", op::sigmoid{});
        add_generic_op("Abs", op::abs{});
        add_generic_op("Exp", op::exp{});
        add_generic_op("Erf", op::erf{});
        add_generic_op("Log", op::log{});
        // disable dropout for inference
        add_generic_op("Dropout", op::identity{});
@@ -63,6 +64,8 @@ struct onnx_parser
        add_variadic_op("Max", op::max{});
        add_variadic_op("Min", op::min{});
        add_mem_op("ArgMax", &onnx_parser::parse_argmax);
        add_mem_op("ArgMin", &onnx_parser::parse_argmin);
        add_mem_op("Clip", &onnx_parser::parse_clip);
        add_mem_op("LRN", &onnx_parser::parse_lrn);
        add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
@@ -93,6 +96,7 @@ struct onnx_parser
        add_mem_op("GRU", &onnx_parser::parse_gru);
        add_mem_op("LSTM", &onnx_parser::parse_lstm);
        add_mem_op("Pad", &onnx_parser::parse_pad);
        add_mem_op("ReduceSum", &onnx_parser::parse_reduce_sum);
        // init the activation function map
        init_actv_func();
@@ -274,6 +278,60 @@ struct onnx_parser
        return prog.add_instruction(op::logsoftmax{axis}, std::move(args));
    }

    instruction_ref parse_argmax(const std::string&,
                                 const attribute_map& attributes,
                                 std::vector<instruction_ref> args)
    {
        int64_t axis = 0;
        if(contains(attributes, "axis"))
        {
            axis = static_cast<int64_t>(parse_value(attributes.at("axis")).at<int>());
        }

        int keep_dims = 1;
        if(contains(attributes, "keepdims"))
        {
            keep_dims = parse_value(attributes.at("keepdims")).at<int>();
        }

        if(keep_dims == 0)
        {
            auto ins = prog.add_instruction(op::argmax{axis}, std::move(args));
            return prog.add_instruction(op::squeeze{{axis}}, ins);
        }
        else
        {
            return prog.add_instruction(op::argmax{axis}, std::move(args));
        }
    }

    instruction_ref parse_argmin(const std::string&,
                                 const attribute_map& attributes,
                                 std::vector<instruction_ref> args)
    {
        int64_t axis = 0;
        if(contains(attributes, "axis"))
        {
            axis = static_cast<int64_t>(parse_value(attributes.at("axis")).at<int>());
        }

        int keep_dims = 1;
        if(contains(attributes, "keepdims"))
        {
            keep_dims = parse_value(attributes.at("keepdims")).at<int>();
        }

        if(keep_dims == 0)
        {
            auto ins = prog.add_instruction(op::argmin{axis}, std::move(args));
            return prog.add_instruction(op::squeeze{{axis}}, ins);
        }
        else
        {
            return prog.add_instruction(op::argmin{axis}, std::move(args));
        }
    }

    instruction_ref parse_conv(const std::string&,
                               attribute_map attributes,
                               std::vector<instruction_ref> args)
    {
@@ -1230,6 +1288,40 @@ struct onnx_parser
        return {hidden_states, last_output, last_cell_output};
    }

    instruction_ref parse_reduce_sum(const std::string&,
                                     attribute_map attributes,
                                     std::vector<instruction_ref> args)
    {
        std::size_t n_dim = args.front()->get_shape().lens().size();

        // default to reduce over all dimensions
        std::vector<std::size_t> axes(n_dim);
        std::iota(axes.begin(), axes.end(), 0);
        if(contains(attributes, "axes"))
        {
            axes.clear();
            auto&& attr_axes = attributes["axes"].ints();
            axes             = std::vector<std::size_t>(attr_axes.begin(), attr_axes.end());
        }

        int keep_dims = 1;
        if(contains(attributes, "keepdims"))
        {
            keep_dims = parse_value(attributes.at("keepdims")).at<int>();
        }

        if(keep_dims == 1)
        {
            return prog.add_instruction(op::reduce_sum{axes}, std::move(args));
        }
        else
        {
            auto ins = prog.add_instruction(op::reduce_sum{axes}, std::move(args));
            std::vector<int64_t> squeeze_axes{axes.begin(), axes.end()};
            return prog.add_instruction(op::squeeze{squeeze_axes}, ins);
        }
    }

    void parse_from(std::istream& is)
    {
        onnx::ModelProto model;
...
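As a quick reference for the keepdims handling above, the hypothetical helper below (not part of the commit) reproduces the shape effect: op::argmax and op::argmin always keep the reduced axis as a dimension of size 1, and the extra op::squeeze removes it when the ONNX attribute keepdims is 0.

#include <cstdint>
#include <vector>

// Hypothetical illustration of the output shape produced by parse_argmax/parse_argmin.
std::vector<std::size_t>
arg_reduce_output_lens(std::vector<std::size_t> lens, int64_t axis, int keep_dims)
{
    lens[axis] = 1;                      // what op::argmax / op::argmin produce
    if(keep_dims == 0)
        lens.erase(lens.begin() + axis); // what the added op::squeeze removes
    return lens;
}

// e.g. arg_reduce_output_lens({2, 3}, 1, 1) -> {2, 1}
//      arg_reduce_output_lens({2, 3}, 1, 0) -> {2}
// parse_reduce_sum follows the same pattern, squeezing all reduced axes when keepdims == 0.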
src/propagate_constant.cpp

@@ -10,8 +10,8 @@ inline namespace MIGRAPHX_INLINE_NS {
bool skip_propogate(instruction_ref ins)
{
    if(ins->name() == "@literal")
        return true;
    if(ins->name() == "contiguous")
        return skip_propogate(ins->inputs().front());
    auto&& s = ins->get_shape();
    if(s.broadcasted() and not s.scalar())
        return true;
...
@@ -33,7 +33,7 @@ void propagate_constant::apply(program& p) const
                                          ins->outputs().end());
    for(auto child : children)
    {
-       if(skip_propogate(child))
+       if(child->name() == "@literal" or skip_propogate(child))
        {
            self(child);
            continue;
...
src/targets/cpu/lowering.cpp

@@ -13,6 +13,8 @@
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/softmax.hpp>
#include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/par_dfor.hpp>
src/targets/gpu/CMakeLists.txt

@@ -12,9 +12,12 @@ endif()
add_library(migraphx_device
    device/add.cpp
    device/argmax.cpp
    device/argmin.cpp
    device/max.cpp
    device/min.cpp
    device/exp.cpp
    device/erf.cpp
    device/log.cpp
    device/sin.cpp
    device/cos.cpp
...
@@ -44,6 +47,8 @@ target_include_directories(migraphx_device PUBLIC $<BUILD_INTERFACE:${CMAKE_CURR
target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/device/include>)

add_library(migraphx_gpu
    argmax.cpp
    argmin.cpp
    eliminate_workspace.cpp
    fuse_ops.cpp
    hip.cpp
...
src/targets/gpu/argmax.cpp (new file, mode 100644)

#include <migraphx/gpu/argmax.hpp>
#include <migraphx/gpu/device/argmax.hpp>
#include <migraphx/gpu/context.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

shape hip_argmax::compute_shape(const std::vector<shape>& inputs) const
{
    check_shapes{inputs, *this}.has(2).standard();
    return op.compute_shape({inputs.at(0)});
}

argument hip_argmax::compute(context& ctx,
                             const shape&,
                             const std::vector<argument>& args) const
{
    // args.back() is the pre-allocated output buffer (the last input is aliased as
    // the output; see output_alias in gpu/argmax.hpp).
    device::argmax(ctx.get_stream().get(), args.back(), args.front(), op.axis);
    return args.back();
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
src/targets/gpu/argmin.cpp (new file, mode 100644)

#include <migraphx/gpu/argmin.hpp>
#include <migraphx/gpu/device/argmin.hpp>
#include <migraphx/gpu/context.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

shape hip_argmin::compute_shape(const std::vector<shape>& inputs) const
{
    check_shapes{inputs, *this}.has(2).standard();
    return op.compute_shape({inputs.at(0)});
}

argument hip_argmin::compute(context& ctx,
                             const shape&,
                             const std::vector<argument>& args) const
{
    device::argmin(ctx.get_stream().get(), args.back(), args.front(), op.axis);
    return args.back();
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
src/targets/gpu/device/argmax.cpp (new file, mode 100644)

#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/argmax.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/arg_op.hpp>
#include <migraphx/gpu/hip.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

void argmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
{
    arg_op(argmax_op{}, stream, result, arg, axis);
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
src/targets/gpu/device/argmin.cpp (new file, mode 100644)

#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/argmin.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/arg_op.hpp>
#include <migraphx/gpu/hip.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

void argmin(hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
{
    arg_op(argmin_op{}, stream, result, arg, axis);
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
src/targets/gpu/device/erf.cpp (new file, mode 100644)

#include <migraphx/gpu/device/erf.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <migraphx/gpu/device/types.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

void erf(hipStream_t stream, const argument& result, const argument& arg)
{
    nary(stream, result, arg)([](auto x) { return ::erf(to_hip_type(x)); });
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
src/targets/gpu/device/include/migraphx/gpu/device/reduce.hpp

@@ -128,7 +128,7 @@ __device__ T dpp_mov(T& x)
template <class T, class Op>
__device__ void dpp_reduce(T& in, Op op)
{
-    T out;
+    T out{};
    out = dpp_mov<dpp_row_shr(1)>(in);
    in  = op(in, out);
    out = dpp_mov<dpp_row_shr(2)>(in);
...
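The only change in this hunk swaps default-initialization of out for value-initialization, presumably so out never holds an indeterminate value before the first assignment when T is a trivially constructible type (for example an aggregate such as the val_index struct introduced below). A minimal illustration of the distinction:

struct pod
{
    int v;
};

void init_example()
{
    pod a;   // default-initialized: a.v is indeterminate for a trivial type
    pod b{}; // value-initialized: b.v is zero-initialized
    (void)a;
    (void)b;
}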
src/targets/gpu/include/migraphx/gpu/argmax.hpp (new file, mode 100644)

#ifndef MIGRAPHX_GUARD_RTGLIB_ARGMAX_HPP
#define MIGRAPHX_GUARD_RTGLIB_ARGMAX_HPP

#include <migraphx/shape.hpp>
#include <migraphx/op/argmax.hpp>
#include <migraphx/gpu/device/argmax.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

struct context;

struct hip_argmax
{
    op::argmax op;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

    std::string name() const { return "gpu::argmax"; }
    shape compute_shape(const std::vector<shape>& inputs) const;
    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return shapes.size() - 1;
    }
};

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
src/targets/gpu/include/migraphx/gpu/argmin.hpp (new file, mode 100644)

#ifndef MIGRAPHX_GUARD_RTGLIB_ARGMIN_HPP
#define MIGRAPHX_GUARD_RTGLIB_ARGMIN_HPP

#include <migraphx/shape.hpp>
#include <migraphx/op/argmin.hpp>
#include <migraphx/gpu/device/argmin.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

struct context;

struct hip_argmin
{
    op::argmin op;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

    std::string name() const { return "gpu::argmin"; }
    shape compute_shape(const std::vector<shape>& inputs) const;
    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return shapes.size() - 1;
    }
};

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
src/targets/gpu/include/migraphx/gpu/device/arg_op.hpp (new file, mode 100644)

#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_ARG_OP_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_ARG_OP_HPP

#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/reduce.hpp>
#include <migraphx/gpu/hip.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

// value/index pair propagated through the reduction along the chosen axis
template <class T>
struct val_index
{
    T val;
    int64_t index;
};

template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> make_val_index(T v)
{
    return {v, -1};
}

template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> make_val_index(T v, int64_t i)
{
    return {v, i};
}

struct argmax_op
{
    template <class T>
    MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
    {
        if(x.val > y.val)
            return x;
        else if(x.val < y.val)
            return y;
        else
        {
            // on ties, keep the smaller index
            return (x.index < y.index) ? x : y;
        }
    }

    MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return lowest(); }
};

struct argmin_op
{
    template <class T>
    MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
    {
        if(x.val < y.val)
            return x;
        else if(x.val > y.val)
            return y;
        else
        {
            // on ties, keep the smaller index
            return (x.index < y.index) ? x : y;
        }
    }

    MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return highest(); }
};

template <class Op>
void arg_op(Op op, hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
{
    auto arg_shape        = arg.get_shape();
    auto lens             = arg_shape.lens();
    auto batch_lens       = lens;
    size_t batch_item_num = lens[axis];
    batch_lens[axis]      = 1;
    migraphx::shape batch_shape{arg_shape.type(), batch_lens};

    hip_visit_all(arg, arg_shape, batch_shape)([&](auto input, auto arg_s, auto batch_s) {
        auto output = device_cast(result.get<int64_t>().data());
        using type  = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
        // use one block for items in one batch.
        const size_t max_block_size  = 256;
        const std::size_t block_size = compute_block_size(batch_item_num, max_block_size);
        gs_launch(stream, batch_shape.elements() * block_size, block_size)(
            [=](auto i, auto idx) __device__ {
                auto batch_idx = batch_s.multi(i / block_size);
                auto data_idx  = batch_idx;
                auto init      = make_val_index<type>(op.init());
                auto op_output = block_reduce<max_block_size>(
                    idx, op, init, batch_item_num, [&](auto j) __device__ {
                        data_idx[axis] = j;
                        return make_val_index(input[arg_s.index(data_idx)], j);
                    });

                if(idx.local == 0)
                {
                    output[batch_s.index(batch_idx)] = op_output.index;
                }
            });
    });
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
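To restate what arg_op parallelizes: each output element is produced by reducing {value, index} pairs over the lens[axis] elements of one batch item, with one block per output element and ties resolved to the smaller index. A sequential CPU sketch of that per-output reduction (a hypothetical helper, not the MIGraphX/HIP API, modelling the tensor as a flat buffer plus a stride):

#include <cstdint>
#include <vector>

// Illustrative sequential analogue of the reduction each GPU block performs.
struct val_index_cpu
{
    float val;
    int64_t index;
};

int64_t reduce_one_output(const std::vector<float>& data,
                          std::size_t stride,   // element stride of the reduced axis
                          std::size_t base,     // flat offset of the first element in the batch item
                          std::size_t item_num) // lens[axis]
{
    val_index_cpu best{data[base], 0};
    for(std::size_t j = 1; j < item_num; ++j)
    {
        val_index_cpu cur{data[base + j * stride], static_cast<int64_t>(j)};
        if(cur.val > best.val) // strictly greater, so ties keep the earlier index
            best = cur;
    }
    return best.index;
}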
src/targets/gpu/include/migraphx/gpu/device/argmax.hpp (new file, mode 100644)

#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_ARGMAX_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_ARGMAX_HPP

#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

void argmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis);

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif