Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
d7a28300
Commit
d7a28300
authored
Apr 14, 2022
by
Shucai Xiao
Browse files
merge develop branch to branch_for_ort2
parents
bcb2c0a4
a930f1d5
Changes
25
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
319 additions
and
40 deletions
+319
-40
cmake/PythonModules.cmake
cmake/PythonModules.cmake
+4
-0
examples/vision/python_resnet50/resnet50_inference.ipynb
examples/vision/python_resnet50/resnet50_inference.ipynb
+2
-2
src/CMakeLists.txt
src/CMakeLists.txt
+4
-2
src/include/migraphx/gemm.hpp
src/include/migraphx/gemm.hpp
+4
-2
src/include/migraphx/generate.hpp
src/include/migraphx/generate.hpp
+4
-4
src/include/migraphx/op/scatter.hpp
src/include/migraphx/op/scatter.hpp
+34
-8
src/include/migraphx/op/scatter_add.hpp
src/include/migraphx/op/scatter_add.hpp
+38
-0
src/include/migraphx/op/scatter_mul.hpp
src/include/migraphx/op/scatter_mul.hpp
+36
-0
src/include/migraphx/op/scatter_none.hpp
src/include/migraphx/op/scatter_none.hpp
+37
-0
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+3
-1
src/onnx/parse_generic_op.cpp
src/onnx/parse_generic_op.cpp
+2
-2
src/onnx/parse_scatter.cpp
src/onnx/parse_scatter.cpp
+44
-0
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+1
-0
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
+8
-3
src/targets/gpu/hip.cpp
src/targets/gpu/hip.cpp
+16
-2
src/targets/gpu/include/migraphx/gpu/scatter.hpp
src/targets/gpu/include/migraphx/gpu/scatter.hpp
+5
-3
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+4
-1
test/onnx/gen_onnx.py
test/onnx/gen_onnx.py
+43
-2
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+25
-4
test/onnx/scatter_add_test.onnx
test/onnx/scatter_add_test.onnx
+5
-4
No files found.
cmake/PythonModules.cmake
View file @
d7a28300
...
@@ -54,6 +54,10 @@ function(py_add_module NAME)
...
@@ -54,6 +54,10 @@ function(py_add_module NAME)
endfunction
()
endfunction
()
set
(
PYTHON_SEARCH_VERSIONS 2.7 3.5 3.6 3.7 3.8 3.9
)
set
(
PYTHON_SEARCH_VERSIONS 2.7 3.5 3.6 3.7 3.8 3.9
)
set
(
PYTHON_DISABLE_VERSIONS
""
CACHE STRING
""
)
foreach
(
PYTHON_DISABLE_VERSION
${
PYTHON_DISABLE_VERSIONS
}
)
list
(
REMOVE_ITEM PYTHON_SEARCH_VERSIONS
${
PYTHON_DISABLE_VERSION
}
)
endforeach
()
set
(
_PYTHON_VERSIONS
)
set
(
_PYTHON_VERSIONS
)
foreach
(
PYTHON_VERSION
${
PYTHON_SEARCH_VERSIONS
}
)
foreach
(
PYTHON_VERSION
${
PYTHON_SEARCH_VERSIONS
}
)
...
...
examples/vision/python_resnet50/resnet50_inference.ipynb
View file @
d7a28300
...
@@ -106,8 +106,8 @@
...
@@ -106,8 +106,8 @@
"outputs": [],
"outputs": [],
"source": [
"source": [
"if not path.exists(\"./resnet50.onnx\"):\n",
"if not path.exists(\"./resnet50.onnx\"):\n",
" !wget https://github.com/onnx/models/
blob/master
/vision/classification/resnet/model/resnet50-v2-7.onnx
?raw=true\n
",
" !wget https://github.com/onnx/models/
raw/main
/vision/classification/resnet/model/resnet50-v2-7.onnx",
" !mv
'
resnet50-v2-7.onnx
?raw=true'
resnet50.onnx"
" !mv resnet50-v2-7.onnx resnet50.onnx"
]
]
},
},
{
{
...
...
src/CMakeLists.txt
View file @
d7a28300
...
@@ -162,10 +162,12 @@ register_migraphx_ops(
...
@@ -162,10 +162,12 @@ register_migraphx_ops(
round
round
rsqrt
rsqrt
scalar
scalar
scatter
scatter_add
scatternd_none
scatter_mul
scatter_none
scatternd_add
scatternd_add
scatternd_mul
scatternd_mul
scatternd_none
sigmoid
sigmoid
sign
sign
sinh
sinh
...
...
src/include/migraphx/gemm.hpp
View file @
d7a28300
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/
shape_for_each
.hpp>
#include <migraphx/
par_for
.hpp>
#include <migraphx/tensor_view.hpp>
#include <migraphx/tensor_view.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
@@ -20,8 +20,10 @@ void gemm(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha
...
@@ -20,8 +20,10 @@ void gemm(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha
assert
(
amat
.
get_shape
().
lens
()[
dim_1
]
==
bmat
.
get_shape
().
lens
()[
dim_0
]);
assert
(
amat
.
get_shape
().
lens
()[
dim_1
]
==
bmat
.
get_shape
().
lens
()[
dim_0
]);
assert
(
cmat
.
get_shape
().
lens
()[
dim_0
]
==
amat
.
get_shape
().
lens
()[
dim_0
]);
assert
(
cmat
.
get_shape
().
lens
()[
dim_0
]
==
amat
.
get_shape
().
lens
()[
dim_0
]);
assert
(
cmat
.
get_shape
().
lens
()[
dim_1
]
==
bmat
.
get_shape
().
lens
()[
dim_1
]);
assert
(
cmat
.
get_shape
().
lens
()[
dim_1
]
==
bmat
.
get_shape
().
lens
()[
dim_1
]);
auto
cs
=
cmat
.
get_shape
();
shape_for_each
(
cmat
.
get_shape
(),
[
&
](
const
auto
&
c_idx
)
{
par_for
(
cs
.
elements
(),
[
&
](
auto
i
)
{
auto
c_idx
=
cs
.
multi
(
i
);
auto
a_idx
=
c_idx
;
auto
a_idx
=
c_idx
;
auto
b_idx
=
c_idx
;
auto
b_idx
=
c_idx
;
double
s
=
0.0
;
double
s
=
0.0
;
...
...
src/include/migraphx/generate.hpp
View file @
d7a28300
...
@@ -88,16 +88,16 @@ struct xorshift_generator
...
@@ -88,16 +88,16 @@ struct xorshift_generator
template
<
class
T
>
template
<
class
T
>
auto
generate_tensor_data
(
const
migraphx
::
shape
&
s
,
unsigned
long
seed
=
0
)
auto
generate_tensor_data
(
const
migraphx
::
shape
&
s
,
unsigned
long
seed
=
0
)
{
{
auto
result
=
make_shared_array
<
T
>
(
s
.
element
s
());
auto
result
=
make_shared_array
<
T
>
(
s
.
element
_space
());
std
::
generate
(
result
.
get
(),
result
.
get
()
+
s
.
element
s
(),
xorshf96_generator
<
T
>
{
seed
});
std
::
generate
(
result
.
get
(),
result
.
get
()
+
s
.
element
_space
(),
xorshf96_generator
<
T
>
{
seed
});
return
result
;
return
result
;
}
}
template
<
class
T
>
template
<
class
T
>
auto
fill_tensor_data
(
const
migraphx
::
shape
&
s
,
unsigned
long
value
=
0
)
auto
fill_tensor_data
(
const
migraphx
::
shape
&
s
,
unsigned
long
value
=
0
)
{
{
auto
result
=
make_shared_array
<
T
>
(
s
.
element
s
());
auto
result
=
make_shared_array
<
T
>
(
s
.
element
_space
());
std
::
generate
(
result
.
get
(),
result
.
get
()
+
s
.
element
s
(),
[
=
]
{
return
value
;
});
std
::
generate
(
result
.
get
(),
result
.
get
()
+
s
.
element
_space
(),
[
=
]
{
return
value
;
});
return
result
;
return
result
;
}
}
...
...
src/include/migraphx/op/scatter.hpp
View file @
d7a28300
...
@@ -8,6 +8,7 @@
...
@@ -8,6 +8,7 @@
#include <migraphx/shape_for_each.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/name.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <cmath>
#include <utility>
#include <utility>
...
@@ -16,7 +17,17 @@ namespace migraphx {
...
@@ -16,7 +17,17 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
namespace
op
{
struct
scatter
// The scatter operator fetches a subset of data given by an index array and then performs a
// reduction operation (add, multiply, or just set the data) on each element returned. We implement
// it as a separate derived struct for each of the three reduction methods. The related operator
// scatterND is a generalization that works on a set of 3 tensors of different ranks. The
// complementary operations are gather/gatherND.
//
// This is a template for deriving child structs from. Each child needs to define
// only a reduction() method. Names are automatically handled by the op_name template.
template
<
class
Derived
>
struct
scatter
:
op_name
<
Derived
>
{
{
int64_t
axis
=
0
;
int64_t
axis
=
0
;
...
@@ -33,29 +44,44 @@ struct scatter
...
@@ -33,29 +44,44 @@ struct scatter
return
{{
"normalize_axes"
,
normalize
}};
return
{{
"normalize_axes"
,
normalize
}};
}
}
std
::
string
name
()
const
{
return
"scatter"
;
}
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
3
).
standard
();
check_shapes
{
inputs
,
*
this
}.
has
(
3
).
standard
();
return
inputs
.
front
();
// If non-packed, this converts to a packed output while preserving permutation of tensor
return
inputs
.
front
().
with_lens
(
inputs
.
front
().
lens
());
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
output_shape
};
argument
result
{
output_shape
};
// max dimension in axis
auto
&
self
=
static_cast
<
const
Derived
&>
(
*
this
);
// max dimension in each axis
auto
axis_dim_size
=
output_shape
.
lens
()[
axis
];
auto
axis_dim_size
=
output_shape
.
lens
()[
axis
];
// cast all arguments as correct type
visit_all
(
result
,
args
[
0
],
args
[
2
])([
&
](
auto
output
,
auto
data
,
auto
update
)
{
visit_all
(
result
,
args
[
0
],
args
[
2
])([
&
](
auto
output
,
auto
data
,
auto
update
)
{
// copy all of data to output
std
::
copy
(
data
.
begin
(),
data
.
end
(),
output
.
begin
());
std
::
copy
(
data
.
begin
(),
data
.
end
(),
output
.
begin
());
args
[
1
].
visit
([
&
](
auto
indices
)
{
args
[
1
].
visit
([
&
](
auto
indices
)
{
auto
ind_s
=
indices
.
get_shape
();
auto
ind_s
=
indices
.
get_shape
();
// iterate through items in shape
shape_for_each
(
ind_s
,
[
&
](
const
auto
&
idx
)
{
shape_for_each
(
ind_s
,
[
&
](
const
auto
&
idx
)
{
auto
out_idx
=
idx
;
auto
out_idx
=
idx
;
auto
index
=
indices
[
ind_s
.
index
(
idx
)];
// Overloaded tensor_view::() invokes indexing logic of
// std::size_t shape::index(std::size_t i) const
// which handles nonstandard shapes correctly
auto
index
=
indices
(
idx
.
begin
(),
idx
.
end
());
// normalize negative indexes (may be redundant after using
// normalize_compute_shape())
index
=
(
index
<
0
)
?
index
+
axis_dim_size
:
index
;
index
=
(
index
<
0
)
?
index
+
axis_dim_size
:
index
;
out_idx
[
axis
]
=
index
;
out_idx
[
axis
]
=
index
;
output
[
output_shape
.
index
(
out_idx
)]
=
update
[
ind_s
.
index
(
idx
)];
// look up the appropriate locations in output, using idx and out_idx.
// call reduction() method of derived struct to copy and reduce that element
self
.
reduction
()(
output
(
out_idx
.
begin
(),
out_idx
.
end
()),
update
(
idx
.
begin
(),
idx
.
end
()));
});
});
});
});
});
});
...
...
src/include/migraphx/op/scatter_add.hpp
0 → 100644
View file @
d7a28300
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTER_ADD_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTER_ADD_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
#include <migraphx/op/scatter.hpp>
// Scatter op. with "add" function as reduction.
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
scatter_add
:
scatter
<
scatter_add
>
{
// reduction (pointwise operation) is called by the parent struct's compute() method.
// It works much like a virtual function overload.
// For the scatter methods, there are three different reduction functions.
auto
reduction
()
const
{
return
[](
auto
&
x
,
const
auto
&
y
)
{
x
+=
y
;
};
}
// name of this struct is automatically assigned by the op_name<>
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/scatter_mul.hpp
0 → 100644
View file @
d7a28300
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTER_MUL_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTER_MUL_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
#include <migraphx/op/scatter.hpp>
// Scatter op. with "multiply" as the reduction function.
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
scatter_mul
:
scatter
<
scatter_mul
>
{
// reduction (pointwise operation) is called by the parent struct's compute() method.
// It works much like a virtual function overload.
// For the scatter operators, there are three different reduction functions.
auto
reduction
()
const
{
return
[](
auto
&
x
,
const
auto
&
y
)
{
x
*=
y
;
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/scatter_none.hpp
0 → 100644
View file @
d7a28300
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTER_NONE_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTER_NONE_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/scatter.hpp>
#include <cmath>
#include <utility>
// Scatter op. with "none" as the reduction function (just copies the value). This is identical to
// the previously existing Scatter op.
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
scatter_none
:
scatter
<
scatter_none
>
{
// reduction (pointwise operation) is called by the parent struct's compute() method.
// It works much like a virtual function overload.
// For the scatter operators, there are three different reduction functions.
auto
reduction
()
const
{
return
[](
auto
&
x
,
const
auto
&
y
)
{
x
=
y
;
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/operators.hpp
View file @
d7a28300
...
@@ -86,7 +86,9 @@
...
@@ -86,7 +86,9 @@
#include <migraphx/op/round.hpp>
#include <migraphx/op/round.hpp>
#include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/scalar.hpp>
#include <migraphx/op/scalar.hpp>
#include <migraphx/op/scatter.hpp>
#include <migraphx/op/scatter_add.hpp>
#include <migraphx/op/scatter_mul.hpp>
#include <migraphx/op/scatter_none.hpp>
#include <migraphx/op/scatternd_add.hpp>
#include <migraphx/op/scatternd_add.hpp>
#include <migraphx/op/scatternd_none.hpp>
#include <migraphx/op/scatternd_none.hpp>
#include <migraphx/op/scatternd_mul.hpp>
#include <migraphx/op/scatternd_mul.hpp>
...
...
src/onnx/parse_generic_op.cpp
View file @
d7a28300
...
@@ -10,6 +10,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
...
@@ -10,6 +10,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
{
{
std
::
vector
<
op_desc
>
operators
()
const
std
::
vector
<
op_desc
>
operators
()
const
{
{
// clang-format off
return
{{
"Abs"
,
"abs"
},
return
{{
"Abs"
,
"abs"
},
{
"Acos"
,
"acos"
},
{
"Acos"
,
"acos"
},
{
"Acosh"
,
"acosh"
},
{
"Acosh"
,
"acosh"
},
...
@@ -36,8 +37,6 @@ struct parse_generic_op : op_parser<parse_generic_op>
...
@@ -36,8 +37,6 @@ struct parse_generic_op : op_parser<parse_generic_op>
{
"Reciprocal"
,
"recip"
},
{
"Reciprocal"
,
"recip"
},
{
"Relu"
,
"relu"
},
{
"Relu"
,
"relu"
},
{
"Round"
,
"round"
},
{
"Round"
,
"round"
},
{
"Scatter"
,
"scatter"
},
{
"ScatterElements"
,
"scatter"
},
{
"Sigmoid"
,
"sigmoid"
},
{
"Sigmoid"
,
"sigmoid"
},
{
"Sign"
,
"sign"
},
{
"Sign"
,
"sign"
},
{
"Sin"
,
"sin"
},
{
"Sin"
,
"sin"
},
...
@@ -46,6 +45,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
...
@@ -46,6 +45,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
{
"Tan"
,
"tan"
},
{
"Tan"
,
"tan"
},
{
"Tanh"
,
"tanh"
},
{
"Tanh"
,
"tanh"
},
{
"Not"
,
"not"
}};
{
"Not"
,
"not"
}};
// clang-format on
}
}
bool
needs_contiguous
(
const
std
::
string
&
op_name
)
const
bool
needs_contiguous
(
const
std
::
string
&
op_name
)
const
...
...
src/onnx/parse_scatter.cpp
0 → 100644
View file @
d7a28300
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
onnx
{
struct
parse_scatter
:
op_parser
<
parse_scatter
>
{
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"ScatterElements"
},
{
"Scatter"
}};
}
instruction_ref
parse
(
const
op_desc
&
/*opd*/
,
const
onnx_parser
&
/*parser*/
,
const
onnx_parser
::
node_info
&
info
,
const
std
::
vector
<
instruction_ref
>&
args
)
const
{
operation
op
;
std
::
string
op_name
=
"scatter_none"
;
int
axis
=
0
;
if
(
contains
(
info
.
attributes
,
"axis"
))
axis
=
info
.
attributes
.
at
(
"axis"
).
i
();
if
(
contains
(
info
.
attributes
,
"reduction"
))
{
std
::
string
reduction_att
(
info
.
attributes
.
at
(
"reduction"
).
s
());
// check for a valid reduction attribute. We have an operator for each one.
if
(
not
contains
({
"none"
,
"add"
,
"mul"
},
reduction_att
))
MIGRAPHX_THROW
(
"PARSE_SCATTER: unsupported reduction mode "
+
reduction_att
);
// merge scatter with reduction attribute to specify which scatter operation. Future
// reduction op names should follow this pattern and should also be added to the check
// above.
op_name
=
std
::
string
(
"scatter_"
)
+
reduction_att
;
}
op
=
migraphx
::
make_op
(
op_name
,
{{
"axis"
,
axis
}});
return
info
.
add_instruction
(
op
,
args
);
}
};
}
// namespace onnx
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/py/migraphx_py.cpp
View file @
d7a28300
...
@@ -7,6 +7,7 @@
...
@@ -7,6 +7,7 @@
#include <migraphx/operation.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/tf.hpp>
...
...
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
View file @
d7a28300
...
@@ -178,10 +178,15 @@ auto hip_vec_visit_all(T&& x, Ts&&... xs)
...
@@ -178,10 +178,15 @@ auto hip_vec_visit_all(T&& x, Ts&&... xs)
return
[
&
](
auto
f
)
{
return
[
&
](
auto
f
)
{
auto
sx
=
get_shape
(
x
);
auto
sx
=
get_shape
(
x
);
auto
lens
=
sx
.
lens
();
auto
lens
=
sx
.
lens
();
assert
(
lens
.
back
()
%
N
==
0
);
assert
(
sx
.
strides
().
back
()
==
1
);
lens
.
back
()
/=
N
;
lens
.
back
()
/=
N
;
shape
ssx
{
sx
.
type
(),
lens
};
shape
vec_sx
{
sx
.
type
(),
lens
};
hip_visit_all_impl
(
hip_visit_all_impl
(
vec_sx
,
ssx
,
make_hip_convert
([](
auto
*
p
)
{
return
as_vec
<
N
>
(
device_cast
(
p
));
}),
f
,
x
,
xs
...);
make_hip_convert
([](
auto
*
p
)
{
return
as_vec
<
N
>
(
device_cast
(
p
));
}),
f
,
x
,
xs
...);
};
};
}
}
...
...
src/targets/gpu/hip.cpp
View file @
d7a28300
...
@@ -27,6 +27,15 @@ using hip_host_ptr = MIGRAPHX_MANAGE_PTR(void, hipHostUnregister);
...
@@ -27,6 +27,15 @@ using hip_host_ptr = MIGRAPHX_MANAGE_PTR(void, hipHostUnregister);
std
::
string
hip_error
(
int
error
)
{
return
hipGetErrorString
(
static_cast
<
hipError_t
>
(
error
));
}
std
::
string
hip_error
(
int
error
)
{
return
hipGetErrorString
(
static_cast
<
hipError_t
>
(
error
));
}
bool
is_device_ptr
(
const
void
*
ptr
)
{
hipPointerAttribute_t
attr
;
auto
status
=
hipPointerGetAttributes
(
&
attr
,
ptr
);
if
(
status
!=
hipSuccess
)
return
false
;
return
attr
.
memoryType
==
hipMemoryTypeDevice
;
}
std
::
size_t
get_available_gpu_memory
()
std
::
size_t
get_available_gpu_memory
()
{
{
size_t
free
;
size_t
free
;
...
@@ -50,8 +59,8 @@ hip_ptr allocate_gpu(std::size_t sz, bool host = false)
...
@@ -50,8 +59,8 @@ hip_ptr allocate_gpu(std::size_t sz, bool host = false)
{
{
if
(
sz
>
get_available_gpu_memory
())
if
(
sz
>
get_available_gpu_memory
())
MIGRAPHX_THROW
(
"Memory not available to allocate buffer: "
+
std
::
to_string
(
sz
));
MIGRAPHX_THROW
(
"Memory not available to allocate buffer: "
+
std
::
to_string
(
sz
));
void
*
result
;
void
*
result
=
nullptr
;
auto
status
=
host
?
hipHostMalloc
(
&
result
,
sz
)
:
hipMalloc
(
&
result
,
sz
);
auto
status
=
host
?
hipHostMalloc
(
&
result
,
sz
)
:
hipMalloc
(
&
result
,
sz
);
if
(
status
!=
hipSuccess
)
if
(
status
!=
hipSuccess
)
{
{
if
(
host
)
if
(
host
)
...
@@ -59,6 +68,7 @@ hip_ptr allocate_gpu(std::size_t sz, bool host = false)
...
@@ -59,6 +68,7 @@ hip_ptr allocate_gpu(std::size_t sz, bool host = false)
else
else
return
allocate_gpu
(
sz
,
true
);
return
allocate_gpu
(
sz
,
true
);
}
}
assert
(
result
!=
nullptr
);
return
hip_ptr
{
result
};
return
hip_ptr
{
result
};
}
}
...
@@ -75,6 +85,8 @@ std::vector<T> read_from_gpu(const void* x, std::size_t sz)
...
@@ -75,6 +85,8 @@ std::vector<T> read_from_gpu(const void* x, std::size_t sz)
{
{
gpu_sync
();
gpu_sync
();
std
::
vector
<
T
>
result
(
sz
);
std
::
vector
<
T
>
result
(
sz
);
assert
(
not
is_device_ptr
(
result
.
data
()));
assert
(
is_device_ptr
(
x
));
auto
status
=
hipMemcpy
(
result
.
data
(),
x
,
sz
*
sizeof
(
T
),
hipMemcpyDeviceToHost
);
auto
status
=
hipMemcpy
(
result
.
data
(),
x
,
sz
*
sizeof
(
T
),
hipMemcpyDeviceToHost
);
if
(
status
!=
hipSuccess
)
if
(
status
!=
hipSuccess
)
MIGRAPHX_THROW
(
"Copy from gpu failed: "
+
hip_error
(
status
));
// NOLINT
MIGRAPHX_THROW
(
"Copy from gpu failed: "
+
hip_error
(
status
));
// NOLINT
...
@@ -85,6 +97,8 @@ hip_ptr write_to_gpu(const void* x, std::size_t sz, bool host = false)
...
@@ -85,6 +97,8 @@ hip_ptr write_to_gpu(const void* x, std::size_t sz, bool host = false)
{
{
gpu_sync
();
gpu_sync
();
auto
result
=
allocate_gpu
(
sz
,
host
);
auto
result
=
allocate_gpu
(
sz
,
host
);
assert
(
is_device_ptr
(
result
.
get
()));
assert
(
not
is_device_ptr
(
x
));
auto
status
=
hipMemcpy
(
result
.
get
(),
x
,
sz
,
hipMemcpyHostToDevice
);
auto
status
=
hipMemcpy
(
result
.
get
(),
x
,
sz
,
hipMemcpyHostToDevice
);
if
(
status
!=
hipSuccess
)
if
(
status
!=
hipSuccess
)
MIGRAPHX_THROW
(
"Copy to gpu failed: "
+
hip_error
(
status
));
MIGRAPHX_THROW
(
"Copy to gpu failed: "
+
hip_error
(
status
));
...
...
src/targets/gpu/include/migraphx/gpu/scatter.hpp
View file @
d7a28300
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/op/scatter.hpp>
#include <migraphx/op/scatter
_none
.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/miopen.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
@@ -14,7 +14,9 @@ struct context;
...
@@ -14,7 +14,9 @@ struct context;
struct
hip_scatter
struct
hip_scatter
{
{
op
::
scatter
op
;
// scatter_none is an exact replacement for previous op::scatter,
// renamed to match an Onnx option. Don't use base class op::scatter
op
::
scatter_none
op
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
...
@@ -22,7 +24,7 @@ struct hip_scatter
...
@@ -22,7 +24,7 @@ struct hip_scatter
return
migraphx
::
reflect
(
self
.
op
,
f
);
return
migraphx
::
reflect
(
self
.
op
,
f
);
}
}
std
::
string
name
()
const
{
return
"gpu::scatter"
;
}
std
::
string
name
()
const
{
return
"gpu::scatter
_none
"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
;
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
;
argument
argument
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
;
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
;
...
...
src/targets/gpu/lowering.cpp
View file @
d7a28300
...
@@ -190,7 +190,7 @@ struct miopen_apply
...
@@ -190,7 +190,7 @@ struct miopen_apply
add_extend_op
(
"rnn_var_sl_last_output"
);
add_extend_op
(
"rnn_var_sl_last_output"
);
add_extend_op
(
"rnn_var_sl_shift_output"
);
add_extend_op
(
"rnn_var_sl_shift_output"
);
add_extend_op
(
"rnn_var_sl_shift_sequence"
);
add_extend_op
(
"rnn_var_sl_shift_sequence"
);
add_extend_op
(
"scatter"
);
add_extend_op
(
"scatter
_none
"
);
add_extend_op
(
"softmax"
);
add_extend_op
(
"softmax"
);
add_extend_op
(
"topk"
);
add_extend_op
(
"topk"
);
...
@@ -381,6 +381,9 @@ struct miopen_apply
...
@@ -381,6 +381,9 @@ struct miopen_apply
});
});
}
}
// add_generic_op just constructs the operator with no fields whereas add_extend_op copies over
// the fields Since it doesn't have fields its default constructed
void
add_generic_op
(
const
std
::
string
&
name
)
{
add_generic_op
(
name
,
"gpu::"
+
name
);
}
void
add_generic_op
(
const
std
::
string
&
name
)
{
add_generic_op
(
name
,
"gpu::"
+
name
);
}
void
add_generic_op
(
const
std
::
string
&
op_name
,
const
std
::
string
&
gpu_name
)
void
add_generic_op
(
const
std
::
string
&
op_name
,
const
std
::
string
&
gpu_name
)
...
...
test/onnx/gen_onnx.py
View file @
d7a28300
...
@@ -4381,7 +4381,7 @@ def roialign_test():
...
@@ -4381,7 +4381,7 @@ def roialign_test():
@
onnx_test
@
onnx_test
def
scatter_test
():
def
scatter_
add_
test
():
x
=
helper
.
make_tensor_value_info
(
'data'
,
TensorProto
.
FLOAT
,
[
3
,
4
,
5
,
6
])
x
=
helper
.
make_tensor_value_info
(
'data'
,
TensorProto
.
FLOAT
,
[
3
,
4
,
5
,
6
])
i
=
helper
.
make_tensor_value_info
(
'indices'
,
TensorProto
.
INT32
,
i
=
helper
.
make_tensor_value_info
(
'indices'
,
TensorProto
.
INT32
,
[
2
,
3
,
4
,
5
])
[
2
,
3
,
4
,
5
])
...
@@ -4390,7 +4390,48 @@ def scatter_test():
...
@@ -4390,7 +4390,48 @@ def scatter_test():
y
=
helper
.
make_tensor_value_info
(
'y'
,
TensorProto
.
FLOAT
,
[
3
,
4
,
5
,
6
])
y
=
helper
.
make_tensor_value_info
(
'y'
,
TensorProto
.
FLOAT
,
[
3
,
4
,
5
,
6
])
node
=
onnx
.
helper
.
make_node
(
node
=
onnx
.
helper
.
make_node
(
'Scatter'
,
'ScatterElements'
,
reduction
=
'add'
,
inputs
=
[
'data'
,
'indices'
,
'update'
],
outputs
=
[
'y'
],
axis
=-
2
,
)
return
([
node
],
[
x
,
i
,
u
],
[
y
])
@
onnx_test
def
scatter_mul_test
():
x
=
helper
.
make_tensor_value_info
(
'data'
,
TensorProto
.
FLOAT
,
[
3
,
4
,
5
,
6
])
i
=
helper
.
make_tensor_value_info
(
'indices'
,
TensorProto
.
INT32
,
[
2
,
3
,
4
,
5
])
u
=
helper
.
make_tensor_value_info
(
'update'
,
TensorProto
.
FLOAT
,
[
2
,
3
,
4
,
5
])
y
=
helper
.
make_tensor_value_info
(
'y'
,
TensorProto
.
FLOAT
,
[
3
,
4
,
5
,
6
])
node
=
onnx
.
helper
.
make_node
(
'ScatterElements'
,
reduction
=
'mul'
,
inputs
=
[
'data'
,
'indices'
,
'update'
],
outputs
=
[
'y'
],
axis
=-
2
,
)
return
([
node
],
[
x
,
i
,
u
],
[
y
])
@
onnx_test
def
scatter_none_test
():
x
=
helper
.
make_tensor_value_info
(
'data'
,
TensorProto
.
FLOAT
,
[
3
,
4
,
5
,
6
])
i
=
helper
.
make_tensor_value_info
(
'indices'
,
TensorProto
.
INT32
,
[
2
,
3
,
4
,
5
])
u
=
helper
.
make_tensor_value_info
(
'update'
,
TensorProto
.
FLOAT
,
[
2
,
3
,
4
,
5
])
y
=
helper
.
make_tensor_value_info
(
'y'
,
TensorProto
.
FLOAT
,
[
3
,
4
,
5
,
6
])
node
=
onnx
.
helper
.
make_node
(
'ScatterElements'
,
reduction
=
'none'
,
inputs
=
[
'data'
,
'indices'
,
'update'
],
inputs
=
[
'data'
,
'indices'
,
'update'
],
outputs
=
[
'y'
],
outputs
=
[
'y'
],
axis
=-
2
,
axis
=-
2
,
...
...
test/onnx/onnx_test.cpp
View file @
d7a28300
...
@@ -4233,7 +4233,8 @@ TEST_CASE(round_test)
...
@@ -4233,7 +4233,8 @@ TEST_CASE(round_test)
EXPECT
(
p
==
prog
);
EXPECT
(
p
==
prog
);
}
}
TEST_CASE
(
scatter_test
)
// the ScatterElements op has 3 reduction modes, which map to separate reference ops
migraphx
::
program
create_scatter_program
(
const
std
::
string
&
scatter_mode
,
int
axis
)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
...
@@ -4242,10 +4243,30 @@ TEST_CASE(scatter_test)
...
@@ -4242,10 +4243,30 @@ TEST_CASE(scatter_test)
mm
->
add_parameter
(
"indices"
,
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
3
,
4
,
5
}});
mm
->
add_parameter
(
"indices"
,
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
3
,
4
,
5
}});
auto
l2
=
auto
l2
=
mm
->
add_parameter
(
"update"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}});
mm
->
add_parameter
(
"update"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}});
int
axis
=
-
2
;
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
scatter_mode
,
{{
"axis"
,
axis
}}),
l0
,
l1
,
l2
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"scatter"
,
{{
"axis"
,
axis
}}),
l0
,
l1
,
l2
);
mm
->
add_return
({
r
});
mm
->
add_return
({
r
});
auto
prog
=
migraphx
::
parse_onnx
(
"scatter_test.onnx"
);
return
p
;
}
TEST_CASE
(
scatter_add_test
)
{
migraphx
::
program
p
=
create_scatter_program
(
"scatter_add"
,
-
2
);
auto
prog
=
migraphx
::
parse_onnx
(
"scatter_add_test.onnx"
);
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
scatter_mul_test
)
{
migraphx
::
program
p
=
create_scatter_program
(
"scatter_mul"
,
-
2
);
auto
prog
=
migraphx
::
parse_onnx
(
"scatter_mul_test.onnx"
);
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
scatter_none_test
)
{
migraphx
::
program
p
=
create_scatter_program
(
"scatter_none"
,
-
2
);
auto
prog
=
migraphx
::
parse_onnx
(
"scatter_none_test.onnx"
);
EXPECT
(
p
==
prog
);
EXPECT
(
p
==
prog
);
}
}
...
...
test/onnx/scatter_test.onnx
→
test/onnx/scatter_
add_
test.onnx
View file @
d7a28300
scatter_test:
scatter_
add_
test:
9
V
data
data
indices
indices
updatey"Scatter*
updatey"ScatterElements*
axisscatter_testZ
axis*
reduction"addscatter_add_testZ
data
data
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment