# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

import os
import glob
import argparse
import numpy as np
import torch

import settings
from loader import get_test_loader, add_depth_channel
from models import UNetResNetV4, UNetResNetV5, UNetResNetV6, UNet7, UNet8
from postprocessing import crop_image_softmax, binarize, resize_image  # crop_image_softmax is called in generate_preds below
from metrics import intersection_over_union, intersection_over_union_thresholds
from utils import create_submission

def do_tta_predict(args, model, ckp_path, tta_num=4):
    '''
    Run test-time-augmentation (TTA) prediction with the given model and
    return an 18000x128x128 numpy array of mask probabilities (one 128x128
    mask per test image), averaged over the TTA variants, plus the test-set
    metadata.
    '''
    model.eval()
    preds = []
    meta = None

    # flip_index is the TTA index. 0: no change, 1: horizontal flip, 2: vertical flip, 3: both
    for flip_index in range(tta_num):
        print('flip_index:', flip_index)
        test_loader = get_test_loader(args.batch_size, index=flip_index, dev_mode=False, pad_mode=args.pad_mode)
        meta = test_loader.meta
        outputs = None
        with torch.no_grad():
            for i, img in enumerate(test_loader):
                add_depth_channel(img, args.pad_mode)
                img = img.cuda()
                output, _ = model(img)
                output = torch.sigmoid(output)
                if outputs is None:
                    outputs = output.squeeze()
                else:
                    outputs = torch.cat([outputs, output.squeeze()], 0)

                print('{} / {}'.format(args.batch_size*(i+1), test_loader.num), end='\r')
        outputs = outputs.cpu().numpy()
        # flip the predicted masks back so all TTA variants are spatially aligned
        if flip_index == 1:
            outputs = np.flip(outputs, 2)
        elif flip_index == 2:
            outputs = np.flip(outputs, 1)
        elif flip_index == 3:
            outputs = np.flip(outputs, 2)
            outputs = np.flip(outputs, 1)
        preds.append(outputs)

    parent_dir = ckp_path+'_out'
    os.makedirs(parent_dir, exist_ok=True)
    np_file = os.path.join(parent_dir, 'pred.npy')

    # average the flipped-back TTA predictions into a single probability map
    model_pred_result = np.mean(preds, 0)
    np.save(np_file, model_pred_result)

    return model_pred_result, meta

def predict(args, model, checkpoint, out_file):
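    '''Run 2-way TTA prediction for one checkpoint, restore the masks to the
    original image size, and write a submission CSV.'''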
    print('predicting {}...'.format(checkpoint))
    pred, meta = do_tta_predict(args, model, checkpoint, tta_num=2)
    print(pred.shape)
    y_pred_test = generate_preds(pred, (settings.ORIG_H, settings.ORIG_W), pad_mode=args.pad_mode)

    submission = create_submission(meta, y_pred_test)
    submission.to_csv(out_file, index=False, encoding='utf-8')


def ensemble(args, model, checkpoints):
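    '''Load each checkpoint in turn, average their TTA predictions, and write
    a single ensembled submission CSV.'''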
    preds = []
    meta = None
    for checkpoint in checkpoints:
        model.load_state_dict(torch.load(checkpoint))
        model = model.cuda()
        print('predicting...', checkpoint)

        pred, meta = do_tta_predict(args, model, checkpoint, tta_num=2)
        preds.append(pred)

    y_pred_test = generate_preds(np.mean(preds, 0), (settings.ORIG_H, settings.ORIG_W), args.pad_mode)

    submission = create_submission(meta, y_pred_test)
    submission.to_csv(args.sub_file, index=False, encoding='utf-8')

def ensemble_np(args, np_files, save_np=None):
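    '''Average previously saved prediction arrays (.npy files written by
    do_tta_predict) and write a submission CSV; optionally save the
    averaged array.'''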
    preds = []
    for np_file in np_files:
        pred = np.load(np_file)
        print(np_file, pred.shape)
        preds.append(pred)

    ensembled = np.mean(preds, 0)
    y_pred_test = generate_preds(ensembled, (settings.ORIG_H, settings.ORIG_W), args.pad_mode)

    if save_np is not None:
        np.save(save_np, ensembled)

    meta = get_test_loader(args.batch_size, index=0, dev_mode=False, pad_mode=args.pad_mode).meta

    submission = create_submission(meta, y_pred_test)
    submission.to_csv(args.sub_file, index=False, encoding='utf-8')

def generate_preds(outputs, target_size, pad_mode, threshold=0.5):
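    '''Undo the test-time padding/resizing so each mask matches the original
    image size, then binarize at the given threshold.'''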
    preds = []

    for output in outputs:
        if pad_mode == 'resize':
            cropped = resize_image(output, target_size=target_size)
        else:
            cropped = crop_image_softmax(output, target_size=target_size)
        pred = binarize(cropped, threshold)
        preds.append(pred)

    return preds


def ensemble_predict(args):
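    '''Ensemble a hard-coded list of checkpoints; the paths below are
    machine-specific and will need adjusting.'''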
    model = eval(args.model_name)(args.layers, num_filters=args.nf)

    checkpoints = [
        r'D:\data\salt\models\pseudo\UNetResNetV4_34\edge\best_5.pth',
        r'D:\data\salt\models\pseudo\UNetResNetV4_34\edge\best_6.pth',
        r'D:\data\salt\models\pseudo\UNetResNetV4_34\edge\best_8.pth',
        r'D:\data\salt\models\pseudo\UNetResNetV4_34\edge\best_9.pth'
    ]
    print(checkpoints)

    ensemble(args, model, checkpoints)

def ensemble_np_results(args):
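    '''Ensemble saved .npy predictions from several model variants; the glob
    patterns below are machine-specific.'''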
    np_files1 = glob.glob(r'D:\data\salt\models\depths\UNetResNetV5_50\edge\*pth_out\*.npy')
    np_files2 = glob.glob(r'D:\data\salt\models\depths\UNetResNetV4_34\edge\*pth_out\*.npy')
    np_files3 = glob.glob(r'D:\data\salt\models\depths\UNetResNetV6_34\edge\*pth_out\*.npy')
    np_files6 = glob.glob(r'D:\data\salt\models\ensemble\*.npy')
    np_files = np_files1 + np_files2 + np_files3 + np_files6
    print(np_files)
    ensemble_np(args, np_files)

def predict_model(args):
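    '''Instantiate the model named on the command line, load the checkpoint
    for the requested fold, and generate a submission.'''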
    model = eval(args.model_name)(args.layers, num_filters=args.nf)
    model_subdir = args.pad_mode
    if args.meta_version == 2:
        model_subdir = args.pad_mode+'_meta2'
    if args.exp_name is None:
        model_file = os.path.join(settings.MODEL_DIR, model.name, model_subdir, 'best_{}.pth'.format(args.ifold))
    else:
        model_file = os.path.join(settings.MODEL_DIR, args.exp_name, model.name, model_subdir, 'best_{}.pth'.format(args.ifold))

    if os.path.exists(model_file):
        print('loading {}...'.format(model_file))
        model.load_state_dict(torch.load(model_file))
    else:
        raise ValueError('model file not found: {}'.format(model_file))
    model = model.cuda()
    predict(args, model, model_file, args.sub_file)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Salt segmentation')
    parser.add_argument('--model_name', required=True, type=str, help='model class name, e.g. UNetResNetV4')
    parser.add_argument('--layers', default=34, type=int, help='model layers')
    parser.add_argument('--nf', default=32, type=int, help='num_filters param for model')
    parser.add_argument('--ifold', required=True, type=int, help='k-fold index')
    parser.add_argument('--batch_size', default=32, type=int, help='batch_size')
    parser.add_argument('--pad_mode', required=True, choices=['reflect', 'edge', 'resize'], help='pad method')
    parser.add_argument('--exp_name', default='depths', type=str, help='exp name')
    parser.add_argument('--meta_version', default=2, type=int, help='meta version')
    parser.add_argument('--sub_file', default='all_ensemble.csv', type=str, help='submission file')

    args = parser.parse_args()

    predict_model(args)
    #ensemble_predict(args)
    #ensemble_np_results(args)
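    # Example invocation (single-fold prediction), assuming the paths configured in settings.py:
    #   python predict.py --model_name UNetResNetV4 --layers 34 --ifold 0 --pad_mode edge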