Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
94d13003
Commit
94d13003
authored
Feb 27, 2019
by
Shucai Xiao
Browse files
add tests for extending the gemm operation
parent
ff40d99c
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
210 additions
and
9 deletions
+210
-9
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+6
-1
src/targets/cpu/gemm.cpp
src/targets/cpu/gemm.cpp
+13
-8
test/cpu_ops_test.cpp
test/cpu_ops_test.cpp
+49
-0
test/gpu/miopen.cpp
test/gpu/miopen.cpp
+42
-0
test/onnx/gemm_test_ex.onnx
test/onnx/gemm_test_ex.onnx
+0
-0
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+21
-0
test/op_shape_test.cpp
test/op_shape_test.cpp
+79
-0
No files found.
src/onnx/onnx.cpp
View file @
94d13003
...
...
@@ -469,7 +469,12 @@ struct onnx_parser
{
transb
=
parse_value
(
attributes
.
at
(
"transB"
)).
at
<
bool
>
();
}
std
::
vector
<
int64_t
>
perm
=
{
1
,
0
};
std
::
vector
<
int64_t
>
perm
(
args
[
0
]
->
get_shape
().
lens
().
size
());
std
::
iota
(
perm
.
begin
(),
perm
.
end
(),
int64_t
{
0
});
// swap the last two elements
std
::
swap
(
*
perm
.
rbegin
(),
*
(
perm
.
rbegin
()
+
1
));
auto
l1
=
(
transa
)
?
prog
.
add_instruction
(
op
::
transpose
{
perm
},
args
[
0
])
:
args
[
0
];
auto
l2
=
(
transb
)
?
prog
.
add_instruction
(
op
::
transpose
{
perm
},
args
[
1
])
:
args
[
1
];
if
(
args
.
size
()
==
3
)
...
...
src/targets/cpu/gemm.cpp
View file @
94d13003
#include <migraphx/cpu/gemm.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/shape_for_each.hpp>
#include <blaze/math/CustomMatrix.h>
namespace
migraphx
{
...
...
@@ -70,18 +71,22 @@ void migemm_impl(tensor_view<T> cmat,
std
::
size_t
n_dims
=
cmat
.
get_shape
().
lens
().
size
();
std
::
size_t
dim_0
=
n_dims
-
2
;
std
::
size_t
dim_1
=
n_dims
-
1
;
auto
m
=
cmat
.
get_shape
().
lens
()[
dim_0
];
auto
n
=
cmat
.
get_shape
().
lens
()[
dim_1
];
auto
k
=
amat
.
get_shape
().
lens
()[
dim_1
];
assert
(
amat
.
get_shape
().
lens
()[
dim_1
]
==
bmat
.
get_shape
().
lens
()[
dim_0
]);
assert
(
m
==
amat
.
get_shape
().
lens
()[
dim_0
]);
assert
(
n
==
bmat
.
get_shape
().
lens
()[
dim_1
]);
assert
(
cmat
.
get_shape
().
lens
()[
dim_0
]
==
amat
.
get_shape
().
lens
()[
dim_0
]);
assert
(
cmat
.
get_shape
().
lens
()[
dim_1
]
==
bmat
.
get_shape
().
lens
()[
dim_1
]);
dfor
(
m
,
n
)([
&
](
auto
ii
,
auto
jj
)
{
double
s
=
cmat
(
ii
,
jj
)
*
beta
;
dfor
(
k
)([
&
](
auto
kk
)
{
s
+=
amat
(
ii
,
kk
)
*
bmat
(
kk
,
jj
);
});
cmat
(
ii
,
jj
)
=
alpha
*
s
;
shape_for_each
(
cmat
.
get_shape
(),
[
&
](
const
auto
&
c_idx
)
{
double
s
=
cmat
(
c_idx
.
begin
(),
c_idx
.
end
())
*
beta
;
auto
a_idx
=
c_idx
;
auto
b_idx
=
c_idx
;
dfor
(
k
)([
&
](
auto
kk
)
{
a_idx
[
dim_1
]
=
b_idx
[
dim_0
]
=
kk
;
s
+=
amat
(
a_idx
.
begin
(),
a_idx
.
end
())
*
bmat
(
b_idx
.
begin
(),
b_idx
.
end
());
});
cmat
(
c_idx
.
begin
(),
c_idx
.
end
())
=
alpha
*
s
;
});
}
...
...
test/cpu_ops_test.cpp
View file @
94d13003
...
...
@@ -925,6 +925,55 @@ void gemm_test()
TEST_CASE_REGISTER
(
gemm_test
<
float
>
)
TEST_CASE_REGISTER
(
gemm_test
<
double
>
)
template
<
class
T
>
void
gemm_test_ex
()
{
migraphx
::
program
p
;
std
::
vector
<
T
>
a
=
{
-
0.00925222
,
0.56250403
,
0.70107397
,
0.75402161
,
-
0.505885
,
1.33628943
,
-
0.11413
,
-
0.31270559
,
1.59336732
,
-
0.19361027
,
-
0.91620867
,
0.40108416
,
-
0.06969921
,
0.68483471
,
-
0.39906632
,
-
1.66423624
,
0.69040076
,
-
1.31490171
,
-
0.11282616
,
-
0.79391814
};
std
::
vector
<
float
>
b
=
{
6.09568541e-01
,
-
6.10527007e-01
,
3.66646462e-01
,
1.18951101e-01
,
5.58777432e-01
,
-
3.21296298e-01
,
-
5.95997198e-01
,
-
5.01425721e-01
,
-
2.84606807e-01
,
-
5.73673557e-01
,
-
8.99430260e-01
,
-
4.25103093e-01
,
1.53027987e+00
,
-
3.81407415e-04
,
-
3.29650255e-01
};
std
::
vector
<
float
>
c
=
{
-
1.56327541e+00
,
-
7.09570140e-01
,
-
5.37424982e-01
,
-
2.22994831e-01
,
-
2.15586437e+00
,
2.09177941e-03
,
-
1.47279677e+00
,
2.02627040e-01
,
-
6.04527691e-01
,
-
1.29885596e+00
,
2.16294914e+00
,
-
1.48101497e-01
};
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
get_type
<
T
>
{},
{
1
,
1
,
4
,
5
}};
auto
al
=
p
.
add_literal
(
migraphx
::
literal
{
a_shape
,
a
});
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
get_type
<
T
>
{},
{
1
,
1
,
5
,
3
}};
auto
bl
=
p
.
add_literal
(
migraphx
::
literal
{
b_shape
,
b
});
p
.
add_instruction
(
migraphx
::
op
::
dot
{},
al
,
bl
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
T
>
results_vector
(
12
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
c
,
results_vector
));
}
TEST_CASE_REGISTER
(
gemm_test_ex
<
float
>
)
TEST_CASE_REGISTER
(
gemm_test_ex
<
double
>
)
TEST_CASE
(
maxpool_test
)
{
migraphx
::
program
p
;
...
...
test/gpu/miopen.cpp
View file @
94d13003
...
...
@@ -746,6 +746,18 @@ struct test_gemm
}
};
struct
test_gemm_ex
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
a
=
p
.
add_parameter
(
"a"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
4
,
5
}});
auto
b
=
p
.
add_parameter
(
"b"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
5
,
3
}});
p
.
add_instruction
(
migraphx
::
op
::
dot
{},
a
,
b
);
return
p
;
}
};
struct
test_gemm_half
{
migraphx
::
program
create_program
()
const
...
...
@@ -785,6 +797,19 @@ struct test_gemm_transposeb
}
};
struct
test_gemm_transposeb_ex
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
a
=
p
.
add_parameter
(
"a"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
5
}});
auto
b
=
p
.
add_parameter
(
"b"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
5
}});
auto
bt
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
2
,
1
}},
b
);
p
.
add_instruction
(
migraphx
::
op
::
dot
{},
a
,
bt
);
return
p
;
}
};
struct
test_gemm_transposea
{
migraphx
::
program
create_program
()
const
...
...
@@ -798,6 +823,20 @@ struct test_gemm_transposea
}
};
struct
test_gemm_transposea_ex
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
a
=
p
.
add_parameter
(
"a"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
5
,
4
}});
auto
b
=
p
.
add_parameter
(
"b"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
5
,
3
}});
auto
at
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
1
,
3
,
2
}},
a
);
p
.
add_instruction
(
migraphx
::
op
::
dot
{},
at
,
b
);
return
p
;
}
};
struct
test_gemm_transposeab
{
migraphx
::
program
create_program
()
const
...
...
@@ -2936,10 +2975,13 @@ int main()
verify_program
<
test_global_avg_pooling
>
();
verify_program
<
test_global_max_pooling
>
();
verify_program
<
test_gemm
>
();
verify_program
<
test_gemm_ex
>
();
verify_program
<
test_gemm_half
>
();
// verify_program<test_gemm_ld>();
verify_program
<
test_gemm_transposeb
>
();
verify_program
<
test_gemm_transposeb_ex
>
();
verify_program
<
test_gemm_transposea
>
();
verify_program
<
test_gemm_transposea_ex
>
();
verify_program
<
test_gemm_transposeab
>
();
verify_program
<
test_contiguous
>
();
verify_program
<
test_eliminate_contiguous
>
();
...
...
test/onnx/gemm_test_ex.onnx
0 → 100644
View file @
94d13003
File added
test/onnx/onnx_test.cpp
View file @
94d13003
...
...
@@ -572,6 +572,27 @@ TEST_CASE(gemm_test)
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
gemm_ex
)
{
migraphx
::
program
p
;
auto
l0
=
p
.
add_parameter
(
"1"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
5
,
6
}});
auto
l1
=
p
.
add_parameter
(
"2"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
5
,
7
}});
auto
l2
=
p
.
add_parameter
(
"3"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
6
,
7
}});
auto
t0
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
1
,
3
,
2
}},
l0
);
auto
alpha
=
0.5
f
;
auto
res_ab
=
p
.
add_instruction
(
migraphx
::
op
::
dot
{
alpha
},
t0
,
l1
);
auto
beta
=
0.8
f
;
auto
l_beta
=
p
.
add_literal
(
beta
);
auto
brcst_beta
=
p
.
add_instruction
(
migraphx
::
op
::
scalar
{
l2
->
get_shape
()},
l_beta
);
auto
res_c
=
p
.
add_instruction
(
migraphx
::
op
::
mul
{},
l2
,
brcst_beta
);
p
.
add_instruction
(
migraphx
::
op
::
add
{},
res_ab
,
res_c
);
auto
prog
=
migraphx
::
parse_onnx
(
"gemm_test_ex.onnx"
);
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
add_scalar_test
)
{
migraphx
::
program
p
;
...
...
test/op_shape_test.cpp
View file @
94d13003
...
...
@@ -316,6 +316,85 @@ TEST_CASE(gather)
}
}
TEST_CASE
(
dot
)
{
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
8
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
,
8
}},
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
4
,
6
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
8
}};
throws_shape
(
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
1
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
1
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
1
}},
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
5
,
7
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
7
}},
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
5
,
7
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
4
,
7
}},
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
5
,
7
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
4
,
7
}},
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
4
,
6
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
5
,
7
}};
throws_shape
(
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
5
,
7
}};
throws_shape
(
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
5
,
7
}};
throws_shape
(
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
2
,
1
,
5
,
7
}};
throws_shape
(
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
);
}
}
TEST_CASE
(
rnn
)
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment