"configs/wan_t2v_dist.json" did not exist on "56af41ebaf3d5420736be25f96aca06b910a3447"
llava_trainer_eval.py 3.36 KB
Newer Older
mashun1's avatar
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import json
import os
import subprocess

from llava.train.llava_trainer import LLaVATrainer


class LLaVAEvalTrainer(LLaVATrainer):
    """LLaVATrainer that evaluates checkpoints by launching lmms_eval as a subprocess."""

    def evaluate(self, evaluate_args):
        # Build the lmms_eval command line from the evaluation arguments.
        cmd = (
            f"accelerate launch --num_processes {evaluate_args.eval_num_processes} -m lmms_eval"
            f" --model {evaluate_args.model}"
            f" --model_args {evaluate_args.model_args}"
            f" --tasks {evaluate_args.task_names}"
            f" --batch_size {evaluate_args.batch_size}"
            f" --log_samples_suffix {evaluate_args.log_samples_suffix}"
            f" --output_path {evaluate_args.output_path}"
        )
        # Optional flags are appended only when they are set.
        if evaluate_args.limit:
            cmd += f" --limit {evaluate_args.limit}"
        if evaluate_args.num_fewshot:
            cmd += f" --num_fewshot {evaluate_args.num_fewshot}"
        if evaluate_args.gen_kwargs != "":
            cmd += f" --gen_kwargs {evaluate_args.gen_kwargs}"
        if evaluate_args.log_samples:
            cmd += " --log_samples"
        else:
            raise ValueError("Please log samples so that the result can be parsed")
        results = subprocess.run(cmd, shell=True, capture_output=True, text=True)
        # lmms_eval reports "Saved samples to <path>.json"; depending on logging
        # configuration the message may appear on stdout or stderr.
        try:
            output = results.stdout
            result_file_index_start = output.index("Saved samples to ") + len("Saved samples to ")
            result_file_index_end = output.index(".json", result_file_index_start)
        except ValueError:
            output = results.stderr
            result_file_index_start = output.index("Saved samples to ") + len("Saved samples to ")
            result_file_index_end = output.index(".json", result_file_index_start)
        file = output[result_file_index_start:result_file_index_end]
        # The aggregated metrics are written to results.json in the same directory.
        file = os.path.join(os.path.dirname(file), "results.json")
        with open(file, "r") as f:
            lmms_eval_results = json.load(f)
        result_dict = {}
        tasks_list = evaluate_args.task_names.split(",")
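        # Flatten per-task metrics into {"<task>_<metric>": value}, skipping the
        # "alias" entry and stderr values; metric keys are of the form "metric,filter".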
        for task in tasks_list:
            task_results = lmms_eval_results["results"][task]
            for k, v in task_results.items():
                if k != "alias" and "stderr" not in k:
                    metric = k.split(",")[0]
                    result_dict[f"{task}_{metric}"] = v
        return result_dict

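    # Alternative in-process implementation, currently disabled: it calls
    # lmms_eval's evaluator.simple_evaluate directly instead of launching a subprocess.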
    """def evaluate(self, evaluate_args):
        initialize_tasks()
        tasks_list = evaluate_args.task_names.split(",")
        result_dict = {}
        results = evaluator.simple_evaluate(
            model=evaluate_args.model,
            model_args=evaluate_args.model_args,
            tasks=tasks_list,
            num_fewshot=evaluate_args.num_fewshot,
            batch_size=evaluate_args.batch_size,
            device=evaluate_args.device,
            limit=evaluate_args.limit,
            check_integrity=evaluate_args.check_integrity,
            show_task_to_terminal=evaluate_args.show_task_to_terminal,
            log_samples=evaluate_args.log_samples,
            gen_kwargs=evaluate_args.gen_kwargs,
            cli_args=evaluate_args,
        )
        for task in tasks_list:
            task_results = results["results"][task]
            for k, v in task_results.items():
                if k != "alias" and "stderr" not in k:
                    metric = k.split(",")[0]
                    result_dict[f"{task}_{metric}"] = v
            
        return result_dict"""