Unverified Commit de9e625f authored by eivtho's avatar eivtho Committed by GitHub
Browse files

[Feature] Add argument params to MlflowLoggerHook (#2193)

parent 595c2ebe
...@@ -20,6 +20,8 @@ class MlflowLoggerHook(LoggerHook): ...@@ -20,6 +20,8 @@ class MlflowLoggerHook(LoggerHook):
will be created. will be created.
tags (Dict[str], optional): Tags for the current run. tags (Dict[str], optional): Tags for the current run.
Default None. If not None, set tags for the current run. Default None. If not None, set tags for the current run.
params (Dict[str], optional): Params for the current run.
Default None. If not None, set params for the current run.
log_model (bool, optional): Whether to log an MLflow artifact. log_model (bool, optional): Whether to log an MLflow artifact.
Default True. If True, log runner.model as an MLflow artifact Default True. If True, log runner.model as an MLflow artifact
for the current run. for the current run.
...@@ -37,6 +39,7 @@ class MlflowLoggerHook(LoggerHook): ...@@ -37,6 +39,7 @@ class MlflowLoggerHook(LoggerHook):
def __init__(self, def __init__(self,
exp_name: Optional[str] = None, exp_name: Optional[str] = None,
tags: Optional[Dict] = None, tags: Optional[Dict] = None,
params: Optional[Dict] = None,
log_model: bool = True, log_model: bool = True,
interval: int = 10, interval: int = 10,
ignore_last: bool = True, ignore_last: bool = True,
...@@ -46,6 +49,7 @@ class MlflowLoggerHook(LoggerHook): ...@@ -46,6 +49,7 @@ class MlflowLoggerHook(LoggerHook):
self.import_mlflow() self.import_mlflow()
self.exp_name = exp_name self.exp_name = exp_name
self.tags = tags self.tags = tags
self.params = params
self.log_model = log_model self.log_model = log_model
def import_mlflow(self) -> None: def import_mlflow(self) -> None:
...@@ -65,6 +69,8 @@ class MlflowLoggerHook(LoggerHook): ...@@ -65,6 +69,8 @@ class MlflowLoggerHook(LoggerHook):
self.mlflow.set_experiment(self.exp_name) self.mlflow.set_experiment(self.exp_name)
if self.tags is not None: if self.tags is not None:
self.mlflow.set_tags(self.tags) self.mlflow.set_tags(self.tags)
if self.params is not None:
self.mlflow.log_params(self.params)
@master_only @master_only
def log(self, runner) -> None: def log(self, runner) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment