Unverified Commit f36b62a9 authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Update sklearn regression example (#2330)

parent ae72aec8
...@@ -24,9 +24,9 @@ import numpy as np ...@@ -24,9 +24,9 @@ import numpy as np
from sklearn.metrics import r2_score from sklearn.metrics import r2_score
from sklearn.preprocessing import StandardScaler from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LinearRegression from sklearn.linear_model import LinearRegression
from sklearn.svm import SVR from sklearn.linear_model import Ridge
from sklearn.neighbors import KNeighborsRegressor from sklearn.linear_model import Lars
from sklearn.tree import DecisionTreeRegressor from sklearn.linear_model import ARDRegression
LOG = logging.getLogger('sklearn_regression') LOG = logging.getLogger('sklearn_regression')
...@@ -55,24 +55,18 @@ def get_model(PARAMS): ...@@ -55,24 +55,18 @@ def get_model(PARAMS):
'''Get model according to parameters''' '''Get model according to parameters'''
model_dict = { model_dict = {
'LinearRegression': LinearRegression(), 'LinearRegression': LinearRegression(),
'SVR': SVR(), 'Ridge': Ridge(),
'KNeighborsRegressor': KNeighborsRegressor(), 'Lars': Lars(),
'DecisionTreeRegressor': DecisionTreeRegressor() 'ARDRegression': ARDRegression()
} }
if not model_dict.get(PARAMS['model_name']): if not model_dict.get(PARAMS['model_name']):
LOG.exception('Not supported model!') LOG.exception('Not supported model!')
exit(1) exit(1)
model = model_dict[PARAMS['model_name']] model = model_dict[PARAMS['model_name']]
model.normalize = bool(PARAMS['normalize'])
try:
if PARAMS['model_name'] == 'SVR':
model.kernel = PARAMS['svr_kernel']
elif PARAMS['model_name'] == 'KNeighborsRegressor':
model.weights = PARAMS['knr_weights']
except Exception as exception:
LOG.exception(exception)
raise
return model return model
def run(X_train, X_test, y_train, y_test, model): def run(X_train, X_test, y_train, y_test, model):
......
{ {
"model_name":{"_type":"choice","_value":["LinearRegression", "SVR", "KNeighborsRegressor", "DecisionTreeRegressor"]}, "model_name":{"_type":"choice","_value":["LinearRegression", "Lars", "Ridge", "ARDRegression"]},
"svr_kernel": {"_type":"choice","_value":["linear", "poly", "rbf"]}, "normalize": {"_type":"choice","_value":["true", "false"]}
"knr_weights": {"_type":"choice","_value":["uniform", "distance"]}
} }
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment