# Implementation of this model is borrowed and modified
# (from torch to paddle) from here:
# https://github.com/MIC-DKFZ/nnUNet

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import pickle
import numpy as np
from copy import deepcopy

sys.path.append(
    os.path.join(os.path.dirname(os.path.realpath(__file__)), "../.."))

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from medicalseg.cvlibs import manager
from medicalseg.utils import utils
from medicalseg.cvlibs import param_init
from medicalseg.models import layers

from tools.preprocess_utils import experiment_planner


@manager.MODELS.add_component
class NNUNet(nn.Layer):
    """nnUNet wrapper that builds a Generic_UNet from an nnU-Net plan file.

    Args:
        plan_path (str): Path to the pickled nnU-Net plan file.
        num_classes (int): Kept only for interface compatibility with other
            models; it is ignored (the class count comes from the plan file).
            Default: 0.
        pretrained (str | optional): The path or url of pretrained model. Default: None.
        stage (int | optional): The stage of nnunet, 0 for nnunet_2d and nnunet_3d, 1 for nnunet_cascade stage 2. Default: None.
        cascade (bool | optional): Whether is cascade model. Default: False.
        deep_supervision (bool | optional): Whether return multi-scale feats when training mode. Default: True.
        feat_map_mul_on_downscale (int | optional): The expansion ratio of stage channels. Default: 2.
        use_dropout (bool | optional): Whether use dropout layer in model. Default: False.
        upscale_logits (bool | optional): Whether upscale output feats with different resolutions to the same resolution. Default: False.
        convolutional_pooling (bool | optional): Whether add pool layer after conv layer. If convolutional_pooling is True, only conv layer is used and reduce resolution by conv stride. Default: True.
        convolutional_upsampling (bool | optional): Use transpose conv layer or interpolate to upsample feature maps. If True, using transpose conv. Default: True.
    """

    def __init__(
            self,
            plan_path,
            num_classes=0,
            pretrained=None,
            stage=None,
            cascade=False,
            deep_supervision=True,
            feat_map_mul_on_downscale=2,
            use_dropout=False,
            upscale_logits=False,
            convolutional_pooling=True,
            convolutional_upsampling=True, ):
        super().__init__()
        self.plan_path = plan_path
        self.stage = stage
        self.cascade = cascade
        self.load_and_process_plan_file(plan_path)

        # Pick 2D or 3D building blocks depending on the plan's patch size.
        if self.threeD:
            selected = (nn.Conv3D, nn.Dropout3D, nn.InstanceNorm3D,
                        experiment_planner.MAX_NUM_FILTERS_3D)
        else:
            selected = (nn.Conv2D, nn.Dropout2D, nn.InstanceNorm2D,
                        experiment_planner.MAX_FILTERS_2D)
        conv_op, dropout_op, norm_op, max_num_features = selected

        self.network = Generic_UNet(
            input_channels=self.num_input_channels,
            base_num_features=self.base_num_features,
            num_classes=self.num_classes,
            num_pool=len(self.net_num_pool_op_kernel_sizes),
            num_conv_per_stage=self.conv_per_stage,
            feat_map_mul_on_downscale=feat_map_mul_on_downscale,
            conv_op=conv_op,
            norm_op=norm_op,
            norm_op_kwargs={'epsilon': 1e-5},
            dropout_op=dropout_op,
            dropout_op_kwargs={'p': 0},
            nonlin=nn.LeakyReLU,
            nonlin_kwargs={'negative_slope': 1e-2},
            deep_supervision=deep_supervision,
            use_dropout=use_dropout,
            pool_op_kernel_sizes=self.net_num_pool_op_kernel_sizes,
            conv_kernel_sizes=self.net_conv_kernel_sizes,
            upscale_logits=upscale_logits,
            convolutional_pooling=convolutional_pooling,
            convolutional_upsampling=convolutional_upsampling,
            max_num_features=max_num_features)

        self.pretrained = pretrained
        self.init_weight()

    def init_weight(self):
        """Load pretrained weights if given; otherwise Kaiming-init convs."""
        if self.pretrained is not None:
            utils.load_entire_model(self, self.pretrained)
            return
        conv_types = (nn.Conv2D, nn.Conv3D, nn.Conv2DTranspose,
                      nn.Conv3DTranspose)
        for sublayer in self.sublayers():
            if not isinstance(sublayer, conv_types):
                continue
            param_init.kaiming_normal_init(sublayer.weight)
            if sublayer.bias is not None:
                param_init.constant_init(sublayer.bias, value=0)

    def load_and_process_plan_file(self, plan_path):
        """Read the pickled plan file and cache network hyper-parameters.

        NOTE(review): ``pickle.load`` executes arbitrary code when fed an
        untrusted file — only load plan files from trusted sources.
        """
        with open(plan_path, 'rb') as f:
            plans = pickle.load(f)

        if self.stage is None:
            stage_keys = list(plans['plans_per_stage'].keys())
            assert len(
                stage_keys
            ) == 1, "If self.stage is None then there can be only one stage in the plans file but got {}. Please specify which stage of the cascade must be trained.".format(
                len(stage_keys))
            self.stage = stage_keys[0]
        self.plans = plans

        stage_plans = plans['plans_per_stage'][self.stage]
        self.net_pool_per_axis = stage_plans['num_pool_per_axis']
        self.patch_size = np.array(stage_plans['patch_size']).astype(int)
        self.net_num_pool_op_kernel_sizes = stage_plans['pool_op_kernel_sizes']
        self.net_conv_kernel_sizes = stage_plans['conv_kernel_sizes']
        self.base_num_features = plans['base_num_features']
        self.num_input_channels = plans['num_modalities']
        self.classes = plans['all_classes']
        self.num_classes = plans['num_classes'] + 1
        # Stage-2 cascade networks additionally receive the previous stage's
        # one-hot segmentation as extra input channels.
        if self.stage == 1 and self.cascade:
            self.num_input_channels += (self.num_classes - 1)

        ndim = len(self.patch_size)
        if ndim not in (2, 3):
            raise RuntimeError("Invalid patch size in plans file: {}".format(
                self.patch_size))
        self.threeD = ndim == 3

        # Older plan files may lack 'conv_per_stage'; fall back to 2.
        self.conv_per_stage = plans.get('conv_per_stage', 2)

    def forward(self, x):
        return self.network(x)


class StackedConvLayers(nn.Layer):
    """A stack of ``num_convs`` conv-dropout-norm-nonlin basic blocks.

    Only the first block may use a different stride (``first_stride``) so the
    stack can downsample on entry; all following blocks keep the stride from
    ``conv_kwargs``.

    Args:
        input_feature_channels (int): Channels of the input feature map.
        output_feature_channels (int): Channels produced by every block.
        num_convs (int): Number of stacked basic blocks.
        conv_op (paddle.nn.Layer | optional): Conv layer type. Default: nn.Conv2D.
        conv_kwargs (dict | optional): Params for the conv layers. Default: None.
        norm_op (paddle.nn.Layer | optional): Norm layer type. Default: nn.BatchNorm2D.
        norm_op_kwargs (dict | optional): Params for the norm layers. Default: None.
        dropout_op (paddle.nn.Layer | optional): Dropout layer type. Default: nn.Dropout2D.
        dropout_op_kwargs (dict | optional): Params for the dropout layers. Default: None.
        nonlin (paddle.nn.Layer | optional): Activation layer type. Default: nn.LeakyReLU.
        nonlin_kwargs (dict | optional): Params for the activation. Default: None.
        first_stride (list | tuple | optional): Stride of the first block only. Default: None.
        basic_block (paddle.nn.Layer): Block implementation. Default: layers.ConvDropoutNormNonlin.
    """

    def __init__(self,
                 input_feature_channels,
                 output_feature_channels,
                 num_convs,
                 conv_op=nn.Conv2D,
                 conv_kwargs=None,
                 norm_op=nn.BatchNorm2D,
                 norm_op_kwargs=None,
                 dropout_op=nn.Dropout2D,
                 dropout_op_kwargs=None,
                 nonlin=nn.LeakyReLU,
                 nonlin_kwargs=None,
                 first_stride=None,
                 basic_block=layers.ConvDropoutNormNonlin):
        super(StackedConvLayers, self).__init__()
        self.input_channels = input_feature_channels
        self.output_channels = output_feature_channels

        if nonlin_kwargs is None:
            nonlin_kwargs = {'negative_slope': 1e-2}
        if dropout_op_kwargs is None:
            dropout_op_kwargs = {'p': 0.5}
        if norm_op_kwargs is None:
            # Fix: Paddle norm layers take 'epsilon'; 'eps' is the torch
            # keyword (a leftover from the original port) and would raise
            # TypeError whenever this default was actually used. The rest of
            # this file consistently passes {'epsilon': ...}.
            norm_op_kwargs = {'epsilon': 1e-5}
        if conv_kwargs is None:
            conv_kwargs = {
                'kernel_size': 3,
                'stride': 1,
                'padding': 1,
                'dilation': 1,
                'bias_attr': True
            }

        self.nonlin_kwargs = nonlin_kwargs
        self.nonlin = nonlin
        self.dropout_op = dropout_op
        self.dropout_op_kwargs = dropout_op_kwargs
        self.norm_op_kwargs = norm_op_kwargs
        self.conv_kwargs = conv_kwargs
        self.conv_op = conv_op
        self.norm_op = norm_op

        # Only the first conv may downsample; copy the kwargs so the stride
        # override does not leak into the shared conv_kwargs dict.
        if first_stride is not None:
            self.conv_kwargs_first_conv = deepcopy(conv_kwargs)
            self.conv_kwargs_first_conv['stride'] = first_stride
        else:
            self.conv_kwargs_first_conv = conv_kwargs

        self.blocks = nn.Sequential(
            basic_block(input_feature_channels, output_feature_channels,
                        self.conv_op, self.conv_kwargs_first_conv, self.norm_op,
                        self.norm_op_kwargs, self.dropout_op,
                        self.dropout_op_kwargs, self.nonlin,
                        self.nonlin_kwargs),
            *[
                basic_block(output_feature_channels, output_feature_channels,
                            self.conv_op, self.conv_kwargs, self.norm_op,
                            self.norm_op_kwargs, self.dropout_op,
                            self.dropout_op_kwargs, self.nonlin,
                            self.nonlin_kwargs) for _ in range(num_convs - 1)
            ])

    def forward(self, x):
        return self.blocks(x)


class Generic_UNet(nn.Layer):
    """
    Generic U-Net backbone used by NNUNet (2D or 3D depending on ``conv_op``).

    Args:
        input_channels (int): The input channels of nnUNet.
        base_num_features (int): Basic number of nnUNet channels.
        num_classes (int): Number of output segmentation classes.
        num_pool (int): The number of MaxPooling.
        num_conv_per_stage (int | optional): The number of conv-bn-nonlin blocks in every stage. Default: 2.
        feat_map_mul_on_downscale (int | optional): The expansion ratio of stage channels. Default: 2.
        conv_op (paddle.nn.Layer | optional): The type of conv layer, only support nn.Conv2D and nn.Conv3D. Default: nn.Conv2D.
        norm_op (paddle.nn.Layer | optional): The type of batchnorm layer. Default: nn.BatchNorm2D.
        norm_op_kwargs (dict | optional): The params for norm_op.
        dropout_op (paddle.nn.Layer | optional): The type of dropout layer. Default: nn.Dropout2D.
        dropout_op_kwargs (dict | optional): The params for dropout_op.
        nonlin (paddle.nn.Layer | optional): The type of activation layer. Default: nn.LeakyReLU.
        nonlin_kwargs (dict | optional): The params for nonlin.
        deep_supervision (bool | optional): Whether return multi-scale feats when training mode. Default: True.
        use_dropout (bool | optional): Whether use dropout layer in model. Default: False.
        pool_op_kernel_sizes (list | optional): The kernel_sizes of pool layers. If None, this param will be computed from num_pool automatically. Default: None.
        conv_kernel_sizes (list | optional): The kernel_sizes of conv layers. If None, this param will be computed from num_pool automatically. Default: None.
        upscale_logits (bool | optional): Whether upscale output feats with different resolutions to the same resolution. Default: False.
        convolutional_pooling (bool | optional): Whether add pool layer after conv layer. If convolutional_pooling is True, only conv layer is used and reduce resolution by conv stride. Default: False.
        convolutional_upsampling (bool | optional): Use transpose conv layer or interpolate to upsample feature maps. If True, using transpose conv. Default: False.
        max_num_features (int | optional): The maximum channels of feature maps. Default: None.
        basic_block (paddle.nn.Layer): Only use conv-drop-norm-nonlin module. Default: layers.ConvDropoutNormNonlin.
        seg_output_use_bias (bool | optional): Whether use bias in segmentation head. Default: False.
    """

    def __init__(self,
                 input_channels,
                 base_num_features,
                 num_classes,
                 num_pool,
                 num_conv_per_stage=2,
                 feat_map_mul_on_downscale=2,
                 conv_op=nn.Conv2D,
                 norm_op=nn.BatchNorm2D,
                 norm_op_kwargs=None,
                 dropout_op=nn.Dropout2D,
                 dropout_op_kwargs=None,
                 nonlin=nn.LeakyReLU,
                 nonlin_kwargs=None,
                 deep_supervision=True,
                 use_dropout=False,
                 pool_op_kernel_sizes=None,
                 conv_kernel_sizes=None,
                 upscale_logits=False,
                 convolutional_pooling=False,
                 convolutional_upsampling=False,
                 max_num_features=None,
                 basic_block=layers.ConvDropoutNormNonlin,
                 seg_output_use_bias=False):
        super().__init__()
        self.convolutional_upsampling = convolutional_upsampling
        self.convolutional_pooling = convolutional_pooling
        self.upscale_logits = upscale_logits

        if nonlin_kwargs is None:
            nonlin_kwargs = {'negative_slope': 1e-2}
        if dropout_op_kwargs is None:
            dropout_op_kwargs = {'p': 0.5}
        if norm_op_kwargs is None:
            norm_op_kwargs = {'epsilon': 1e-5}
        # NOTE: this dict is mutated below ('kernel_size'/'padding' are set
        # per stage before each StackedConvLayers is built), so statement
        # order in this constructor is significant.
        self.conv_kwargs = {'stride': 1, 'dilation': 1, 'bias_attr': True}

        self.nonlin = nonlin
        self.nonlin_kwargs = nonlin_kwargs
        self.dropout_op_kwargs = dropout_op_kwargs
        self.norm_op_kwargs = norm_op_kwargs
        self.conv_op = conv_op
        self.norm_op = norm_op
        self.dropout_op = dropout_op
        self.num_classes = num_classes
        self._deep_supervision = deep_supervision

        # Select the 2D or 3D flavour of pooling/upsampling to match conv_op.
        if conv_op == nn.Conv2D:
            upsample_mode = 'bilinear'
            pool_op = nn.MaxPool2D
            transpconv = nn.Conv2DTranspose
            if pool_op_kernel_sizes is None:
                pool_op_kernel_sizes = [(2, 2)] * num_pool
            if conv_kernel_sizes is None:
                conv_kernel_sizes = [(3, 3)] * (num_pool + 1)
        elif conv_op == nn.Conv3D:
            upsample_mode = 'trilinear'
            pool_op = nn.MaxPool3D
            transpconv = nn.Conv3DTranspose
            if pool_op_kernel_sizes is None:
                pool_op_kernel_sizes = [(2, 2, 2)] * num_pool
            if conv_kernel_sizes is None:
                conv_kernel_sizes = [(3, 3, 3)] * (num_pool + 1)
        else:
            raise ValueError("Unknown convolution dimensionality, conv op: {}.".
                             format(str(conv_op)))

        # Product of all pool strides per axis: inputs must be divisible by
        # this for the skip connections to line up.
        self.input_shape_must_be_divisible_by = np.prod(
            pool_op_kernel_sizes, 0, dtype=np.int64)
        self.pool_op_kernel_sizes = pool_op_kernel_sizes
        self.conv_kernel_sizes = conv_kernel_sizes
        self.max_num_features = max_num_features

        # 'same' padding per axis: pad 1 for kernel 3, pad 0 for kernel 1.
        self.conv_pad_sizes = []
        for krnl in self.conv_kernel_sizes:
            self.conv_pad_sizes.append([1 if i == 3 else 0 for i in krnl])

        self.conv_blocks_context = []
        self.conv_blocks_localization = []
        self.pool_layers = []
        self.upsample_ops = []
        self.seg_heads = []

        output_features = base_num_features
        input_features = input_channels

        # ---- Encoder (context) path: one StackedConvLayers per resolution.
        for d in range(num_pool):
            # With convolutional pooling, stages after the first downsample
            # via the first conv's stride instead of a pooling layer.
            if d != 0 and self.convolutional_pooling:
                first_stride = pool_op_kernel_sizes[d - 1]
            else:
                first_stride = None

            self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[d]
            self.conv_kwargs['padding'] = self.conv_pad_sizes[d]

            self.conv_blocks_context.append(
                StackedConvLayers(
                    input_features,
                    output_features,
                    num_conv_per_stage,
                    self.conv_op,
                    self.conv_kwargs,
                    self.norm_op,
                    self.norm_op_kwargs,
                    self.dropout_op,
                    self.dropout_op_kwargs,
                    self.nonlin,
                    self.nonlin_kwargs,
                    first_stride,
                    basic_block=basic_block))
            if not self.convolutional_pooling:
                self.pool_layers.append(pool_op(pool_op_kernel_sizes[d]))
            # Widen channels for the next stage, capped at max_num_features.
            input_features = output_features
            output_features = int(
                np.round(output_features * feat_map_mul_on_downscale))
            output_features = min(output_features, self.max_num_features)

        # ---- Bottleneck stage.
        if self.convolutional_pooling:
            first_stride = pool_op_kernel_sizes[-1]
        else:
            first_stride = None

        # With interpolation upsampling the bottleneck must already output
        # the skip's channel count, since nn.Upsample cannot change channels.
        if self.convolutional_upsampling:
            final_num_features = output_features
        else:
            final_num_features = self.conv_blocks_context[-1].output_channels

        self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[num_pool]
        self.conv_kwargs['padding'] = self.conv_pad_sizes[num_pool]
        self.conv_blocks_context.append(
            nn.Sequential(
                StackedConvLayers(
                    input_features,
                    output_features,
                    num_conv_per_stage - 1,
                    self.conv_op,
                    self.conv_kwargs,
                    self.norm_op,
                    self.norm_op_kwargs,
                    self.dropout_op,
                    self.dropout_op_kwargs,
                    self.nonlin,
                    self.nonlin_kwargs,
                    first_stride,
                    basic_block=basic_block),
                StackedConvLayers(
                    output_features,
                    final_num_features,
                    1,
                    self.conv_op,
                    self.conv_kwargs,
                    self.norm_op,
                    self.norm_op_kwargs,
                    self.dropout_op,
                    self.dropout_op_kwargs,
                    self.nonlin,
                    self.nonlin_kwargs,
                    basic_block=basic_block)))

        # NOTE(review): mutates the shared dropout kwargs dict after the
        # encoder blocks were constructed — whether the encoder is affected
        # depends on when basic_block reads 'p'; confirm in
        # layers.ConvDropoutNormNonlin.
        if not use_dropout:
            self.dropout_op_kwargs['p'] = 0.0

        # ---- Decoder (localization) path: upsample, concat skip, refine.
        for u in range(num_pool):
            nfeatures_from_down = final_num_features
            nfeatures_from_skip = self.conv_blocks_context[-(2 + u
                                                             )].output_channels
            n_features_after_tu_and_concat = nfeatures_from_skip * 2

            if u != num_pool - 1 and not self.convolutional_upsampling:
                final_num_features = self.conv_blocks_context[-(
                    3 + u)].output_channels
            else:
                final_num_features = nfeatures_from_skip

            if not self.convolutional_upsampling:
                self.upsample_ops.append(
                    nn.Upsample(
                        scale_factor=pool_op_kernel_sizes[-(u + 1)],
                        mode=upsample_mode))
            else:
                # Transposed conv with kernel == stride == pool kernel, so it
                # exactly inverts the matching downsampling step.
                self.upsample_ops.append(
                    transpconv(
                        nfeatures_from_down,
                        nfeatures_from_skip,
                        pool_op_kernel_sizes[-(u + 1)],
                        pool_op_kernel_sizes[-(u + 1)],
                        bias_attr=False))

            self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[-(u + 1)]
            self.conv_kwargs['padding'] = self.conv_pad_sizes[-(u + 1)]
            self.conv_blocks_localization.append(
                nn.Sequential(
                    StackedConvLayers(
                        n_features_after_tu_and_concat,
                        nfeatures_from_skip,
                        num_conv_per_stage - 1,
                        self.conv_op,
                        self.conv_kwargs,
                        self.norm_op,
                        self.norm_op_kwargs,
                        self.dropout_op,
                        self.dropout_op_kwargs,
                        self.nonlin,
                        self.nonlin_kwargs,
                        basic_block=basic_block),
                    StackedConvLayers(
                        nfeatures_from_skip,
                        final_num_features,
                        1,
                        self.conv_op,
                        self.conv_kwargs,
                        self.norm_op,
                        self.norm_op_kwargs,
                        self.dropout_op,
                        self.dropout_op_kwargs,
                        self.nonlin,
                        self.nonlin_kwargs,
                        basic_block=basic_block)))

        # One 1x1(x1) segmentation head per decoder stage (deep supervision).
        for ds in range(len(self.conv_blocks_localization)):
            self.seg_heads.append(
                conv_op(
                    self.conv_blocks_localization[ds][-1].output_channels,
                    num_classes,
                    1,
                    1,
                    0,
                    1,
                    1,
                    bias_attr=seg_output_use_bias))

        # Optional ops that bring lower-resolution logits up to size; when
        # upscale_logits is False these are parameter-free identities.
        self.upscale_logits_ops = []
        cum_upsample = np.cumprod(np.vstack(pool_op_kernel_sizes), axis=0)[::-1]
        for usl in range(num_pool - 1):
            if self.upscale_logits:
                self.upscale_logits_ops.append(
                    nn.Upsample(
                        scale_factor=tuple(
                            [int(i) for i in cum_upsample[usl + 1]]),
                        mode=upsample_mode))
            else:
                self.upscale_logits_ops.append(layers.Identity())

        # Register the python lists as LayerLists so parameters are tracked.
        self.conv_blocks_localization = nn.LayerList(
            self.conv_blocks_localization)
        self.conv_blocks_context = nn.LayerList(self.conv_blocks_context)
        self.pool_layers = nn.LayerList(self.pool_layers)
        self.upsample_ops = nn.LayerList(self.upsample_ops)
        self.seg_heads = nn.LayerList(self.seg_heads)
        if self.upscale_logits:
            # Identity ops hold no parameters, so registration is only
            # needed for the Upsample variant.
            self.upscale_logits_ops = nn.LayerList(self.upscale_logits_ops)

    def forward(self, x):
        """Run encoder, bottleneck and decoder; return a list of logits.

        Returns one highest-resolution output, or (training with deep
        supervision) a nested list with all decoder-stage outputs.
        """
        skips = []
        outputs = []
        # Encoder: collect the skip feature map at each resolution.
        for d in range(len(self.conv_blocks_context) - 1):
            x = self.conv_blocks_context[d](x)
            skips.append(x)
            if not self.convolutional_pooling:
                x = self.pool_layers[d](x)

        # Bottleneck.
        x = self.conv_blocks_context[-1](x)

        # Decoder: upsample, concat matching skip along channels, refine,
        # and emit a per-stage segmentation prediction.
        for u in range(len(self.upsample_ops)):
            x = self.upsample_ops[u](x)
            x = paddle.concat([x, skips[-(u + 1)]], axis=1)
            x = self.conv_blocks_localization[u](x)
            outputs.append(self.seg_heads[u](x))

        if self._deep_supervision and self.training:
            # Full-resolution logits first, then the (optionally upscaled)
            # coarser auxiliary outputs from high to low resolution.
            return [[outputs[-1]] + [
                up_op(feat)
                for up_op, feat in zip(
                    list(self.upscale_logits_ops)[::-1], outputs[:-1][::-1])
            ]]
        else:
            return [outputs[-1]]