get_metrics.py 2.39 KB
Newer Older
Zhekai Zhang's avatar
Zhekai Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import argparse
import json
import os

from data import get_dataset
from metrics.fid import compute_fid
from metrics.image_reward import compute_image_reward
from metrics.multimodal import compute_image_multimodal_metrics
from metrics.similarity import compute_image_similarity_metrics


def get_args(argv=None):
    """Parse the command-line arguments of the metrics script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is the
            original command-line behavior.

    Returns:
        argparse.Namespace with ``input_roots`` (list of str, possibly empty;
        the caller validates the count) and ``output_path`` (str).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "input_roots",
        type=str,
        nargs="*",
        help="One or two root directories, each holding one sub-directory of generated images per dataset. "
        "With two roots, similarity metrics between matching dataset directories are also computed.",
    )
    parser.add_argument(
        "-o",
        "--output-path",
        type=str,
        default="metrics.json",
        help="Path of the JSON file the computed metrics are written to",
    )
    return parser.parse_args(argv)


def _print_metrics(metrics):
    """Print each ``name: value`` pair of a metrics dict on its own line."""
    for name, value in metrics.items():
        print(f"{name}:", value)


def main():
    """Compute metrics for every dataset sub-directory and save them as JSON.

    For each sub-directory ``<dataset_name>`` of the first input root, this
    computes FID, the multimodal metrics and the image-reward metrics against
    the dataset returned by ``get_dataset(name=dataset_name, return_gt=True)``.
    If a second root is given and it contains a directory with the same name,
    the similarity metrics between the two image directories are computed too.
    Every result is printed and the collected results are written to
    ``--output-path`` as ``{dataset_name: {metric_name: value}}``.

    Raises:
        SystemExit: If zero or more than two input roots are given.
    """
    args = get_args()
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if not 1 <= len(args.input_roots) <= 2:
        raise SystemExit(f"Expected one or two input roots, got {len(args.input_roots)}.")

    image_root1 = args.input_roots[0]
    image_root2 = args.input_roots[1] if len(args.input_roots) > 1 else None

    results = {}
    # Only sub-directories are datasets; skip stray files such as a metrics JSON saved into the root.
    dataset_names = sorted(
        name for name in os.listdir(image_root1) if os.path.isdir(os.path.join(image_root1, name))
    )
    for dataset_name in dataset_names:
        print("##Results for dataset:", dataset_name)
        gen_dirpath = os.path.join(image_root1, dataset_name)
        dataset_results = {}
        results[dataset_name] = dataset_results
        dataset = get_dataset(name=dataset_name, return_gt=True)

        fid = compute_fid(ref_dirpath_or_dataset=dataset, gen_dirpath=gen_dirpath)
        dataset_results["fid"] = fid
        print("FID:", fid)

        # Compute and print each metric group before starting the next one.
        multimodal_metrics = compute_image_multimodal_metrics(ref_dataset=dataset, gen_dirpath=gen_dirpath)
        dataset_results.update(multimodal_metrics)
        _print_metrics(multimodal_metrics)

        image_reward = compute_image_reward(ref_dataset=dataset, gen_dirpath=gen_dirpath)
        dataset_results.update(image_reward)
        _print_metrics(image_reward)

        if image_root2 is not None:
            other_dirpath = os.path.join(image_root2, dataset_name)
            if os.path.isdir(other_dirpath):
                similarity_results = compute_image_similarity_metrics(gen_dirpath, other_dirpath)
                dataset_results.update(similarity_results)
                _print_metrics(similarity_results)
            else:
                # Report the skip instead of silently omitting the similarity metrics.
                print(f"Skipping similarity metrics: {other_dirpath} is not a directory.")

    os.makedirs(os.path.dirname(os.path.abspath(args.output_path)), exist_ok=True)
    with open(args.output_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, sort_keys=True)


# Run the metrics CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()