Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
f36b62a9
Unverified
Commit
f36b62a9
authored
Apr 17, 2020
by
SparkSnail
Committed by
GitHub
Apr 17, 2020
Browse files
Update sklearn regression example (#2330)
parent
ae72aec8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
17 deletions
+10
-17
examples/trials/sklearn/regression/main.py
examples/trials/sklearn/regression/main.py
+8
-14
examples/trials/sklearn/regression/search_space.json
examples/trials/sklearn/regression/search_space.json
+2
-3
No files found.
examples/trials/sklearn/regression/main.py
View file @
f36b62a9
...
@@ -24,9 +24,9 @@ import numpy as np
...
@@ -24,9 +24,9 @@ import numpy as np
from
sklearn.metrics
import
r2_score
from
sklearn.metrics
import
r2_score
from
sklearn.preprocessing
import
StandardScaler
from
sklearn.preprocessing
import
StandardScaler
from
sklearn.linear_model
import
LinearRegression
from
sklearn.linear_model
import
LinearRegression
from
sklearn.
svm
import
SVR
from
sklearn.
linear_model
import
Ridge
from
sklearn.
neighbors
import
KNeighborsRegressor
from
sklearn.
linear_model
import
Lars
from
sklearn.
tree
import
DecisionTree
Regress
or
from
sklearn.
linear_model
import
ARD
Regress
ion
LOG
=
logging
.
getLogger
(
'sklearn_regression'
)
LOG
=
logging
.
getLogger
(
'sklearn_regression'
)
...
@@ -55,24 +55,18 @@ def get_model(PARAMS):
...
@@ -55,24 +55,18 @@ def get_model(PARAMS):
'''Get model according to parameters'''
'''Get model according to parameters'''
model_dict
=
{
model_dict
=
{
'LinearRegression'
:
LinearRegression
(),
'LinearRegression'
:
LinearRegression
(),
'SVR'
:
SVR
(),
'Ridge'
:
Ridge
(),
'KNeighborsRegressor'
:
KNeighborsRegressor
(),
'Lars'
:
Lars
(),
'DecisionTreeRegressor'
:
DecisionTreeRegressor
()
'ARDRegression'
:
ARDRegression
()
}
}
if
not
model_dict
.
get
(
PARAMS
[
'model_name'
]):
if
not
model_dict
.
get
(
PARAMS
[
'model_name'
]):
LOG
.
exception
(
'Not supported model!'
)
LOG
.
exception
(
'Not supported model!'
)
exit
(
1
)
exit
(
1
)
model
=
model_dict
[
PARAMS
[
'model_name'
]]
model
=
model_dict
[
PARAMS
[
'model_name'
]]
model
.
normalize
=
bool
(
PARAMS
[
'normalize'
])
try
:
if
PARAMS
[
'model_name'
]
==
'SVR'
:
model
.
kernel
=
PARAMS
[
'svr_kernel'
]
elif
PARAMS
[
'model_name'
]
==
'KNeighborsRegressor'
:
model
.
weights
=
PARAMS
[
'knr_weights'
]
except
Exception
as
exception
:
LOG
.
exception
(
exception
)
raise
return
model
return
model
def
run
(
X_train
,
X_test
,
y_train
,
y_test
,
model
):
def
run
(
X_train
,
X_test
,
y_train
,
y_test
,
model
):
...
...
examples/trials/sklearn/regression/search_space.json
View file @
f36b62a9
{
{
"model_name"
:{
"_type"
:
"choice"
,
"_value"
:[
"LinearRegression"
,
"SVR"
,
"KNeighborsRegressor"
,
"DecisionTreeRegressor"
]},
"model_name"
:{
"_type"
:
"choice"
,
"_value"
:[
"LinearRegression"
,
"Lars"
,
"Ridge"
,
"ARDRegression"
]},
"svr_kernel"
:
{
"_type"
:
"choice"
,
"_value"
:[
"linear"
,
"poly"
,
"rbf"
]},
"normalize"
:
{
"_type"
:
"choice"
,
"_value"
:[
"true"
,
"false"
]}
"knr_weights"
:
{
"_type"
:
"choice"
,
"_value"
:[
"uniform"
,
"distance"
]}
}
}
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment