Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
2d98b64e
Commit
2d98b64e
authored
Mar 11, 2019
by
Shucai Xiao
Browse files
Added the CPU implementation of the dot operator
parent
b2106be7
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
93 additions
and
16 deletions
+93
-16
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+4
-0
src/targets/cpu/gemm.cpp
src/targets/cpu/gemm.cpp
+25
-12
src/targets/cpu/lowering.cpp
src/targets/cpu/lowering.cpp
+64
-4
No files found.
src/include/migraphx/operators.hpp
View file @
2d98b64e
...
...
@@ -819,6 +819,10 @@ struct gather
// vector input; if A or B is 2-dim, it is a matrix (no case of a batch of
// vectors as input). If A or B is 3 or more dims, it is considered as a
// stack(batch) of matrices.
// Note that we optimize the scenario of either the Matmul or Gemm operators,
// but for extended scenarios like GEMM with three inputs, where each arg
// is a batch of matrices, the implementation may need further optimization
// later.
struct
dot
{
float
alpha
=
1.0
;
...
...
src/targets/cpu/gemm.cpp
View file @
2d98b64e
...
...
@@ -73,18 +73,32 @@ void migemm_impl(tensor_view<T> cmat,
float
beta
,
std
::
false_type
)
{
std
::
size_t
n_dims
=
cmat
.
get_shape
().
lens
().
size
();
auto
a_lens
=
amat
.
get_shape
().
lens
();
auto
b_lens
=
bmat
.
get_shape
().
lens
();
auto
c_lens
=
cmat
.
get_shape
().
lens
();
std
::
size_t
n_dims
=
c_lens
.
size
();
std
::
size_t
dim_0
=
n_dims
-
2
;
std
::
size_t
dim_1
=
n_dims
-
1
;
auto
k
=
amat
.
get_shape
().
lens
()[
dim_1
];
auto
k
=
a_lens
[
dim_1
];
assert
(
a_lens
[
dim_1
]
==
b_lens
[
dim_0
]);
assert
(
c_lens
[
dim_0
]
==
a_lens
[
dim_0
]);
assert
(
c_lens
[
dim_1
]
==
b_lens
[
dim_1
]);
assert
(
amat
.
get_shape
().
lens
()[
dim_1
]
==
bmat
.
get_shape
().
lens
()[
dim_0
]);
assert
(
cmat
.
get_shape
().
lens
()[
dim_0
]
==
amat
.
get_shape
().
lens
()[
dim_0
]);
assert
(
cmat
.
get_shape
().
lens
()[
dim_1
]
==
bmat
.
get_shape
().
lens
()[
dim_1
]);
std
::
size_t
a_len_diff
=
c_lens
.
size
()
-
a_lens
.
size
();
std
::
size_t
b_len_diff
=
c_lens
.
size
()
-
b_lens
.
size
();
std
::
vector
<
std
::
size_t
>
a_idx
(
a_lens
.
size
());
std
::
vector
<
std
::
size_t
>
b_idx
(
b_lens
.
size
());
shape_for_each
(
cmat
.
get_shape
(),
[
&
](
const
auto
&
c_idx
)
{
auto
a_idx
=
c_idx
;
auto
b_idx
=
c_idx
;
std
::
transform
(
c_lens
.
begin
()
+
a_len_diff
,
c_lens
.
end
(),
a_lens
.
begin
(),
a_idx
.
begin
(),
[
&
](
auto
i
,
auto
j
)
{
return
(
j
==
1
)
?
0
:
i
;
});
std
::
transform
(
c_lens
.
begin
()
+
b_len_diff
,
c_lens
.
end
(),
b_lens
.
begin
(),
b_idx
.
begin
(),
[
&
](
auto
i
,
auto
j
)
{
return
(
j
==
1
)
?
0
:
i
;
});
double
s
=
0.0
;
dfor
(
k
)([
&
](
auto
kk
)
{
a_idx
[
dim_1
]
=
b_idx
[
dim_0
]
=
kk
;
...
...
@@ -98,11 +112,10 @@ template <class T>
void
migemm_impl
(
tensor_view
<
T
>
cmat
,
tensor_view
<
T
>
amat
,
tensor_view
<
T
>
bmat
,
float
alpha
,
float
beta
)
{
auto
lens
=
amat
.
get_shape
().
lens
();
bool
batch_mul
=
std
::
accumulate
(
lens
.
begin
(),
lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
())
==
(
*
lens
.
rbegin
())
*
(
*
(
lens
.
rbegin
()
+
1
));
if
(
batch_mul
)
auto
lens
=
cmat
.
get_shape
().
lens
();
std
::
size_t
num_matrices
=
std
::
accumulate
(
lens
.
rbegin
()
+
2
,
lens
.
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
if
(
num_matrices
==
1
)
{
migemm_impl
(
cmat
,
amat
,
bmat
,
alpha
,
beta
,
is_fast_gemm_type
<
T
>
{});
}
...
...
src/targets/cpu/lowering.cpp
View file @
2d98b64e
...
...
@@ -374,20 +374,80 @@ struct cpu_gemm
argument
compute
(
context
&
,
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
// all args are scalar
if
(
output_shape
.
scalar
())
{
visit_all
(
result
,
args
[
0
],
args
[
1
],
args
[
2
])([
&
](
auto
ret
,
auto
a
,
auto
b
,
auto
c
)
{
ret
[
0
]
=
op
.
alpha
*
a
[
0
]
*
b
[
0
]
+
op
.
beta
*
c
[
0
];
});
return
result
;
}
// first argument is 1-dim, pre-pend 1 at beginning
auto
a_lens
=
args
[
0
].
get_shape
().
lens
();
auto
b_lens
=
args
[
1
].
get_shape
().
lens
();
auto
out_lens
=
output_shape
.
lens
();
bool
is_a_prepended
=
false
;
shape
::
type_t
t
=
output_shape
.
type
();
if
(
a_lens
.
size
()
==
1
)
{
is_a_prepended
=
true
;
a_lens
.
insert
(
a_lens
.
begin
(),
1
);
out_lens
.
push_back
(
1
);
std
::
swap
(
*
out_lens
.
rbegin
(),
*
(
out_lens
.
rbegin
()
+
1
));
}
bool
is_b_appended
=
false
;
if
(
b_lens
.
size
()
==
1
)
{
is_b_appended
=
true
;
b_lens
.
push_back
(
1
);
out_lens
.
push_back
(
1
);
}
// if there is a C input
if
(
args
.
size
()
==
3
)
if
(
args
.
size
()
==
2
)
{
result
.
visit
([
&
](
auto
output
)
{
std
::
fill
(
output
.
begin
(),
output
.
end
(),
0
);
});
migemm
({{
t
,
out_lens
},
result
.
data
()},
{{
t
,
a_lens
},
args
[
0
].
data
()},
{{
t
,
b_lens
},
args
[
1
].
data
()},
op
.
alpha
,
op
.
beta
);
return
result
;
}
// 3 input arguments
auto
c_shape
=
args
[
2
].
get_shape
();
// In GEMM, C is broadcastable to A * B, so we should consider C
// is not the same shape as A * B. If the same shape, copy C to
// the memory of the output
if
(
c_shape
==
output_shape
)
{
// memory copy is more efficient than doing element by element
result
.
visit
([
&
](
auto
output
)
{
args
[
2
].
visit
(
[
&
](
auto
input
)
{
std
::
copy
(
input
.
begin
(),
input
.
end
(),
output
.
begin
());
});
[
&
](
auto
input
)
{
std
::
memcpy
(
output
.
data
(),
input
.
data
(),
c_shape
.
bytes
());
});
});
}
else
{
result
.
visit
([
&
](
auto
output
)
{
std
::
fill
(
output
.
begin
(),
output
.
end
(),
0
);
});
auto
out_len
=
output_shape
.
lens
();
auto
c_lens
=
c_shape
.
lens
();
std
::
size_t
len_diff
=
out_len
.
size
()
-
c_lens
.
size
();
visit_all
(
result
,
args
[
2
])
([
&
](
auto
output
,
auto
c
)
{
shape_for_each
(
output_shape
,
[
&
](
auto
out_idx
)
{
// compute the input index
std
::
vector
<
std
::
size_t
>
in_idx
(
c_lens
.
size
());
std
::
transform
(
c_lens
.
begin
(),
c_lens
.
end
(),
out_len
.
begin
()
+
len_diff
,
in_idx
.
begin
(),
[
&
](
auto
i
,
auto
j
)
{
return
(
i
==
1
)
?
0
:
j
;
});
output
(
out_idx
.
begin
(),
out_idx
.
end
())
=
c
(
in_idx
.
begin
(),
in_idx
.
end
());
});
});
}
migemm
(
result
,
args
[
0
],
args
[
1
],
op
.
alpha
,
op
.
beta
);
migemm
({{
t
,
out_lens
},
result
.
data
()},
{{
t
,
a_lens
},
args
[
0
].
data
()},
{{
t
,
b_lens
},
args
[
1
].
data
()},
op
.
alpha
,
op
.
beta
);
return
result
;
}
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment