Unverified Commit de9e625f authored by eivtho's avatar eivtho Committed by GitHub
Browse files

[Feature] Add argument params to MlflowLoggerHook (#2193)

parent 595c2ebe
......@@ -20,6 +20,8 @@ class MlflowLoggerHook(LoggerHook):
will be created.
tags (Dict[str], optional): Tags for the current run.
Default None. If not None, set tags for the current run.
params (Dict[str], optional): Params for the current run.
Default None. If not None, set params for the current run.
log_model (bool, optional): Whether to log an MLflow artifact.
Default True. If True, log runner.model as an MLflow artifact
for the current run.
......@@ -37,6 +39,7 @@ class MlflowLoggerHook(LoggerHook):
def __init__(self,
exp_name: Optional[str] = None,
tags: Optional[Dict] = None,
params: Optional[Dict] = None,
log_model: bool = True,
interval: int = 10,
ignore_last: bool = True,
......@@ -46,6 +49,7 @@ class MlflowLoggerHook(LoggerHook):
self.import_mlflow()
self.exp_name = exp_name
self.tags = tags
self.params = params
self.log_model = log_model
def import_mlflow(self) -> None:
......@@ -65,6 +69,8 @@ class MlflowLoggerHook(LoggerHook):
self.mlflow.set_experiment(self.exp_name)
if self.tags is not None:
self.mlflow.set_tags(self.tags)
if self.params is not None:
self.mlflow.log_params(self.params)
@master_only
def log(self, runner) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment