Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
238bfadd
Commit
238bfadd
authored
Aug 04, 2018
by
Paul
Browse files
Add simple fallback for now
parent
0b5fa390
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
36 additions
and
18 deletions
+36
-18
src/include/migraph/tensor_view.hpp
src/include/migraph/tensor_view.hpp
+2
-0
src/targets/cpu/gemm.cpp
src/targets/cpu/gemm.cpp
+19
-10
test/cpu_ops_test.cpp
test/cpu_ops_test.cpp
+9
-7
test/gpu/miopen.cpp
test/gpu/miopen.cpp
+6
-1
No files found.
src/include/migraph/tensor_view.hpp
View file @
238bfadd
...
@@ -29,6 +29,7 @@ struct tensor_view
...
@@ -29,6 +29,7 @@ struct tensor_view
template
<
class
...
Ts
,
MIGRAPH_REQUIRES
(
std
::
is_integral
<
Ts
>{}...)
>
template
<
class
...
Ts
,
MIGRAPH_REQUIRES
(
std
::
is_integral
<
Ts
>{}...)
>
const
T
&
operator
()(
Ts
...
xs
)
const
const
T
&
operator
()(
Ts
...
xs
)
const
{
{
assert
(
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
xs
)...}
<
m_shape
.
lens
());
assert
(
m_shape
.
index
({
static_cast
<
std
::
size_t
>
(
xs
)...})
<
m_shape
.
bytes
()
/
sizeof
(
T
));
assert
(
m_shape
.
index
({
static_cast
<
std
::
size_t
>
(
xs
)...})
<
m_shape
.
bytes
()
/
sizeof
(
T
));
return
m_data
[
m_shape
.
index
({
static_cast
<
std
::
size_t
>
(
xs
)...})];
return
m_data
[
m_shape
.
index
({
static_cast
<
std
::
size_t
>
(
xs
)...})];
}
}
...
@@ -36,6 +37,7 @@ struct tensor_view
...
@@ -36,6 +37,7 @@ struct tensor_view
template
<
class
...
Ts
,
MIGRAPH_REQUIRES
(
std
::
is_integral
<
Ts
>{}...)
>
template
<
class
...
Ts
,
MIGRAPH_REQUIRES
(
std
::
is_integral
<
Ts
>{}...)
>
T
&
operator
()(
Ts
...
xs
)
T
&
operator
()(
Ts
...
xs
)
{
{
assert
(
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
xs
)...}
<
m_shape
.
lens
());
assert
(
m_shape
.
index
({
static_cast
<
std
::
size_t
>
(
xs
)...})
<
m_shape
.
bytes
()
/
sizeof
(
T
));
assert
(
m_shape
.
index
({
static_cast
<
std
::
size_t
>
(
xs
)...})
<
m_shape
.
bytes
()
/
sizeof
(
T
));
return
m_data
[
m_shape
.
index
({
static_cast
<
std
::
size_t
>
(
xs
)...})];
return
m_data
[
m_shape
.
index
({
static_cast
<
std
::
size_t
>
(
xs
)...})];
}
}
...
...
src/targets/cpu/gemm.cpp
View file @
238bfadd
#include <migraph/cpu/gemm.hpp>
#include <migraph/cpu/gemm.hpp>
#include <migraph/dfor.hpp>
#include <migraph/requires.hpp>
#include <migraph/requires.hpp>
#include <blaze/math/CustomMatrix.h>
#include <blaze/math/CustomMatrix.h>
...
@@ -50,10 +51,7 @@ void migemm_impl(tensor_view<T> cmat,
...
@@ -50,10 +51,7 @@ void migemm_impl(tensor_view<T> cmat,
visit_mat
(
amat
,
[
&
](
const
auto
&
a
)
{
visit_mat
(
amat
,
[
&
](
const
auto
&
a
)
{
visit_mat
(
bmat
,
[
&
](
const
auto
&
b
)
{
visit_mat
(
bmat
,
[
&
](
const
auto
&
b
)
{
auto
c
=
make_mat
(
cmat
);
auto
c
=
make_mat
(
cmat
);
if
(
alpha
==
1.0
and
beta
==
0.0
)
c
=
(
a
*
b
)
*
alpha
+
beta
*
c
;
c
=
a
*
b
;
else
c
=
(
a
*
b
)
*
alpha
+
beta
*
c
;
});
});
});
});
}
}
...
@@ -66,12 +64,23 @@ void migemm_impl(tensor_view<T> cmat,
...
@@ -66,12 +64,23 @@ void migemm_impl(tensor_view<T> cmat,
float
beta
,
float
beta
,
std
::
false_type
)
std
::
false_type
)
{
{
(
void
)
cmat
;
auto
m
=
cmat
.
get_shape
().
lens
()[
0
];
(
void
)
amat
;
auto
n
=
cmat
.
get_shape
().
lens
()[
1
];
(
void
)
bmat
;
auto
k
=
amat
.
get_shape
().
lens
()[
1
];
(
void
)
alpha
;
(
void
)
beta
;
assert
(
amat
.
get_shape
().
lens
()[
1
]
==
bmat
.
get_shape
().
lens
()[
0
]);
assert
(
true
&&
"TODO"
);
assert
(
m
==
amat
.
get_shape
().
lens
()[
0
]);
assert
(
n
==
bmat
.
get_shape
().
lens
()[
1
]);
dfor
(
m
,
n
)([
&
](
auto
ii
,
auto
jj
)
{
double
s
=
cmat
(
ii
,
jj
)
*
beta
;
dfor
(
k
)([
&
](
auto
kk
)
{
s
+=
amat
(
ii
,
kk
)
*
bmat
(
kk
,
jj
);
});
cmat
(
ii
,
jj
)
=
alpha
*
s
;
});
}
}
template
<
class
T
>
template
<
class
T
>
...
...
test/cpu_ops_test.cpp
View file @
238bfadd
...
@@ -242,14 +242,15 @@ void reshape_test()
...
@@ -242,14 +242,15 @@ void reshape_test()
}
}
}
}
template
<
class
T
>
void
gemm_test
()
void
gemm_test
()
{
{
migraph
::
program
p
;
migraph
::
program
p
;
std
::
vector
<
float
>
a
=
{
-
0.00925222
,
0.56250403
,
0.70107397
,
0.75402161
,
-
0.505885
,
std
::
vector
<
T
>
a
=
{
-
0.00925222
,
0.56250403
,
0.70107397
,
0.75402161
,
-
0.505885
,
1.33628943
,
-
0.11413
,
-
0.31270559
,
1.59336732
,
-
0.19361027
,
1.33628943
,
-
0.11413
,
-
0.31270559
,
1.59336732
,
-
0.19361027
,
-
0.91620867
,
0.40108416
,
-
0.06969921
,
0.68483471
,
-
0.39906632
,
-
0.91620867
,
0.40108416
,
-
0.06969921
,
0.68483471
,
-
0.39906632
,
-
1.66423624
,
0.69040076
,
-
1.31490171
,
-
0.11282616
,
-
0.79391814
};
-
1.66423624
,
0.69040076
,
-
1.31490171
,
-
0.11282616
,
-
0.79391814
};
std
::
vector
<
float
>
b
=
{
6.09568541e-01
,
std
::
vector
<
T
>
b
=
{
6.09568541e-01
,
-
6.10527007e-01
,
-
6.10527007e-01
,
3.66646462e-01
,
3.66646462e-01
,
1.18951101e-01
,
1.18951101e-01
,
...
@@ -264,7 +265,7 @@ void gemm_test()
...
@@ -264,7 +265,7 @@ void gemm_test()
1.53027987e+00
,
1.53027987e+00
,
-
3.81407415e-04
,
-
3.81407415e-04
,
-
3.29650255e-01
};
-
3.29650255e-01
};
std
::
vector
<
float
>
c
=
{
-
1.56327541e+00
,
std
::
vector
<
T
>
c
=
{
-
1.56327541e+00
,
-
7.09570140e-01
,
-
7.09570140e-01
,
-
5.37424982e-01
,
-
5.37424982e-01
,
-
2.22994831e-01
,
-
2.22994831e-01
,
...
@@ -276,14 +277,14 @@ void gemm_test()
...
@@ -276,14 +277,14 @@ void gemm_test()
-
1.29885596e+00
,
-
1.29885596e+00
,
2.16294914e+00
,
2.16294914e+00
,
-
1.48101497e-01
};
-
1.48101497e-01
};
migraph
::
shape
a_shape
{
migraph
::
shape
::
floa
t_type
,
{
4
,
5
}};
migraph
::
shape
a_shape
{
migraph
::
shape
::
ge
t_type
<
T
>
{}
,
{
4
,
5
}};
auto
al
=
p
.
add_literal
(
migraph
::
literal
{
a_shape
,
a
});
auto
al
=
p
.
add_literal
(
migraph
::
literal
{
a_shape
,
a
});
migraph
::
shape
b_shape
{
migraph
::
shape
::
floa
t_type
,
{
5
,
3
}};
migraph
::
shape
b_shape
{
migraph
::
shape
::
ge
t_type
<
T
>
{}
,
{
5
,
3
}};
auto
bl
=
p
.
add_literal
(
migraph
::
literal
{
b_shape
,
b
});
auto
bl
=
p
.
add_literal
(
migraph
::
literal
{
b_shape
,
b
});
p
.
add_instruction
(
migraph
::
gemm
{},
al
,
bl
);
p
.
add_instruction
(
migraph
::
gemm
{},
al
,
bl
);
p
.
compile
(
migraph
::
cpu
::
cpu_target
{});
p
.
compile
(
migraph
::
cpu
::
cpu_target
{});
auto
result
=
p
.
eval
({});
auto
result
=
p
.
eval
({});
std
::
vector
<
float
>
results_vector
(
12
);
std
::
vector
<
T
>
results_vector
(
12
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
float
tol
=
1e-6
;
float
tol
=
1e-6
;
for
(
int
i
=
0
;
i
<
results_vector
.
size
();
i
++
)
for
(
int
i
=
0
;
i
<
results_vector
.
size
();
i
++
)
...
@@ -656,7 +657,8 @@ int main()
...
@@ -656,7 +657,8 @@ int main()
add_broadcast_test
();
add_broadcast_test
();
sub_test
();
sub_test
();
mul_test
();
mul_test
();
gemm_test
();
gemm_test
<
float
>
();
gemm_test
<
double
>
();
reshape_test
();
reshape_test
();
transpose_test
();
transpose_test
();
contiguous_test
();
contiguous_test
();
...
...
test/gpu/miopen.cpp
View file @
238bfadd
...
@@ -7,6 +7,7 @@
...
@@ -7,6 +7,7 @@
#include <migraph/gpu/miopen.hpp>
#include <migraph/gpu/miopen.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/manage_ptr.hpp>
#include <migraph/manage_ptr.hpp>
#include <migraph/type_name.hpp>
#include <miopen/miopen.h>
#include <miopen/miopen.h>
...
@@ -48,7 +49,11 @@ void verify_program()
...
@@ -48,7 +49,11 @@ void verify_program()
{
{
auto
cpu_arg
=
run_cpu
<
V
>
();
auto
cpu_arg
=
run_cpu
<
V
>
();
auto
gpu_arg
=
run_gpu
<
V
>
();
auto
gpu_arg
=
run_gpu
<
V
>
();
visit_all
(
cpu_arg
,
gpu_arg
)([](
auto
cpu
,
auto
gpu
)
{
EXPECT
(
test
::
verify_range
(
cpu
,
gpu
));
});
visit_all
(
cpu_arg
,
gpu_arg
)([](
auto
cpu
,
auto
gpu
)
{
if
(
not
test
::
verify_range
(
cpu
,
gpu
))
{
std
::
cout
<<
"FAILED: "
<<
migraph
::
get_type_name
<
V
>
()
<<
std
::
endl
;
}
});
}
}
struct
test_literals
struct
test_literals
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment