Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
0ea7b7a3
Commit
0ea7b7a3
authored
Mar 11, 2019
by
Shucai Xiao
Browse files
clang format
parent
2d98b64e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
43 additions
and
31 deletions
+43
-31
src/targets/cpu/gemm.cpp
src/targets/cpu/gemm.cpp
+15
-11
src/targets/cpu/lowering.cpp
src/targets/cpu/lowering.cpp
+28
-20
No files found.
src/targets/cpu/gemm.cpp
View file @
0ea7b7a3
...
@@ -92,14 +92,18 @@ void migemm_impl(tensor_view<T> cmat,
...
@@ -92,14 +92,18 @@ void migemm_impl(tensor_view<T> cmat,
std
::
vector
<
std
::
size_t
>
b_idx
(
b_lens
.
size
());
std
::
vector
<
std
::
size_t
>
b_idx
(
b_lens
.
size
());
shape_for_each
(
cmat
.
get_shape
(),
[
&
](
const
auto
&
c_idx
)
{
shape_for_each
(
cmat
.
get_shape
(),
[
&
](
const
auto
&
c_idx
)
{
std
::
transform
(
c_lens
.
begin
()
+
a_len_diff
,
c_lens
.
end
(),
a_lens
.
begin
(),
a_idx
.
begin
(),
[
&
](
auto
i
,
auto
j
)
{
std
::
transform
(
c_lens
.
begin
()
+
a_len_diff
,
return
(
j
==
1
)
?
0
:
i
;
c_lens
.
end
(),
});
a_lens
.
begin
(),
std
::
transform
(
c_lens
.
begin
()
+
b_len_diff
,
c_lens
.
end
(),
b_lens
.
begin
(),
b_idx
.
begin
(),
[
&
](
auto
i
,
auto
j
)
{
a_idx
.
begin
(),
return
(
j
==
1
)
?
0
:
i
;
[
&
](
auto
i
,
auto
j
)
{
return
(
j
==
1
)
?
0
:
i
;
});
});
std
::
transform
(
c_lens
.
begin
()
+
b_len_diff
,
c_lens
.
end
(),
double
s
=
0.0
;
b_lens
.
begin
(),
b_idx
.
begin
(),
[
&
](
auto
i
,
auto
j
)
{
return
(
j
==
1
)
?
0
:
i
;
});
double
s
=
0.0
;
dfor
(
k
)([
&
](
auto
kk
)
{
dfor
(
k
)([
&
](
auto
kk
)
{
a_idx
[
dim_1
]
=
b_idx
[
dim_0
]
=
kk
;
a_idx
[
dim_1
]
=
b_idx
[
dim_0
]
=
kk
;
s
+=
amat
(
a_idx
.
begin
(),
a_idx
.
end
())
*
bmat
(
b_idx
.
begin
(),
b_idx
.
end
());
s
+=
amat
(
a_idx
.
begin
(),
a_idx
.
end
())
*
bmat
(
b_idx
.
begin
(),
b_idx
.
end
());
...
@@ -112,9 +116,9 @@ template <class T>
...
@@ -112,9 +116,9 @@ template <class T>
void
migemm_impl
(
void
migemm_impl
(
tensor_view
<
T
>
cmat
,
tensor_view
<
T
>
amat
,
tensor_view
<
T
>
bmat
,
float
alpha
,
float
beta
)
tensor_view
<
T
>
cmat
,
tensor_view
<
T
>
amat
,
tensor_view
<
T
>
bmat
,
float
alpha
,
float
beta
)
{
{
auto
lens
=
cmat
.
get_shape
().
lens
();
auto
lens
=
cmat
.
get_shape
().
lens
();
std
::
size_t
num_matrices
=
std
::
size_t
num_matrices
=
std
::
accumulate
(
std
::
accumulate
(
lens
.
rbegin
()
+
2
,
lens
.
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
lens
.
rbegin
()
+
2
,
lens
.
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
if
(
num_matrices
==
1
)
if
(
num_matrices
==
1
)
{
{
migemm_impl
(
cmat
,
amat
,
bmat
,
alpha
,
beta
,
is_fast_gemm_type
<
T
>
{});
migemm_impl
(
cmat
,
amat
,
bmat
,
alpha
,
beta
,
is_fast_gemm_type
<
T
>
{});
...
...
src/targets/cpu/lowering.cpp
View file @
0ea7b7a3
...
@@ -375,22 +375,22 @@ struct cpu_gemm
...
@@ -375,22 +375,22 @@ struct cpu_gemm
{
{
argument
result
{
output_shape
};
argument
result
{
output_shape
};
// all args are scalar
// all args are scalar
if
(
output_shape
.
scalar
())
if
(
output_shape
.
scalar
())
{
{
visit_all
(
result
,
args
[
0
],
args
[
1
],
args
[
2
])([
&
](
auto
ret
,
auto
a
,
auto
b
,
auto
c
)
{
visit_all
(
result
,
args
[
0
],
args
[
1
],
args
[
2
])([
&
](
auto
ret
,
auto
a
,
auto
b
,
auto
c
)
{
ret
[
0
]
=
op
.
alpha
*
a
[
0
]
*
b
[
0
]
+
op
.
beta
*
c
[
0
];
ret
[
0
]
=
op
.
alpha
*
a
[
0
]
*
b
[
0
]
+
op
.
beta
*
c
[
0
];
});
});
return
result
;
return
result
;
}
}
// first argument is 1-dim, pre-pend 1 at beginning
// first argument is 1-dim, pre-pend 1 at beginning
auto
a_lens
=
args
[
0
].
get_shape
().
lens
();
auto
a_lens
=
args
[
0
].
get_shape
().
lens
();
auto
b_lens
=
args
[
1
].
get_shape
().
lens
();
auto
b_lens
=
args
[
1
].
get_shape
().
lens
();
auto
out_lens
=
output_shape
.
lens
();
auto
out_lens
=
output_shape
.
lens
();
bool
is_a_prepended
=
false
;
bool
is_a_prepended
=
false
;
shape
::
type_t
t
=
output_shape
.
type
();
shape
::
type_t
t
=
output_shape
.
type
();
if
(
a_lens
.
size
()
==
1
)
if
(
a_lens
.
size
()
==
1
)
{
{
is_a_prepended
=
true
;
is_a_prepended
=
true
;
a_lens
.
insert
(
a_lens
.
begin
(),
1
);
a_lens
.
insert
(
a_lens
.
begin
(),
1
);
...
@@ -399,7 +399,7 @@ struct cpu_gemm
...
@@ -399,7 +399,7 @@ struct cpu_gemm
}
}
bool
is_b_appended
=
false
;
bool
is_b_appended
=
false
;
if
(
b_lens
.
size
()
==
1
)
if
(
b_lens
.
size
()
==
1
)
{
{
is_b_appended
=
true
;
is_b_appended
=
true
;
b_lens
.
push_back
(
1
);
b_lens
.
push_back
(
1
);
...
@@ -410,17 +410,20 @@ struct cpu_gemm
...
@@ -410,17 +410,20 @@ struct cpu_gemm
if
(
args
.
size
()
==
2
)
if
(
args
.
size
()
==
2
)
{
{
result
.
visit
([
&
](
auto
output
)
{
std
::
fill
(
output
.
begin
(),
output
.
end
(),
0
);
});
result
.
visit
([
&
](
auto
output
)
{
std
::
fill
(
output
.
begin
(),
output
.
end
(),
0
);
});
migemm
({{
t
,
out_lens
},
result
.
data
()},
{{
t
,
a_lens
},
args
[
0
].
data
()},
migemm
({{
t
,
out_lens
},
result
.
data
()},
{{
t
,
b_lens
},
args
[
1
].
data
()},
op
.
alpha
,
op
.
beta
);
{{
t
,
a_lens
},
args
[
0
].
data
()},
{{
t
,
b_lens
},
args
[
1
].
data
()},
op
.
alpha
,
op
.
beta
);
return
result
;
return
result
;
}
}
// 3 input arguments
// 3 input arguments
auto
c_shape
=
args
[
2
].
get_shape
();
auto
c_shape
=
args
[
2
].
get_shape
();
// In GEMM, C is broadcastable to A * B, so we should consider C
// In GEMM, C is broadcastable to A * B, so we should consider C
// is not the same shape as A * B. If the same shape, copy C to
// is not the same shape as A * B. If the same shape, copy C to
// the memory of the output
// the memory of the output
if
(
c_shape
==
output_shape
)
if
(
c_shape
==
output_shape
)
{
{
// memory copy is more efficient than doing element by element
// memory copy is more efficient than doing element by element
result
.
visit
([
&
](
auto
output
)
{
result
.
visit
([
&
](
auto
output
)
{
...
@@ -430,23 +433,28 @@ struct cpu_gemm
...
@@ -430,23 +433,28 @@ struct cpu_gemm
}
}
else
else
{
{
auto
out_len
=
output_shape
.
lens
();
auto
out_len
=
output_shape
.
lens
();
auto
c_lens
=
c_shape
.
lens
();
auto
c_lens
=
c_shape
.
lens
();
std
::
size_t
len_diff
=
out_len
.
size
()
-
c_lens
.
size
();
std
::
size_t
len_diff
=
out_len
.
size
()
-
c_lens
.
size
();
visit_all
(
result
,
args
[
2
])
([
&
](
auto
output
,
auto
c
)
{
visit_all
(
result
,
args
[
2
])([
&
](
auto
output
,
auto
c
)
{
shape_for_each
(
output_shape
,
[
&
](
auto
out_idx
)
{
shape_for_each
(
output_shape
,
[
&
](
auto
out_idx
)
{
// compute the input index
// compute the input index
std
::
vector
<
std
::
size_t
>
in_idx
(
c_lens
.
size
());
std
::
vector
<
std
::
size_t
>
in_idx
(
c_lens
.
size
());
std
::
transform
(
c_lens
.
begin
(),
c_lens
.
end
(),
out_len
.
begin
()
+
len_diff
,
in_idx
.
begin
(),
[
&
](
auto
i
,
auto
j
)
{
std
::
transform
(
c_lens
.
begin
(),
return
(
i
==
1
)
?
0
:
j
;
c_lens
.
end
(),
});
out_len
.
begin
()
+
len_diff
,
in_idx
.
begin
(),
[
&
](
auto
i
,
auto
j
)
{
return
(
i
==
1
)
?
0
:
j
;
});
output
(
out_idx
.
begin
(),
out_idx
.
end
())
=
c
(
in_idx
.
begin
(),
in_idx
.
end
());
output
(
out_idx
.
begin
(),
out_idx
.
end
())
=
c
(
in_idx
.
begin
(),
in_idx
.
end
());
});
});
});
});
}
}
migemm
({{
t
,
out_lens
},
result
.
data
()},
{{
t
,
a_lens
},
args
[
0
].
data
()},
migemm
({{
t
,
out_lens
},
result
.
data
()},
{{
t
,
b_lens
},
args
[
1
].
data
()},
op
.
alpha
,
op
.
beta
);
{{
t
,
a_lens
},
args
[
0
].
data
()},
{{
t
,
b_lens
},
args
[
1
].
data
()},
op
.
alpha
,
op
.
beta
);
return
result
;
return
result
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment