Unverified Commit 3a46b6c6 authored by Yike Yuan's avatar Yike Yuan Committed by GitHub
Browse files

[Fix] Fix bugs of multiple rounds of inference when using mm_eval (#201)

parent 4fc17012
...@@ -6,6 +6,7 @@ import random ...@@ -6,6 +6,7 @@ import random
import time import time
from typing import List, Sequence from typing import List, Sequence
import mmengine
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from mmengine.config import Config, ConfigDict from mmengine.config import Config, ConfigDict
...@@ -75,8 +76,8 @@ class MultimodalInferTask: ...@@ -75,8 +76,8 @@ class MultimodalInferTask:
dataset_name = self.dataloader['dataset']['type'] dataset_name = self.dataloader['dataset']['type']
evaluator_name = self.evaluator[0]['type'] evaluator_name = self.evaluator[0]['type']
return osp.join(model_name, return osp.join(self.cfg.work_dir, model_name, dataset_name,
f'{dataset_name}-{evaluator_name}.{file_extension}') f'{evaluator_name}.{file_extension}')
def get_output_paths(self, file_extension: str = 'json') -> List[str]: def get_output_paths(self, file_extension: str = 'json') -> List[str]:
"""Get the path to the output file. """Get the path to the output file.
...@@ -90,7 +91,7 @@ class MultimodalInferTask: ...@@ -90,7 +91,7 @@ class MultimodalInferTask:
evaluator_name = self.evaluator[0]['type'] evaluator_name = self.evaluator[0]['type']
return [ return [
osp.join(model_name, dataset_name, osp.join(self.cfg.work_dir, model_name, dataset_name,
f'{evaluator_name}.{file_extension}') f'{evaluator_name}.{file_extension}')
] ]
...@@ -134,7 +135,8 @@ class MultimodalInferTask: ...@@ -134,7 +135,8 @@ class MultimodalInferTask:
evaluator.process(data_samples) evaluator.process(data_samples)
metrics = evaluator.evaluate(len(dataloader.dataset)) metrics = evaluator.evaluate(len(dataloader.dataset))
metrics_file = osp.join(cfg.work_dir, 'res.log') metrics_file = self.get_output_paths()[0]
mmengine.mkdir_or_exist(osp.split(metrics_file)[0])
with open(metrics_file, 'w') as f: with open(metrics_file, 'w') as f:
json.dump(metrics, f) json.dump(metrics, f)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment