import os
import copy
from typing import Callable, Sequence, List

import torch
import numpy as np
from loguru import logger

from nndet.planning.estimator import MemoryEstimator, MemoryEstimatorDetection
from nndet.planning.architecture.boxes.base import BoxC001
from nndet.planning.architecture.boxes.utils import (
    proxy_num_boxes_in_patch,
    scale_with_abs_strides,
    )
from nndet.core.boxes import (
    get_anchor_generator,
    expand_to_boxes,
    box_center,
    box_size_np,
    permute_boxes,
    )


class BoxC002(BoxC001):
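    """
    Architecture planner for the C002 detection baseline: extends
    :class:`BoxC001` with adjusted default channel and batch settings
    and an optional 2d planning mode.
    """
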
    def __init__(self,
                 preprocessed_output_dir: os.PathLike,
                 save_dir: os.PathLike,
                 network_cls: Callable,
                 estimator: MemoryEstimator = MemoryEstimatorDetection(),
                 model_cfg: dict = None,
                 **kwargs,
                 ):
        super().__init__(
            preprocessed_output_dir=preprocessed_output_dir,
            save_dir=save_dir,
            network_cls=network_cls,
            estimator=estimator,
            model_cfg=model_cfg,
            **kwargs
            )

    def create_default_settings(self):
        """
        Generate default settings for the architecture
        """
        super().create_default_settings()
        self.architecture_kwargs["start_channels"] = 48 if self.dim == 2 else 32
        self.architecture_kwargs["fpn_channels"] = \
            self.architecture_kwargs["start_channels"] * 4
        self.architecture_kwargs["head_channels"] = \
            self.architecture_kwargs["fpn_channels"]
        self.batch_size = 16 if self.dim == 2 else 4
        self.min_feature_map_size = 8 if self.dim == 2 else 4
        self.num_decoder_level = 5 if self.dim == 2 else 4

    def get_anchor_init(self, boxes: torch.Tensor) -> Sequence[Sequence[int]]:
        """
        Initialize anchor sizes for optimization

        Args:
            boxes: scaled and transposed boxes

        Returns:
            Sequence[Sequence[int]]: anchor initialization
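
        Example:
            for 3d boxes of shape ``(N, 6)``, ``box_dim`` is 3 and the
            returned initialization is ``[(4, 8, 16)] * 3``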
        """
        box_dim = int(boxes.shape[1]) // 2
        return [(4, 8, 16), ] * box_dim

    def process_properties(self, **kwargs):
        """
        Load dataset properties and extract information
        """
        logger.info("Processing dataset properties")
        self.all_boxes = [case["boxes"] for case_id, case
                          in self.dataset_properties["instance_props_per_patient"].items()]
        self.all_spacings = [case["original_spacing"] for case_id, case
                             in self.dataset_properties["instance_props_per_patient"].items()]
        self.num_instances_per_case = {
            case_id: sum(case["num_instances"].values())
            for case_id, case
            in self.dataset_properties["instance_props_per_patient"].items()}

        self.all_ious = self.dataset_properties["all_ious"]
        self.class_ious = self.dataset_properties["class_ious"]
        self.num_instances = self.dataset_properties["num_instances"]
        self.dim = self.dataset_properties["dim"]

        self.architecture_kwargs["classifier_classes"] = \
            len(self.dataset_properties["class_dct"])
        self.architecture_kwargs["seg_classes"] = \
            self.architecture_kwargs["classifier_classes"]
        self.architecture_kwargs["in_channels"] = \
            len(self.dataset_properties["modalities"])
        self.architecture_kwargs["dim"] = \
            self.dataset_properties["dim"]

    def plan(self,
             target_spacing_transposed: Sequence[float],
             median_shape_transposed: Sequence[float],
             transpose_forward: Sequence[int],
             mode: str = '3d',
             ) -> dict:
        """
        Plan network architecture, anchors, patch size and batch size

        Args:
            target_spacing_transposed: spacing after data is transposed and resampled
            median_shape_transposed: median shape after data is
                transposed and resampled
            transpose_forward: new ordering of axes for forward pass
            mode: mode to use for planning ('3d' | '2d')

        Returns:
            dict: training and architecture information

        See Also:
            :meth:`_plan_architecture`, :meth:`_plan_anchors`
        """
        if mode == "2d":
            logger.info("Running 2d mode")
            self.process_properties()
            kwargs_2d = self.activate_2d_mode(
                transpose_forward=transpose_forward,
                target_spacing_transposed=target_spacing_transposed,
                median_shape_transposed=median_shape_transposed,
            )
            res = super().plan(**kwargs_2d)
        else:
            res = super().plan(
                transpose_forward=transpose_forward,
                target_spacing_transposed=target_spacing_transposed,
                median_shape_transposed=median_shape_transposed,
            )
        return res
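
    # Illustrative planning sequence (a sketch only; the concrete network
    # class and directory layout depend on the surrounding pipeline):
    #
    #   planner = BoxC002(
    #       preprocessed_output_dir="./preprocessed",
    #       save_dir="./plans",
    #       network_cls=SomeDetectionNetwork,  # hypothetical network class
    #       )
    #   plan = planner.plan(
    #       target_spacing_transposed=[3.0, 0.7, 0.7],
    #       median_shape_transposed=[120, 512, 512],
    #       transpose_forward=[0, 1, 2],
    #       )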

    def activate_2d_mode(self,
                         target_spacing_transposed: Sequence[float],
                         median_shape_transposed: Sequence[float],
                         transpose_forward: Sequence[int],
                         ) -> dict:
        target_spacing_transposed = target_spacing_transposed[1:]
        median_shape_transposed = median_shape_transposed[1:]
        keep = copy.copy(transpose_forward[1:])
        transpose_forward = [t - 1 for t in keep]

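        # boxes are stored as (x1, y1, x2, y2, z1, z2), so build an index
        # list that selects the min/max coordinates of the two spatial axes
        # that remain after the leading axis has been dropped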
        keep_box = [0, 0, 0, 0]
        for idx, k in enumerate(keep):
            if k < 2:
                keep_box[idx] = k
                keep_box[idx + 2] = k + 2
            else:
                keep_box[idx] = 2 * k
                keep_box[idx + 2] = 2 * k + 1

        self.all_boxes = [
            b[:, keep_box] if (not isinstance(b, list) and b.shape[1] == 6) else b
            for b in self.all_boxes]
        self.all_spacings = [c[keep] if len(c) == 3 else c for c in self.all_spacings]

        self.dim = 2
        self.architecture_kwargs["dim"] = self.dim
        return {
            "target_spacing_transposed": target_spacing_transposed,
            "median_shape_transposed": median_shape_transposed,
            "transpose_forward": transpose_forward,
        }

    def _plan_architecture(self,
                           target_spacing_transposed: Sequence[float],
                           target_median_shape_transposed: Sequence[float],
                           transpose_forward: Sequence[int],
                           **kwargs,
                           ) -> Sequence[int]:
        """
        Plan patch size and main aspects of the architecture
        Fills entries in :param:`self.architecture_kwargs`:
            `conv_kernels`
            `strides`
            `decoder_levels`

        Args:
            target_spacing_transposed: spacing after data is transposed and resampled
            target_median_shape_transposed: median shape after data is 
                transposed and resampled
        
        Returns:
            Sequence[int]: patch size to use for training
        """
        self.estimator.batch_size = self.batch_size
        patch_size = np.asarray(self._get_initial_patch_size(
            target_spacing_transposed, target_median_shape_transposed))
        first_run = True
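        # iteratively shrink the patch size until the memory estimator
        # reports that the resulting configuration fits into the budget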
        while True:
            if not first_run:
                patch_size = self._decrease_patch_size(
                    patch_size, target_median_shape_transposed,
                    pooling, must_be_divisible_by)
            num_pool_per_axis, pooling, convs, patch_size, must_be_divisible_by = \
                self.plan_pool_and_conv_pool_late(patch_size, target_spacing_transposed)
            self.architecture_kwargs["conv_kernels"] = convs
            self.architecture_kwargs["strides"] = pooling
            num_resolutions = len(self.architecture_kwargs["conv_kernels"])

            decoder_levels_start = min(
                max(1, num_resolutions - self.num_decoder_level),
                self.min_decoder_level)
            self.architecture_kwargs["decoder_levels"] = \
                tuple(range(decoder_levels_start, num_resolutions))
            _, fits_in_mem = self.estimator.estimate(
                min_shape=must_be_divisible_by,
                target_shape=patch_size,
                in_channels=self.architecture_kwargs["in_channels"],
                network=self.network_cls.from_config_plan(
                    model_cfg=self.model_cfg,
                    plan_arch=self.architecture_kwargs,
                    plan_anchors=self.get_anchors_for_estimation()),
                optimizer_cls=torch.optim.Adam,
                num_instances=self._estimate_num_instances_per_patch(
                    patch_size=patch_size,
                    target_spacing_transposed=target_spacing_transposed,
                    transpose_forward=transpose_forward,
                    ),
                )
            if fits_in_mem:
                break
            first_run = False
        logger.info(f"decoder levels: {self.architecture_kwargs['decoder_levels']}; \n"
                    f"pooling strides: {self.architecture_kwargs['strides']}; \n"
                    f"kernel sizes: {self.architecture_kwargs['conv_kernels']}; \n"
                    f"patch size: {patch_size}; \n")
        return patch_size

    def _estimate_num_instances_per_patch(self,
                                          patch_size,
                                          target_spacing_transposed,
                                          transpose_forward,
                                          ) -> int:
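        # proxy estimate: the largest number of boxes a single image could
        # contribute to one training patch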
        max_instances_per_image = []
        for boxes in self._get_scaled_boxes(
            target_spacing_transposed=target_spacing_transposed,
            transpose_forward=transpose_forward,
            cat=False,
            ):
            max_instances_per_image.append(
                max(proxy_num_boxes_in_patch(torch.from_numpy(boxes), patch_size)).item())
        return max(max_instances_per_image)

    def _plan_anchors(self,
                      target_spacing_transposed: Sequence[float],
                      transpose_forward: Sequence[int],
                      **kwargs,
                      ) -> dict:
        """
        Optimize anchors
        """
        boxes_np_full = self._get_scaled_boxes(
            target_spacing_transposed=target_spacing_transposed,
            transpose_forward=transpose_forward,
        )

        boxes_np = self.filter_boxes(boxes_np_full)
        logger.info(f"Filtered {boxes_np_full.shape[0] - boxes_np.shape[0]} "
                    f"boxes, {boxes_np.shape[0]} boxes remaining for anchor "
                    "planning.")
        boxes_torch = torch.from_numpy(boxes_np).float()
        boxes_torch = boxes_torch - expand_to_boxes(box_center(boxes_torch))
        anchor_generator = get_anchor_generator(self.dim, s_param=True)

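        # accumulate the relative strides of the selected decoder levels
        # into the stride factors used for anchor planning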
        rel_strides = self.architecture_kwargs["strides"]
        filt_rel_strides = [[1] * self.dim, *rel_strides]
        filt_rel_strides = [filt_rel_strides[i] for i in self.architecture_kwargs["decoder_levels"]]
        strides = np.cumprod(filt_rel_strides, axis=0) / np.asarray(rel_strides[0])

        params = self.find_anchors(boxes_torch, strides.astype(np.int32), anchor_generator)
        scaled_params = {
            key: scale_with_abs_strides(item, strides, dim_idx)
            for dim_idx, (key, item) in enumerate(params.items())}
        logger.info(f"Determined anchors: {params}; "
                    f"scaled with absolute strides: {scaled_params}")
        self.anchors = scaled_params
        self.anchors["stride"] = 1
        return self.anchors
    
    def _get_scaled_boxes(self,
                          target_spacing_transposed: Sequence[float],
                          transpose_forward: Sequence[int],
                          cat: bool = True,
                          ) -> np.ndarray:
        """
        training is conducted in preprocessed image space and thus
        we need to scale the extracted boxes to compensate for resampling
        """
        boxes_np_list = []
        for spacing, boxes in zip(self.all_spacings, self.all_boxes):
            if not isinstance(boxes, list) and boxes.size > 0:
                spacing_transposed = np.asarray(spacing)[transpose_forward]
                scaling_transposed = spacing_transposed / np.asarray(target_spacing_transposed)
                boxes_transposed = permute_boxes(np.asarray(boxes), dims=transpose_forward)
                boxes_np_list.append(boxes_transposed * expand_to_boxes(scaling_transposed))
        if cat:
            return np.concatenate(boxes_np_list).astype(np.float32)
        else:
            return boxes_np_list

    @staticmethod
    def _get_initial_patch_size(target_spacing_transposed: np.ndarray,
                                target_median_shape_transposed: Sequence[int],
                                ) -> List[int]:
        """
        Generate initial patch which relies on the spacing of underlying images.
        This is based on the fact that most acquisition protocols are optimized
        to focus on the most importatnt aspects.
        
        Returns:
            List[int]: initial patch size
        """
        voxels_per_mm = 1 / np.array(target_spacing_transposed)

        # normalize voxels per mm
        input_patch_size = voxels_per_mm / voxels_per_mm.mean()

        # create a starting patch that is isotropic in physical space, with
        # 512 voxels along the axis that has the fewest voxels per mm
        input_patch_size *= 1 / min(input_patch_size) * 512
        input_patch_size = np.round(input_patch_size).astype(np.int32)

        # clip it to the median shape of the dataset because larger patches
        # make little sense; this also accounts for rectangular patches
        if len(target_spacing_transposed) > 2:
            lowres_axis = np.argmax(target_spacing_transposed)
            isotropic_axes = list(range(len(target_median_shape_transposed)))
            isotropic_axes.pop(lowres_axis)
            min_isotropic_axes_shape = min([target_median_shape_transposed[t] for t in isotropic_axes])
            lowres_shape = target_median_shape_transposed[lowres_axis]
        else:
            lowres_axis = -1
            lowres_shape = None
            min_isotropic_axes_shape = min(target_median_shape_transposed)

        initial_patch_size = []
        for i in range(len(target_median_shape_transposed)):
            if i == lowres_axis:
                assert lowres_shape is not None
                initial_patch_size.append(min(input_patch_size[i], lowres_shape))
            else:
                initial_patch_size.append(min(input_patch_size[i], min_isotropic_axes_shape))
        initial_patch_size = np.round(initial_patch_size).astype(np.int32)
        logger.info(f"Using initial patch size: {initial_patch_size}")
        return initial_patch_size

    def plot_box_distribution(self, 
                              target_spacing_transposed: Sequence[float],
                              transpose_forward: Sequence[int],
                              **kwargs):
        """
        Plot histogram with ground truth bounding box distribution for
        all axis
        """
        super().plot_box_distribution()
        try:
            from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 unused import
            import matplotlib.pyplot as plt
        except ImportError:
            logger.error("Failed to import matplotlib continue anyway.")
            plt = None

        if plt is not None:
            if isinstance(self.all_boxes, list):
                _boxes = np.concatenate(
                    [b for b in self.all_boxes if not isinstance(b, list) and b.size > 0], axis=0)
                dists = box_size_np(_boxes)
            else:
                dists = box_size_np(self.all_boxes)

            if dists.shape[1] == 3:
                fig = plt.figure()
                ax = fig.add_subplot(111, projection='3d')
                ax.scatter(dists[:, 0], dists[:, 1], dists[:, 2])
                ax.set_title(f"Transpose forward {transpose_forward}")
                plt.savefig(self.save_dir / 'bbox_sizes_3d_orig.png')
                plt.close()

                dists = box_size_np(self._get_scaled_boxes(
                    target_spacing_transposed, transpose_forward))
                fig = plt.figure()
                ax = fig.add_subplot(111, projection='3d')
                ax.scatter(dists[:, 0], dists[:, 1], dists[:, 2])
                plt.savefig(self.save_dir / 'bbox_sizes_3d.png')
                plt.close()
            else:
                fig = plt.figure()
                ax = fig.add_subplot(111)
                ax.scatter(dists[:, 0], dists[:, 1])
                ax.grid(True)
                ax.set_title(f"Transpose forward {transpose_forward}")
                plt.savefig(self.save_dir / 'bbox_sizes_2d_orig.png')
                plt.close()

                dists = box_size_np(self._get_scaled_boxes(
                    target_spacing_transposed, transpose_forward))
                fig = plt.figure()
                ax = fig.add_subplot(111)
                ax.scatter(dists[:, 0], dists[:, 1])
                ax.grid(True)
                plt.savefig(self.save_dir / 'bbox_sizes_2d.png')
                plt.close()