"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import time
import torch
import copy
import collections
import numpy as np

from loguru import logger
from typing import Hashable, List, Sequence, Dict, Union, Any, Optional, Callable, TypeVar
from pathlib import Path

from nndet.io.load import save_pickle
from nndet.arch.abstract import AbstractModel
from nndet.io.transforms import NoOp
from nndet.inference.patching import save_get_crop, create_grid
from nndet.utils import to_device, maybe_verbose_iterable

from rising.transforms import AbstractTransform
from rising.loading import DataLoader


__all__ = ["Predictor"]
# Type alias: a device may be given either as a torch.device instance
# or as a device string such as "cuda:0" or "cpu".
torch_device = Union[torch.device, str]


class Predictor:
    def __init__(self,
                 ensembler: Dict[str, Callable],
                 models: Sequence[AbstractModel],
                 crop_size: Sequence[int],
                 overlap: float = 0.5,
                 tile_keys: Sequence[str] = ('data',),
                 model_keys: Sequence[str] = ('data',),
                 tta_transforms: Sequence[AbstractTransform] = (NoOp(),),
                 tta_inverse_transforms: Sequence[AbstractTransform] = (NoOp(),),
                 pre_transform: AbstractTransform = None,
                 post_transform: AbstractTransform = None,
                 batch_size: int = 4,
                 model_weights: Sequence[float] = None,
                 device: torch_device = "cuda:0",
                 ensemble_on_device: bool = True,
                 ):
        """
        Predict entire cases with TTA and Model-Ensembling

        Workflow
        - Load whole patient
        -> create predictor from patient
        - tile patient
        * for each model:
            * for each batch (batches of tiles):
                * for each tta transform:
                    - pre transform
                    - tta transform
                    - post transform
                    - predict batch
                    - inverse tta transform
                    - forward predictions and batch to ensembler classes
        <- return patient result

        Args:
            ensembler: Callable to instantiate ensembler from case and
                properties
            models: models to ensemble
            crop_size: size of each crop (for most cases this should be
                the same as in training)
            overlap: overlap of crops
            tile_keys: keys which are tiles
            model_keys: these keys are passed as positional arguments to
                the model
            tta_transforms: tta transformations
            tta_inverse_transforms: inverse tta transformation
            pre_transform: transform which is performed before every tta
                transform
            post_transform: transform which is performed after every tta
                transform
            batch_size: batch size to use for prediction
            model_weights: additional weighting of individual models
            device: device used for prediction
            ensemble_on_device: The results will be passed to the ensembler
                class with the current device. The ensembler needs to make
                sure to avoid memory leaks!

        Raises:
            ValueError: if the number of tta transforms does not match the
                number of inverse tta transforms
        """
        self.ensemble_on_device = ensemble_on_device
        self.device = device
        self.ensembler_fns = ensembler
        # per-case ensembler instances; (re)created inside `predict_case`
        self.ensembler = {}

        self.models = models
        # default to equal weighting of all models
        self.model_weights = [1.] * len(models) if model_weights is None else model_weights

        self.crop_size = crop_size
        self.overlap = overlap
        self.tile_keys = tile_keys
        self.model_keys = model_keys

        self.batch_size = batch_size

        # every tta transform needs its inverse to map predictions back
        if len(tta_transforms) != len(tta_inverse_transforms):
            raise ValueError("Every tta transform needs a reverse transform")
        self.tta_transforms = tta_transforms
        self.tta_inverse_transforms = tta_inverse_transforms
        self.post_transform = post_transform
        self.pre_transform = pre_transform

        # modes forwarded to `create_grid` / `save_get_crop`
        self.grid_mode = 'symmetric'
        self.save_get_mode = 'shift'

    @classmethod
    def create(cls, *args, **kwargs):
        """
        Create predictor object with specific ensembler objects

        Raises:
            NotImplementedError: Need to be overwritten in subclasses
        """
        raise NotImplementedError

    @classmethod
    def get_ensembler(cls, key: Hashable, dim: int) -> Callable:
        """
        Return ensembler class for specific keys
        Typically: `boxes`, `seg`, `instances`

        Args:
            key: Key to return
            dim: number of spatial dimensions the network expects

        Raises:
            NotImplementedError: Need to be overwritten in subclasses

        Returns:
            Callable: Ensembler class
        """
        raise NotImplementedError

    def predict_case(self,
                     case: Dict,
                     properties: Optional[Dict] = None,
                     save_dir: Optional[Union[Path, str]] = None,
                     case_id: Optional[str] = None,
                     restore: bool = False,
                     ) -> dict:
        """
        Load and predict a single case.

        Args:
            case: data of a single case
            properties: additional properties of the case. E.g. to
                restore prediction in original image space
            save_dir: directory to save predictions
            case_id: used for saving
            restore: restore prediction in original image space
                ("revert" preprocessing)

        Returns:
            dict: result of each ensembler (converted to numpy)
        """
        tic = time.perf_counter()
        # instantiate fresh ensemblers for this case
        for name, fn in self.ensembler_fns.items():
            self.ensembler[name] = fn(case, properties=properties)

        tiles = self.tile_case(case)
        self.predict_tiles(tiles)

        result = {key: value.get_case_result(restore=restore) for key, value in self.ensembler.items()}
        if save_dir is not None:
            save_dir = Path(save_dir)
            save_dir.mkdir(parents=True, exist_ok=True)
            for ensembler in self.ensembler.values():
                ensembler.save_state(save_dir, name=case_id)
            save_pickle(properties, save_dir / f"{case_id}_properties.pkl")
        toc = time.perf_counter()
        logger.info(f"Prediction took {toc - tic} s")
        return result

    def tile_case(self, case: dict, update_remaining: bool = True) -> \
            Sequence[Dict[str, np.ndarray]]:
        """
        Create patches from whole patient for prediction

        Args:
            case: data of a single case
            update_remaining: properties from case which are not tiles
                are saved into all patches

        Returns:
            Sequence[Dict[str, np.ndarray]]: extracted crops from case
                and added new keys:
                    `tile_origin`: Sequence[int] offset of tile relative
                        to case origin
                    `crop`: crop used to extract the tile
        """
        dshape = case[self.tile_keys[0]].shape
        overlap = [int(c * self.overlap) for c in self.crop_size]
        # first axis of dshape is the channel dimension and is not tiled
        crops = create_grid(
            cshape=self.crop_size,
            dshape=dshape[1:],
            overlap=overlap,
            mode=self.grid_mode,
            )

        tiles = []
        for crop in crops:
            try:
                # try selected extraction mode
                tile = self._extract_tile(case, crop, mode=self.save_get_mode)
            except RuntimeError:
                # fallback to symmetric padding if the patch exceeds the case
                logger.warning("Patch size is bigger than whole case, padding case to match patch size")
                tile = self._extract_tile(case, crop, mode="symmetric")

            if update_remaining:
                # carry over non-tiled entries (e.g. meta data) unchanged
                tile.update({key: item for key, item in case.items()
                             if key not in self.tile_keys})
            tiles.append(tile)
        return tiles

    def _extract_tile(self, case: dict, crop, mode: str) -> Dict[str, Any]:
        """
        Extract a single tile for all tile keys of a case

        Args:
            case: data of a single case
            crop: crop to extract
            mode: extraction mode forwarded to `save_get_crop`

        Returns:
            Dict[str, Any]: extracted tile data plus `tile_origin` and
                `crop` taken from the first tile key

        Raises:
            RuntimeError: forwarded from `save_get_crop` (e.g. patch
                larger than case in 'shift' mode)
        """
        first_key = self.tile_keys[0]
        # single call yields data, origin and crop for the first key
        data, origin, used_crop = save_get_crop(case[first_key], crop, mode=mode)
        tile = {first_key: data}
        for key in self.tile_keys[1:]:
            tile[key] = save_get_crop(case[key], crop, mode=mode)[0]
        tile["tile_origin"] = origin
        tile["crop"] = used_crop
        return tile

    @torch.no_grad()
    def predict_tiles(self, tiles: Sequence[Dict]) -> None:
        """
        Predict tiles of a single case with ensembling and tta. Results
        are saved inside ensemblers

        Args:
            tiles: tiles from single case
        """
        dataloader = DataLoader(tiles,
                                batch_size=self.batch_size,
                                shuffle=False,
                                collate_fn=slice_collate,
                                )
        for model_idx, (model, model_weight) in enumerate(
            zip(self.models, self.model_weights)):
            logger.info(f"Predicting model {model_idx + 1} of "
                        f"{len(self.models)} with weight {model_weight}.")

            model.to(device=self.device)
            model.eval()

            for t, (transform, inverse_transform) in enumerate(maybe_verbose_iterable(
                    list(zip(self.tta_transforms, self.tta_inverse_transforms)),
                    desc="Transform", position=0)):
                # register one logical "model" per (model, tta transform) pair
                for ensembler in self.ensembler.values():
                    ensembler.add_model(name=f"model{model_idx}_t{t}", model_weight=model_weight)

                for batch_num, batch in enumerate(maybe_verbose_iterable(
                    dataloader, desc="Crop", position=1)):
                    self.predict_with_transformation(
                        model=model,
                        batch=batch,
                        batch_num=batch_num,
                        transform=transform,
                        inverse_transform=inverse_transform,
                    )

            # free GPU memory before moving on to the next model
            model.cpu()
            torch.cuda.empty_cache()

    def predict_with_transformation(self,
                                    model: AbstractModel,
                                    batch: Dict,
                                    batch_num: int,
                                    transform: Callable,
                                    inverse_transform: Callable,
                                    ):
        """
        Run prediction with the specified transformations

        Args:
            model: model to predict
            batch: input batch to model
            batch_num: batch index
            transform: transform to apply to batch.
            inverse_transform: inverse transform to apply to batch and results
        """
        batch = to_device(batch, device=self.device)
        if self.pre_transform is not None:
            batch = self.pre_transform(**batch)

        transformed = transform(**batch)

        if self.post_transform is not None:
            transformed = self.post_transform(**transformed)

        # model_keys are forwarded as positional arguments
        inp = [transformed[key] for key in self.model_keys]
        with torch.cuda.amp.autocast():
            result = model.inference_step(*inp, batch_num=batch_num)
        result = inverse_transform(**result)

        if not self.ensemble_on_device:
            result = to_device(result, device="cpu")

        for ensembler in self.ensembler.values():
            ensembler.process_batch(result=result, batch=batch)


def slice_collate(batch: List[Any]):
    """
    Collate function with additional support for `slice` objects

    `slice` instances cannot be handled by the default torch collate
    function, so lists of slices are passed through unchanged. Strings
    and bytes are delegated to the default collate function *before*
    the generic Sequence branch: `str` is registered as a
    `collections.abc.Sequence`, so without this check a batch of strings
    would be zipped character-wise (mangling the values) and a batch of
    single-character strings would recurse infinitely.

    Args:
        batch: batch to collate

    Returns:
        Any: collated items
    """
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, slice):
        # slices are passed through without modification
        return batch
    elif isinstance(elem, (str, bytes)):
        # default_collate returns lists of strings/bytes unchanged
        return torch.utils.data._utils.collate.default_collate(batch)
    elif isinstance(elem, collections.abc.Mapping):
        return {key: slice_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(slice_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        transposed = zip(*batch)
        return [slice_collate(samples) for samples in transposed]
    else:
        return torch.utils.data._utils.collate.default_collate(batch)


# Bound TypeVar so factory functions / classmethods can be annotated to
# return the concrete Predictor subclass instead of the base class.
PredictorType = TypeVar('PredictorType', bound=Predictor)