itk.py 3.84 KB
Newer Older
mibaumgartner's avatar
utils  
mibaumgartner committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from pathlib import Path

import numpy as np
import SimpleITK as sitk
from itertools import product


from typing import Sequence, Union, Tuple


def create_circle_mask_itk(image_itk: sitk.Image,
                           world_centers: Sequence[Sequence[float]],
                           world_rads: Sequence[float],
                           ndim: int = 3,
                           ) -> sitk.Image:
    """
    Creates an itk image with circles defined by center points and radii

    Every voxel whose physical position lies within ``world_rad`` of the
    respective center is labelled with the 1-based position of the circle
    in ``world_centers``. Where circles overlap, the later circle
    overwrites the earlier one.

    Args:
        image_itk: original image (used for the coordinate frame)
        world_centers: Sequence of center points in world coordinates (x, y, z)
        world_rads: Sequence of radii to use (world units)
        ndim: number of spatial dimensions; if the image array has more
            dimensions, only the first entry along its leading axis is
            used to determine the mask shape

    Returns:
        sitk.Image: mask with circles; origin, direction and spacing are
            taken from :param:`image_itk`. The mask is ``uint8`` for up to
            255 circles and the smallest sufficient unsigned integer type
            otherwise.

    Raises:
        AssertionError: if a circle does not cover a single voxel inside
            the image
    """
    image_np = sitk.GetArrayFromImage(image_itk)
    min_spacing = min(image_itk.GetSpacing())

    if image_np.ndim > ndim:
        image_np = image_np[0]

    # zip() below stops at the shorter sequence, so that is the largest id
    # which can be written. Pick a dtype that can hold it (uint8 up to 255)
    # instead of silently wrapping around for larger numbers of circles.
    num_instances = min(len(world_centers), len(world_rads))
    mask_np = np.zeros(image_np.shape, dtype=np.min_scalar_type(num_instances))

    for _id, (world_center, world_rad) in enumerate(zip(world_centers, world_rads), start=1):
        # radius in voxels along the finest axis, plus some buffer
        check_rad = (world_rad / min_spacing) * 1.5
        # sitk indices are (x, y, z), numpy arrays are (z, y, x)
        center = image_itk.TransformPhysicalPointToContinuousIndex(world_center)[::-1]

        # floor the lower and ceil (+1, range end is exclusive) the upper
        # bound so that truncation can never cut off voxels of the circle
        bounds = [
            (
                max(0, int(np.floor(c - check_rad))),
                min(mask_np.shape[ax], int(np.ceil(c + check_rad)) + 1),
            )
            for ax, c in enumerate(center)
        ]

        world_center_np = np.array(world_center)
        found_voxel = False
        # loop over every voxel position inside the bounding box
        for coord in product(*(range(low, high) for low, high in bounds)):
            # reverse order to x, y, z for sitk
            world_coord = image_itk.TransformIndexToPhysicalPoint(coord[::-1])
            dist = np.linalg.norm(np.array(world_coord) - world_center_np)
            if dist <= world_rad:
                mask_np[coord] = _id
                found_voxel = True

        # explicit raise (instead of ``assert``) so the check is not
        # stripped when running with ``python -O``
        if not found_voxel:
            raise AssertionError(
                f"Circle {_id} with center {world_center} and radius "
                f"{world_rad} does not cover any voxel of the image."
            )

    mask_itk = sitk.GetImageFromArray(mask_np)
    return copy_meta_data_itk(image_itk, mask_itk)


def copy_meta_data_itk(source: sitk.Image, target: sitk.Image) -> sitk.Image:
    """
    Transfer the coordinate frame of one image onto another

    Only origin, direction and spacing are copied; entries of the
    meta data dictionary are not transferred.

    Args:
        source: image to read origin, direction and spacing from
        target: image which receives the values (modified in place)

    Returns:
        sitk.Image: the target image with the updated coordinate frame
    """
    origin = source.GetOrigin()
    target.SetOrigin(origin)

    direction = source.GetDirection()
    target.SetDirection(direction)

    spacing = source.GetSpacing()
    target.SetSpacing(spacing)

    return target


def load_sitk(path: Union[Path, str], **kwargs) -> sitk.Image:
    """
    Read an image from disk with SimpleITK

    Args:
        path: location of the file to read; converted to a string before
            it is handed to SimpleITK
        **kwargs: forwarded unchanged to :func:`sitk.ReadImage`

    Returns:
        sitk.Image: the image returned by :func:`sitk.ReadImage`
    """
    file_name = str(path)
    image = sitk.ReadImage(file_name, **kwargs)
    return image


def load_sitk_as_array(path: Union[Path, str], **kwargs) -> Tuple[np.ndarray, dict]:
    """
    Read an image from disk and return its voxel data and meta data

    Args:
        path: location of the file to read
        **kwargs: forwarded unchanged to :func:`load_sitk`

    Returns:
        np.ndarray: voxel data as produced by :func:`sitk.GetArrayFromImage`
        dict: every meta data key of the image mapped to its value
    """
    image = load_sitk(path, **kwargs)

    # collect the complete meta data dictionary of the image
    meta_data = {}
    for name in image.GetMetaDataKeys():
        meta_data[name] = image.GetMetaData(name)

    data = sitk.GetArrayFromImage(image)
    return data, meta_data