Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
742e6a4b
Commit
742e6a4b
authored
Jun 02, 2018
by
Scott Thornton
Browse files
Added gemm operator and cpu target
parent
dc2b0abf
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
67 additions
and
0 deletions
+67
-0
src/include/rtg/operators.hpp
src/include/rtg/operators.hpp
+28
-0
src/targets/cpu/cpu_target.cpp
src/targets/cpu/cpu_target.cpp
+39
-0
No files found.
src/include/rtg/operators.hpp
View file @
742e6a4b
#ifndef RTG_GUARD_OPERATORS_HPP
#define RTG_GUARD_OPERATORS_HPP
#include <array>
#include <rtg/operation.hpp>
#include <rtg/stringutils.hpp>
#include <rtg/streamutils.hpp>
...
...
@@ -218,6 +219,33 @@ struct reshape
}
};
/// Descriptor for a matrix-multiply (GEMM) operation.
///
/// Validates a pair of 2-D input shapes and derives the shape of their
/// product. The struct carries no implementation: compute() always throws.
struct gemm
{
    /// Operator name used when printing.
    std::string name() const { return "gemm"; }

    // NOTE(review): presumably BLAS-style leading dimensions; no member of
    // this struct reads them -- confirm intended use with the targets.
    std::size_t lda = 1;
    std::size_t ldb = 1;
    std::size_t ldc = 1;

    /// Compute the output shape for inputs {A, B}.
    ///
    /// @param inputs exactly two shapes, both 2-D and of the same type.
    /// @return a shape of A's type with lens {A.lens()[0], B.lens()[1]}.
    /// @throws via RTG_THROW when A.lens()[1] != B.lens()[0]; check_shapes
    ///         reports violations of the arity/type/rank requirements.
    shape compute_shape(std::vector<shape> inputs) const
    {
        // same_dims() was dropped from this chain: requiring both operands to
        // have identical dimensions would reject every non-square product
        // (MxK times KxN). Only the inner dimensions have to agree, which is
        // checked explicitly below.
        // NOTE(review): assumes same_dims() enforces equal lens() across all
        // inputs -- confirm against check_shapes.
        check_shapes{inputs}.has(2).same_type().only_dims(2);

        const shape& A = inputs.at(0);
        const shape& B = inputs.at(1);
        if(A.lens()[1] != B.lens()[0])
            RTG_THROW("Inner dimensions do not match");

        return {A.type(), {A.lens()[0], B.lens()[1]}};
    }

    /// This operator is not directly computable; always throws.
    argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }

    /// Print the operator as "gemm[]".
    friend std::ostream& operator<<(std::ostream& os, const gemm& op)
    {
        os << op.name() << "[";
        os << "]";
        // The stream must be returned: flowing off the end of a function
        // returning std::ostream& is undefined behaviour.
        return os;
    }
};
}
// namespace rtg
#endif
src/targets/cpu/cpu_target.cpp
View file @
742e6a4b
...
...
@@ -47,6 +47,45 @@ struct cpu_convolution
}
};
/// CPU implementation of the gemm operator: computes C = A * B with a plain
/// triple loop.
struct cpu_gemm
{
    gemm op;

    /// Operator name used when printing.
    std::string name() const { return "cpu::gemm"; }

    /// Shape inference is delegated to the wrapped gemm operator.
    /// Marked const so it can be called on const instances, like the other
    /// members of this struct.
    shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }

    /// Multiply args[0] (M x K) by args[1] (K x N) into a new argument.
    ///
    /// The buffers are indexed as packed row-major arrays (row * width + col);
    /// shape strides are not consulted.
    ///
    /// @param output_shape shape used to construct the result argument.
    /// @param args         the two operands; args[0] is A and args[1] is B.
    /// @return the newly constructed argument holding the product.
    argument compute(shape output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        visit_all(result, args[0], args[1])([&](auto cmat, auto amat, auto bmat) {
            // NOTE(review): assumes lens() elements convert losslessly to
            // std::size_t -- confirm against shape.
            const std::size_t m = amat.get_shape().lens()[0];
            const std::size_t n = bmat.get_shape().lens()[1];
            const std::size_t k = bmat.get_shape().lens()[0];

            auto a = amat.data();
            auto b = bmat.data();
            auto c = cmat.data();

            // Clear the output before accumulating into it.
            for(std::size_t idx = 0; idx < m * n; ++idx)
            {
                c[idx] = 0;
            }

            // i-k-j loop order: the innermost loop walks one row of B and one
            // row of C contiguously.
            for(std::size_t ii = 0; ii < m; ++ii)
            {
                for(std::size_t kk = 0; kk < k; ++kk)
                {
                    const auto aik = a[ii * k + kk];
                    auto* bkj      = &b[kk * n];
                    auto* cij      = &c[ii * n];
                    for(std::size_t jj = 0; jj < n; ++jj, ++cij, ++bkj)
                    {
                        *cij += aik * (*bkj);
                    }
                }
            }
        });
        // The result must be returned: flowing off the end of a function
        // returning argument is undefined behaviour.
        // NOTE(review): assumes the view passed to the lambda writes through
        // to result's storage -- confirm against visit_all.
        return result;
    }
};
struct
relu
{
std
::
string
name
()
const
{
return
"cpu::relu"
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment