Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
89b80be6
Commit
89b80be6
authored
Feb 25, 2019
by
Shucai Xiao
Browse files
Merge branch 'gather_operator' into seq2seq_example
parents
82bd8e2e
33b6bcb6
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
168 additions
and
40 deletions
+168
-40
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+25
-40
test/cpu_ops_test.cpp
test/cpu_ops_test.cpp
+42
-0
test/gpu/miopen.cpp
test/gpu/miopen.cpp
+51
-0
test/op_shape_test.cpp
test/op_shape_test.cpp
+50
-0
No files found.
src/include/migraphx/operators.hpp
View file @
89b80be6
...
...
@@ -766,7 +766,7 @@ struct gather
}
// for scalar output
if
(
lens
.
size
()
==
0
)
if
(
lens
.
empty
()
)
{
return
{
type
,
{
1
},
{
0
}};
}
...
...
@@ -774,21 +774,27 @@ struct gather
return
{
type
,
lens
};
}
// template <class T>
// void compute_index(const T& out_idx,
// const int axis_index,
// const std::vector<std::size_t>& vec_indices,
// const std::size_t max_dim,
// T& in_idx) const
// {
// in_idx = out_idx;
// std::size_t idx = vec_indices.at(out_idx[axis_index]);
// if(idx >= max_dim)
// {
// MIGRAPHX_THROW("Gather: indices are out of range in input tensor");
// }
// in_idx[axis_index] = idx;
// }
template
<
typename
V
,
typename
T
>
T
compute_data_index
(
const
V
&
indices
,
const
int
axis_index
,
const
T
&
out_idx
)
const
{
auto
data_idx
=
out_idx
;
std
::
size_t
index
{};
if
(
!
indices
.
get_shape
().
scalar
())
{
auto
start_it
=
data_idx
.
begin
()
+
axis_index
;
auto
end_it
=
data_idx
.
begin
()
+
axis_index
+
indices
.
get_shape
().
lens
().
size
();
std
::
vector
<
std
::
size_t
>
ind_idx
(
start_it
,
end_it
);
data_idx
.
erase
(
start_it
,
end_it
);
index
=
indices
(
ind_idx
.
begin
(),
ind_idx
.
end
());
}
else
{
index
=
indices
.
front
();
}
data_idx
.
insert
(
data_idx
.
begin
()
+
axis_index
,
index
);
return
data_idx
;
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
...
...
@@ -797,37 +803,16 @@ struct gather
int
axis_index
=
(
axis
<
0
)
?
(
args
[
0
].
get_shape
().
lens
().
size
()
+
axis
)
:
axis
;
// max dimension in axis
// std::size_t max_dim = args[0].get_shape().lens()[axis_index];
// std::vector<std::size_t> vec_indices;
// args[1].visit([&](auto indices) { vec_indices.assign(indices.begin(), indices.end()); });
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
data
)
{
args
[
1
].
visit
([
&
](
auto
indices
)
{
if
(
indices
.
ge
t_shape
()
.
scalar
())
if
(
outpu
t_shape
.
scalar
())
{
if
(
output_shape
.
scalar
())
{
output
[
0
]
=
data
[
indices
.
front
()];
}
else
{
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
out_idx
)
{
auto
data_idx
=
out_idx
;
data_idx
.
insert
(
data_idx
.
begin
()
+
axis_index
,
indices
.
front
());
output
(
out_idx
.
begin
(),
out_idx
.
end
())
=
data
(
data_idx
.
begin
(),
data_idx
.
end
());
});
}
output
[
0
]
=
data
[
indices
.
front
()];
}
else
{
auto
ind_lens
=
indices
.
get_shape
().
lens
();
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
out_idx
)
{
auto
data_idx
=
out_idx
;
auto
start_it
=
data_idx
.
begin
()
+
axis_index
;
auto
end_it
=
data_idx
.
begin
()
+
axis_index
+
ind_lens
.
size
();
std
::
vector
<
std
::
size_t
>
ind_idx
(
start_it
,
end_it
);
data_idx
.
erase
(
start_it
,
end_it
);
data_idx
.
insert
(
start_it
,
indices
(
ind_idx
.
begin
(),
ind_idx
.
end
()));
auto
data_idx
=
compute_data_index
(
indices
,
axis_index
,
out_idx
);
output
(
out_idx
.
begin
(),
out_idx
.
end
())
=
data
(
data_idx
.
begin
(),
data_idx
.
end
());
});
...
...
test/cpu_ops_test.cpp
View file @
89b80be6
...
...
@@ -164,6 +164,48 @@ TEST_CASE(gather_test)
result
.
visit
([
&
](
auto
output
)
{
res_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
res_data
,
golden
));
}
{
migraphx
::
program
p
;
std
::
vector
<
float
>
data
(
3
*
3
);
std
::
iota
(
data
.
begin
(),
data
.
end
(),
0.5
);
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
,
3
}};
auto
a0
=
p
.
add_literal
(
migraphx
::
literal
{
s
,
data
});
// scalar index
migraphx
::
shape
s_indices
{
migraphx
::
shape
::
int32_type
,
{
1
},
{
0
}};
std
::
vector
<
int
>
indices
{
0
};
auto
a1
=
p
.
add_literal
(
migraphx
::
literal
{
s_indices
,
indices
});
int
axis
=
-
1
;
p
.
add_instruction
(
migraphx
::
op
::
gather
{
axis
},
a0
,
a1
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
float
>
res_data
{};
std
::
vector
<
float
>
golden
=
{
0.5
f
,
3.5
f
,
6.5
f
};
result
.
visit
([
&
](
auto
output
)
{
res_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
res_data
,
golden
));
}
{
migraphx
::
program
p
;
std
::
vector
<
float
>
data
(
3
);
std
::
iota
(
data
.
begin
(),
data
.
end
(),
0.5
);
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
}};
auto
a0
=
p
.
add_literal
(
migraphx
::
literal
{
s
,
data
});
// scalar index
migraphx
::
shape
s_indices
{
migraphx
::
shape
::
int32_type
,
{
1
},
{
0
}};
std
::
vector
<
int
>
indices
{
0
};
auto
a1
=
p
.
add_literal
(
migraphx
::
literal
{
s_indices
,
indices
});
int
axis
=
-
1
;
p
.
add_instruction
(
migraphx
::
op
::
gather
{
axis
},
a0
,
a1
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
float
>
res_data
{};
std
::
vector
<
float
>
golden
=
{
0.5
f
};
result
.
visit
([
&
](
auto
output
)
{
res_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
res_data
,
golden
));
}
}
TEST_CASE
(
squeeze_test
)
...
...
test/gpu/miopen.cpp
View file @
89b80be6
...
...
@@ -1068,6 +1068,54 @@ struct test_gather_neg_axis
}
};
struct
test_gather_scalar_output
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
}};
migraphx
::
shape
s_indices
{
migraphx
::
shape
::
int32_type
,
{
1
},
{
0
}};
std
::
vector
<
int
>
indices
{
1
};
auto
a0
=
p
.
add_parameter
(
"data"
,
s
);
auto
a1
=
p
.
add_literal
(
migraphx
::
literal
{
s_indices
,
indices
});
int
axis
=
0
;
p
.
add_instruction
(
migraphx
::
op
::
gather
{
axis
},
a0
,
a1
);
return
p
;
}
};
struct
test_gather_scalar_index
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
,
3
}};
migraphx
::
shape
s_indices
{
migraphx
::
shape
::
int32_type
,
{
1
},
{
0
}};
std
::
vector
<
int
>
indices
{
1
};
auto
a0
=
p
.
add_parameter
(
"data"
,
s
);
auto
a1
=
p
.
add_literal
(
migraphx
::
literal
{
s_indices
,
indices
});
int
axis
=
-
1
;
p
.
add_instruction
(
migraphx
::
op
::
gather
{
axis
},
a0
,
a1
);
return
p
;
}
};
struct
test_gather_1d_index
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
,
3
}};
migraphx
::
shape
s_indices
{
migraphx
::
shape
::
int32_type
,
{
1
}};
std
::
vector
<
int
>
indices
{
1
};
auto
a0
=
p
.
add_parameter
(
"data"
,
s
);
auto
a1
=
p
.
add_literal
(
migraphx
::
literal
{
s_indices
,
indices
});
int
axis
=
-
1
;
p
.
add_instruction
(
migraphx
::
op
::
gather
{
axis
},
a0
,
a1
);
return
p
;
}
};
void
manual_identity
()
{
migraphx
::
program
p
;
...
...
@@ -2904,6 +2952,9 @@ int main()
verify_program
<
test_slice
>
();
verify_program
<
test_gather
>
();
verify_program
<
test_gather_neg_axis
>
();
verify_program
<
test_gather_scalar_output
>
();
verify_program
<
test_gather_scalar_index
>
();
verify_program
<
test_gather_1d_index
>
();
verify_program
<
test_rnn_forward
>
();
verify_program
<
test_rnn_forward10
>
();
verify_program
<
test_rnn_reverse
>
();
...
...
test/op_shape_test.cpp
View file @
89b80be6
...
...
@@ -251,6 +251,56 @@ TEST_CASE(gather)
indices
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
indices
{
migraphx
::
shape
::
int32_type
,
{
1
}};
int
axis
=
-
4
;
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
4
,
5
}},
migraphx
::
op
::
gather
{
axis
},
input
,
indices
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
indices
{
migraphx
::
shape
::
int32_type
,
{
1
},
{
0
}};
int
axis
=
-
4
;
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
5
}},
migraphx
::
op
::
gather
{
axis
},
input
,
indices
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
indices
{
migraphx
::
shape
::
int32_type
,
{
1
},
{
0
}};
int
axis
=
3
;
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
}},
migraphx
::
op
::
gather
{
axis
},
input
,
indices
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
3
}};
migraphx
::
shape
indices
{
migraphx
::
shape
::
int32_type
,
{
1
},
{
0
}};
int
axis
=
0
;
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
},
{
0
}},
migraphx
::
op
::
gather
{
axis
},
input
,
indices
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
3
}};
migraphx
::
shape
indices
{
migraphx
::
shape
::
int32_type
,
{
1
}};
int
axis
=
0
;
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
}},
migraphx
::
op
::
gather
{
axis
},
input
,
indices
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
indices
{
migraphx
::
shape
::
int32_type
,
{
2
,
3
}};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment