Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
3a0c7c77
Commit
3a0c7c77
authored
Jun 25, 2018
by
Scott Thornton
Browse files
Fixed up computing shape for pooling
parent
bff0223b
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
72 additions
and
16 deletions
+72
-16
src/include/rtg/operators.hpp
src/include/rtg/operators.hpp
+14
-16
test/cpu_ops_test.cpp
test/cpu_ops_test.cpp
+58
-0
No files found.
src/include/rtg/operators.hpp
View file @
3a0c7c77
...
@@ -183,20 +183,17 @@ struct pooling
...
@@ -183,20 +183,17 @@ struct pooling
assert
(
lengths
[
0
]
<
(
input
.
lens
()[
2
]
+
2
*
padding
[
0
]));
assert
(
lengths
[
0
]
<
(
input
.
lens
()[
2
]
+
2
*
padding
[
0
]));
assert
(
lengths
[
1
]
<
(
input
.
lens
()[
3
]
+
2
*
padding
[
1
]));
assert
(
lengths
[
1
]
<
(
input
.
lens
()[
3
]
+
2
*
padding
[
1
]));
return
{
t
,
return
{
t
,
{
{
input
.
lens
()[
0
],
input
.
lens
()[
0
],
input
.
lens
()[
1
],
input
.
lens
()[
1
],
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
1
,
(
input
.
lens
()[
2
]
+
2
*
padding
[
0
]
-
lengths
[
0
])
/
stride
[
0
])
+
std
::
ceil
((
input
.
lens
()[
2
]
+
2
*
padding
[
0
]
-
lengths
[
0
])
/
1
),
static_cast
<
float
>
(
stride
[
0
]))
+
1
)),
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
1
,
(
input
.
lens
()[
3
]
+
2
*
padding
[
1
]
-
lengths
[
1
])
/
stride
[
1
])
+
std
::
ceil
((
input
.
lens
()[
3
]
+
2
*
padding
[
1
]
-
lengths
[
1
])
/
1
),
static_cast
<
float
>
(
stride
[
1
]))
+
1
)),
}};
}};
}
}
...
@@ -320,7 +317,7 @@ struct gemm
...
@@ -320,7 +317,7 @@ struct gemm
std
::
string
name
()
const
{
return
"gemm"
;
}
std
::
string
name
()
const
{
return
"gemm"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
}.
has
(
2
).
same_type
();
check_shapes
{
inputs
}.
has
(
2
).
same_type
()
.
same_ndims
().
only_dims
(
2
)
;
const
shape
&
a
=
inputs
.
at
(
0
);
const
shape
&
a
=
inputs
.
at
(
0
);
const
shape
&
b
=
inputs
.
at
(
1
);
const
shape
&
b
=
inputs
.
at
(
1
);
auto
t
=
a
.
type
();
auto
t
=
a
.
type
();
...
@@ -431,6 +428,7 @@ struct broadcast
...
@@ -431,6 +428,7 @@ struct broadcast
auto
input
=
inputs
.
at
(
1
);
auto
input
=
inputs
.
at
(
1
);
std
::
vector
<
size_t
>
bcast_strides
(
result
.
lens
().
size
(),
0
);
std
::
vector
<
size_t
>
bcast_strides
(
result
.
lens
().
size
(),
0
);
if
(
std
::
all_of
(
if
(
std
::
all_of
(
result
.
lens
().
cbegin
(),
result
.
lens
().
cend
(),
[
&
](
auto
x
)
{
return
x
==
1
;
}))
result
.
lens
().
cbegin
(),
result
.
lens
().
cend
(),
[
&
](
auto
x
)
{
return
x
==
1
;
}))
{
{
...
...
test/cpu_ops_test.cpp
View file @
3a0c7c77
...
@@ -252,6 +252,63 @@ void gemm_test()
...
@@ -252,6 +252,63 @@ void gemm_test()
}
}
}
}
void
maxpool_test
()
{
rtg
::
program
p
;
std
::
vector
<
float
>
a
=
{
-
2.1314404
,
-
1.63041711
,
1.54562736
,
1.04625261
,
-
1.42931843
,
-
0.48703974
,
0.4065806
,
-
0.1524526
,
1.30775225
,
0.45538983
,
-
0.06631992
,
-
1.75332725
,
1.33493888
,
0.47327688
,
0.36873096
,
1.18358743
,
-
0.34640595
,
1.22098756
,
0.01946825
,
-
0.20238149
,
0.43348005
,
-
0.67991608
,
-
0.83041084
,
0.93537551
,
0.70241445
,
-
0.5654031
,
-
1.30899191
,
-
0.26735824
,
-
0.52444768
,
1.99097753
,
1.86504853
,
-
0.26506025
,
0.26236168
,
0.43763575
,
0.95300823
,
-
1.02733946
,
-
0.74655169
,
-
0.5374338
,
-
0.28901565
,
-
0.59789604
,
0.5310151
,
0.99125904
,
0.40609556
,
-
1.57175648
,
0.22031412
,
1.45862222
,
0.53217483
,
1.39087725
,
1.00170159
,
-
0.87175864
,
-
1.7204628
,
-
1.72008383
,
-
0.38656762
,
-
0.01443311
,
1.46645272
,
-
1.39995027
,
0.22505587
,
-
0.43461126
,
-
0.05511411
,
-
0.79950953
,
-
0.01439556
,
0.08795211
,
1.18943918
,
-
0.84079367
,
-
1.73383629
,
-
0.55662078
,
-
0.30626822
,
-
0.67339015
,
0.44179603
,
0.54316711
,
0.40899998
,
-
0.27831686
,
-
1.11900508
,
-
0.0881724
,
0.35483059
,
2.36277103
,
-
0.04765317
,
-
0.36865309
,
0.73814237
,
1.47151589
,
1.36546791
,
-
0.32649881
,
-
1.0517807
,
2.24768877
,
0.68883753
,
0.58646208
,
-
0.91017133
,
-
0.50462508
,
-
0.4013325
,
-
0.72348958
,
-
0.47368807
,
0.35285577
,
-
1.01817429
,
-
0.5152272
,
0.60321307
,
0.43521205
,
-
0.23733577
,
0.66427642
,
0.82949388
,
0.82443929
,
0.71550399
,
0.34561086
,
0.68570769
,
-
0.40718508
,
-
1.20350206
,
0.15793853
,
-
2.31013632
,
-
0.07934658
,
-
0.09348056
,
0.36576006
,
2.46601582
,
0.11090943
,
0.9144392
,
0.56759721
,
-
0.22112127
,
-
0.21955389
,
0.72474903
,
-
1.28448462
,
1.53285873
,
0.37437943
,
0.31409341
,
1.95433736
,
0.91620457
,
0.86205518
,
1.24365854
,
0.19248386
,
0.22526583
,
0.13462132
,
-
0.27561715
,
-
2.06446075
,
-
0.02306402
,
-
1.38278747
,
1.1411345
,
1.31293464
,
-
1.86041689
,
1.06763375
,
-
0.26541466
,
1.4545635
,
1.11430049
,
-
0.66491818
,
0.87101674
,
0.67768967
,
-
1.02062869
,
-
1.05031872
,
-
2.2764678
,
-
2.0200038
,
0.37592548
,
-
0.26701379
,
-
0.83388507
,
0.19403623
,
1.00968623
,
0.11020003
,
1.16736257
,
-
1.1160326
,
0.47346735
,
0.6126079
,
-
0.19135755
,
1.33624589
,
-
0.29802522
,
-
0.57873946
,
-
1.06555879
,
-
0.20686582
,
1.36892557
,
-
0.19937795
,
0.8649236
,
-
1.40126073
,
1.53441942
,
0.34682792
,
-
1.31724346
,
-
1.32898355
,
2.40126371
,
0.07845283
,
1.35732043
,
-
0.63678312
,
0.39429256
,
-
1.36487007
,
-
0.31026676
,
-
0.44981545
,
-
0.28994772
,
-
0.14657612
,
-
1.75206447
,
-
0.70612341
,
1.20071781
,
-
1.64647579
,
-
0.7133292
,
0.88494766
,
0.52119428
,
-
2.77387547
,
2.07681108
,
-
0.90133125
,
0.2847338
,
0.6174528
,
-
0.20616426
,
-
0.64263535
,
-
1.08496261
,
0.54275119
,
-
0.88503587
,
0.6629802
,
1.47319221
,
-
1.05829155
,
-
0.97027361
,
-
0.93187737
,
-
1.39954746
,
-
0.52359426
,
-
0.14743951
,
1.51522756
,
0.2078452
,
-
1.28156149
,
-
1.19363916
,
-
0.78680223
,
-
0.89094824
,
1.30212069
,
-
0.77974445
,
-
0.58411664
,
0.48764706
,
-
0.67132682
};
std
::
vector
<
float
>
c
=
{
1.33493888
,
1.54562736
,
1.22098756
,
1.33493888
,
1.18358743
,
1.99097753
,
1.00170159
,
1.45862222
,
1.39087725
,
1.46645272
,
1.18943918
,
-
0.01443311
,
1.47151589
,
2.36277103
,
2.24768877
,
0.68883753
,
0.82949388
,
0.71550399
,
1.95433736
,
2.46601582
,
1.53285873
,
1.95433736
,
1.06763375
,
1.4545635
,
1.33624589
,
1.16736257
,
0.6126079
,
1.36892557
,
2.40126371
,
1.53441942
,
0.52119428
,
2.07681108
,
0.88494766
,
1.51522756
,
0.54275119
,
0.6629802
};
rtg
::
shape
a_shape
{
rtg
::
shape
::
float_type
,
{
2
,
3
,
6
,
6
}};
auto
al
=
p
.
add_literal
(
rtg
::
literal
{
a_shape
,
a
});
p
.
add_instruction
(
rtg
::
pooling
{
"max"
,
{{
0
,
0
}},
{{
2
,
2
}},
{{
3
,
2
}}},
al
);
p
.
compile
(
rtg
::
cpu
::
cpu_target
{});
auto
result
=
p
.
eval
({});
std
::
cout
<<
result
.
get_shape
()
<<
std
::
endl
;
std
::
vector
<
float
>
results_vector
(
36
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
float
tol
=
1e-6
;
for
(
int
i
=
0
;
i
<
results_vector
.
size
();
i
++
)
{
// std::cout << results_vector[i] << " " << c[i] << std::endl;
EXPECT
(
std
::
abs
(
results_vector
[
i
]
-
c
[
i
])
<
tol
);
}
}
void
softmax_test
()
void
softmax_test
()
{
{
rtg
::
program
p
;
rtg
::
program
p
;
...
@@ -564,6 +621,7 @@ int main()
...
@@ -564,6 +621,7 @@ int main()
transpose_test
();
transpose_test
();
contiguous_test
();
contiguous_test
();
softmax_test
();
softmax_test
();
maxpool_test
();
conv2d_test
();
conv2d_test
();
conv2d_padding_test
();
conv2d_padding_test
();
conv2d_padding_stride_test
();
conv2d_padding_stride_test
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment