gaoqiong / MIGraphX · Commits

Commit 007ea283
Authored Mar 08, 2019 by Shucai Xiao

backup second implementation of the compute_shape for the dot operator

Parent: 359ec2f8

Showing 1 changed file with 98 additions and 107 deletions:
src/include/migraphx/operators.hpp (+98 / -107)
@@ -830,16 +830,36 @@ struct dot
         return pack(f(self.alpha, "alpha"), f(self.beta, "beta"));
     }
 
+    // if not a multi_broadcast, b should be broadcastable to a
     std::vector<std::size_t> shape_broadcast(std::vector<std::size_t>& a,
-                                             std::vector<std::size_t>& b) const
+                                             std::vector<std::size_t>& b,
+                                             bool is_mutli_broadcast = true) const
     {
-        if(a.empty())
-            return b;
-        else if(b.empty())
-            return a;
+        if(b.empty())
+            return a;
+        if(a.empty())
+        {
+            if(is_mutli_broadcast)
+            {
+                return b;
+            }
+            else
+            {
+                MIGRAPHX_THROW("DOT: C is not broadcastable to A * B (scalar)");
+            }
+        }
+
         auto a_size = a.size();
         auto b_size = b.size();
+        if(is_mutli_broadcast && b_size > a_size)
+        {
+            MIGRAPHX_THROW("DOT: C {" + to_string_range(b) + "} is not broadcastable to A * b {" +
+                           to_string_range(a) + "}");
+        }
+
         auto n_dim = std::min(a_size, b_size);
         std::vector<std::size_t> out_lens(std::max(a_size, b_size));
         for(std::size_t i = 0; i < n_dim; ++i)
@@ -848,19 +868,31 @@ struct dot
             {
                 out_lens[i] = a[a_size - 1 - i];
             }
-            else if(a[a_size - 1 - i] == 1)
-            {
-                out_lens[i] = b[b_size - 1 - i];
-            }
             else if(b[b_size - 1 - i] == 1)
             {
                 out_lens[i] = a[a_size - 1 - i];
             }
             else
             {
-                MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" + to_string_range(a) +
-                               "}, and matrix B: {" + to_string_range(b) +
-                               "} are not broadcastable");
+                if(a[a_size - 1 - i] == 1 && is_mutli_broadcast)
+                {
+                    out_lens[i] = b[b_size - 1 - i];
+                }
+                else
+                {
+                    if(is_mutli_broadcast)
+                    {
+                        MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" +
+                                       to_string_range(a) + "}, and matrix B: {" +
+                                       to_string_range(b) + "} are not broadcastable");
+                    }
+                    else
+                    {
+                        MIGRAPHX_THROW("DOT: C {" + to_string_range(b) +
+                                       "} is not broadcastable to A * b {" +
+                                       to_string_range(a) + "}");
+                    }
+                }
             }
         }
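For reference, the rule this helper revolves around is the usual numpy-style broadcast: the two length vectors are aligned from the right, and a pair of sizes is compatible when they are equal or one of them is 1. The sketch below only illustrates that rule under those assumptions; it is not part of this commit, and broadcast_lens is a made-up name.

    #include <algorithm>
    #include <cstddef>
    #include <iostream>
    #include <stdexcept>
    #include <vector>

    // Illustrative only: numpy-style broadcast of two length vectors,
    // aligned from the right; sizes must match or one of them must be 1.
    std::vector<std::size_t> broadcast_lens(const std::vector<std::size_t>& a,
                                            const std::vector<std::size_t>& b)
    {
        std::vector<std::size_t> out(std::max(a.size(), b.size()), 1);
        for(std::size_t i = 0; i < out.size(); ++i)
        {
            std::size_t da = (i < a.size()) ? a[a.size() - 1 - i] : 1;
            std::size_t db = (i < b.size()) ? b[b.size() - 1 - i] : 1;
            if(da != db && da != 1 && db != 1)
                throw std::runtime_error("lens are not broadcastable");
            out[out.size() - 1 - i] = std::max(da, db);
        }
        return out;
    }

    int main()
    {
        // {6, 1, 3} and {7, 1} broadcast to {6, 7, 3}
        for(auto d : broadcast_lens({6, 1, 3}, {7, 1}))
            std::cout << d << ' ';
        std::cout << '\n';
    }

Unlike the helper above, this sketch writes the result in natural (leftmost-first) order; it is meant only to show which length pairs are considered compatible.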
@@ -894,120 +926,79 @@ struct dot
         auto a_lens = a.lens();
         auto b_lens = b.lens();
-        std::vector<std::size_t> out_lens;
+        bool is_a_appended = false;
+        bool is_b_appended = false;
         if(a_lens.size() == 1)
         {
-            // inner product, output is a scalar, following numpy.matmul()
-            if(b_lens.size() == 1)
-            {
-                if(a_lens.front() != b_lens.front())
-                {
-                    MIGRAPHX_THROW("DOT : dimension mismatch, vector A: {" +
-                                   to_string_range(a_lens) + "}, cannot multiply vector B: {" +
-                                   to_string_range(b_lens) + "}");
-                }
-            }
-            else
-            {
-                std::size_t dim_0 = b_lens.size() - 2;
-                if(a_lens.front() != b_lens[dim_0])
-                {
-                    MIGRAPHX_THROW("DOT : dimension mismatch, vector A: {" +
-                                   to_string_range(a_lens) + "}, cannot multiply matrix B: {" +
-                                   to_string_range(b_lens) + "}");
-                }
-                out_lens = b_lens;
-                out_lens.erase(out_lens.begin() + dim_0);
-            }
+            a_lens.insert(a_lens.begin(), 1);
+            is_a_appended = true;
         }
-        else
+
+        if(b_lens.size() == 1)
         {
-            std::size_t dim_0 = a_lens.size() - 1;
-            if(b_lens.size() == 1)
-            {
-                if(a_lens.back() != b_lens.back())
-                {
-                    MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" +
-                                   to_string_range(a_lens) + "}, cannot multiply vector B: {" +
-                                   to_string_range(b_lens) + "}");
-                }
-                out_lens = a_lens;
-                out_lens.pop_back();
-            }
-            else
-            {
-                std::size_t dim_0 = a_lens.size() - 1;
-                std::size_t dim_1 = b_lens.size() - 2;
-                if(a_lens[dim_0] != b_lens[dim_1])
-                {
-                    MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" +
-                                   to_string_range(a_lens) + "}, cannot multiply matrix B: {" +
-                                   to_string_range(b_lens) + "}");
-                }
-                a_lens.pop_back();
-                std::size_t out_m = a_lens.back();
-                a_lens.pop_back();
-                std::size_t out_n = b_lens.back();
-                b_lens.pop_back();
-                b_lens.pop_back();
-                out_lens = shape_broadcast(a_lens, b_lens);
-                out_lens.push_back(out_m);
-                out_lens.push_back(out_n);
-            }
+            b_lens.push_back(1);
+            is_b_appended = true;
         }
-
-        // c is broadcast
-        if(inputs.size() == 3)
-        // according to the specification of the numpy.matmul()
-        // inputs with the shape dims more than 2 are acceptable
-        // as long as dim values are the same in the two inputs
-        if(!std::equal(a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2))
-        {
-            MIGRAPHX_THROW("DOT: number of matrices in stack are different in A and B");
-        }
-
-        if(inputs.size() == 3)
-        {
-            check_shapes{{inputs[0], inputs[2]}, *this}.has(2).same_type();
-            const shape& c = inputs.at(2);
-            if(!std::equal(a.lens().rbegin() + 2, a.lens().rend(), c.lens().rbegin() + 2))
-            {
-                MIGRAPHX_THROW("DOT: number of matrices in stack are different in A and C");
-            }
-        }
-
-        std::size_t dim_0 = a.lens().size() - 2;
-        std::size_t dim_1 = a.lens().size() - 1;
-        if(a.lens()[dim_1] != b.lens()[dim_0])
-            MIGRAPHX_THROW("DOT : inner dimensions do not match: {" + to_string_range(a.lens()) +
-                           "} x {" + to_string_range(b.lens()) + "}");
-
-        if(inputs.size() == 3)
-        {
-            const shape& c = inputs.at(2);
-            if(a.lens()[dim_0] != c.lens()[dim_0])
-            {
-                MIGRAPHX_THROW("DOT : matrix size does not match: A: {" +
-                               to_string_range(a.lens()) + "}, C: {" +
-                               to_string_range(c.lens()) + "}");
-            }
-            if(b.lens()[dim_1] != c.lens()[dim_1])
-            {
-                MIGRAPHX_THROW("DOT : matrix size does not match: B: {" +
-                               to_string_range(b.lens()) + "}, C: {" +
-                               to_string_range(c.lens()) + "}");
-            }
-        }
-
-        auto out_lens   = a.lens();
-        out_lens[dim_1] = b.lens()[dim_1];
-        return {t, out_lens};
+
+        std::size_t dim_0 = a_lens.size() - 1;
+        std::size_t dim_1 = b_lens.size() - 2;
+        if(a_lens[dim_0] != b_lens[dim_1])
+        {
+            MIGRAPHX_THROW("DOT : dimension mismatch, operand A: {" + to_string_range(a.lens()) +
+                           "}, cannot multiply operand B: {" + to_string_range(b.lens()) + "}");
+        }
+
+        // remove the matrix dims, do multi_broadcast of the shape of the batch
+        a_lens.pop_back();
+        std::size_t out_m = a_lens.back();
+        a_lens.pop_back();
+        std::size_t out_n = b_lens.back();
+        b_lens.pop_back();
+        b_lens.pop_back();
+        auto out_lens = shape_broadcast(a_lens, b_lens);
+        out_lens.push_back(out_m);
+        out_lens.push_back(out_n);
+
+        // remove the prepended 1, if a is a vector
+        if(is_a_appended)
+        {
+            out_lens.erase(out_lens.begin() + out_lens.size() - 2);
+        }
+
+        // remove the appended 1, if b is a vector
+        if(is_b_appended)
+        {
+            out_lens.pop_back();
+        }
+
+        // c is unibroadcastable to A * B
+        if(inputs.size() == 3)
+        {
+            // same type as A and B
+            check_shapes{{inputs[0], inputs[2]}, *this}.has(2).same_type();
+            if(out_lens.empty() && (!inputs[2].scalar()))
+            {
+                MIGRAPHX_THROW("DOT: C is not broadcastable to A*B (scalar)");
+            }

+            // check c is broadcastable to A * B
+            auto c_lens = inputs[2].lens();
+            shape_broadcast(out_lens, c_lens, false);
+        }
+
+        if(out_lens.empty())
+        {
+            return {t};
+        }
+        else
+        {
+            return {t, out_lens};
+        }
     }
 };
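As a worked reference for the shape rule the second implementation is aiming at (numpy.matmul semantics: a 1-D A borrows a leading 1, a 1-D B borrows a trailing 1, the batch dimensions are broadcast, and the borrowed 1s are dropped again), the expected output lens would be, for example:

    A {4}          * B {4}        -> scalar (empty lens)
    A {4}          * B {4, 5}     -> {5}
    A {3, 4}       * B {4}        -> {3}
    A {2, 3, 4}    * B {4, 5}     -> {2, 3, 5}
    A {6, 1, 3, 4} * B {7, 4, 5}  -> {6, 7, 3, 5}

These values follow from the numpy.matmul specification itself, not from running the code in this commit.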