test.py 3.16 KB
Newer Older
yuhai's avatar
yuhai committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import numpy as np
import torch
import torch.nn as nn
try:
    import deepks
except ImportError as e:
    import sys
    sys.path.append(os.path.dirname(os.path.realpath(__file__)) + "/../../")
from deepks.model.model import CorrNet
from deepks.model.reader import GroupReader
from deepks.utils import load_yaml, load_dirs, check_list


# Device used for both models and data in this module: the first CUDA GPU
# ("cuda:0") when CUDA is available, otherwise the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def test(model, g_reader, dump_prefix="test", group=False):
    """Evaluate ``model`` on every system of ``g_reader`` and report errors.

    For each system index ``i`` the full sample ``g_reader.sample_all(i)`` is
    moved to ``DEVICE``; its ``"eig"`` entry is fed to the model and the
    result is compared with its ``"lb_e"`` entry.  Labels and predictions are
    then reshaped to ``(nframes, -1)`` and summed per frame before the L1
    error and the overall statistics are computed.

    Note that the per-system "l2" value written to the file header is the
    root of ``nn.MSELoss`` on the raw ``pred``/``label`` tensors, whereas the
    per-system "l1" value and both overall values use the per-frame sums.

    Args:
        model: torch module called as ``model(data)``; it is switched to
            evaluation mode with ``model.eval()``.
        g_reader: object exposing ``nsystems``, ``sample_all(i)`` and
            ``path_list`` (indexed by system).
        dump_prefix: path prefix for output files, or ``None`` to write
            nothing.
        group: when false, one file ``{dump_prefix}.{i}.out`` (zero-padded
            index) is written per system; when true, a single
            ``{dump_prefix}.out`` holding all systems is written instead.

    Returns:
        Tuple ``(all_err_l1, all_err_l2)``: mean absolute error and root mean
        squared error over the per-frame sums of all systems.

    Raises:
        ValueError: if ``g_reader`` provides no systems.
    """
    model.eval()
    loss_fn = nn.MSELoss()
    label_list = []
    pred_list = []
    # Zero-padding width of the system index in file names (at least 2 digits);
    # it does not change between systems, so compute it once.
    nd = max(len(str(g_reader.nsystems)), 2)
    dump_each = not group and dump_prefix is not None

    for i in range(g_reader.nsystems):
        sample = g_reader.sample_all(i)
        nframes = sample["lb_e"].shape[0]
        sample = {k: v.to(DEVICE, non_blocking=True) for k, v in sample.items()}
        label, data = sample["lb_e"], sample["eig"]
        # Evaluation only: nothing below back-propagates, so skip building
        # the autograd graph to save memory and time.
        with torch.no_grad():
            pred = model(data)
            error = torch.sqrt(loss_fn(pred, label))

        error_np = error.item()
        # Collapse everything except the frame axis into one value per frame.
        label_np = label.cpu().numpy().reshape(nframes, -1).sum(axis=1)
        pred_np = pred.detach().cpu().numpy().reshape(nframes, -1).sum(axis=1)
        error_l1 = np.mean(np.abs(label_np - pred_np))
        label_list.append(label_np)
        pred_list.append(pred_np)

        if dump_each:
            dump_res = np.stack([label_np, pred_np], axis=1)
            header = f"{g_reader.path_list[i]}\nmean l1 error: {error_l1}\nmean l2 error: {error_np}\nreal_ene  pred_ene"
            filename = f"{dump_prefix}.{i:0{nd}}.out"
            np.savetxt(filename, dump_res, header=header)

    if not label_list:
        # np.concatenate would raise a ValueError with an unrelated message.
        raise ValueError("g_reader contains no systems to test")

    all_label = np.concatenate(label_list, axis=0)
    all_pred = np.concatenate(pred_list, axis=0)
    all_err_l1 = np.mean(np.abs(all_label - all_pred))
    all_err_l2 = np.sqrt(np.mean((all_label - all_pred) ** 2))
    info = f"all systems mean l1 error: {all_err_l1}\nall systems mean l2 error: {all_err_l2}"
    print(info)
    if dump_prefix is not None and group:
        np.savetxt(f"{dump_prefix}.out", np.stack([all_label, all_pred], axis=1),
                   header=info + "\nreal_ene  pred_ene")
    return all_err_l1, all_err_l2


def main(data_paths, model_file="model.pth", 
         output_prefix='test', group=False,
         e_name='l_e_delta', d_name=['dm_eig']):
    """Run ``test`` for one or more saved models on the given data systems.

    Args:
        data_paths: passed through ``load_dirs`` to obtain the system paths
            handed to ``GroupReader``.
        model_file: model path(s); normalised with ``check_list`` and
            iterated (presumably a single path is wrapped into a list --
            confirm against ``check_list``).
        output_prefix: joined onto each model file's directory to form the
            dump prefix given to ``test``; its directory part is created if
            it is non-empty.
        group: forwarded to ``test``.
        e_name: forwarded to ``GroupReader`` as ``e_name``.
        d_name: forwarded to ``GroupReader`` as ``d_name``; a sequence of
            length one is unwrapped to its single element first.
    """
    systems = load_dirs(data_paths)
    descriptor = d_name[0] if len(d_name) == 1 else d_name
    reader = GroupReader(systems, e_name=e_name, d_name=descriptor,
                         conv_filter=False, extra_label=True)

    for ckpt in check_list(model_file):
        print(ckpt)
        net = CorrNet.load(ckpt).double().to(DEVICE)

        # Results are written next to the model file.
        out_prefix = os.path.join(os.path.dirname(ckpt), output_prefix)
        out_dir = os.path.dirname(out_prefix)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        # When the model carries an element table, apply its constants to the
        # reader before testing; they are reverted after each model.
        table = net.elem_table
        if table is not None:
            elems, consts = table
            reader.collect_elems(elems)
            reader.subtract_elem_const(consts)

        test(net, reader, dump_prefix=out_prefix, group=group)
        reader.revert_elem_const()


if __name__ == "__main__":
    # Script entry point: hand control to the package's command-line handler.
    from deepks.main import test_cli
    test_cli()