Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
f5e0c343
Commit
f5e0c343
authored
Jan 15, 2019
by
Shucai Xiao
Browse files
Add the gather operator and onnx parsing of shape and constantfill.
parent
301b7605
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
538 additions
and
1 deletion
+538
-1
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+111
-0
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+110
-0
src/targets/cpu/lowering.cpp
src/targets/cpu/lowering.cpp
+27
-0
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+2
-0
src/targets/gpu/device/gather.cpp
src/targets/gpu/device/gather.cpp
+70
-0
src/targets/gpu/gather.cpp
src/targets/gpu/gather.cpp
+40
-0
src/targets/gpu/include/migraphx/gpu/device/gather.hpp
src/targets/gpu/include/migraphx/gpu/device/gather.hpp
+30
-0
src/targets/gpu/include/migraphx/gpu/gather.hpp
src/targets/gpu/include/migraphx/gpu/gather.hpp
+51
-0
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+2
-1
test/cpu_ops_test.cpp
test/cpu_ops_test.cpp
+43
-0
test/gpu/miopen.cpp
test/gpu/miopen.cpp
+16
-0
test/onnx/gather_test.onnx
test/onnx/gather_test.onnx
+23
-0
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+13
-0
No files found.
src/include/migraphx/operators.hpp
View file @
f5e0c343
...
...
@@ -6,6 +6,8 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
...
...
@@ -631,6 +633,115 @@ struct as_shape
int
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
// Gather to use the algorithm in onnx::gather operator
struct
gather
{
std
::
size_t
axis
=
0
;
std
::
string
name
()
const
{
return
"gather"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
auto
lens
=
inputs
[
0
].
lens
();
if
(
axis
>=
lens
.
size
())
{
MIGRAPHX_THROW
(
"Gather, axis is out of range."
);
}
auto
type
=
inputs
[
0
].
type
();
lens
[
axis
]
=
inputs
[
1
].
elements
();
return
{
type
,
lens
};
}
template
<
class
T
>
void
compute_index
(
const
T
&
out_idx
,
const
std
::
vector
<
argument
>&
args
,
T
&
in_idx
)
const
{
in_idx
=
out_idx
;
// max dimension in axis
std
::
size_t
max_dim
=
args
[
0
].
get_shape
().
lens
()[
axis
];
std
::
size_t
idx
=
args
[
1
].
at
<
std
::
size_t
>
(
out_idx
[
axis
]);
if
(
idx
>=
max_dim
)
{
MIGRAPHX_THROW
(
"Gather, indices are out of range in input tensor"
);
}
in_idx
[
axis
]
=
idx
;
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
std
::
vector
<
std
::
size_t
>
in_idx
;
this
->
compute_index
(
idx
,
args
,
in_idx
);
output
(
idx
.
begin
(),
idx
.
end
())
=
input
(
in_idx
.
begin
(),
in_idx
.
end
());
});
});
return
result
;
}
int
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
// Gather to use the algorithm in torch.nn.gather, which is diffrent
// from the onnx::gather operator.
struct
gather_torch
{
std
::
size_t
axis
=
0
;
std
::
string
name
()
const
{
return
"gather_torch"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
auto
lens
=
inputs
[
0
].
lens
();
if
(
axis
>=
lens
.
size
())
{
MIGRAPHX_THROW
(
"Gather, axis is out of range."
);
}
auto
type
=
inputs
[
0
].
type
();
// output shape is the same as that of the indices
return
{
type
,
inputs
[
1
].
lens
()};
}
template
<
class
T
>
void
compute_index
(
const
T
&
out_idx
,
const
std
::
vector
<
argument
>&
args
,
T
&
in_idx
)
const
{
in_idx
=
out_idx
;
// max dimension in axis
std
::
size_t
max_dim
=
args
[
0
].
get_shape
().
lens
()[
axis
];
args
[
1
].
visit
([
&
](
auto
idx
)
{
std
::
size_t
i
=
idx
(
out_idx
.
begin
(),
out_idx
.
end
());
if
(
i
>=
max_dim
)
{
MIGRAPHX_THROW
(
"gather_torch, indices are out of range in input tensor"
);
}
in_idx
[
axis
]
=
i
;
});
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
out_idx
)
{
std
::
vector
<
std
::
size_t
>
in_idx
;
this
->
compute_index
(
out_idx
,
args
,
in_idx
);
std
::
cout
<<
"gather torch input = "
<<
input
(
in_idx
.
begin
(),
in_idx
.
end
())
<<
std
::
endl
;
output
(
out_idx
.
begin
(),
out_idx
.
end
())
=
input
(
in_idx
.
begin
(),
in_idx
.
end
());
std
::
cout
<<
"gather torch out = "
<<
output
(
out_idx
.
begin
(),
out_idx
.
end
())
<<
std
::
endl
;
});
});
return
result
;
}
int
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
struct
dot
{
float
alpha
=
1.0
;
...
...
src/onnx/onnx.cpp
View file @
f5e0c343
...
...
@@ -80,6 +80,9 @@ struct onnx_parser
add_mem_op
(
"Unsqueeze"
,
&
onnx_parser
::
parse_unsqueeze
);
add_mem_op
(
"Slice"
,
&
onnx_parser
::
parse_slice
);
add_mem_op
(
"Concat"
,
&
onnx_parser
::
parse_concat
);
add_mem_op
(
"Gather"
,
&
onnx_parser
::
parse_gather
);
add_mem_op
(
"Shape"
,
&
onnx_parser
::
parse_shape
);
add_mem_op
(
"ConstantFill"
,
&
onnx_parser
::
parse_constant_fill
);
add_mem_op
(
"Transpose"
,
&
onnx_parser
::
parse_transpose
);
}
...
...
@@ -356,6 +359,18 @@ struct onnx_parser
return
prog
.
add_instruction
(
op
,
std
::
move
(
args
));
}
instruction_ref
parse_gather
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
std
::
size_t
axis
=
0
;
if
(
contains
(
attributes
,
"axis"
))
{
axis
=
parse_value
(
attributes
.
at
(
"axis"
)).
at
<
int
>
();
}
op
::
gather_torch
op
{
axis
};
return
prog
.
add_instruction
(
op
,
std
::
move
(
args
));
}
instruction_ref
parse_slice
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
...
...
@@ -525,6 +540,79 @@ struct onnx_parser
return
prog
.
add_instruction
(
migraphx
::
op
::
transpose
{
perm
},
args
.
front
());
}
// Use a literal instruction to replace the shape since, output of
// shape operator are literals in migraphx
instruction_ref
parse_shape
(
const
std
::
string
&
,
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
if
(
args
.
size
()
!=
1
)
MIGRAPHX_THROW
(
"Shape, operator should have 1 operand"
);
std
::
vector
<
std
::
size_t
>
arg_shape
=
args
[
0
]
->
get_shape
().
lens
();
std
::
vector
<
int64_t
>
vec_shape
(
arg_shape
.
size
());
migraphx
::
shape
s
(
migraphx
::
shape
::
int64_type
,
{
arg_shape
.
size
()});
std
::
transform
(
arg_shape
.
begin
(),
arg_shape
.
end
(),
vec_shape
.
begin
(),
[](
auto
i
)
{
return
int64_t
(
i
);
});
return
prog
.
add_literal
(
migraphx
::
literal
{
s
,
vec_shape
});
}
// Use a literal instruction to replace the constantFill operator. In RNN, input shape
// and value are fixed, so no need to do the actual computation for the constantFill
// operator
instruction_ref
parse_constant_fill
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
if
(
args
.
size
()
!=
1
)
{
MIGRAPHX_THROW
(
"Constantfill, MIGraphX only handle the case with 1 operand"
);
}
int
input_as_shape
=
0
;
int
dtype
=
1
;
float
value
=
0.0
f
;
if
(
contains
(
attributes
,
"dtype"
))
{
dtype
=
parse_value
(
attributes
.
at
(
"dtype"
)).
at
<
int
>
();
}
migraphx
::
shape
::
type_t
type
=
get_type
(
dtype
);
if
(
contains
(
attributes
,
"input_as_shape"
))
{
input_as_shape
=
parse_value
(
attributes
.
at
(
"input_as_shape"
)).
at
<
int
>
();
}
if
(
contains
(
attributes
,
"value"
))
{
value
=
parse_value
(
attributes
.
at
(
"value"
)).
at
<
float
>
();
}
if
(
input_as_shape
==
1
)
{
migraphx
::
argument
in
=
args
[
0
]
->
eval
();
if
(
in
.
empty
())
{
MIGRAPHX_THROW
(
"ConstantFill, cannot handle dynamic shape as input for ConstantFill"
);
}
std
::
vector
<
std
::
size_t
>
dims
;
in
.
visit
([
&
](
auto
input
)
{
dims
.
assign
(
input
.
begin
(),
input
.
end
());
});
migraphx
::
shape
s
(
type
,
dims
);
return
prog
.
add_literal
(
migraphx
::
literal
(
s
,
{
value
}));
}
else
if
(
input_as_shape
==
0
)
{
std
::
vector
<
std
::
size_t
>
dims
=
args
[
0
]
->
get_shape
().
lens
();
migraphx
::
shape
s
{
type
,
dims
};
return
prog
.
add_literal
(
migraphx
::
literal
(
s
,
{
value
}));
}
else
{
MIGRAPHX_THROW
(
"Wrong input for ConstantFill"
);
}
}
void
parse_from
(
std
::
istream
&
is
)
{
onnx
::
ModelProto
model
;
...
...
@@ -774,6 +862,28 @@ struct onnx_parser
});
return
{
shape_type
,
dims
};
}
shape
::
type_t
get_type
(
int
dtype
)
{
switch
(
dtype
)
{
case
1
:
return
shape
::
float_type
;
case
2
:
return
shape
::
uint8_type
;
case
3
:
return
shape
::
int8_type
;
case
4
:
return
shape
::
uint16_type
;
case
5
:
return
shape
::
int16_type
;
case
6
:
return
shape
::
int32_type
;
case
7
:
return
shape
::
int64_type
;
case
10
:
return
shape
::
half_type
;
case
11
:
return
shape
::
double_type
;
case
12
:
return
shape
::
uint32_type
;
case
13
:
return
shape
::
uint64_type
;
default:
{
MIGRAPHX_THROW
(
"Prototensor data type "
+
std
::
to_string
(
dtype
)
+
" not supported"
);
}
}
}
};
program
parse_onnx
(
const
std
::
string
&
name
)
...
...
src/targets/cpu/lowering.cpp
View file @
f5e0c343
...
...
@@ -322,6 +322,30 @@ struct cpu_gemm
}
};
struct
cpu_gather
{
op
::
gather
op
;
std
::
string
name
()
const
{
return
"cpu::gather"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
return
op
.
compute_shape
(
inputs
);
}
argument
compute
(
context
&
,
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
return
op
.
compute
(
output_shape
,
args
);
}
};
struct
cpu_gather_torch
{
op
::
gather_torch
op
;
std
::
string
name
()
const
{
return
"cpu::gather_torch"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
return
op
.
compute_shape
(
inputs
);
}
argument
compute
(
context
&
,
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
return
op
.
compute
(
output_shape
,
args
);
}
};
struct
identity_op
{
std
::
string
name
()
const
{
return
"cpu::identity"
;
}
...
...
@@ -651,6 +675,9 @@ struct cpu_apply
extend_op
<
cpu_batch_norm_inference
,
op
::
batch_norm_inference
>
();
apply_map
[
"contiguous"
]
=
extend_op
<
cpu_contiguous
,
op
::
contiguous
>
();
apply_map
[
"concat"
]
=
extend_op
<
cpu_concat
,
op
::
concat
>
();
// To support the rnn from pytorch, we need to use the algorithm
// of gather in torch.nn.gather
apply_map
[
"gather"
]
=
extend_op
<
cpu_gather_torch
,
op
::
gather_torch
>
();
apply_map
[
"leaky_relu"
]
=
extend_op
<
cpu_unary
<
leaky_relu_op
>
,
op
::
leaky_relu
>
();
apply_map
[
"elu"
]
=
extend_op
<
cpu_unary
<
elu_op
>
,
op
::
elu
>
();
apply_map
[
"identity"
]
=
simple_op
<
cpu_unary
<
identity_op
>>
();
...
...
src/targets/gpu/CMakeLists.txt
View file @
f5e0c343
...
...
@@ -28,6 +28,7 @@ add_library(migraphx_device
device/contiguous.cpp
device/mul.cpp
device/concat.cpp
device/gather.cpp
)
set_target_properties
(
migraphx_device PROPERTIES EXPORT_NAME device
)
rocm_clang_tidy_check
(
migraphx_device
)
...
...
@@ -56,6 +57,7 @@ add_library(migraphx_gpu
sigmoid.cpp
abs.cpp
elu.cpp
gather.cpp
)
set_target_properties
(
migraphx_gpu PROPERTIES EXPORT_NAME gpu
)
rocm_clang_tidy_check
(
migraphx_gpu
)
...
...
src/targets/gpu/device/gather.cpp
0 → 100644
View file @
f5e0c343
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/gather.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/hip.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
device
{
argument
gather
(
hipStream_t
stream
,
const
migraphx
::
shape
&
output_shape
,
std
::
vector
<
migraphx
::
argument
>
args
,
std
::
size_t
axis
)
{
visit_all
(
args
.
back
(),
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
std
::
size_t
nelements
=
output_shape
.
elements
();
args
[
1
].
visit
([
&
](
auto
indices
)
{
visit_tensor_size
(
output_shape
.
lens
().
size
(),
[
&
](
auto
ndim
)
{
const
auto
*
indices_ptr
=
device_cast
(
indices
.
data
());
auto
*
outptr
=
device_cast
(
output
.
data
());
const
auto
*
inptr
=
device_cast
(
input
.
data
());
hip_tensor_descriptor
<
ndim
>
desc_input
(
input
.
get_shape
());
hip_tensor_descriptor
<
ndim
>
desc_output
(
output
.
get_shape
());
gs_launch
(
stream
,
nelements
)([
=
](
auto
i
)
{
auto
lens
=
desc_output
.
multi
(
i
);
lens
[
axis
]
=
indices_ptr
[
lens
[
axis
]];
outptr
[
i
]
=
inptr
[
desc_input
.
linear
(
lens
)];
});
});
});
});
return
args
.
back
();
}
argument
gather_torch
(
hipStream_t
stream
,
const
migraphx
::
shape
&
output_shape
,
std
::
vector
<
migraphx
::
argument
>
args
,
std
::
size_t
axis
)
{
visit_all
(
args
.
back
(),
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
std
::
size_t
nelements
=
output_shape
.
elements
();
args
[
1
].
visit
([
&
](
auto
indices
)
{
visit_tensor_size
(
output_shape
.
lens
().
size
(),
[
&
](
auto
ndim
)
{
const
auto
*
indices_ptr
=
device_cast
(
indices
.
data
());
auto
*
outptr
=
device_cast
(
output
.
data
());
const
auto
*
inptr
=
device_cast
(
input
.
data
());
hip_tensor_descriptor
<
ndim
>
desc_input
(
input
.
get_shape
());
hip_tensor_descriptor
<
ndim
>
desc_output
(
output
.
get_shape
());
hip_tensor_descriptor
<
ndim
>
desc_ind
(
output
.
get_shape
());
gs_launch
(
stream
,
nelements
)([
=
](
auto
i
)
{
auto
lens
=
desc_output
.
multi
(
i
);
lens
[
axis
]
=
indices_ptr
[
desc_ind
.
linear
(
lens
)];
outptr
[
i
]
=
inptr
[
desc_input
.
linear
(
lens
)];
});
});
});
});
return
args
.
back
();
}
}
// namespace device
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/gather.cpp
0 → 100644
View file @
f5e0c343
#include <migraphx/gpu/gather.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/device/concat.hpp>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
shape
hip_gather
::
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
inputs
.
pop_back
();
return
op
.
compute_shape
(
inputs
);
}
argument
hip_gather
::
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
{
return
device
::
gather
(
ctx
.
get_stream
().
get
(),
output_shape
,
args
,
op
.
axis
);
}
shape
hip_gather_torch
::
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
inputs
.
pop_back
();
return
op
.
compute_shape
(
inputs
);
}
argument
hip_gather_torch
::
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
{
return
device
::
gather_torch
(
ctx
.
get_stream
().
get
(),
output_shape
,
args
,
op
.
axis
);
}
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/include/migraphx/gpu/device/gather.hpp
0 → 100644
View file @
f5e0c343
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_GATHER_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_GATHER_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
device
{
// use algorithm of onnx::gather (not used for now)
argument
gather
(
hipStream_t
stream
,
const
migraphx
::
shape
&
output_shape
,
std
::
vector
<
migraphx
::
argument
>
args
,
std
::
size_t
axis
);
// use algorithm of torch.nn.gather
argument
gather_torch
(
hipStream_t
stream
,
const
migraphx
::
shape
&
output_shape
,
std
::
vector
<
migraphx
::
argument
>
args
,
std
::
size_t
axis
);
}
// namespace device
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/targets/gpu/include/migraphx/gpu/gather.hpp
0 → 100644
View file @
f5e0c343
#ifndef MIGRAPHX_GUARD_RTGLIB_GATHER_HPP
#define MIGRAPHX_GUARD_RTGLIB_GATHER_HPP
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
#include <migraphx/gpu/device/gather.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/context.hpp>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
// use algorithm of onnx::gather (not use for now)
struct
hip_gather
{
op
::
gather
op
;
std
::
string
name
()
const
{
return
"gpu::gather"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
;
argument
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
;
int
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
{
return
shapes
.
size
()
-
1
;
}
};
// use algorithm of torch.nn.gather
struct
hip_gather_torch
{
op
::
gather_torch
op
;
std
::
string
name
()
const
{
return
"gpu::gather_torch"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
;
argument
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
;
int
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
{
return
shapes
.
size
()
-
1
;
}
};
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/targets/gpu/lowering.cpp
View file @
f5e0c343
...
...
@@ -40,6 +40,7 @@
#include <migraphx/gpu/pooling.hpp>
#include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/concat.hpp>
#include <migraphx/gpu/gather.hpp>
#include <utility>
#include <functional>
#include <algorithm>
...
...
@@ -90,7 +91,7 @@ struct miopen_apply
add_extend_op
<
miopen_contiguous
,
op
::
contiguous
>
(
"contiguous"
);
add_extend_op
<
hip_concat
,
op
::
concat
>
(
"concat"
);
add_extend_op
<
miopen_softmax
,
op
::
softmax
>
(
"softmax"
);
add_extend_op
<
hip_gather_torch
,
op
::
gather_torch
>
(
"gather"
);
add_convolution_op
();
add_pooling_op
();
add_batch_norm_inference_op
();
...
...
test/cpu_ops_test.cpp
View file @
f5e0c343
...
...
@@ -101,6 +101,49 @@ TEST_CASE(concat_test)
}
}
TEST_CASE
(
gather_test
)
{
{
migraphx
::
program
p
;
std
::
vector
<
float
>
data
(
3
*
3
);
std
::
iota
(
data
.
begin
(),
data
.
end
(),
0.5
);
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
,
3
}};
auto
a0
=
p
.
add_literal
(
migraphx
::
literal
{
s
,
data
});
migraphx
::
shape
s_indices
{
migraphx
::
shape
::
int32_type
,
{
1
,
2
}};
std
::
vector
<
int
>
indices
{
0
,
2
};
auto
a1
=
p
.
add_literal
(
migraphx
::
literal
{
s_indices
,
indices
});
std
::
size_t
axis
=
0
;
p
.
add_instruction
(
migraphx
::
op
::
gather_torch
{
axis
},
a0
,
a1
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
float
>
res_data
(
4
*
5
);
std
::
vector
<
float
>
golden
=
{
0.5
f
,
7.5
f
};
result
.
visit
([
&
](
auto
output
)
{
res_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
res_data
,
golden
));
}
{
migraphx
::
program
p
;
std
::
vector
<
float
>
data
(
3
*
3
);
std
::
iota
(
data
.
begin
(),
data
.
end
(),
0.5
);
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
,
3
}};
auto
a0
=
p
.
add_literal
(
migraphx
::
literal
{
s
,
data
});
migraphx
::
shape
s_indices
{
migraphx
::
shape
::
int32_type
,
{
1
,
2
}};
std
::
vector
<
int
>
indices
{
0
,
2
};
auto
a1
=
p
.
add_literal
(
migraphx
::
literal
{
s_indices
,
indices
});
std
::
size_t
axis
=
1
;
p
.
add_instruction
(
migraphx
::
op
::
gather_torch
{
axis
},
a0
,
a1
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
float
>
res_data
(
4
*
5
);
std
::
vector
<
float
>
golden
=
{
0.5
f
,
2.5
f
};
result
.
visit
([
&
](
auto
output
)
{
res_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
res_data
,
golden
));
}
}
TEST_CASE
(
squeeze_test
)
{
{
...
...
test/gpu/miopen.cpp
View file @
f5e0c343
...
...
@@ -934,6 +934,22 @@ struct test_concat_relu
}
};
struct
test_gather
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
,
3
}};
migraphx
::
shape
s_indices
{
migraphx
::
shape
::
int32_type
,
{
2
,
2
}};
std
::
vector
<
int
>
indices
{
1
,
2
,
2
,
1
};
auto
a0
=
p
.
add_parameter
(
"data"
,
s
);
auto
a1
=
p
.
add_literal
(
migraphx
::
literal
{
s_indices
,
indices
});
std
::
size_t
axis
=
0
;
p
.
add_instruction
(
migraphx
::
op
::
gather_torch
{
axis
},
a0
,
a1
);
return
p
;
}
};
void
manual_identity
()
{
migraphx
::
program
p
;
...
...
test/onnx/gather_test.onnx
0 → 100644
View file @
f5e0c343
gather - example :–
'
data
indicesy "Gather*
axis test_gatherZ
data
Z !
indices
b
y
B
\ No newline at end of file
test/onnx/onnx_test.cpp
View file @
f5e0c343
...
...
@@ -400,6 +400,19 @@ TEST_CASE(reshape_test)
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
gather_test
)
{
migraphx
::
program
p
;
auto
l0
=
p
.
add_parameter
(
"data"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
5
,
6
}});
auto
l1
=
p
.
add_parameter
(
"indices"
,
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
3
,
4
,
5
}});
std
::
size_t
axis
=
1
;
p
.
add_instruction
(
migraphx
::
op
::
gather_torch
{
axis
},
l0
,
l1
);
auto
prog
=
migraphx
::
parse_onnx
(
"gather_test.onnx"
);
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
flatten_test
)
{
migraphx
::
program
p
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment