Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
77212cc1
Commit
77212cc1
authored
Apr 03, 2019
by
Shucai Xiao
Browse files
code backup.
parent
900bad8b
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
54 additions
and
8 deletions
+54
-8
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+13
-3
src/targets/cpu/gemm.cpp
src/targets/cpu/gemm.cpp
+8
-3
src/targets/cpu/lowering.cpp
src/targets/cpu/lowering.cpp
+33
-2
No files found.
src/include/migraphx/operators.hpp
View file @
77212cc1
...
@@ -820,7 +820,7 @@ struct gather
...
@@ -820,7 +820,7 @@ struct gather
struct
dot
struct
dot
{
{
float
alpha
=
1.0
;
float
alpha
=
1.0
;
float
beta
=
0
.0
;
float
beta
=
1
.0
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
...
@@ -839,7 +839,7 @@ struct dot
...
@@ -839,7 +839,7 @@ struct dot
// according to the specification of the numpy.matmul()
// according to the specification of the numpy.matmul()
// inputs with the shape dims more than 2 are acceptable
// inputs with the shape dims more than 2 are acceptable
// as long as dim values are the same in the two inputs
// as long as dim values are the same in the two inputs
if
(
!
std
::
equal
(
a
.
lens
().
rbegin
()
+
2
,
a
.
lens
().
rend
(),
b
.
lens
().
rbegin
()
+
2
))
if
(
!
std
::
equal
(
a
.
lens
().
rbegin
()
+
2
,
a
.
lens
().
rend
(),
b
.
lens
().
rbegin
()
+
2
,
b
.
lens
().
rend
()
))
{
{
MIGRAPHX_THROW
(
"DOT: dim values mismatch"
);
MIGRAPHX_THROW
(
"DOT: dim values mismatch"
);
}
}
...
@@ -847,10 +847,20 @@ struct dot
...
@@ -847,10 +847,20 @@ struct dot
std
::
size_t
dim_0
=
a
.
lens
().
size
()
-
2
;
std
::
size_t
dim_0
=
a
.
lens
().
size
()
-
2
;
std
::
size_t
dim_1
=
a
.
lens
().
size
()
-
1
;
std
::
size_t
dim_1
=
a
.
lens
().
size
()
-
1
;
if
(
a
.
lens
()[
dim_1
]
!=
b
.
lens
()[
dim_0
])
if
(
a
.
lens
()[
dim_1
]
!=
b
.
lens
()[
dim_0
])
MIGRAPHX_THROW
(
"Inner dimensions do not match: {"
+
to_string_range
(
a
.
lens
())
+
{
MIGRAPHX_THROW
(
"DOT: inner dimensions do not match: {"
+
to_string_range
(
a
.
lens
())
+
"} x {"
+
to_string_range
(
b
.
lens
())
+
"}"
);
"} x {"
+
to_string_range
(
b
.
lens
())
+
"}"
);
}
auto
out_lens
=
a
.
lens
();
auto
out_lens
=
a
.
lens
();
out_lens
[
dim_1
]
=
b
.
lens
()[
dim_1
];
out_lens
[
dim_1
]
=
b
.
lens
()[
dim_1
];
if
(
inputs
.
size
()
==
3
&&
out_lens
!=
inputs
.
at
(
2
).
lens
())
{
MIGRAPHX_THROW
(
"DOT: dimension mismatch, operand C: {"
+
to_string_range
(
c_lens
)
+
"}, cannot add to operand A * B: {"
+
to_string_range
(
out_lens
)
+
"}"
);
}
return
{
t
,
out_lens
};
return
{
t
,
out_lens
};
}
}
};
};
...
...
src/targets/cpu/gemm.cpp
View file @
77212cc1
...
@@ -55,7 +55,13 @@ void migemm_impl(tensor_view<T> cmat,
...
@@ -55,7 +55,13 @@ void migemm_impl(tensor_view<T> cmat,
visit_mat
(
amat
,
[
&
](
const
auto
&
a
)
{
visit_mat
(
amat
,
[
&
](
const
auto
&
a
)
{
visit_mat
(
bmat
,
[
&
](
const
auto
&
b
)
{
visit_mat
(
bmat
,
[
&
](
const
auto
&
b
)
{
auto
c
=
make_mat
(
cmat
);
auto
c
=
make_mat
(
cmat
);
c
=
(
a
*
b
)
*
alpha
+
beta
*
c
;
c
=
beta
*
c
;
// This is a simple optimization to avoid
// compute A * B if alpha is 0.0
if
(
alpha
!=
0.0
)
{
c
=
c
+
alpha
*
a
*
b
;
}
});
});
});
});
}
}
...
@@ -95,8 +101,7 @@ void migemm_impl(
...
@@ -95,8 +101,7 @@ void migemm_impl(
{
{
auto
lens
=
amat
.
get_shape
().
lens
();
auto
lens
=
amat
.
get_shape
().
lens
();
bool
batch_mul
=
bool
batch_mul
=
std
::
accumulate
(
lens
.
begin
(),
lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
())
==
std
::
accumulate
(
lens
.
rbegin
()
+
2
,
lens
.
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
())
==
1
;
(
*
lens
.
rbegin
())
*
(
*
(
lens
.
rbegin
()
+
1
));
if
(
batch_mul
)
if
(
batch_mul
)
{
{
migemm_impl
(
cmat
,
amat
,
bmat
,
alpha
,
beta
,
is_fast_gemm_type
<
T
>
{});
migemm_impl
(
cmat
,
amat
,
bmat
,
alpha
,
beta
,
is_fast_gemm_type
<
T
>
{});
...
...
src/targets/cpu/lowering.cpp
View file @
77212cc1
...
@@ -369,12 +369,43 @@ struct cpu_gemm
...
@@ -369,12 +369,43 @@ struct cpu_gemm
{
{
op
::
dot
op
;
op
::
dot
op
;
std
::
string
name
()
const
{
return
"cpu::dot"
;
}
std
::
string
name
()
const
{
return
"cpu::dot"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
return
op
.
compute_shape
(
inputs
);
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
if
(
inputs
.
size
()
==
3
)
{
auto
c_shape
=
inputs
.
at
(
2
);
check_shapes
{{
c_shape
}}.
not_broadcasted
();
}
return
op
.
compute_shape
(
inputs
);
}
argument
compute
(
context
&
,
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
,
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
output_shape
};
argument
result
{
output_shape
};
// 3 inputs, it is alpha * A * B + beta * C, then
// A and B are matrics, and C is broadcastable to A * B
if
(
args
.
size
()
==
3
)
{
// no need to consider the value of args[2]
if
(
op
.
beta
==
0.0
f
)
{
result
.
visit
([
&
](
auto
output
)
{
std
::
fill
(
output
.
begin
(),
output
.
end
(),
0
);
});
}
else
{
visit_all
(
result
,
args
[
2
])([
&
](
auto
output
,
auto
input
)
{
std
::
copy
(
input
.
begin
(),
input
.
end
(),
output
.
begin
());
});
}
migemm
(
result
,
args
[
0
],
args
[
1
],
op
.
alpha
,
op
.
beta
);
migemm
(
result
,
args
[
0
],
args
[
1
],
op
.
alpha
,
op
.
beta
);
return
result
;
}
// 2 input arguments
migemm
(
result
,
args
[
0
],
args
[
1
],
op
.
alpha
,
0.0
f
);
return
result
;
return
result
;
}
}
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment