Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dlib
Commits
facefa02
Commit
facefa02
authored
Jun 20, 2020
by
Davis King
Browse files
Fix random forest regression not doing quite the right thing.
parent
fe803b56
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
20 additions
and
11 deletions
+20
-11
dlib/random_forest/random_forest_regression.h
dlib/random_forest/random_forest_regression.h
+7
-6
dlib/test/random_forest.cpp
dlib/test/random_forest.cpp
+13
-5
No files found.
dlib/random_forest/random_forest_regression.h
View file @
facefa02
...
@@ -376,8 +376,6 @@ namespace dlib
...
@@ -376,8 +376,6 @@ namespace dlib
std
::
vector
<
std
::
vector
<
internal_tree_node
<
feature_extractor
>>>
all_trees
(
num_trees
);
std
::
vector
<
std
::
vector
<
internal_tree_node
<
feature_extractor
>>>
all_trees
(
num_trees
);
std
::
vector
<
std
::
vector
<
float
>>
all_leaves
(
num_trees
);
std
::
vector
<
std
::
vector
<
float
>>
all_leaves
(
num_trees
);
const
double
sumy
=
sum
(
mat
(
y
));
const
size_t
feats_per_node
=
std
::
max
(
1.0
,
std
::
round
(
fe
.
max_num_feats
()
*
feature_subsampling_frac
));
const
size_t
feats_per_node
=
std
::
max
(
1.0
,
std
::
round
(
fe
.
max_num_feats
()
*
feature_subsampling_frac
));
// Each tree couldn't have more than this many interior nodes. It might
// Each tree couldn't have more than this many interior nodes. It might
...
@@ -412,15 +410,18 @@ namespace dlib
...
@@ -412,15 +410,18 @@ namespace dlib
// don't make any tree. Just average the things and be done.
// don't make any tree. Just average the things and be done.
if
(
y
.
size
()
<=
min_samples_per_leaf
)
if
(
y
.
size
()
<=
min_samples_per_leaf
)
{
{
leaves
.
push_back
(
sumy
/
y
.
size
(
));
leaves
.
push_back
(
mean
(
mat
(
y
)
));
return
;
return
;
}
}
double
sumy
=
0
;
// pick a random bootstrap of the data.
// pick a random bootstrap of the data.
std
::
vector
<
std
::
pair
<
float
,
uint32_t
>>
idxs
(
y
.
size
());
std
::
vector
<
std
::
pair
<
float
,
uint32_t
>>
idxs
(
y
.
size
());
for
(
auto
&
idx
:
idxs
)
for
(
auto
&
idx
:
idxs
)
{
idx
=
std
::
make_pair
(
0.0
f
,
static_cast
<
uint32_t
>
(
rnd
.
get_integer
(
y
.
size
())));
idx
=
std
::
make_pair
(
0.0
f
,
static_cast
<
uint32_t
>
(
rnd
.
get_integer
(
y
.
size
())));
sumy
+=
y
[
idx
.
second
];
}
// We are going to use ranges_to_process as a stack that tracks which
// We are going to use ranges_to_process as a stack that tracks which
// range of samples we are going to split next.
// range of samples we are going to split next.
...
@@ -702,7 +703,7 @@ namespace dlib
...
@@ -702,7 +703,7 @@ namespace dlib
for
(
auto
i
=
range
.
begin
;
i
<
range
.
end
;
++
i
)
for
(
auto
i
=
range
.
begin
;
i
<
range
.
end
;
++
i
)
idxs
[
i
].
first
=
fe
.
extract_feature_value
(
x
[
idxs
[
i
].
second
],
feat
);
idxs
[
i
].
first
=
fe
.
extract_feature_value
(
x
[
idxs
[
i
].
second
],
feat
);
std
::
sort
(
idxs
.
begin
()
+
range
.
begin
,
idxs
.
begin
()
+
range
.
end
,
compare_first
);
std
::
stable_
sort
(
idxs
.
begin
()
+
range
.
begin
,
idxs
.
begin
()
+
range
.
end
,
compare_first
);
auto
split
=
find_best_split
(
range
,
y
,
idxs
);
auto
split
=
find_best_split
(
range
,
y
,
idxs
);
...
@@ -716,7 +717,7 @@ namespace dlib
...
@@ -716,7 +717,7 @@ namespace dlib
// resort idxs based on winning feat
// resort idxs based on winning feat
for
(
auto
i
=
range
.
begin
;
i
<
range
.
end
;
++
i
)
for
(
auto
i
=
range
.
begin
;
i
<
range
.
end
;
++
i
)
idxs
[
i
].
first
=
fe
.
extract_feature_value
(
x
[
idxs
[
i
].
second
],
best
.
split_feature
);
idxs
[
i
].
first
=
fe
.
extract_feature_value
(
x
[
idxs
[
i
].
second
],
best
.
split_feature
);
std
::
sort
(
idxs
.
begin
()
+
range
.
begin
,
idxs
.
begin
()
+
range
.
end
,
compare_first
);
std
::
stable_
sort
(
idxs
.
begin
()
+
range
.
begin
,
idxs
.
begin
()
+
range
.
end
,
compare_first
);
return
best
;
return
best
;
}
}
...
...
dlib/test/random_forest.cpp
View file @
facefa02
...
@@ -62,15 +62,23 @@ namespace
...
@@ -62,15 +62,23 @@ namespace
DLIB_TEST
(
df
.
get_num_trees
()
==
1000
);
DLIB_TEST
(
df
.
get_num_trees
()
==
1000
);
auto
result
=
test_regression_function
(
df
,
samples
,
labels
);
auto
result
=
test_regression_function
(
df
,
samples
,
labels
);
// train:
2.239 0.987173 0.970669 1.1399
// train:
1.95064 0.990374 0.92738 1.04536
dlog
<<
LINFO
<<
"train: "
<<
result
;
dlog
<<
LINFO
<<
"train: "
<<
result
;
DLIB_TEST_MSG
(
result
(
0
)
<
2.3
,
result
(
0
));
DLIB_TEST_MSG
(
result
(
0
)
<
2.0
,
result
(
0
));
// By construction, output values should be in the span of the training labels.
const
double
min_label
=
min
(
mat
(
labels
));
const
double
max_label
=
max
(
mat
(
labels
));
for
(
auto
&&
x
:
samples
)
{
double
y
=
df
(
x
);
DLIB_TEST
(
min_label
<=
y
&&
y
<=
max_label
);
}
running_stats
<
double
>
rs
;
running_stats
<
double
>
rs
;
for
(
size_t
i
=
0
;
i
<
oobs
.
size
();
++
i
)
for
(
size_t
i
=
0
;
i
<
oobs
.
size
();
++
i
)
rs
.
add
(
std
::
pow
(
oobs
[
i
]
-
labels
[
i
],
2.0
));
rs
.
add
(
std
::
pow
(
oobs
[
i
]
-
labels
[
i
],
2.0
));
dlog
<<
LINFO
<<
"OOB MSE: "
<<
rs
.
mean
();
dlog
<<
LINFO
<<
"OOB MSE: "
<<
rs
.
mean
();
DLIB_TEST_MSG
(
rs
.
mean
()
<
10.
2
,
rs
.
mean
());
DLIB_TEST_MSG
(
rs
.
mean
()
<
10.
0
,
rs
.
mean
());
print_spinner
();
print_spinner
();
...
@@ -80,9 +88,9 @@ namespace
...
@@ -80,9 +88,9 @@ namespace
deserialize
(
df2
,
ss
);
deserialize
(
df2
,
ss
);
DLIB_TEST
(
df2
.
get_num_trees
()
==
1000
);
DLIB_TEST
(
df2
.
get_num_trees
()
==
1000
);
result
=
test_regression_function
(
df2
,
samples
,
labels
);
result
=
test_regression_function
(
df2
,
samples
,
labels
);
// train:
2.239 0.987173 0.970669 1.1399
// train:
1.95064 0.990374 0.92738 1.04536
dlog
<<
LINFO
<<
"serialized train results: "
<<
result
;
dlog
<<
LINFO
<<
"serialized train results: "
<<
result
;
DLIB_TEST_MSG
(
result
(
0
)
<
2.
3
,
result
(
0
));
DLIB_TEST_MSG
(
result
(
0
)
<
2.
0
,
result
(
0
));
}
}
}
a
;
}
a
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment