Unverified Commit 7df4ebed authored by Hakjin Lee's avatar Hakjin Lee Committed by GitHub
Browse files

[Fix] Fix the mismatch torch version of MlflowLoggerHook(#1680)

* [Fix] mlflow logger error

* [Update] mlflow hook test arguments

* Replace torch.__version__ with mmcv.TORCH_VERSION

* update test code
parent 26aba2f5
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import TORCH_VERSION
from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook
......@@ -73,4 +74,7 @@ class MlflowLoggerHook(LoggerHook):
@master_only
def after_run(self, runner):
if self.log_model:
self.mlflow_pytorch.log_model(runner.model, 'models')
self.mlflow_pytorch.log_model(
runner.model,
'models',
pip_requirements=[f'torch=={TORCH_VERSION}'])
......@@ -34,6 +34,7 @@ from mmcv.runner.hooks.lr_updater import (CosineRestartLrUpdaterHook,
FlatCosineAnnealingLrUpdaterHook,
OneCycleLrUpdaterHook,
StepLrUpdaterHook)
from mmcv.utils import TORCH_VERSION
sys.modules['petrel_client'] = MagicMock()
sys.modules['petrel_client.client'] = MagicMock()
......@@ -1392,7 +1393,9 @@ def test_mlflow_hook(log_model):
}, step=6)
if log_model:
hook.mlflow_pytorch.log_model.assert_called_with(
runner.model, 'models')
runner.model,
'models',
pip_requirements=[f'torch=={TORCH_VERSION}'])
else:
assert not hook.mlflow_pytorch.log_model.called
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment