gaoqiong / MIGraphX / Commits

Commit 803eb3ce (unverified)
Authored Nov 12, 2018 by Paul Fultz II; committed by GitHub on Nov 12, 2018

Merge pull request #112 from ROCmSoftwarePlatform/multi_broadcast

Multi broadcast

Parents: d9b08400, bd60be01
Showing 4 changed files with 212 additions and 35 deletions (+212 -35)
src/include/migraph/operators.hpp  +45  -1
src/onnx/onnx.cpp                  +65  -17
test/cpu_ops_test.cpp              +38  -17
test/op_shape_test.cpp             +64  -0
src/include/migraph/operators.hpp

@@ -762,6 +762,48 @@ struct broadcast
     int output_alias(const std::vector<shape>&) const { return 0; }
 };
 
+struct multibroadcast
+{
+    std::vector<std::size_t> output_lens;
+
+    template <class Self, class F>
+    static auto reflect(Self& self, F f)
+    {
+        return pack(f(self.output_lens, "output_lens"));
+    }
+
+    std::string name() const { return "multibroadcast"; }
+    shape compute_shape(std::vector<shape> inputs) const
+    {
+        check_shapes{inputs, *this}.has(1);
+        auto t     = inputs.at(0).type();
+        auto input = inputs.at(0);
+
+        if(input.lens().empty())
+            MIGRAPH_THROW("inputs dimensions should be > 0");
+        if(input.lens().size() > output_lens.size())
+            MIGRAPH_THROW("inputs dimensions should <= output size");
+
+        std::vector<size_t> bcast_strides(output_lens.size(), 0);
+        auto offset = output_lens.size() - input.lens().size();
+        for(int i = input.lens().size() - 1; i >= 0; i--)
+        {
+            if(output_lens[i + offset] == input.lens()[i])
+            {
+                bcast_strides[i + offset] = input.strides()[i];
+            }
+        }
+        return {t, output_lens, bcast_strides};
+    }
+    argument compute(context&, shape output_shape, std::vector<argument> args) const
+    {
+        return {std::move(output_shape), std::move(args.at(0).data)};
+    }
+    int output_alias(const std::vector<shape>&) const { return 0; }
+};
+
 struct scalar
 {
     shape scalar_bcast;
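The key idea in multibroadcast::compute_shape is that broadcasting is expressed entirely through strides: the input lengths are right-aligned against output_lens, dimensions whose lengths match keep the input's stride, and every other (broadcast) dimension gets stride 0, so reads along it keep returning the same element. Below is a small standalone sketch of that rule, checked against the first case of the new TEST_CASE(multibroadcast) further down; the helper names (row_major_strides, broadcast_strides) are ours for illustration and not MIGraphX API.

#include <cassert>
#include <cstddef>
#include <vector>

// Row-major ("standard") strides for the given lengths.
std::vector<std::size_t> row_major_strides(const std::vector<std::size_t>& lens)
{
    std::vector<std::size_t> strides(lens.size(), 1);
    for(int i = static_cast<int>(lens.size()) - 2; i >= 0; i--)
        strides[i] = strides[i + 1] * lens[i + 1];
    return strides;
}

// Mirrors the loop in multibroadcast::compute_shape: right-align the input,
// keep its stride where the lengths match, use 0 for broadcast dimensions.
std::vector<std::size_t> broadcast_strides(const std::vector<std::size_t>& in_lens,
                                           const std::vector<std::size_t>& out_lens)
{
    auto in_strides = row_major_strides(in_lens);
    std::vector<std::size_t> bcast_strides(out_lens.size(), 0);
    auto offset = out_lens.size() - in_lens.size();
    for(int i = static_cast<int>(in_lens.size()) - 1; i >= 0; i--)
    {
        if(out_lens[i + offset] == in_lens[i])
            bcast_strides[i + offset] = in_strides[i];
    }
    return bcast_strides;
}

int main()
{
    // Same case as the first block of TEST_CASE(multibroadcast):
    // a {2, 1, 3} input broadcast to {4, 2, 5, 3} gets strides {0, 3, 0, 1}.
    assert((broadcast_strides({2, 1, 3}, {4, 2, 5, 3}) == std::vector<std::size_t>{0, 3, 0, 1}));
}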
@@ -789,7 +831,9 @@ struct binary
     shape compute_shape(std::vector<shape> inputs) const
     {
         check_shapes{inputs}.has(2).same_type().same_dims();
-        return inputs.at(0);
+        auto t    = inputs.at(0).type();
+        auto lens = inputs.at(0).lens();
+        return {t, lens};
     }
 };
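Rebuilding the result shape from just the element type and lengths (instead of returning inputs.at(0) verbatim) appears to matter because the operands of a binary op can now be multibroadcast views carrying zero strides; a shape constructed from the lengths alone has ordinary row-major strides, which is presumably why the new add_broadcast_test block below checks result.get_shape().packed(). A small standalone illustration of the difference, using plain vectors rather than the MIGraphX shape class:

#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Logical element count described by the lengths.
std::size_t elements(const std::vector<std::size_t>& lens)
{
    return std::accumulate(lens.begin(), lens.end(), std::size_t{1}, std::multiplies<>{});
}

// Highest offset addressed by the view, plus one: how much storage it spans.
std::size_t element_space(const std::vector<std::size_t>& lens,
                          const std::vector<std::size_t>& strides)
{
    std::size_t last = 0;
    for(std::size_t i = 0; i < lens.size(); i++)
        last += (lens[i] - 1) * strides[i];
    return last + 1;
}

int main()
{
    std::vector<std::size_t> lens{2, 2, 3};
    // A {2, 2, 1} tensor multibroadcast to {2, 2, 3} carries strides {2, 1, 0}
    // under the rule above: it spans only 4 distinct elements, so it cannot
    // describe a dense output buffer.
    assert(element_space(lens, {2, 1, 0}) == 4);
    // A shape built from the lengths alone gets row-major strides {6, 3, 1}
    // and spans exactly elements(lens) == 12 entries, i.e. it is packed.
    assert(element_space(lens, {6, 3, 1}) == elements(lens));
}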
src/onnx/onnx.cpp

@@ -49,16 +49,17 @@ struct onnx_parser
     onnx_parser()
     {
-        add_generic_op("Add", op::add{});
-        add_generic_op("Div", op::div{});
         add_generic_op("MatMul", op::dot{});
-        add_generic_op("Mul", op::mul{});
         add_generic_op("Relu", op::relu{});
-        add_generic_op("Sub", op::sub{});
-        add_generic_op("Sum", op::add{});
         // disable dropout for inference
         add_generic_op("Dropout", op::identity{});
+        add_broadcastable_binary_op("Add", op::add{});
+        add_broadcastable_binary_op("Div", op::div{});
+        add_broadcastable_binary_op("Mul", op::mul{});
+        add_broadcastable_binary_op("Sub", op::sub{});
+        add_broadcastable_binary_op("Sum", op::add{});
         add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
         add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
         add_mem_op("Constant", &onnx_parser::parse_constant);
@@ -92,12 +93,13 @@ struct onnx_parser
             return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
         });
     }
 
     template <class T>
-    void add_generic_op(std::string name, T x)
+    void add_broadcastable_binary_op(std::string name, T x)
     {
         ops.emplace(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) {
-            if(args.size() == 2 and contains(attributes, "broadcast"))
+            if(args.size() != 2)
+                MIGRAPH_THROW("binary operators should have 2 operands");
+            if(contains(attributes, "broadcast"))
             {
                 uint64_t broadcasted = parse_value(attributes.at("broadcast")).at<uint64_t>();
                 if(broadcasted != 0)
@@ -109,7 +111,51 @@ struct onnx_parser
                     prog.add_instruction(op::broadcast{axis, args[0]->get_shape()}, args[1]);
                     return prog.add_instruction(x, args[0], l);
                 }
+                return prog.add_instruction(x, args);
             }
-            return prog.add_instruction(x, args);
+            else
+            {
+                // Example:
+                // s0 = (3,2,4,5) and s1 = (2,1,1)
+                //
+                // In this case we need to broadcast (:,1,1) portion of
+                // s1 plus broadcast the 1st dimension of s1
+                // giving output_lens = (3,2,4,5)
+                //
+                // Another example:
+                // s0 = (3,2,1,5) and s1 = (2,7,5)
+                // In this case we need to broadcast the (:,:,1:,:) axis
+                // of s0 plus the 1st dimension of s1 giving
+                // output_lens = (3,2,7,5)
+                //
+                // Get lengths for both arguments
+                const std::vector<std::size_t>* s0 = &args[0]->get_shape().lens();
+                const std::vector<std::size_t>* s1 = &args[1]->get_shape().lens();
+
+                // Make sure s0 is the smaller size
+                if(s0->size() > s1->size())
+                    std::swap(s0, s1);
+
+                // Copy the larger vector to output_lens
+                std::vector<std::size_t> output_lens(s1->size());
+                auto offset = s1->size() - s0->size();
+                std::transform(s0->begin(),
+                               s0->end(),
+                               s1->begin() + offset,
+                               output_lens.begin() + offset,
+                               [](auto a, auto b) { return std::max(a, b); });
+
+                auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, args[0]);
+                auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, args[1]);
+                return prog.add_instruction(x, l0, l1);
+            }
+        });
+    }
+
+    template <class T>
+    void add_generic_op(std::string name, T x)
+    {
+        ops.emplace(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
+            return prog.add_instruction(x, args);
         });
     }
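The else branch computes multidirectional (NumPy/ONNX-style) broadcast lengths: the two operand shapes are right-aligned and each output length is the larger of the aligned pair, with leading dimensions coming from the longer shape. Below is a standalone sketch of that rule checked against both examples from the comments above; the helper name is ours, and for clarity the sketch seeds output_lens by copying the longer shape before taking the element-wise max over the overlap.

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Right-align the two shapes and take the element-wise max of the overlap;
// leading dimensions carry over from the longer shape.
std::vector<std::size_t> broadcasted_lens(std::vector<std::size_t> s0, std::vector<std::size_t> s1)
{
    if(s0.size() > s1.size())
        std::swap(s0, s1); // make s0 the shorter shape
    std::vector<std::size_t> output_lens(s1);
    auto offset = s1.size() - s0.size();
    std::transform(s0.begin(),
                   s0.end(),
                   s1.begin() + offset,
                   output_lens.begin() + offset,
                   [](auto a, auto b) { return std::max(a, b); });
    return output_lens;
}

int main()
{
    // The two examples from the comments in add_broadcastable_binary_op:
    assert((broadcasted_lens({3, 2, 4, 5}, {2, 1, 1}) == std::vector<std::size_t>{3, 2, 4, 5}));
    assert((broadcasted_lens({3, 2, 1, 5}, {2, 7, 5}) == std::vector<std::size_t>{3, 2, 7, 5}));
}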
@@ -607,8 +653,10 @@ struct onnx_parser
         }
         std::vector<std::size_t> dims;
         auto&& tensor_dims = t.tensor_type().shape().dim();
-        std::transform(
-            tensor_dims.begin(), tensor_dims.end(), std::back_inserter(dims), [](auto&& d) {
+        std::transform(tensor_dims.begin(),
+                       tensor_dims.end(),
+                       std::back_inserter(dims),
+                       [](auto&& d) -> std::size_t {
                 if(not d.has_dim_value())
                 {
                     long default_batch_size = 1; // FIXME
test/cpu_ops_test.cpp

@@ -487,6 +487,7 @@ TEST_CASE(broadcast_test)
 }
 
 TEST_CASE(add_broadcast_test)
 {
+    {
     migraph::program p;
     migraph::shape a_shape{migraph::shape::float_type, {2, 2, 3}};
     std::vector<float> a_data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
@@ -504,6 +505,26 @@ TEST_CASE(add_broadcast_test)
     result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
     std::vector<float> gold = {0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8};
     EXPECT(migraph::verify_range(results_vector, gold));
+    }
+    {
+        migraph::program p;
+        migraph::shape a_shape{migraph::shape::float_type, {2, 2, 3}};
+        std::vector<float> a_data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
+        migraph::shape b_shape{migraph::shape::float_type, {2, 2, 1}};
+        std::vector<float> b_data{0, -1, -2, -3};
+        auto l1 = p.add_literal(migraph::literal{a_shape, a_data});
+        auto l2 = p.add_literal(migraph::literal{b_shape, b_data});
+        auto l3 = p.add_instruction(migraph::op::multibroadcast{{2, 2, 3}}, l1);
+        auto l4 = p.add_instruction(migraph::op::multibroadcast{{2, 2, 3}}, l2);
+        p.add_instruction(migraph::op::add{}, l3, l4);
+        p.compile(migraph::cpu::target{});
+        auto result = p.eval({});
+        EXPECT(result.get_shape().packed());
+        std::vector<float> results_vector(12);
+        result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+        std::vector<float> gold = {0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8};
+        EXPECT(migraph::verify_range(results_vector, gold));
+    }
 }
 
 TEST_CASE(sub_test)
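The gold values in the new block are the same as in the original attribute-based broadcast case: b_data is a {2, 2, 1} tensor, so after multibroadcast to {2, 2, 3} its last axis has stride 0 and each of the four b values is reused for three consecutive outputs. A quick standalone check of that arithmetic (plain C++, no MIGraphX API):

#include <cassert>
#include <cstddef>
#include <vector>

int main()
{
    std::vector<float> a_data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; // shape {2, 2, 3}
    std::vector<float> b_data{0, -1, -2, -3};                        // shape {2, 2, 1}
    std::vector<float> gold{0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8};

    // Broadcasting b along the last axis means flat output index i reads
    // a_data[i] and b_data[i / 3].
    for(std::size_t i = 0; i < a_data.size(); i++)
        assert(a_data[i] + b_data[i / 3] == gold[i]);
}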
test/op_shape_test.cpp

@@ -145,4 +145,68 @@ TEST_CASE(slice_shape)
                  migraph::op::slice{{2}, {2}, {10}},
                  input);
 }
 
+TEST_CASE(multibroadcast)
+{
+    {
+        std::vector<std::size_t> lens{4, 2, 5, 3};
+        migraph::shape input{migraph::shape::float_type, {2, 1, 3}};
+        expect_shape(migraph::shape{migraph::shape::float_type, lens, {0, 3, 0, 1}},
+                     migraph::op::multibroadcast{lens},
+                     input);
+    }
+    {
+        std::vector<std::size_t> lens{4, 2, 5, 3};
+        migraph::shape input{migraph::shape::float_type, {2, 1, 1}};
+        expect_shape(migraph::shape{migraph::shape::float_type, lens, {0, 1, 0, 0}},
+                     migraph::op::multibroadcast{lens},
+                     input);
+    }
+    {
+        std::vector<std::size_t> lens{4, 2, 5, 3};
+        migraph::shape input{migraph::shape::float_type, {5, 1}};
+        expect_shape(migraph::shape{migraph::shape::float_type, lens, {0, 0, 1, 0}},
+                     migraph::op::multibroadcast{lens},
+                     input);
+    }
+    {
+        std::vector<std::size_t> lens{4, 2, 5, 3};
+        migraph::shape input{migraph::shape::float_type, {4, 1, 1, 1}};
+        expect_shape(migraph::shape{migraph::shape::float_type, lens, {1, 0, 0, 0}},
+                     migraph::op::multibroadcast{lens},
+                     input);
+    }
+    {
+        std::vector<std::size_t> lens{4, 2, 5, 3};
+        migraph::shape input{migraph::shape::float_type, {3}};
+        expect_shape(migraph::shape{migraph::shape::float_type, lens, {0, 0, 0, 1}},
+                     migraph::op::multibroadcast{lens},
+                     input);
+    }
+    {
+        std::vector<std::size_t> lens{4, 4, 1, 3};
+        migraph::shape input{migraph::shape::float_type, {4, 1, 3}};
+        expect_shape(migraph::shape{migraph::shape::float_type, lens, {0, 3, 3, 1}},
+                     migraph::op::multibroadcast{lens},
+                     input);
+    }
+    {
+        std::vector<std::size_t> lens{4, 1, 1, 3};
+        migraph::shape input{migraph::shape::float_type, {4, 1, 1, 1}};
+        expect_shape(migraph::shape{migraph::shape::float_type, lens, {1, 1, 1, 0}},
+                     migraph::op::multibroadcast{lens},
+                     input);
+    }
+    {
+        std::vector<std::size_t> lens{4, 1, 3};
+        migraph::shape input{migraph::shape::float_type, {4, 1, 1, 1}};
+        throws_shape(migraph::op::multibroadcast{lens}, input);
+    }
+    {
+        std::vector<std::size_t> lens{4, 1, 3};
+        migraph::shape input{migraph::shape::float_type, {}};
+        throws_shape(migraph::op::multibroadcast{lens}, input);
+    }
+}
+
 int main(int argc, const char* argv[]) { test::run(argc, argv); }