# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import numpy as np

# Make the repository root (two directories above this file) importable so the
# `medicalseg` package resolves when this file is executed directly.
sys.path.append(
    os.path.join(os.path.dirname(os.path.realpath(__file__)), "../.."))

from medicalseg.cvlibs import manager
from medicalseg.transforms import Compose
from medicalseg.datasets import MedicalDataset

# Placeholder that is forwarded to the base class as `data_URL`.
# NOTE(review): the value is a single space, not an empty string; it is kept
# as-is because the base class may treat the two differently.
URL = ' '  # TODO: add the download URL for this dataset


@manager.DATASETS.add_component
class MRISpineSeg(MedicalDataset):
    """
    The MRISpineSeg dataset comes from the MRI Spine Seg competition.

    This class adds no logic of its own: it forwards its constructor
    arguments, together with the module-level ``URL``, to ``MedicalDataset``.

    Args:
        dataset_root (str): The dataset directory. Default: None
        result_dir (str): The directory to save the result file. Default: None
        transforms (list): Transforms for image.
        num_classes (int): Number of classes; forwarded unchanged to the
            base class. Default: None
        mode (str, optional): Which part of dataset to use. It is one of
            ('train', 'val'). Default: 'train'.
        ignore_index (int, optional): Forwarded unchanged to the base class
            (presumably the label value to ignore; confirm against
            ``MedicalDataset``). Default: 255
        dataset_json_path (str, optional): Forwarded unchanged to the base
            class. Default: "".

    Examples:

        dataset_root = "data/MRSpineSeg/MRI_spine_seg_phase0_class3"
        result_dir = "data/MRSpineSeg/MRI_spine_seg_phase1"
        dataset = MRISpineSeg(dataset_root=dataset_root, result_dir=result_dir,
                              transforms=[], num_classes=3, mode="train")

        for data in dataset:
            img, label = data
            print(img.shape, label.shape)
            print(np.unique(label))

    """

    def __init__(self,
                 dataset_root=None,
                 result_dir=None,
                 transforms=None,
                 num_classes=None,
                 mode='train',
                 ignore_index=255,
                 dataset_json_path=""):
        # Zero-argument super() does not depend on what the registering
        # decorator rebinds the module-level name `MRISpineSeg` to.
        super().__init__(
            dataset_root,
            result_dir,
            transforms,
            num_classes,
            mode,
            ignore_index,
            data_URL=URL,
            dataset_json_path=dataset_json_path)


if __name__ == "__main__":
    # Manual sanity check: walk the training split and report the dtypes of
    # any sample whose image contains NaN values.
    dataset = MRISpineSeg(
        dataset_root="data/MRSpineSeg/MRI_spine_seg_phase0_class3",
        result_dir="data/MRSpineSeg/MRI_spine_seg_phase1",
        transforms=[],
        mode="train",
        num_classes=3)
    for item in dataset:
        img, label = item
        # Nothing is printed for samples without NaNs.
        if np.any(np.isnan(img)):
            # Original author's note on a sample: (1, 128, 128, 12) float32, int64
            print(img.dtype, label.dtype)