Unverified Commit bcf02ec7 authored by Sourab Mangrulkar's avatar Sourab Mangrulkar Committed by GitHub
Browse files

Update hyperparameter_search.py (#24515)

* Update hyperparameter_search.py

* resolve comments
parent 6fe8d198
...@@ -40,7 +40,8 @@ class HyperParamSearchBackendBase: ...@@ -40,7 +40,8 @@ class HyperParamSearchBackendBase:
name: str name: str
pip_package: str = None pip_package: str = None
def is_available(self): @staticmethod
def is_available():
raise NotImplementedError raise NotImplementedError
def run(self, trainer, n_trials: int, direction: str, **kwargs): def run(self, trainer, n_trials: int, direction: str, **kwargs):
...@@ -63,7 +64,8 @@ class HyperParamSearchBackendBase: ...@@ -63,7 +64,8 @@ class HyperParamSearchBackendBase:
class OptunaBackend(HyperParamSearchBackendBase): class OptunaBackend(HyperParamSearchBackendBase):
name = "optuna" name = "optuna"
def is_available(self): @staticmethod
def is_available():
return is_optuna_available() return is_optuna_available()
def run(self, trainer, n_trials: int, direction: str, **kwargs): def run(self, trainer, n_trials: int, direction: str, **kwargs):
...@@ -77,7 +79,8 @@ class RayTuneBackend(HyperParamSearchBackendBase): ...@@ -77,7 +79,8 @@ class RayTuneBackend(HyperParamSearchBackendBase):
name = "ray" name = "ray"
pip_package = "'ray[tune]'" pip_package = "'ray[tune]'"
def is_available(self): @staticmethod
def is_available():
return is_ray_available() return is_ray_available()
def run(self, trainer, n_trials: int, direction: str, **kwargs): def run(self, trainer, n_trials: int, direction: str, **kwargs):
...@@ -90,7 +93,8 @@ class RayTuneBackend(HyperParamSearchBackendBase): ...@@ -90,7 +93,8 @@ class RayTuneBackend(HyperParamSearchBackendBase):
class SigOptBackend(HyperParamSearchBackendBase): class SigOptBackend(HyperParamSearchBackendBase):
name = "sigopt" name = "sigopt"
def is_available(self): @staticmethod
def is_available():
return is_sigopt_available() return is_sigopt_available()
def run(self, trainer, n_trials: int, direction: str, **kwargs): def run(self, trainer, n_trials: int, direction: str, **kwargs):
...@@ -103,7 +107,8 @@ class SigOptBackend(HyperParamSearchBackendBase): ...@@ -103,7 +107,8 @@ class SigOptBackend(HyperParamSearchBackendBase):
class WandbBackend(HyperParamSearchBackendBase): class WandbBackend(HyperParamSearchBackendBase):
name = "wandb" name = "wandb"
def is_available(self): @staticmethod
def is_available():
return is_wandb_available() return is_wandb_available()
def run(self, trainer, n_trials: int, direction: str, **kwargs): def run(self, trainer, n_trials: int, direction: str, **kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment