Unverified Commit f36b62a9 authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Update sklearn regression example (#2330)

parent ae72aec8
......@@ -24,9 +24,9 @@ import numpy as np
from sklearn.metrics import r2_score
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LinearRegression
from sklearn.svm import SVR
from sklearn.neighbors import KNeighborsRegressor
from sklearn.tree import DecisionTreeRegressor
from sklearn.linear_model import Ridge
from sklearn.linear_model import Lars
from sklearn.linear_model import ARDRegression
LOG = logging.getLogger('sklearn_regression')
......@@ -55,24 +55,18 @@ def get_model(PARAMS):
'''Get model according to parameters'''
model_dict = {
'LinearRegression': LinearRegression(),
'SVR': SVR(),
'KNeighborsRegressor': KNeighborsRegressor(),
'DecisionTreeRegressor': DecisionTreeRegressor()
'Ridge': Ridge(),
'Lars': Lars(),
'ARDRegression': ARDRegression()
}
if not model_dict.get(PARAMS['model_name']):
LOG.exception('Not supported model!')
exit(1)
model = model_dict[PARAMS['model_name']]
model.normalize = bool(PARAMS['normalize'])
try:
if PARAMS['model_name'] == 'SVR':
model.kernel = PARAMS['svr_kernel']
elif PARAMS['model_name'] == 'KNeighborsRegressor':
model.weights = PARAMS['knr_weights']
except Exception as exception:
LOG.exception(exception)
raise
return model
def run(X_train, X_test, y_train, y_test, model):
......
{
"model_name":{"_type":"choice","_value":["LinearRegression", "SVR", "KNeighborsRegressor", "DecisionTreeRegressor"]},
"svr_kernel": {"_type":"choice","_value":["linear", "poly", "rbf"]},
"knr_weights": {"_type":"choice","_value":["uniform", "distance"]}
"model_name":{"_type":"choice","_value":["LinearRegression", "Lars", "Ridge", "ARDRegression"]},
"normalize": {"_type":"choice","_value":["true", "false"]}
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment