Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
238bfadd
Commit
238bfadd
authored
Aug 04, 2018
by
Paul
Browse files
Add simple fallback for now
parent
0b5fa390
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
36 additions
and
18 deletions
+36
-18
src/include/migraph/tensor_view.hpp
src/include/migraph/tensor_view.hpp
+2
-0
src/targets/cpu/gemm.cpp
src/targets/cpu/gemm.cpp
+19
-10
test/cpu_ops_test.cpp
test/cpu_ops_test.cpp
+9
-7
test/gpu/miopen.cpp
test/gpu/miopen.cpp
+6
-1
No files found.
src/include/migraph/tensor_view.hpp
View file @
238bfadd
...
@@ -29,6 +29,7 @@ struct tensor_view
...
@@ -29,6 +29,7 @@ struct tensor_view
template
<
class
...
Ts
,
MIGRAPH_REQUIRES
(
std
::
is_integral
<
Ts
>{}...)
>
template
<
class
...
Ts
,
MIGRAPH_REQUIRES
(
std
::
is_integral
<
Ts
>{}...)
>
const
T
&
operator
()(
Ts
...
xs
)
const
const
T
&
operator
()(
Ts
...
xs
)
const
{
{
assert
(
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
xs
)...}
<
m_shape
.
lens
());
assert
(
m_shape
.
index
({
static_cast
<
std
::
size_t
>
(
xs
)...})
<
m_shape
.
bytes
()
/
sizeof
(
T
));
assert
(
m_shape
.
index
({
static_cast
<
std
::
size_t
>
(
xs
)...})
<
m_shape
.
bytes
()
/
sizeof
(
T
));
return
m_data
[
m_shape
.
index
({
static_cast
<
std
::
size_t
>
(
xs
)...})];
return
m_data
[
m_shape
.
index
({
static_cast
<
std
::
size_t
>
(
xs
)...})];
}
}
...
@@ -36,6 +37,7 @@ struct tensor_view
...
@@ -36,6 +37,7 @@ struct tensor_view
template
<
class
...
Ts
,
MIGRAPH_REQUIRES
(
std
::
is_integral
<
Ts
>{}...)
>
template
<
class
...
Ts
,
MIGRAPH_REQUIRES
(
std
::
is_integral
<
Ts
>{}...)
>
T
&
operator
()(
Ts
...
xs
)
T
&
operator
()(
Ts
...
xs
)
{
{
assert
(
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
xs
)...}
<
m_shape
.
lens
());
assert
(
m_shape
.
index
({
static_cast
<
std
::
size_t
>
(
xs
)...})
<
m_shape
.
bytes
()
/
sizeof
(
T
));
assert
(
m_shape
.
index
({
static_cast
<
std
::
size_t
>
(
xs
)...})
<
m_shape
.
bytes
()
/
sizeof
(
T
));
return
m_data
[
m_shape
.
index
({
static_cast
<
std
::
size_t
>
(
xs
)...})];
return
m_data
[
m_shape
.
index
({
static_cast
<
std
::
size_t
>
(
xs
)...})];
}
}
...
...
src/targets/cpu/gemm.cpp
View file @
238bfadd
#include <migraph/cpu/gemm.hpp>
#include <migraph/cpu/gemm.hpp>
#include <migraph/dfor.hpp>
#include <migraph/requires.hpp>
#include <migraph/requires.hpp>
#include <blaze/math/CustomMatrix.h>
#include <blaze/math/CustomMatrix.h>
...
@@ -50,9 +51,6 @@ void migemm_impl(tensor_view<T> cmat,
...
@@ -50,9 +51,6 @@ void migemm_impl(tensor_view<T> cmat,
visit_mat
(
amat
,
[
&
](
const
auto
&
a
)
{
visit_mat
(
amat
,
[
&
](
const
auto
&
a
)
{
visit_mat
(
bmat
,
[
&
](
const
auto
&
b
)
{
visit_mat
(
bmat
,
[
&
](
const
auto
&
b
)
{
auto
c
=
make_mat
(
cmat
);
auto
c
=
make_mat
(
cmat
);
if
(
alpha
==
1.0
and
beta
==
0.0
)
c
=
a
*
b
;
else
c
=
(
a
*
b
)
*
alpha
+
beta
*
c
;
c
=
(
a
*
b
)
*
alpha
+
beta
*
c
;
});
});
});
});
...
@@ -66,12 +64,23 @@ void migemm_impl(tensor_view<T> cmat,
...
@@ -66,12 +64,23 @@ void migemm_impl(tensor_view<T> cmat,
float
beta
,
float
beta
,
std
::
false_type
)
std
::
false_type
)
{
{
(
void
)
cmat
;
auto
m
=
cmat
.
get_shape
().
lens
()[
0
];
(
void
)
amat
;
auto
n
=
cmat
.
get_shape
().
lens
()[
1
];
(
void
)
bmat
;
auto
k
=
amat
.
get_shape
().
lens
()[
1
];
(
void
)
alpha
;
(
void
)
beta
;
assert
(
amat
.
get_shape
().
lens
()[
1
]
==
bmat
.
get_shape
().
lens
()[
0
]);
assert
(
true
&&
"TODO"
);
assert
(
m
==
amat
.
get_shape
().
lens
()[
0
]);
assert
(
n
==
bmat
.
get_shape
().
lens
()[
1
]);
dfor
(
m
,
n
)([
&
](
auto
ii
,
auto
jj
)
{
double
s
=
cmat
(
ii
,
jj
)
*
beta
;
dfor
(
k
)([
&
](
auto
kk
)
{
s
+=
amat
(
ii
,
kk
)
*
bmat
(
kk
,
jj
);
});
cmat
(
ii
,
jj
)
=
alpha
*
s
;
});
}
}
template
<
class
T
>
template
<
class
T
>
...
...
test/cpu_ops_test.cpp
View file @
238bfadd
...
@@ -242,14 +242,15 @@ void reshape_test()
...
@@ -242,14 +242,15 @@ void reshape_test()
}
}
}
}
template
<
class
T
>
void
gemm_test
()
void
gemm_test
()
{
{
migraph
::
program
p
;
migraph
::
program
p
;
std
::
vector
<
float
>
a
=
{
-
0.00925222
,
0.56250403
,
0.70107397
,
0.75402161
,
-
0.505885
,
std
::
vector
<
T
>
a
=
{
-
0.00925222
,
0.56250403
,
0.70107397
,
0.75402161
,
-
0.505885
,
1.33628943
,
-
0.11413
,
-
0.31270559
,
1.59336732
,
-
0.19361027
,
1.33628943
,
-
0.11413
,
-
0.31270559
,
1.59336732
,
-
0.19361027
,
-
0.91620867
,
0.40108416
,
-
0.06969921
,
0.68483471
,
-
0.39906632
,
-
0.91620867
,
0.40108416
,
-
0.06969921
,
0.68483471
,
-
0.39906632
,
-
1.66423624
,
0.69040076
,
-
1.31490171
,
-
0.11282616
,
-
0.79391814
};
-
1.66423624
,
0.69040076
,
-
1.31490171
,
-
0.11282616
,
-
0.79391814
};
std
::
vector
<
float
>
b
=
{
6.09568541e-01
,
std
::
vector
<
T
>
b
=
{
6.09568541e-01
,
-
6.10527007e-01
,
-
6.10527007e-01
,
3.66646462e-01
,
3.66646462e-01
,
1.18951101e-01
,
1.18951101e-01
,
...
@@ -264,7 +265,7 @@ void gemm_test()
...
@@ -264,7 +265,7 @@ void gemm_test()
1.53027987e+00
,
1.53027987e+00
,
-
3.81407415e-04
,
-
3.81407415e-04
,
-
3.29650255e-01
};
-
3.29650255e-01
};
std
::
vector
<
float
>
c
=
{
-
1.56327541e+00
,
std
::
vector
<
T
>
c
=
{
-
1.56327541e+00
,
-
7.09570140e-01
,
-
7.09570140e-01
,
-
5.37424982e-01
,
-
5.37424982e-01
,
-
2.22994831e-01
,
-
2.22994831e-01
,
...
@@ -276,14 +277,14 @@ void gemm_test()
...
@@ -276,14 +277,14 @@ void gemm_test()
-
1.29885596e+00
,
-
1.29885596e+00
,
2.16294914e+00
,
2.16294914e+00
,
-
1.48101497e-01
};
-
1.48101497e-01
};
migraph
::
shape
a_shape
{
migraph
::
shape
::
floa
t_type
,
{
4
,
5
}};
migraph
::
shape
a_shape
{
migraph
::
shape
::
ge
t_type
<
T
>
{}
,
{
4
,
5
}};
auto
al
=
p
.
add_literal
(
migraph
::
literal
{
a_shape
,
a
});
auto
al
=
p
.
add_literal
(
migraph
::
literal
{
a_shape
,
a
});
migraph
::
shape
b_shape
{
migraph
::
shape
::
floa
t_type
,
{
5
,
3
}};
migraph
::
shape
b_shape
{
migraph
::
shape
::
ge
t_type
<
T
>
{}
,
{
5
,
3
}};
auto
bl
=
p
.
add_literal
(
migraph
::
literal
{
b_shape
,
b
});
auto
bl
=
p
.
add_literal
(
migraph
::
literal
{
b_shape
,
b
});
p
.
add_instruction
(
migraph
::
gemm
{},
al
,
bl
);
p
.
add_instruction
(
migraph
::
gemm
{},
al
,
bl
);
p
.
compile
(
migraph
::
cpu
::
cpu_target
{});
p
.
compile
(
migraph
::
cpu
::
cpu_target
{});
auto
result
=
p
.
eval
({});
auto
result
=
p
.
eval
({});
std
::
vector
<
float
>
results_vector
(
12
);
std
::
vector
<
T
>
results_vector
(
12
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
float
tol
=
1e-6
;
float
tol
=
1e-6
;
for
(
int
i
=
0
;
i
<
results_vector
.
size
();
i
++
)
for
(
int
i
=
0
;
i
<
results_vector
.
size
();
i
++
)
...
@@ -656,7 +657,8 @@ int main()
...
@@ -656,7 +657,8 @@ int main()
add_broadcast_test
();
add_broadcast_test
();
sub_test
();
sub_test
();
mul_test
();
mul_test
();
gemm_test
();
gemm_test
<
float
>
();
gemm_test
<
double
>
();
reshape_test
();
reshape_test
();
transpose_test
();
transpose_test
();
contiguous_test
();
contiguous_test
();
...
...
test/gpu/miopen.cpp
View file @
238bfadd
...
@@ -7,6 +7,7 @@
...
@@ -7,6 +7,7 @@
#include <migraph/gpu/miopen.hpp>
#include <migraph/gpu/miopen.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/manage_ptr.hpp>
#include <migraph/manage_ptr.hpp>
#include <migraph/type_name.hpp>
#include <miopen/miopen.h>
#include <miopen/miopen.h>
...
@@ -48,7 +49,11 @@ void verify_program()
...
@@ -48,7 +49,11 @@ void verify_program()
{
{
auto
cpu_arg
=
run_cpu
<
V
>
();
auto
cpu_arg
=
run_cpu
<
V
>
();
auto
gpu_arg
=
run_gpu
<
V
>
();
auto
gpu_arg
=
run_gpu
<
V
>
();
visit_all
(
cpu_arg
,
gpu_arg
)([](
auto
cpu
,
auto
gpu
)
{
EXPECT
(
test
::
verify_range
(
cpu
,
gpu
));
});
visit_all
(
cpu_arg
,
gpu_arg
)([](
auto
cpu
,
auto
gpu
)
{
if
(
not
test
::
verify_range
(
cpu
,
gpu
))
{
std
::
cout
<<
"FAILED: "
<<
migraph
::
get_type_name
<
V
>
()
<<
std
::
endl
;
}
});
}
}
struct
test_literals
struct
test_literals
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment