Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
57444235
Commit
57444235
authored
Oct 29, 2018
by
Khalique
Browse files
fix merge conflict
parents
a0ea12f6
d8bf45cf
Changes
49
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
236 additions
and
18 deletions
+236
-18
src/targets/gpu/relu.cpp
src/targets/gpu/relu.cpp
+1
-1
src/targets/gpu/rocblas.cpp
src/targets/gpu/rocblas.cpp
+7
-0
src/targets/gpu/softmax.cpp
src/targets/gpu/softmax.cpp
+1
-1
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+4
-5
test/CMakeLists.txt
test/CMakeLists.txt
+5
-2
test/common_subexpression_elimination_test.cpp
test/common_subexpression_elimination_test.cpp
+121
-0
test/cpu_ops_test.cpp
test/cpu_ops_test.cpp
+52
-1
test/fwd_conv_batchnorm_rewrite_test.cpp
test/fwd_conv_batchnorm_rewrite_test.cpp
+3
-3
test/gpu/miopen.cpp
test/gpu/miopen.cpp
+42
-5
No files found.
src/targets/gpu/relu.cpp
View file @
57444235
...
@@ -20,7 +20,7 @@ argument miopen_relu::compute(context& ctx,
...
@@ -20,7 +20,7 @@ argument miopen_relu::compute(context& ctx,
float
alpha
=
1
,
beta
=
0
;
float
alpha
=
1
,
beta
=
0
;
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
y_desc
=
make_tensor
(
output_shape
);
auto
y_desc
=
make_tensor
(
output_shape
);
miopenActivationForward
(
ctx
.
handle
.
get
(),
miopenActivationForward
(
ctx
.
get_stream
().
get_miopen
(),
ad
.
get
(),
ad
.
get
(),
&
alpha
,
&
alpha
,
x_desc
.
get
(),
x_desc
.
get
(),
...
...
src/targets/gpu/rocblas.cpp
View file @
57444235
...
@@ -10,6 +10,13 @@ rocblas_handle_ptr create_rocblas_handle_ptr()
...
@@ -10,6 +10,13 @@ rocblas_handle_ptr create_rocblas_handle_ptr()
return
rocblas_handle_ptr
{
handle
};
return
rocblas_handle_ptr
{
handle
};
}
}
rocblas_handle_ptr
create_rocblas_handle_ptr
(
hipStream_t
s
)
{
rocblas_handle_ptr
rb
=
create_rocblas_handle_ptr
();
rocblas_set_stream
(
rb
.
get
(),
s
);
return
rb
;
}
}
// namespace gpu
}
// namespace gpu
}
// namespace migraph
}
// namespace migraph
src/targets/gpu/softmax.cpp
View file @
57444235
...
@@ -20,7 +20,7 @@ argument miopen_softmax::compute(context& ctx,
...
@@ -20,7 +20,7 @@ argument miopen_softmax::compute(context& ctx,
float
alpha
=
1
,
beta
=
0
;
float
alpha
=
1
,
beta
=
0
;
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
y_desc
=
make_tensor
(
output_shape
);
auto
y_desc
=
make_tensor
(
output_shape
);
miopenSoftmaxForward
(
ctx
.
handle
.
get
(),
miopenSoftmaxForward
(
ctx
.
get_stream
().
get_miopen
(),
&
alpha
,
&
alpha
,
x_desc
.
get
(),
x_desc
.
get
(),
args
[
0
].
implicit
(),
args
[
0
].
implicit
(),
...
...
src/targets/gpu/target.cpp
View file @
57444235
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
#include <migraph/simplify_algebra.hpp>
#include <migraph/simplify_algebra.hpp>
#include <migraph/constant_propagate.hpp>
#include <migraph/constant_propagate.hpp>
#include <migraph/eliminate_contiguous.hpp>
#include <migraph/eliminate_contiguous.hpp>
#include <migraph/common_subexpression_elimination.hpp>
#include <migraph/fwd_conv_batchnorm_rewrite.hpp>
#include <migraph/fwd_conv_batchnorm_rewrite.hpp>
namespace
migraph
{
namespace
migraph
{
...
@@ -27,6 +28,8 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const
...
@@ -27,6 +28,8 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const
dead_code_elimination
{},
dead_code_elimination
{},
fwd_conv_batchnorm_rewrite
{},
fwd_conv_batchnorm_rewrite
{},
dead_code_elimination
{},
dead_code_elimination
{},
common_subexpression_elimination
{},
dead_code_elimination
{},
simplify_algebra
{},
simplify_algebra
{},
dead_code_elimination
{},
dead_code_elimination
{},
constant_propagate
{},
constant_propagate
{},
...
@@ -51,10 +54,6 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const
...
@@ -51,10 +54,6 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const
std
::
string
target
::
name
()
const
{
return
"miopen"
;
}
std
::
string
target
::
name
()
const
{
return
"miopen"
;
}
migraph
::
context
target
::
get_context
()
const
migraph
::
context
target
::
get_context
()
const
{
return
context
{};
}
{
return
context
{
share
(
make_obj
<
miopen_handle
>
(
&
miopenCreate
)),
share
(
create_rocblas_handle_ptr
()),
{}};
}
}
// namespace gpu
}
// namespace gpu
}
// namespace migraph
}
// namespace migraph
test/CMakeLists.txt
View file @
57444235
...
@@ -7,7 +7,7 @@ find_package(Threads REQUIRED)
...
@@ -7,7 +7,7 @@ find_package(Threads REQUIRED)
include
(
ProcessorCount
)
include
(
ProcessorCount
)
ProcessorCount
(
N
)
ProcessorCount
(
N
)
set
(
CTEST_PARALLEL_LEVEL
${
N
}
CACHE STRING
"CTest parallel level"
)
set
(
CTEST_PARALLEL_LEVEL
${
N
}
CACHE STRING
"CTest parallel level"
)
add_custom_target
(
check COMMAND
${
CMAKE_CTEST_COMMAND
}
--output-on-failure -j
${
CTEST_PARALLEL_LEVEL
}
-C
${
CMAKE_CFG_INTDIR
}
)
add_custom_target
(
check COMMAND
${
CMAKE_CTEST_COMMAND
}
--output-on-failure -j
${
CTEST_PARALLEL_LEVEL
}
-C
${
CMAKE_CFG_INTDIR
}
--timeout 1500
)
add_custom_target
(
tests
)
add_custom_target
(
tests
)
find_program
(
MIGRAPH_GDB gdb
)
find_program
(
MIGRAPH_GDB gdb
)
...
@@ -103,7 +103,10 @@ if(MIGRAPH_ENABLE_GPU)
...
@@ -103,7 +103,10 @@ if(MIGRAPH_ENABLE_GPU)
get_filename_component
(
BASE_NAME
${
TEST
}
NAME_WE
)
get_filename_component
(
BASE_NAME
${
TEST
}
NAME_WE
)
add_test_executable
(
test_gpu_
${
BASE_NAME
}
${
TEST
}
)
add_test_executable
(
test_gpu_
${
BASE_NAME
}
${
TEST
}
)
rocm_clang_tidy_check
(
test_gpu_
${
BASE_NAME
}
)
rocm_clang_tidy_check
(
test_gpu_
${
BASE_NAME
}
)
set_tests_properties
(
test_gpu_
${
BASE_NAME
}
PROPERTIES COST 10
)
set_tests_properties
(
test_gpu_
${
BASE_NAME
}
PROPERTIES
COST 10
RESOURCE_LOCK gpu
)
target_link_libraries
(
test_gpu_
${
BASE_NAME
}
migraph_gpu
)
target_link_libraries
(
test_gpu_
${
BASE_NAME
}
migraph_gpu
)
endforeach
()
endforeach
()
endif
()
endif
()
...
...
test/common_subexpression_elimination_test.cpp
0 → 100644
View file @
57444235
#include <migraph/common_subexpression_elimination.hpp>
#include <migraph/dead_code_elimination.hpp>
#include <migraph/operators.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
struct
cse_target
{
std
::
string
name
()
const
{
return
"dce"
;
}
std
::
vector
<
migraph
::
pass
>
get_passes
(
migraph
::
context
&
)
const
{
return
{
migraph
::
common_subexpression_elimination
{},
migraph
::
dead_code_elimination
{}};
}
migraph
::
context
get_context
()
const
{
return
{};
}
};
void
cse_test1
()
{
migraph
::
program
p1
;
{
auto
one
=
p1
.
add_literal
(
1
);
auto
two
=
p1
.
add_literal
(
2
);
auto
sum1
=
p1
.
add_instruction
(
migraph
::
op
::
add
{},
one
,
two
);
auto
sum2
=
p1
.
add_instruction
(
migraph
::
op
::
add
{},
one
,
two
);
auto
sum3
=
p1
.
add_instruction
(
migraph
::
op
::
add
{},
sum1
,
sum2
);
p1
.
add_instruction
(
pass_op
{},
sum3
);
}
p1
.
compile
(
cse_target
{});
migraph
::
program
p2
;
{
auto
one
=
p2
.
add_literal
(
1
);
auto
two
=
p2
.
add_literal
(
2
);
auto
sum1
=
p2
.
add_instruction
(
migraph
::
op
::
add
{},
one
,
two
);
auto
sum3
=
p2
.
add_instruction
(
migraph
::
op
::
add
{},
sum1
,
sum1
);
p2
.
add_instruction
(
pass_op
{},
sum3
);
}
EXPECT
(
p1
==
p2
);
}
void
cse_test2
()
{
migraph
::
program
p1
;
{
auto
one
=
p1
.
add_literal
(
1
);
auto
two
=
p1
.
add_literal
(
2
);
auto
sum1
=
p1
.
add_instruction
(
migraph
::
op
::
add
{},
one
,
two
);
auto
sum2
=
p1
.
add_instruction
(
migraph
::
op
::
add
{},
two
,
one
);
auto
sum3
=
p1
.
add_instruction
(
migraph
::
op
::
add
{},
sum1
,
sum2
);
p1
.
add_instruction
(
pass_op
{},
sum3
);
}
p1
.
compile
(
cse_target
{});
migraph
::
program
p2
;
{
auto
one
=
p2
.
add_literal
(
1
);
auto
two
=
p2
.
add_literal
(
2
);
auto
sum1
=
p2
.
add_instruction
(
migraph
::
op
::
add
{},
one
,
two
);
auto
sum2
=
p2
.
add_instruction
(
migraph
::
op
::
add
{},
two
,
one
);
auto
sum3
=
p2
.
add_instruction
(
migraph
::
op
::
add
{},
sum1
,
sum2
);
p2
.
add_instruction
(
pass_op
{},
sum3
);
}
EXPECT
(
p1
==
p2
);
}
void
cse_test3
()
{
migraph
::
program
p1
;
{
auto
one
=
p1
.
add_literal
(
1
);
auto
two
=
p1
.
add_literal
(
1
);
auto
sum1
=
p1
.
add_instruction
(
migraph
::
op
::
add
{},
one
,
two
);
auto
sum2
=
p1
.
add_instruction
(
migraph
::
op
::
add
{},
two
,
one
);
auto
sum3
=
p1
.
add_instruction
(
migraph
::
op
::
add
{},
sum1
,
sum2
);
p1
.
add_instruction
(
pass_op
{},
sum3
);
}
p1
.
compile
(
cse_target
{});
migraph
::
program
p2
;
{
auto
one
=
p2
.
add_literal
(
1
);
auto
sum1
=
p2
.
add_instruction
(
migraph
::
op
::
add
{},
one
,
one
);
auto
sum3
=
p2
.
add_instruction
(
migraph
::
op
::
add
{},
sum1
,
sum1
);
p2
.
add_instruction
(
pass_op
{},
sum3
);
}
EXPECT
(
p1
==
p2
);
}
void
cse_test4
()
{
migraph
::
program
p1
;
{
auto
one
=
p1
.
add_literal
(
1
);
auto
two
=
p1
.
add_literal
(
1
);
auto
sum1
=
p1
.
add_instruction
(
migraph
::
op
::
add
{},
one
,
two
);
auto
sum2
=
p1
.
add_instruction
(
migraph
::
op
::
add
{},
two
,
one
);
auto
sum3
=
p1
.
add_instruction
(
migraph
::
op
::
add
{},
sum1
,
one
);
auto
sum4
=
p1
.
add_instruction
(
migraph
::
op
::
add
{},
sum2
,
two
);
auto
sum5
=
p1
.
add_instruction
(
migraph
::
op
::
add
{},
sum4
,
sum3
);
p1
.
add_instruction
(
pass_op
{},
sum5
);
}
p1
.
compile
(
cse_target
{});
migraph
::
program
p2
;
{
auto
one
=
p2
.
add_literal
(
1
);
auto
sum1
=
p2
.
add_instruction
(
migraph
::
op
::
add
{},
one
,
one
);
auto
sum3
=
p2
.
add_instruction
(
migraph
::
op
::
add
{},
sum1
,
one
);
auto
sum5
=
p2
.
add_instruction
(
migraph
::
op
::
add
{},
sum3
,
sum3
);
p2
.
add_instruction
(
pass_op
{},
sum5
);
}
EXPECT
(
p1
==
p2
);
}
int
main
()
{
cse_test1
();
cse_test2
();
cse_test3
();
cse_test4
();
}
test/cpu_ops_test.cpp
View file @
57444235
...
@@ -47,6 +47,56 @@ void slice_test()
...
@@ -47,6 +47,56 @@ void slice_test()
}
}
}
}
void
concat_test
()
{
{
migraph
::
program
p
;
std
::
size_t
axis
=
1
;
std
::
vector
<
int
>
data0
=
{
0
,
1
,
5
,
6
};
std
::
vector
<
int
>
data1
=
{
2
,
3
,
4
,
7
,
8
,
9
};
std
::
vector
<
int
>
data2
=
{
10
,
20
};
migraph
::
shape
s0
{
migraph
::
shape
::
int32_type
,
{
2
,
2
}};
migraph
::
shape
s1
{
migraph
::
shape
::
int32_type
,
{
2
,
3
}};
migraph
::
shape
s2
{
migraph
::
shape
::
int32_type
,
{
2
,
1
}};
auto
l0
=
p
.
add_literal
(
migraph
::
literal
{
s0
,
data0
});
auto
l1
=
p
.
add_literal
(
migraph
::
literal
{
s1
,
data1
});
auto
l2
=
p
.
add_literal
(
migraph
::
literal
{
s2
,
data2
});
p
.
add_instruction
(
migraph
::
op
::
concat
{
axis
},
l0
,
l1
,
l2
);
p
.
compile
(
migraph
::
cpu
::
cpu_target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
int
>
gold
=
{
0
,
1
,
2
,
3
,
4
,
10
,
5
,
6
,
7
,
8
,
9
,
20
};
std
::
vector
<
int
>
results_vector
(
2
*
6
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraph
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraph
::
verify_range
(
result
.
get_shape
().
lens
(),
std
::
vector
<
std
::
size_t
>
({
2
,
6
})));
EXPECT
(
migraph
::
verify_range
(
result
.
get_shape
().
strides
(),
std
::
vector
<
std
::
size_t
>
({
6
,
1
})));
}
{
migraph
::
program
p
;
std
::
size_t
axis
=
0
;
std
::
vector
<
int
>
data0
=
{
0
,
1
,
2
,
3
};
std
::
vector
<
int
>
data1
=
{
4
,
5
,
6
,
7
,
8
,
9
};
std
::
vector
<
int
>
data2
=
{
10
,
11
};
migraph
::
shape
s0
{
migraph
::
shape
::
int32_type
,
{
2
,
2
}};
migraph
::
shape
s1
{
migraph
::
shape
::
int32_type
,
{
3
,
2
}};
migraph
::
shape
s2
{
migraph
::
shape
::
int32_type
,
{
1
,
2
}};
auto
l0
=
p
.
add_literal
(
migraph
::
literal
{
s0
,
data0
});
auto
l1
=
p
.
add_literal
(
migraph
::
literal
{
s1
,
data1
});
auto
l2
=
p
.
add_literal
(
migraph
::
literal
{
s2
,
data2
});
p
.
add_instruction
(
migraph
::
op
::
concat
{
axis
},
l0
,
l1
,
l2
);
p
.
compile
(
migraph
::
cpu
::
cpu_target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
int
>
gold
=
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
};
std
::
vector
<
int
>
results_vector
(
6
*
2
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraph
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraph
::
verify_range
(
result
.
get_shape
().
lens
(),
std
::
vector
<
std
::
size_t
>
({
6
,
2
})));
EXPECT
(
migraph
::
verify_range
(
result
.
get_shape
().
strides
(),
std
::
vector
<
std
::
size_t
>
({
2
,
1
})));
}
}
void
squeeze_test
()
void
squeeze_test
()
{
{
{
{
...
@@ -605,7 +655,7 @@ void gemm_test()
...
@@ -605,7 +655,7 @@ void gemm_test()
auto
al
=
p
.
add_literal
(
migraph
::
literal
{
a_shape
,
a
});
auto
al
=
p
.
add_literal
(
migraph
::
literal
{
a_shape
,
a
});
migraph
::
shape
b_shape
{
migraph
::
shape
::
get_type
<
T
>
{},
{
5
,
3
}};
migraph
::
shape
b_shape
{
migraph
::
shape
::
get_type
<
T
>
{},
{
5
,
3
}};
auto
bl
=
p
.
add_literal
(
migraph
::
literal
{
b_shape
,
b
});
auto
bl
=
p
.
add_literal
(
migraph
::
literal
{
b_shape
,
b
});
p
.
add_instruction
(
migraph
::
op
::
gemm
{},
al
,
bl
);
p
.
add_instruction
(
migraph
::
op
::
dot
{},
al
,
bl
);
p
.
compile
(
migraph
::
cpu
::
cpu_target
{});
p
.
compile
(
migraph
::
cpu
::
cpu_target
{});
auto
result
=
p
.
eval
({});
auto
result
=
p
.
eval
({});
std
::
vector
<
T
>
results_vector
(
12
);
std
::
vector
<
T
>
results_vector
(
12
);
...
@@ -970,6 +1020,7 @@ void contiguous_test()
...
@@ -970,6 +1020,7 @@ void contiguous_test()
int
main
()
int
main
()
{
{
concat_test
();
slice_test
();
slice_test
();
squeeze_test
();
squeeze_test
();
unsqueeze_test
();
unsqueeze_test
();
...
...
test/fwd_conv_batchnorm_rewrite_test.cpp
View file @
57444235
...
@@ -36,9 +36,9 @@ void fwd_conv_batchnorm_rewrite_test()
...
@@ -36,9 +36,9 @@ void fwd_conv_batchnorm_rewrite_test()
auto
create_program
=
[
&
]()
{
auto
create_program
=
[
&
]()
{
migraph
::
program
p
;
migraph
::
program
p
;
auto
x
=
p
.
add_literal
(
xs
,
xdata
);
auto
x
=
p
.
add_literal
(
xs
,
xdata
);
auto
w
=
p
.
add_literal
(
ws
,
wdata
);
auto
w
=
p
.
add_literal
(
ws
,
wdata
);
auto
conv
=
p
.
add_instruction
(
migraph
::
op
::
convolution
{{
0
,
0
},
{
1
,
1
},
{
1
,
1
}},
x
,
w
);
auto
conv
=
p
.
add_instruction
(
migraph
::
op
::
convolution
{{
{
0
,
0
}
}
,
{
{
1
,
1
}
}
,
{
{
1
,
1
}}
}
,
x
,
w
);
auto
scale
=
p
.
add_literal
(
migraph
::
literal
{
vars
,
{
3.0
f
}});
auto
scale
=
p
.
add_literal
(
migraph
::
literal
{
vars
,
{
3.0
f
}});
auto
bias
=
p
.
add_literal
(
migraph
::
literal
{
vars
,
{
8.1
f
}});
auto
bias
=
p
.
add_literal
(
migraph
::
literal
{
vars
,
{
8.1
f
}});
auto
mean
=
p
.
add_literal
(
migraph
::
literal
{
vars
,
{
4.0
f
}});
auto
mean
=
p
.
add_literal
(
migraph
::
literal
{
vars
,
{
4.0
f
}});
...
...
test/gpu/miopen.cpp
View file @
57444235
...
@@ -129,6 +129,7 @@ template <class V>
...
@@ -129,6 +129,7 @@ template <class V>
void
verify_program
()
void
verify_program
()
{
{
auto_print
::
set_terminate_handler
(
migraph
::
get_type_name
<
V
>
());
auto_print
::
set_terminate_handler
(
migraph
::
get_type_name
<
V
>
());
// std::cout << migraph::get_type_name<V>() << std::endl;
migraph
::
program
cpu_prog
;
migraph
::
program
cpu_prog
;
migraph
::
program
gpu_prog
;
migraph
::
program
gpu_prog
;
auto
cpu_arg_f
=
detach_async
([
&
]
{
return
run_cpu
<
V
>
(
cpu_prog
);
});
auto
cpu_arg_f
=
detach_async
([
&
]
{
return
run_cpu
<
V
>
(
cpu_prog
);
});
...
@@ -429,7 +430,7 @@ struct test_gemm
...
@@ -429,7 +430,7 @@ struct test_gemm
migraph
::
program
p
;
migraph
::
program
p
;
auto
a
=
p
.
add_parameter
(
"a"
,
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
4
,
5
}});
auto
a
=
p
.
add_parameter
(
"a"
,
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
4
,
5
}});
auto
b
=
p
.
add_parameter
(
"b"
,
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
5
,
3
}});
auto
b
=
p
.
add_parameter
(
"b"
,
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
5
,
3
}});
p
.
add_instruction
(
migraph
::
op
::
gemm
{},
a
,
b
);
p
.
add_instruction
(
migraph
::
op
::
dot
{},
a
,
b
);
return
p
;
return
p
;
}
}
};
};
...
@@ -441,7 +442,7 @@ struct test_gemm_ld
...
@@ -441,7 +442,7 @@ struct test_gemm_ld
migraph
::
program
p
;
migraph
::
program
p
;
auto
a
=
p
.
add_parameter
(
"a"
,
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
4
,
5
},
{
10
,
1
}});
auto
a
=
p
.
add_parameter
(
"a"
,
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
4
,
5
},
{
10
,
1
}});
auto
b
=
p
.
add_parameter
(
"b"
,
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
5
,
3
},
{
20
,
1
}});
auto
b
=
p
.
add_parameter
(
"b"
,
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
5
,
3
},
{
20
,
1
}});
p
.
add_instruction
(
migraph
::
op
::
gemm
{},
a
,
b
);
p
.
add_instruction
(
migraph
::
op
::
dot
{},
a
,
b
);
return
p
;
return
p
;
}
}
};
};
...
@@ -454,7 +455,7 @@ struct test_gemm_transposeb
...
@@ -454,7 +455,7 @@ struct test_gemm_transposeb
auto
a
=
p
.
add_parameter
(
"a"
,
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
4
,
5
}});
auto
a
=
p
.
add_parameter
(
"a"
,
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
4
,
5
}});
auto
b
=
p
.
add_parameter
(
"b"
,
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
3
,
5
}});
auto
b
=
p
.
add_parameter
(
"b"
,
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
3
,
5
}});
auto
bt
=
p
.
add_instruction
(
migraph
::
op
::
transpose
{{
1
,
0
}},
b
);
auto
bt
=
p
.
add_instruction
(
migraph
::
op
::
transpose
{{
1
,
0
}},
b
);
p
.
add_instruction
(
migraph
::
op
::
gemm
{},
a
,
bt
);
p
.
add_instruction
(
migraph
::
op
::
dot
{},
a
,
bt
);
return
p
;
return
p
;
}
}
};
};
...
@@ -467,7 +468,7 @@ struct test_gemm_transposea
...
@@ -467,7 +468,7 @@ struct test_gemm_transposea
auto
a
=
p
.
add_parameter
(
"a"
,
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
5
,
4
}});
auto
a
=
p
.
add_parameter
(
"a"
,
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
5
,
4
}});
auto
b
=
p
.
add_parameter
(
"b"
,
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
5
,
3
}});
auto
b
=
p
.
add_parameter
(
"b"
,
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
5
,
3
}});
auto
at
=
p
.
add_instruction
(
migraph
::
op
::
transpose
{{
1
,
0
}},
a
);
auto
at
=
p
.
add_instruction
(
migraph
::
op
::
transpose
{{
1
,
0
}},
a
);
p
.
add_instruction
(
migraph
::
op
::
gemm
{},
at
,
b
);
p
.
add_instruction
(
migraph
::
op
::
dot
{},
at
,
b
);
return
p
;
return
p
;
}
}
};
};
...
@@ -481,7 +482,7 @@ struct test_gemm_transposeab
...
@@ -481,7 +482,7 @@ struct test_gemm_transposeab
auto
b
=
p
.
add_parameter
(
"b"
,
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
3
,
5
}});
auto
b
=
p
.
add_parameter
(
"b"
,
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
3
,
5
}});
auto
at
=
p
.
add_instruction
(
migraph
::
op
::
transpose
{{
1
,
0
}},
a
);
auto
at
=
p
.
add_instruction
(
migraph
::
op
::
transpose
{{
1
,
0
}},
a
);
auto
bt
=
p
.
add_instruction
(
migraph
::
op
::
transpose
{{
1
,
0
}},
b
);
auto
bt
=
p
.
add_instruction
(
migraph
::
op
::
transpose
{{
1
,
0
}},
b
);
p
.
add_instruction
(
migraph
::
op
::
gemm
{},
at
,
bt
);
p
.
add_instruction
(
migraph
::
op
::
dot
{},
at
,
bt
);
return
p
;
return
p
;
}
}
};
};
...
@@ -604,6 +605,40 @@ struct test_conv_bn_relu_pooling
...
@@ -604,6 +605,40 @@ struct test_conv_bn_relu_pooling
}
}
};
};
struct
test_concat
{
migraph
::
program
create_program
()
const
{
migraph
::
program
p
;
std
::
size_t
axis
=
1
;
migraph
::
shape
s0
{
migraph
::
shape
::
int32_type
,
{
2
,
2
}};
migraph
::
shape
s1
{
migraph
::
shape
::
int32_type
,
{
2
,
3
}};
migraph
::
shape
s2
{
migraph
::
shape
::
int32_type
,
{
2
,
1
}};
auto
l0
=
p
.
add_parameter
(
"x"
,
s0
);
auto
l1
=
p
.
add_parameter
(
"y"
,
s1
);
auto
l2
=
p
.
add_parameter
(
"z"
,
s2
);
p
.
add_instruction
(
migraph
::
op
::
concat
{
axis
},
l0
,
l1
,
l2
);
return
p
;
}
};
struct
test_concat2
{
migraph
::
program
create_program
()
const
{
migraph
::
program
p
;
std
::
size_t
axis
=
0
;
migraph
::
shape
s0
{
migraph
::
shape
::
int32_type
,
{
2
,
2
}};
migraph
::
shape
s1
{
migraph
::
shape
::
int32_type
,
{
3
,
2
}};
migraph
::
shape
s2
{
migraph
::
shape
::
int32_type
,
{
1
,
2
}};
auto
l0
=
p
.
add_parameter
(
"x"
,
s0
);
auto
l1
=
p
.
add_parameter
(
"y"
,
s1
);
auto
l2
=
p
.
add_parameter
(
"z"
,
s2
);
p
.
add_instruction
(
migraph
::
op
::
concat
{
axis
},
l0
,
l1
,
l2
);
return
p
;
}
};
struct
test_conv_bn_relu_pooling2
struct
test_conv_bn_relu_pooling2
{
{
static
migraph
::
instruction_ref
static
migraph
::
instruction_ref
...
@@ -642,6 +677,8 @@ struct test_conv_bn_relu_pooling2
...
@@ -642,6 +677,8 @@ struct test_conv_bn_relu_pooling2
int
main
()
int
main
()
{
{
verify_program
<
test_concat
>
();
verify_program
<
test_concat2
>
();
verify_program
<
test_add
>
();
verify_program
<
test_add
>
();
verify_program
<
test_mul
>
();
verify_program
<
test_mul
>
();
verify_program
<
test_scale
>
();
verify_program
<
test_scale
>
();
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment