"""get_metrics.py: compute image metrics for generated image folders and save them as JSON.

Usage::

    python get_metrics.py IMAGE_ROOT1 [IMAGE_ROOT2] [-o OUTPUT_PATH] [--max-dataset-size N]
"""
import argparse
import json
import os

from data import get_dataset
from metrics.fid import compute_fid
from metrics.image_reward import compute_image_reward
from metrics.multimodal import compute_image_multimodal_metrics
from metrics.similarity import compute_image_similarity_metrics


def get_args(argv=None):
    """Parse the command-line arguments of the metrics script.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the default),
            argparse reads ``sys.argv[1:]``, exactly as before.

    Returns:
        An ``argparse.Namespace`` with:

        * ``input_roots``: list of root directories (positional, zero or more are
          accepted here; the caller decides how many it supports).
        * ``output_path``: path of the JSON file the metrics are written to
          (default ``"metrics.json"``).
        * ``max_dataset_size``: maximum number of images per dataset (default 1024).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("input_roots", type=str, nargs="*")
    parser.add_argument(
        "-o", "--output-path", type=str, default="metrics.json", help="Output path of the metrics JSON file"
    )
    parser.add_argument(
        "--max-dataset-size",
        type=int,
        default=1024,
        help="Maximum number of images to compute metrics for each dataset",
    )
    return parser.parse_args(argv)


def _record_metrics(dataset_results, metrics):
    """Merge ``metrics`` into ``dataset_results`` and print each ``name: value`` pair."""
    dataset_results.update(metrics)
    for name, value in metrics.items():
        print(f"{name}:", value)


def main():
    """Compute metrics for every dataset folder under the first input root and save them as JSON.

    The positional ``input_roots`` must contain one or two directories:

    * ``input_roots[0]`` holds one sub-directory of generated images per dataset name.
      For each sub-directory, the dataset is loaded with ``get_dataset`` (with ground
      truth, capped at ``--max-dataset-size`` images). Then FID, the multimodal metrics
      and ImageReward are computed against the images in that sub-directory.
    * ``input_roots[1]`` (optional) is a second root. When it is given, only dataset
      names that also appear in it are evaluated. The image-similarity metrics between
      the two same-named folders are added to the results.

    Each metric is printed as it is computed. All results are written to ``--output-path``
    as ``{dataset_name: {metric_name: value}}`` with sorted keys.

    Raises:
        SystemExit: if zero or more than two input roots are given.
    """
    args = get_args()
    # Explicit check rather than ``assert`` so validation still happens under ``python -O``.
    if not 1 <= len(args.input_roots) <= 2:
        raise SystemExit(f"Expected 1 or 2 input roots, got {len(args.input_roots)}")

    image_root1 = args.input_roots[0]
    image_root2 = args.input_roots[1] if len(args.input_roots) > 1 else None

    # List the second root once instead of re-listing it for every dataset.
    names_in_root2 = set(os.listdir(image_root2)) if image_root2 is not None else None

    results = {}
    for dataset_name in sorted(os.listdir(image_root1)):
        gen_dirpath = os.path.join(image_root1, dataset_name)
        # Only sub-directories hold generated images; skip stray files
        # (e.g. a metrics JSON saved next to the image folders).
        if not os.path.isdir(gen_dirpath):
            continue
        if names_in_root2 is not None and dataset_name not in names_in_root2:
            continue
        print("Results for dataset:", dataset_name)

        dataset_results = {}
        results[dataset_name] = dataset_results
        dataset = get_dataset(name=dataset_name, return_gt=True, max_dataset_size=args.max_dataset_size)

        fid = compute_fid(ref_dirpath_or_dataset=dataset, gen_dirpath=gen_dirpath)
        dataset_results["fid"] = fid
        print("FID:", fid)
        _record_metrics(
            dataset_results, compute_image_multimodal_metrics(ref_dataset=dataset, gen_dirpath=gen_dirpath)
        )
        _record_metrics(dataset_results, compute_image_reward(ref_dataset=dataset, gen_dirpath=gen_dirpath))

        if image_root2 is not None:
            other_dirpath = os.path.join(image_root2, dataset_name)
            if os.path.exists(other_dirpath):
                _record_metrics(dataset_results, compute_image_similarity_metrics(gen_dirpath, other_dirpath))

    os.makedirs(os.path.dirname(os.path.abspath(args.output_path)), exist_ok=True)
    with open(args.output_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, sort_keys=True)


if __name__ == "__main__":
    main()