# Implementation of this model is borrowed and modified
# (from torch to paddle) from here:
# https://github.com/MIC-DKFZ/nnUNet
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import sys
import pickle
import shutil
import numpy as np
from copy import deepcopy
from typing import Tuple, List, Union

# Make the repository root importable when this script is run directly.
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)

import paddle
from paddle.inference import create_predictor, PrecisionType
from paddle.inference import Config as PredictConfig
from nnunet.utils.static_predictor import StaticPredictor
from nnunet.transforms import default_2D_augmentation_params, default_3D_augmentation_params
from nnunet.predict import predict_from_folder
from tools.preprocess_utils import GenericPreprocessor, PreprocessorFor2D


def parse_args():
    """Parse command-line arguments for nnUNet static-graph inference."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--image_folder",
        help="Must contain all modalities for each patient in the correct"
        " order (same as training). Files must be named "
        "CASENAME_XXXX.nii.gz where XXXX is the modality "
        "identifier (0000, 0001, etc)",
        required=True)
    parser.add_argument(
        '--output_folder',
        required=True,
        help="folder for saving predictions")
    parser.add_argument(
        '--model_type',
        required=True,
        type=str,
        help="Model type, only support '2d', '3d', 'cascade_lowres', 'cascade_fullres'."
    )
    parser.add_argument(
        '--plan_path', required=True, type=str, help='the path to plan_path')
    parser.add_argument(
        '--model_paths',
        nargs='+',
        required=True,
        help="The multi pdmodel paths.")
    parser.add_argument(
        '--param_paths',
        nargs='+',
        required=True,
        help="The multi pdiparams paths.")
    parser.add_argument(
        '--postprocessing_json_path',
        required=True,
        default=None,
        type=str,
        help='the path to postprocessing json.')
    parser.add_argument(
        '--folds',
        required=False,
        type=int,
        default=5,
        help='number of folds, default: 5.')
    parser.add_argument(
        '--lowres_segmentations',
        required=False,
        default=None,
        help="If model is the highres stage of the cascade then you can use this folder to provide "
        "predictions from the low resolution 3D U-Net.")
    parser.add_argument(
        '--save_npz',
        required=False,
        action='store_true',
        help="use this if you want to ensemble these predictions with those of other models. Softmax "
        "probabilities will be saved as compressed numpy arrays in output_folder and can be "
        "merged between output_folders with nnUNet_ensemble_predictions")
    parser.add_argument(
        "--num_threads_preprocessing",
        required=False,
        default=6,
        type=int,
        help="Determines many background processes will be used for data preprocessing. Reduce this if you "
        "run into out of memory (RAM) problems. Default: 6")
    parser.add_argument(
        "--num_threads_nifti_save",
        required=False,
        default=2,
        type=int,
        help="Determines many background processes will be used for segmentation export. Reduce this if you "
        "run into out of memory (RAM) problems. Default: 2")
    parser.add_argument(
        "--mode", type=str, default="normal", required=False, help="Hands off!")
    parser.add_argument(
        "--step_size",
        type=float,
        default=0.5,
        required=False,
        help="don't touch")
    parser.add_argument(
        "--overwrite_existing",
        required=False,
        default=False,
        action="store_true",
        help="Set this flag if the target folder contains predictions that you would like to overwrite"
    )
    parser.add_argument(
        "--disable_postprocessing",
        required=False,
        default=False,
        action="store_true",
        help="Set this flag if no need postprocessing")
    parser.add_argument(
        "--disable_tta",
        required=False,
        default=False,
        action="store_true",
        help="set this flag to disable test time data augmentation via mirroring. Speeds up inference "
        "by roughly factor 4 (2D) or 8 (3D)")
    parser.add_argument(
        '--min_subgraph_size',
        default=3,
        type=int,
        help='The min subgraph size in tensorrt prediction.')
    return parser.parse_args()


class StaticMultiFolderPredictor:
    """Ensemble predictor that averages softmax outputs of several exported
    static-graph nnUNet models (one per fold).

    Args:
        model_paths (list[str]): Paths to the exported pdmodel files.
        param_paths (list[str]): Paths to the matching pdiparams files.
        plan_path (str): Path to the pickled nnUNet plan file.
        stage (int): Plan stage index (0 for lowres/2d, 1 for fullres 3d).
        min_subgraph_size (int): TensorRT min subgraph size forwarded to
            each underlying ``StaticPredictor``.
    """

    def __init__(self, model_paths, param_paths, plan_path, stage,
                 min_subgraph_size=3):
        self.stage = stage
        self.plans = self.load_plans(plan_path)
        # Plans store foreground class count; +1 adds the background class.
        self.num_classes = self.plans['num_classes'] + 1
        self.patch_size = np.array(self.plans['plans_per_stage'][self.stage][
            'patch_size']).astype(int)
        if len(self.patch_size) == 2:
            self.threeD = False
            self.data_aug_params = default_2D_augmentation_params
        elif len(self.patch_size) == 3:
            self.threeD = True
            self.data_aug_params = default_3D_augmentation_params
        else:
            # Previously self.threeD was left undefined here, which deferred
            # the failure to a confusing AttributeError later on.
            raise ValueError(
                "Unsupported patch_size rank {}: expected 2 (2D) or 3 (3D).".
                format(len(self.patch_size)))
        self.intensity_properties = self.plans['dataset_properties'][
            'intensityproperties']
        self.normalization_schemes = self.plans['normalization_schemes']
        self.use_mask_for_norm = self.plans['use_mask_for_norm']
        if self.plans.get('transpose_forward') is None or self.plans.get(
                'transpose_backward') is None:
            print(
                "WARNING! You seem to have data that was preprocessed with a previous version of nnU-Net. "
                "You should rerun preprocessing. We will proceed and assume that both transpose_foward "
                "and transpose_backward are [0, 1, 2]. If that is not correct then weird things will happen!"
            )
            self.plans['transpose_forward'] = [0, 1, 2]
            self.plans['transpose_backward'] = [0, 1, 2]
        self.transpose_forward = self.plans['transpose_forward']
        self.transpose_backward = self.plans['transpose_backward']

        # One static predictor per fold; predictions are averaged later.
        self.predictors = []
        for model_path, param_path in zip(model_paths, param_paths):
            self.predictors.append(
                StaticPredictor(model_path, param_path, self.plans, stage,
                                min_subgraph_size))

    def load_plans(self, plan_path):
        """Load the pickled nnUNet plan dict from ``plan_path``.

        NOTE: pickle is only safe for trusted, locally-generated plan files.
        """
        with open(plan_path, 'rb') as f:
            plans = pickle.load(f)
        return plans

    def preprocess_patient(self, input_files):
        """Preprocess one patient's modality files for inference.

        Uses the 3D or 2D preprocessor depending on the plan's patch rank.

        Returns:
            tuple: (data, seg, properties) as produced by
            ``preprocess_test_case``.
        """
        if self.threeD:
            preprocessor_class = GenericPreprocessor
        else:
            preprocessor_class = PreprocessorFor2D
        preprocessor = preprocessor_class(
            self.normalization_schemes, self.use_mask_for_norm,
            self.transpose_forward, self.intensity_properties)
        d, s, properties = preprocessor.preprocess_test_case(
            input_files,
            self.plans['plans_per_stage'][self.stage]['current_spacing'])
        return d, s, properties

    def multi_folds_predict_preprocessed_data_return_seg_and_softmax(
            self,
            data: np.ndarray,
            do_mirroring: bool=True,
            mirror_axes: Tuple[int]=None,
            use_sliding_window: bool=True,
            step_size: float=0.5,
            use_gaussian: bool=True,
            pad_border_mode: str='constant',
            pad_kwargs: dict=None,
            verbose: bool=True,
            mixed_precision=True):
        """Run every fold's predictor on ``data`` and return the mean softmax.

        Each predictor returns (seg, softmax); index [1] keeps only the
        softmax volume, which is accumulated and averaged over all folds.
        """
        softmax_res = None
        for predictor in self.predictors:
            x = predictor.predict_preprocessed_data_return_seg_and_softmax(
                data=data,
                do_mirroring=do_mirroring,
                mirror_axes=mirror_axes,
                use_sliding_window=use_sliding_window,
                step_size=step_size,
                use_gaussian=use_gaussian,
                pad_border_mode=pad_border_mode,
                pad_kwargs=pad_kwargs,
                verbose=verbose,
                mixed_precision=mixed_precision)[1]
            if softmax_res is None:
                softmax_res = x
            else:
                softmax_res += x
        return softmax_res / len(self.predictors)


def main(args):
    """Validate CLI arguments, build the multi-fold predictor, and run
    folder-level prediction."""
    assert args.model_type in [
        '2d', '3d', 'cascade_lowres', 'cascade_fullres'
    ], "model only support ['2d', '3d', 'cascade_lowres', 'cascade_fullres'], but got {}.".format(
        args.model_type)
    # BUGFIX: the message previously read args.params_paths (nonexistent
    # attribute), which raised AttributeError instead of the intended
    # AssertionError when the counts differed.
    assert len(args.model_paths) == len(
        args.param_paths
    ), "The number of pdmodel is not the same with pdiparams. {} != {}.".format(
        len(args.model_paths), len(args.param_paths))

    print("model type: ", args.model_type)
    print("The plan path: ", args.plan_path)
    print("The model paths: ", args.model_paths)
    print("The postprocessing json path: ", args.postprocessing_json_path)

    # Fullres models use plan stage 1; 2d / lowres use stage 0.
    if args.model_type in ['3d', 'cascade_fullres']:
        stage = 1
    else:
        stage = 0

    predictor = StaticMultiFolderPredictor(args.model_paths, args.param_paths,
                                           args.plan_path, stage,
                                           args.min_subgraph_size)

    if args.lowres_segmentations is not None:
        assert args.model_type == 'cascade_fullres', "You supply lowres_segmentations dir but the model is not 'cascade_fullres'. Please check model_type."
        print("Cascade lowres segmentation result dir: ",
              args.lowres_segmentations)

    predict_from_folder(
        predictor=predictor,
        input_folder=args.image_folder,
        output_folder=args.output_folder,
        save_npz=args.save_npz,
        num_threads_preprocessing=args.num_threads_preprocessing,
        num_threads_nifti_save=args.num_threads_nifti_save,
        lowres_segmentations=args.lowres_segmentations,
        tta=not args.disable_tta,
        mixed_precision=False,
        overwrite_existing=args.overwrite_existing,
        # NOTE(review): --mode is parsed but deliberately not forwarded;
        # 'normal' is hardcoded here as in the original implementation.
        mode='normal',
        step_size=args.step_size,
        plan_path=args.plan_path,
        disable_postprocessing=args.disable_postprocessing,
        postprocessing_json_path=args.postprocessing_json_path)


if __name__ == '__main__':
    args = parse_args()
    main(args)