Commit 14782f9a authored Jun 24, 2020 by Hongkun Yu, committed by A. Unique TensorFlower on Jun 24, 2020

Opensource 3D Unet model to Tensorflow Official Model Garden.

PiperOrigin-RevId: 318151912
parent d495e481
Showing 10 changed files with 0 additions and 1881 deletions

+0 −188  official/vision/segmentation/README.md
+0 −0    official/vision/segmentation/__init__.py
+0 −269  official/vision/segmentation/convert_lits.py
+0 −58   official/vision/segmentation/convert_lits_nii_to_npy.py
+0 −76   official/vision/segmentation/unet_config.py
+0 −175  official/vision/segmentation/unet_data.py
+0 −350  official/vision/segmentation/unet_main.py
+0 −248  official/vision/segmentation/unet_main_test.py
+0 −279  official/vision/segmentation/unet_metrics.py
+0 −238  official/vision/segmentation/unet_model.py
official/vision/segmentation/README.md
deleted
100644 → 0
# UNet 3D Model

This repository contains a TensorFlow 2.x implementation of the 3D UNet model
[[1]](#1), as well as instructions for producing the data for training and
evaluation. The implementation also supports spatial partitioning [[2]](#2) on
TPUs, which makes training on high-resolution images feasible.
## Contents

* [Contents](#contents)
* [Prerequisites](#prerequisites)
* [Setup](#setup)
* [Data Preparation](#data-preparation)
* [Training](#training)
* [Train with Spatial Partition](#train-with-spatial-partition)
* [Evaluation](#evaluation)
* [References](#references)
## Prerequisites

To use high-resolution image data, spatial partitioning should be used to
avoid out-of-memory issues. It is currently only supported on TPUs. To use
TPUs for training, run the following command in the Google Cloud console to
create a Cloud TPU VM:

```shell
ctpu up -name=[tpu_name] -tf-version=nightly -tpu-size=v3-8 -zone=us-central1-b
```
## Setup

Before running any binary, install the necessary packages on the cloud VM:

```shell
pip install -r requirements.txt
```
## Data Preparation

This software uses TFRecords as input. We provide example scripts to convert
NumPy (.npy) files or NIfTI-1 (.nii) files to TFRecords, using the Liver Tumor
Segmentation (LiTS) dataset (Christ et al.,
https://competitions.codalab.org/competitions/17094). You can download the
dataset by registering on the competition website.

**Example**:

```shell
cd data_preprocess
# Change input_path and output_path in convert_lits_nii_to_npy.py,
# then run the script to convert nii to npy.
python convert_lits_nii_to_npy.py
# Convert npy files to TFRecords.
python convert_lits.py \
  --image_file_pattern=Downloads/.../volume-{}.npy \
  --label_file_pattern=Downloads/.../segmentation-{}.npy \
  --output_path=Downloads/...
```
## Training

Working configs on TPU v3-8:

+ TF 2.2, train_batch_size=16, use_batch_norm=true, dtype='bfloat16' or
  'float16', spatial partition not used.
+ tf-nightly, train_batch_size=32, use_batch_norm=true, dtype='bfloat16',
  spatial partition used.

The following example shows how to train volumetric UNet on TPU v3-8. The loss
is *adaptive_dice32* and the training batch size is 32. For detailed
configuration, refer to `unet_config.py` and the example config file shown
below.
**Example**:

```shell
DATA_BUCKET=<GS bucket for data>
TRAIN_FILES="${DATA_BUCKET}/tfrecords/trainbox*.tfrecord"
VAL_FILES="${DATA_BUCKET}/tfrecords/validationbox*.tfrecord"
MODEL_BUCKET=<GS bucket for model checkpoints>
EXP_NAME=unet_20190610_dice_t1

python unet_main.py \
  --distribution_strategy=<"mirrored" or "tpu"> \
  --num_gpus=<number of GPUs to use if using mirrored strategy> \
  --tpu=<TPU name> \
  --model_dir="gs://${MODEL_BUCKET}/models/${EXP_NAME}" \
  --training_file_pattern="${TRAIN_FILES}" \
  --eval_file_pattern="${VAL_FILES}" \
  --steps_per_loop=10 \
  --mode=train \
  --config_file="./configs/cloud/v3-8_128x128x128_ce.yaml"
```
The following script example is for running evaluation on TPU v3-8.
Configurations such as `train_batch_size`, `train_steps`, `eval_batch_size`,
and `eval_item_count` are defined in the configuration file passed via the
`config_file` flag. It is only a one-line change from the previous script:
change the `mode` flag to "eval".
### Train with Spatial Partition

The following example enables spatial partitioning via `input_partition_dims`
in the config file. For example, setting
`input_partition_dims: [1, 16, 1, 1, 1]` in the config file splits the image
16 ways along the first spatial (width) dimension. The leading dimension
(set to 1) is the batch dimension, which is left unpartitioned.
**Example: Train with 16-way spatial partition**:

```shell
DATA_BUCKET=<GS bucket for data>
TRAIN_FILES="${DATA_BUCKET}/tfrecords/trainbox*.tfrecord"
VAL_FILES="${DATA_BUCKET}/tfrecords/validationbox*.tfrecord"
MODEL_BUCKET=<GS bucket for model checkpoints>
EXP_NAME=unet_20190610_dice_t1

python unet_main.py \
  --distribution_strategy=<"mirrored" or "tpu"> \
  --num_gpus=<number of GPUs to use if using mirrored strategy> \
  --tpu=<TPU name> \
  --model_dir="gs://${MODEL_BUCKET}/models/${EXP_NAME}" \
  --training_file_pattern="${TRAIN_FILES}" \
  --eval_file_pattern="${VAL_FILES}" \
  --steps_per_loop=10 \
  --mode=train \
  --config_file="./configs/cloud/v3-8_128x128x128_ce.yaml"
```
**Example: config file with 16-way spatial partition**:
```
train_steps: 3000
loss: 'adaptive_dice32'
train_batch_size: 8
eval_batch_size: 8
use_index_label_in_train: false
input_partition_dims: [1,16,1,1,1]
input_image_size: [256,256,256]
dtype: 'bfloat16'
label_dtype: 'float32'
train_item_count: 5400
eval_item_count: 1674
```
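For intuition on what this setting does at runtime, here is a minimal sketch
(an illustration, not part of the training binary) that mirrors the
`train_step` override in `unet_main.py`: each input tensor is split across
logical TPU devices according to the partition dims. The tensor shapes in the
comments are assumptions for this example.

```python
import tensorflow as tf

partition_dims = [1, 16, 1, 1, 1]  # batch kept whole; x split 16 ways

def partitioned_train_step(train_fn, data):
  # Inside a TPUStrategy replica function: split each 5-D input
  # [batch, x, y, z, channels] so that every logical device receives a
  # [batch, x/16, y, z, channels] shard before running the original step.
  x, y = data
  strategy = tf.distribute.get_strategy()
  x = strategy.experimental_split_to_logical_devices(x, partition_dims)
  y = strategy.experimental_split_to_logical_devices(y, partition_dims)
  return train_fn((x, y))
```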
## Evaluation

```shell
DATA_BUCKET=<GS bucket for data>
TRAIN_FILES="${DATA_BUCKET}/tfrecords/trainbox*.tfrecord"
VAL_FILES="${DATA_BUCKET}/tfrecords/validationbox*.tfrecord"
MODEL_BUCKET=<GS bucket for model checkpoints>
EXP_NAME=unet_20190610_dice_t1

python unet_main.py \
  --distribution_strategy=<"mirrored" or "tpu"> \
  --num_gpus=<number of GPUs to use if using mirrored strategy> \
  --tpu=<TPU name> \
  --model_dir="gs://${MODEL_BUCKET}/models/${EXP_NAME}" \
  --training_file_pattern="${TRAIN_FILES}" \
  --eval_file_pattern="${VAL_FILES}" \
  --steps_per_loop=10 \
  --mode="eval" \
  --config_file="./configs/cloud/v3-8_128x128x128_ce.yaml"
```
## License

[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)

This project is licensed under the terms of the **Apache License 2.0**.
## References

<a id="1">[1]</a> Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas
Brox, Olaf Ronneberger. "3D U-Net: Learning Dense Volumetric Segmentation from
Sparse Annotation": https://arxiv.org/abs/1606.06650 (MICCAI 2016).

<a id="2">[2]</a> Le Hou, Youlong Cheng, Noam Shazeer, Niki Parmar, Yeqing Li,
Panagiotis Korfiatis, Travis M. Drucker, Daniel J. Blezek, Xiaodan Song. "High
Resolution Medical Image Analysis with Spatial Partitioning":
https://arxiv.org/abs/1909.03816.
official/vision/segmentation/__init__.py
deleted
100644 → 0
official/vision/segmentation/convert_lits.py
deleted
100644 → 0
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Converts raw LiTS numpy data to TFRecord.

The file is forked from:
https://github.com/tensorflow/tpu/blob/master/models/official/unet3d/data_preprocess/convert_lits.py
"""

from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import os

from absl import app
from absl import flags
from absl import logging
import numpy as np
from PIL import Image
from scipy import ndimage
import tensorflow.compat.v1 as tf

flags.DEFINE_string("image_file_pattern", None,
                    "path pattern to an input image npy file.")
flags.DEFINE_string("label_file_pattern", None,
                    "path pattern to an input label npy file.")
flags.DEFINE_string("output_path", None, "path to output TFRecords.")
flags.DEFINE_boolean("crop_liver_region", True,
                     "whether to crop liver region out.")
flags.DEFINE_boolean("apply_data_aug", False,
                     "whether to apply data augmentation.")
flags.DEFINE_integer("shard_start", 0, "start with volume-${shard_start}.npy.")
flags.DEFINE_integer(
    "shard_stride", 1, "this process will convert "
    "volume-${shard_start + n * shard_stride}.npy for all n.")
flags.DEFINE_integer("output_size", 128,
                     "output, cropped size along x, y, and z.")
flags.DEFINE_integer("resize_size", 192,
                     "size along x, y, and z before cropping.")

FLAGS = flags.FLAGS


def to_1hot(label):
  """Converts an index label volume to one-hot encoding."""
  per_class = []
  for classes in range(3):
    per_class.append((label == classes)[..., np.newaxis])
  label = np.concatenate(per_class, axis=-1).astype(label.dtype)
  return label


def save_to_tfrecord(image, label, idx, im_id, output_path,
                     convert_label_to_1hot):
  """Save to TFRecord."""
  if convert_label_to_1hot:
    label = to_1hot(label)
  d_feature = {}
  d_feature["image/ct_image"] = tf.train.Feature(
      bytes_list=tf.train.BytesList(value=[image.reshape([-1]).tobytes()]))
  d_feature["image/label"] = tf.train.Feature(
      bytes_list=tf.train.BytesList(value=[label.reshape([-1]).tobytes()]))

  example = tf.train.Example(features=tf.train.Features(feature=d_feature))
  serialized = example.SerializeToString()

  result_file = os.path.join(output_path,
                             "instance-{}-{}.tfrecords".format(im_id, idx))
  options = tf.python_io.TFRecordOptions(
      tf.python_io.TFRecordCompressionType.GZIP)
  with tf.python_io.TFRecordWriter(result_file, options=options) as w:
    w.write(serialized)


def intensity_change(im):
  """Color augmentation."""
  if np.random.rand() < 0.1:
    return im
  # Randomly scale color.
  sigma = 0.05
  truncate_rad = 0.1
  im *= np.clip(
      np.random.normal(1.0, sigma), 1.0 - truncate_rad, 1.0 + truncate_rad)
  return im


def rand_crop_liver(image, label, res_s, out_s, apply_data_aug,
                    augment_times=54):
  """Crop image and label; Randomly change image intensity.

  Randomly crop image and label around liver.

  Args:
    image: 3D numpy array.
    label: 3D numpy array.
    res_s: resized size of image and label.
    out_s: output size of random crops.
    apply_data_aug: whether to apply data augmentation.
    augment_times: the number of times to randomly crop and augment data.

  Yields:
    cropped and augmented image and label.
  """
  if image.shape != (res_s, res_s, res_s) or \
      label.shape != (res_s, res_s, res_s):
    logging.info("Unexpected shapes. "
                 "image.shape: %s, label.shape: %s", image.shape, label.shape)
    return
  rough_liver_label = 1
  x, y, z = np.where(label == rough_liver_label)
  bbox_center = [(x.min() + x.max()) // 2,
                 (y.min() + y.max()) // 2,
                 (z.min() + z.max()) // 2]

  def in_range_check(c):
    c = max(c, out_s // 2)
    c = min(c, res_s - out_s // 2)
    return c

  for _ in range(augment_times):
    rand_c = []
    for c in bbox_center:
      sigma = out_s // 6
      truncate_rad = out_s // 4
      c += np.clip(np.random.randn() * sigma, -truncate_rad, truncate_rad)
      rand_c.append(int(in_range_check(c)))
    image_aug = image[rand_c[0] - out_s // 2:rand_c[0] + out_s // 2,
                      rand_c[1] - out_s // 2:rand_c[1] + out_s // 2,
                      rand_c[2] - out_s // 2:rand_c[2] + out_s // 2].copy()
    label_aug = label[rand_c[0] - out_s // 2:rand_c[0] + out_s // 2,
                      rand_c[1] - out_s // 2:rand_c[1] + out_s // 2,
                      rand_c[2] - out_s // 2:rand_c[2] + out_s // 2].copy()
    if apply_data_aug:
      image_aug = intensity_change(image_aug)
    yield image_aug, label_aug


def rand_crop_whole_ct(image, label, res_s, out_s, apply_data_aug,
                       augment_times=2):
  """Crop image and label; Randomly change image intensity.

  Randomly crop image and label.

  Args:
    image: 3D numpy array.
    label: 3D numpy array.
    res_s: resized size of image and label.
    out_s: output size of random crops.
    apply_data_aug: whether to apply data augmentation.
    augment_times: the number of times to randomly crop and augment data.

  Yields:
    cropped and augmented image and label.
  """
  if image.shape != (res_s, res_s, res_s) or \
      label.shape != (res_s, res_s, res_s):
    logging.info("Unexpected shapes. "
                 "image.shape: %s, label.shape: %s", image.shape, label.shape)
    return
  if not apply_data_aug:
    # Do not augment data.
    idx = (res_s - out_s) // 2
    image = image[idx:idx + out_s, idx:idx + out_s, idx:idx + out_s]
    label = label[idx:idx + out_s, idx:idx + out_s, idx:idx + out_s]
    yield image, label
  else:
    cut = res_s - out_s
    for _ in range(augment_times):
      for i in [0, cut // 2, cut]:
        for j in [0, cut // 2, cut]:
          for k in [0, cut // 2, cut]:
            image_aug = image[i:i + out_s, j:j + out_s, k:k + out_s].copy()
            label_aug = label[i:i + out_s, j:j + out_s, k:k + out_s].copy()
            image_aug = intensity_change(image_aug)
            yield image_aug, label_aug


def resize_3d_image_nearest_interpolation(im, res_s):
  """Resize 3D image, but with nearest interpolation."""
  new_shape = [res_s, im.shape[1], im.shape[2]]
  ret0 = np.zeros(new_shape, dtype=im.dtype)
  for i in range(im.shape[2]):
    im_slice = np.array(
        Image.fromarray(im[..., i]).resize((im.shape[1], res_s),
                                           resample=Image.NEAREST))
    ret0[..., i] = im_slice

  new_shape = [res_s, res_s, res_s]
  ret = np.zeros(new_shape, dtype=im.dtype)
  for i in range(res_s):
    im_slice = np.array(
        Image.fromarray(ret0[i, ...]).resize((res_s, res_s),
                                             resample=Image.NEAREST))
    ret[i, ...] = im_slice
  return ret


def process_one_file(image_path, label_path, im_id, output_path, res_s, out_s,
                     crop_liver_region, apply_data_aug):
  """Convert one npy file."""
  with tf.gfile.Open(image_path, "rb") as f:
    image = np.load(f)
  with tf.gfile.Open(label_path, "rb") as f:
    label = np.load(f)
  image = ndimage.zoom(image, [
      float(res_s) / image.shape[0],
      float(res_s) / image.shape[1],
      float(res_s) / image.shape[2]
  ])
  label = resize_3d_image_nearest_interpolation(
      label.astype(np.uint8), res_s).astype(np.float32)

  if crop_liver_region:
    for idx, (image_aug, label_aug) in enumerate(
        rand_crop_liver(image, label, res_s, out_s, apply_data_aug)):
      save_to_tfrecord(
          image_aug, label_aug, idx, im_id, output_path,
          convert_label_to_1hot=True)
  else:  # not crop_liver_region
    # If we output the entire CT scan (crop_liver_region=False),
    # do not convert_label_to_1hot to save storage.
    for idx, (image_aug, label_aug) in enumerate(
        rand_crop_whole_ct(image, label, res_s, out_s, apply_data_aug)):
      save_to_tfrecord(
          image_aug, label_aug, idx, im_id, output_path,
          convert_label_to_1hot=False)


def main(argv):
  del argv
  output_path = FLAGS.output_path
  res_s = FLAGS.resize_size
  out_s = FLAGS.output_size
  crop_liver_region = FLAGS.crop_liver_region
  apply_data_aug = FLAGS.apply_data_aug
  for im_id in range(FLAGS.shard_start, 1000000, FLAGS.shard_stride):
    image_path = FLAGS.image_file_pattern.format(im_id)
    label_path = FLAGS.label_file_pattern.format(im_id)
    if not tf.gfile.Exists(image_path):
      logging.info(
          "Reached the end. Image does not exist: %s. "
          "Process finish.", image_path)
      break
    process_one_file(image_path, label_path, im_id, output_path, res_s, out_s,
                     crop_liver_region, apply_data_aug)


if __name__ == "__main__":
  app.run(main)
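For reference, a minimal sketch (assuming TF 2.x and this script's defaults:
GZIP compression, a 128³ crop, and one-hot labels with three classes) of
reading back one record written above; the file name is a placeholder:

```python
import tensorflow as tf  # TF2

# Parse one GZIP-compressed record produced by convert_lits.py. The decode
# pattern mirrors LiverInput in unet_data.py.
ds = tf.data.TFRecordDataset("instance-0-0.tfrecords", compression_type="GZIP")
for serialized in ds.take(1):
  parsed = tf.io.parse_single_example(
      serialized,
      {"image/ct_image": tf.io.FixedLenFeature([], tf.string),
       "image/label": tf.io.FixedLenFeature([], tf.string)})
  image = tf.reshape(
      tf.io.decode_raw(parsed["image/ct_image"], tf.float32), [128, 128, 128])
  label = tf.reshape(
      tf.io.decode_raw(parsed["image/label"], tf.float32), [128, 128, 128, 3])
```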
official/vision/segmentation/convert_lits_nii_to_npy.py
deleted
100644 → 0
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Converts .nii files in LiTS dataset to .npy files.

This script should be run just once before running convert_lits.py.

The file is forked from:
https://github.com/tensorflow/tpu/blob/master/models/official/unet3d/data_preprocess/convert_lits_nii_to_npy.py
"""

from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import glob
import multiprocessing
import os

import nibabel as nib
import numpy as np

num_processes = 2
input_path = "Downloads/LiTS/Train/"  # where the .nii files are.
output_path = "Downloads/LiTS/Train_np/"  # where to put the npy files.


def process_one_file(image_path):
  """Convert one nii file to npy."""
  im_id = os.path.basename(image_path).split("volume-")[1].split(".nii")[0]
  label_path = image_path.replace("volume-", "segmentation-")

  image = nib.load(image_path).get_data().astype(np.float32)
  label = nib.load(label_path).get_data().astype(np.float32)
  print("image shape: {}, dtype: {}".format(image.shape, image.dtype))
  print("label shape: {}, dtype: {}".format(label.shape, label.dtype))

  np.save(os.path.join(output_path, "volume-{}.npy".format(im_id)), image)
  np.save(os.path.join(output_path, "segmentation-{}.npy".format(im_id)),
          label)


if __name__ == "__main__":
  # Guard the pool so worker processes do not re-execute it on import.
  nii_dir = os.path.join(input_path, "volume-*")
  p = multiprocessing.Pool(num_processes)
  p.map(process_one_file, glob.glob(nii_dir))
official/vision/segmentation/unet_config.py
deleted
100644 → 0
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Config to train UNet."""
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
UNET_CONFIG
=
{
# Place holder for tpu configs.
'tpu_config'
:
{},
'model_dir'
:
''
,
'training_file_pattern'
:
None
,
'eval_file_pattern'
:
None
,
# The input files are GZip compressed and need decompression.
'compressed_input'
:
True
,
'dtype'
:
'bfloat16'
,
'label_dtype'
:
'float32'
,
'train_batch_size'
:
8
,
'eval_batch_size'
:
8
,
'predict_batch_size'
:
8
,
'train_epochs'
:
20
,
'train_steps'
:
1000
,
'eval_steps'
:
10
,
'num_steps_per_eval'
:
100
,
'min_eval_interval'
:
180
,
'eval_timeout'
:
None
,
'optimizer'
:
'adam'
,
'momentum'
:
0.9
,
# Spatial dimension of input image.
'input_image_size'
:
[
128
,
128
,
128
],
# Number of channels of the input image.
'num_channels'
:
1
,
# Spatial partition dimensions.
'input_partition_dims'
:
None
,
# Use deconvolution to upsample, otherwise upsampling.
'deconvolution'
:
True
,
# Number of areas i need to segment
'num_classes'
:
3
,
# Number of filters used by the architecture
'num_base_filters'
:
32
,
# Depth of the network
'depth'
:
4
,
# Dropout values to use across the network
'dropout_rate'
:
0.5
,
# Number of levels that contribute to the output.
'num_segmentation_levels'
:
2
,
# Use batch norm.
'use_batch_norm'
:
True
,
'init_learning_rate'
:
0.1
,
# learning rate decay steps.
'lr_decay_steps'
:
100
,
# learning rate decay rate.
'lr_decay_rate'
:
0.5
,
# Data format, 'channels_last' and 'channels_first'
'data_format'
:
'channels_last'
,
# Use class index for training. Otherwise, use one-hot encoding.
'use_index_label_in_train'
:
False
,
# e.g. softmax cross entropy, adaptive_dice32
'loss'
:
'adaptive_dice32'
,
}
UNET_RESTRICTIONS
=
[]
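As a quick illustration of how this dictionary is consumed, the following
sketch mirrors `extract_params` in `unet_main.py`; it assumes the Model Garden
packages are importable, and the override values are arbitrary examples:

```python
from official.modeling.hyperparams import params_dict
from official.vision.segmentation import unet_config

# Build a params object from the defaults above, then override a few fields
# the same way unet_main.py does with flag/config-file values.
params = params_dict.ParamsDict(
    unet_config.UNET_CONFIG, unet_config.UNET_RESTRICTIONS)
params.override({
    'train_batch_size': 16,
    'input_image_size': [64, 64, 64],
}, is_strict=True)
params.validate()
params.lock()
print(params.train_batch_size)  # 16
```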
official/vision/segmentation/unet_data.py
deleted
100644 → 0
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Defines input_fn of TF2 UNet-3D model."""

from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import functools
import tensorflow as tf


class BaseInput(object):
  """Input function for 3D Unet model."""

  def __init__(self, file_pattern, params, is_training):
    self._params = params
    self._file_pattern = file_pattern
    self._is_training = is_training
    self._parser_fn = self.create_parser_fn(params)
    if params.compressed_input:
      self._dataset_fn = functools.partial(
          tf.data.TFRecordDataset, compression_type='GZIP')
    else:
      self._dataset_fn = tf.data.TFRecordDataset

  def create_parser_fn(self, params):
    """Create parse fn to extract tensors from tf.Example."""

    def _parser(serialized_example):
      """Parses a single tf.Example into image and label tensors."""
      features = tf.io.parse_example(
          serialized=[serialized_example],
          features={
              'image/encoded': tf.io.VarLenFeature(dtype=tf.float32),
              'image/segmentation/mask': tf.io.VarLenFeature(dtype=tf.float32),
          })
      image = features['image/encoded']
      if isinstance(image, tf.SparseTensor):
        image = tf.sparse.to_dense(image)
      gt_mask = features['image/segmentation/mask']
      if isinstance(gt_mask, tf.SparseTensor):
        gt_mask = tf.sparse.to_dense(gt_mask)

      image_size, label_size = self.get_input_shapes(params)
      image = tf.reshape(image, image_size)
      gt_mask = tf.reshape(gt_mask, label_size)
      image = tf.cast(image, dtype=params.dtype)
      gt_mask = tf.cast(gt_mask, dtype=params.dtype)
      return image, gt_mask

    return _parser

  def get_input_shapes(self, params):
    image_size = params.input_image_size + [params.num_channels]
    label_size = params.input_image_size + [params.num_classes]
    return image_size, label_size

  def __call__(self, input_pipeline_context=None):
    """Generates features and labels for training or evaluation.

    This uses the input pipeline based approach using file name queue
    to read data so that entire data is not loaded in memory.

    Args:
      input_pipeline_context: Context used by distribution strategy to shard
        dataset across workers.

    Returns:
      tf.data.Dataset
    """
    params = self._params
    batch_size = (
        params.train_batch_size
        if self._is_training else params.eval_batch_size)
    dataset = tf.data.Dataset.list_files(
        self._file_pattern, shuffle=self._is_training)

    # Shard dataset when there are more than 1 workers in training.
    if input_pipeline_context:
      batch_size = input_pipeline_context.get_per_replica_batch_size(
          batch_size)
      if input_pipeline_context.num_input_pipelines > 1:
        dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
                                input_pipeline_context.input_pipeline_id)

    if self._is_training:
      dataset = dataset.repeat()

    dataset = dataset.apply(
        tf.data.experimental.parallel_interleave(
            lambda file_name: self._dataset_fn(file_name).prefetch(1),
            cycle_length=32,
            sloppy=self._is_training))

    if self._is_training:
      dataset = dataset.shuffle(64)

    # Parses the fetched records to input tensors for model function.
    dataset = dataset.map(self._parser_fn, tf.data.experimental.AUTOTUNE)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset


class LiverInput(BaseInput):
  """Input function of Liver Segmentation data set."""

  def create_parser_fn(self, params):
    """Create parse fn to extract tensors from tf.Example."""

    def _decode_liver_example(serialized_example):
      """Parses a single tf.Example into image and label tensors."""
      features = {}
      features['image/ct_image'] = tf.io.FixedLenFeature([], tf.string)
      features['image/label'] = tf.io.FixedLenFeature([], tf.string)
      parsed = tf.io.parse_single_example(
          serialized=serialized_example, features=features)

      # Here, assumes the `image` is normalized to [0, 1] of type float32 and
      # the `label` is a binary matrix, whose last dimension is one_hot encoded
      # labels. The dtype of `label` can be either float32 or int64.
      image = tf.io.decode_raw(parsed['image/ct_image'],
                               tf.as_dtype(tf.float32))
      label = tf.io.decode_raw(parsed['image/label'],
                               tf.as_dtype(params.label_dtype))
      image_size = params.input_image_size + [params.num_channels]
      image = tf.reshape(image, image_size)
      label_size = params.input_image_size + [params.num_classes]
      label = tf.reshape(label, label_size)
      if self._is_training and params.use_index_label_in_train:
        # Use class index for labels and remove the channel dim (#channels=1).
        channel_dim = -1
        label = tf.argmax(input=label, axis=channel_dim, output_type=tf.int32)
      image = tf.cast(image, dtype=params.dtype)
      label = tf.cast(label, dtype=params.dtype)
      # TPU doesn't support tf.int64 well, use tf.int32 directly.
      if label.dtype == tf.int64:
        label = tf.cast(label, dtype=tf.int32)
      return image, label

    return _decode_liver_example

  def get_input_shapes(self, params):
    image_size = params.input_image_size + [params.num_channels]
    if self._is_training and params.use_index_label_in_train:
      label_size = params.input_image_size
    else:
      label_size = params.input_image_size + [params.num_classes]
    return image_size, label_size
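A minimal usage sketch, assuming a `params` object built from
`unet_config.UNET_CONFIG` (see the example after `unet_config.py` above) and
TFRecords produced by `convert_lits.py`; the file pattern is a placeholder:

```python
from official.vision.segmentation import unet_data

# Build a non-training dataset from the GZIP TFRecords and take one batch.
input_fn = unet_data.LiverInput(
    "gs://my-bucket/tfrecords/validationbox*.tfrecord",  # placeholder pattern
    params, is_training=False)
dataset = input_fn()
for image, label in dataset.take(1):
  # image: [eval_batch_size, x, y, z, num_channels]
  # label: [eval_batch_size, x, y, z, num_classes]
  print(image.shape, label.shape)
```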
official/vision/segmentation/unet_main.py
deleted
100644 → 0
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Training script for UNet-3D."""

from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import functools
import os

from absl import app
from absl import flags
import numpy as np
import tensorflow as tf

from official.modeling.hyperparams import params_dict
from official.utils import hyperparams_flags
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.vision.segmentation import unet_config
from official.vision.segmentation import unet_data
from official.vision.segmentation import unet_metrics
from official.vision.segmentation import unet_model as unet_model_lib


def define_unet3d_flags():
  """Defines flags for training 3D Unet."""
  hyperparams_flags.initialize_common_flags()
  flags.DEFINE_enum(
      'distribution_strategy', 'tpu', ['tpu', 'mirrored'],
      'Distribution Strategy type to use for training. `tpu` uses TPUStrategy '
      'for running on TPUs, `mirrored` uses GPUs with single host.')
  flags.DEFINE_integer(
      'steps_per_loop', 50,
      'Number of steps to execute in a loop for performance optimization.')
  flags.DEFINE_integer('checkpoint_interval', 100,
                       'Minimum step interval between two checkpoints.')
  flags.DEFINE_integer('epochs', 10, 'Number of epochs to run training.')
  flags.DEFINE_string(
      'gcp_project',
      default=None,
      help='Project name for the Cloud TPU-enabled project. If not specified, '
      'we will attempt to automatically detect the GCE project from metadata.')
  flags.DEFINE_string(
      'eval_checkpoint_dir',
      default=None,
      help='Directory for reading checkpoint file when `mode` == `eval`.')
  flags.DEFINE_multi_integer(
      'input_partition_dims', [1],
      'A list that describes the partition dims for all the tensors.')
  flags.DEFINE_string(
      'mode', 'train', 'Mode to run: train or eval or train_and_eval '
      '(default: train)')
  flags.DEFINE_string('training_file_pattern', None,
                      'Location of the train data.')
  flags.DEFINE_string('eval_file_pattern', None, 'Location of the eval data.')
  flags.DEFINE_float('lr_init_value', 0.0001, 'Initial learning rate.')
  flags.DEFINE_float('lr_decay_rate', 0.9, 'Learning rate decay rate.')
  flags.DEFINE_integer('lr_decay_steps', 100, 'Learning rate decay steps.')


def save_params(params):
  """Save parameters to config files if model_dir is defined."""
  model_dir = params.model_dir
  assert model_dir is not None
  if not tf.io.gfile.exists(model_dir):
    tf.io.gfile.makedirs(model_dir)
  file_name = os.path.join(model_dir, 'params.yaml')
  params_dict.save_params_dict_to_yaml(params, file_name)


def extract_params(flags_obj):
  """Extract configuration parameters for training and evaluation."""
  params = params_dict.ParamsDict(unet_config.UNET_CONFIG,
                                  unet_config.UNET_RESTRICTIONS)
  params = params_dict.override_params_dict(
      params, flags_obj.config_file, is_strict=False)
  if flags_obj.training_file_pattern:
    params.override(
        {'training_file_pattern': flags_obj.training_file_pattern},
        is_strict=True)
  if flags_obj.eval_file_pattern:
    params.override({'eval_file_pattern': flags_obj.eval_file_pattern},
                    is_strict=True)

  train_epoch_steps = params.train_item_count // params.train_batch_size
  eval_epoch_steps = params.eval_item_count // params.eval_batch_size

  params.override(
      {
          'model_dir': flags_obj.model_dir,
          'eval_checkpoint_dir': flags_obj.eval_checkpoint_dir,
          'mode': flags_obj.mode,
          'distribution_strategy': flags_obj.distribution_strategy,
          'tpu': flags_obj.tpu,
          'num_gpus': flags_obj.num_gpus,
          'init_learning_rate': flags_obj.lr_init_value,
          'lr_decay_rate': flags_obj.lr_decay_rate,
          'lr_decay_steps': train_epoch_steps,
          'train_epoch_steps': train_epoch_steps,
          'eval_epoch_steps': eval_epoch_steps,
          'steps_per_loop': flags_obj.steps_per_loop,
          'epochs': flags_obj.epochs,
          'checkpoint_interval': flags_obj.checkpoint_interval,
      },
      is_strict=False)
  params.validate()
  params.lock()

  return params


def unet3d_callbacks(params, checkpoint_manager=None):
  """Custom callbacks during training."""
  tensorboard_callback = tf.keras.callbacks.TensorBoard(
      log_dir=params.model_dir)
  if checkpoint_manager:
    checkpoint_callback = keras_utils.SimpleCheckpoint(checkpoint_manager)
    return [tensorboard_callback, checkpoint_callback]
  else:
    return [tensorboard_callback]


def get_computation_shape_for_model_parallelism(input_partition_dims):
  """Return computation shape to be used for TPUStrategy spatial partition."""
  num_logical_devices = np.prod(input_partition_dims)
  if num_logical_devices == 1:
    return [1, 1, 1, 1]
  if num_logical_devices == 2:
    return [1, 1, 1, 2]
  if num_logical_devices == 4:
    return [1, 2, 1, 2]
  if num_logical_devices == 8:
    return [2, 2, 1, 2]
  if num_logical_devices == 16:
    return [4, 2, 1, 2]
  raise ValueError('Unsupported number of spatial partition configuration.')


def create_distribution_strategy(params):
  """Creates distribution strategy to use for computation."""
  if params.input_partition_dims is not None:
    if params.distribution_strategy != 'tpu':
      raise ValueError('Spatial partitioning is only supported '
                       'for TPUStrategy.')

    # When `input_partition_dims` is specified create custom TPUStrategy
    # instance with computation shape for model parallelism.
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu=params.tpu)
    if params.tpu not in ('', 'local'):
      tf.config.experimental_connect_to_cluster(resolver)
    topology = tf.tpu.experimental.initialize_tpu_system(resolver)
    num_replicas = resolver.get_tpu_system_metadata().num_cores // np.prod(
        params.input_partition_dims)
    device_assignment = tf.tpu.experimental.DeviceAssignment.build(
        topology,
        num_replicas=num_replicas,
        computation_shape=get_computation_shape_for_model_parallelism(
            params.input_partition_dims))
    return tf.distribute.experimental.TPUStrategy(
        resolver, device_assignment=device_assignment)

  return distribution_utils.get_distribution_strategy(
      distribution_strategy=params.distribution_strategy,
      tpu_address=params.tpu,
      num_gpus=params.num_gpus)


def get_train_dataset(params, ctx=None):
  """Returns training dataset."""
  return unet_data.LiverInput(
      params.training_file_pattern, params, is_training=True)(ctx)


def get_eval_dataset(params, ctx=None):
  """Returns evaluation dataset."""
  return unet_data.LiverInput(
      params.eval_file_pattern, params, is_training=False)(ctx)


def expand_1d(data):
  """Expands 1-dimensional `Tensor`s into 2-dimensional `Tensor`s."""

  def _expand_single_1d_tensor(t):
    if (isinstance(t, tf.Tensor) and isinstance(t.shape, tf.TensorShape) and
        t.shape.rank == 1):
      return tf.expand_dims(t, axis=-1)
    return t

  return tf.nest.map_structure(_expand_single_1d_tensor, data)


def train_step(train_fn, input_partition_dims, data):
  """The logic for one training step with spatial partitioning."""
  # Keras expects rank 2 inputs. As so, expand single rank inputs.
  data = expand_1d(data)
  x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
  if input_partition_dims:
    strategy = tf.distribute.get_strategy()
    x = strategy.experimental_split_to_logical_devices(
        x, input_partition_dims)
    y = strategy.experimental_split_to_logical_devices(
        y, input_partition_dims)
  partitioned_data = tf.keras.utils.pack_x_y_sample_weight(
      x, y, sample_weight)
  return train_fn(partitioned_data)


def test_step(test_fn, input_partition_dims, data):
  """The logic for one testing step with spatial partitioning."""
  # Keras expects rank 2 inputs. As so, expand single rank inputs.
  data = expand_1d(data)
  x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
  if input_partition_dims:
    strategy = tf.distribute.get_strategy()
    x = strategy.experimental_split_to_logical_devices(
        x, input_partition_dims)
    y = strategy.experimental_split_to_logical_devices(
        y, input_partition_dims)
  partitioned_data = tf.keras.utils.pack_x_y_sample_weight(
      x, y, sample_weight)
  return test_fn(partitioned_data)


def train(params, strategy, unet_model, train_input_fn, eval_input_fn):
  """Trains 3D Unet model."""
  assert tf.distribute.has_strategy()
  # Override Keras Model's train_step() and test_step() functions so that
  # inputs are spatially partitioned. Note that if the `predict()` API is
  # used, then `predict_step()` should also be overridden.
  unet_model.train_step = functools.partial(
      train_step, unet_model.train_step, params.input_partition_dims)
  unet_model.test_step = functools.partial(
      test_step, unet_model.test_step, params.input_partition_dims)
  optimizer = unet_model_lib.create_optimizer(params.init_learning_rate,
                                              params)
  loss_fn = unet_metrics.get_loss_fn(params.mode, params)
  unet_model.compile(
      loss=loss_fn,
      optimizer=optimizer,
      metrics=[unet_metrics.metric_accuracy],
      experimental_steps_per_execution=params.steps_per_loop)

  train_ds = strategy.experimental_distribute_datasets_from_function(
      train_input_fn)
  eval_ds = strategy.experimental_distribute_datasets_from_function(
      eval_input_fn)

  checkpoint = tf.train.Checkpoint(model=unet_model)
  train_epoch_steps = params.train_item_count // params.train_batch_size
  eval_epoch_steps = params.eval_item_count // params.eval_batch_size
  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint,
      directory=params.model_dir,
      max_to_keep=10,
      step_counter=unet_model.optimizer.iterations,
      checkpoint_interval=params.checkpoint_interval)
  checkpoint_manager.restore_or_initialize()

  train_result = unet_model.fit(
      x=train_ds,
      epochs=params.epochs,
      steps_per_epoch=train_epoch_steps,
      validation_data=eval_ds,
      validation_steps=eval_epoch_steps,
      callbacks=unet3d_callbacks(params, checkpoint_manager))
  return train_result


def evaluate(params, strategy, unet_model, input_fn):
  """Reads from checkpoint and evaluate 3D Unet model."""
  assert tf.distribute.has_strategy()
  unet_model.compile(
      metrics=[unet_metrics.metric_accuracy],
      experimental_steps_per_execution=params.steps_per_loop)
  # Override test_step() function so that inputs are spatially partitioned.
  unet_model.test_step = functools.partial(
      test_step, unet_model.test_step, params.input_partition_dims)

  # Load checkpoint for evaluation.
  checkpoint = tf.train.Checkpoint(model=unet_model)
  checkpoint_path = tf.train.latest_checkpoint(params.eval_checkpoint_dir)
  status = checkpoint.restore(checkpoint_path)
  status.assert_existing_objects_matched()

  eval_ds = strategy.experimental_distribute_datasets_from_function(input_fn)
  eval_epoch_steps = params.eval_item_count // params.eval_batch_size
  eval_result = unet_model.evaluate(
      x=eval_ds, steps=eval_epoch_steps, callbacks=unet3d_callbacks(params))
  return eval_result


def main(_):
  params = extract_params(flags.FLAGS)
  assert params.mode in {'train', 'eval'}, 'only support train and eval'
  save_params(params)

  input_dtype = params.dtype
  if input_dtype == 'float16' or input_dtype == 'bfloat16':
    policy = tf.keras.mixed_precision.experimental.Policy(
        'mixed_bfloat16' if input_dtype == 'bfloat16' else 'mixed_float16')
    tf.keras.mixed_precision.experimental.set_policy(policy)

  strategy = create_distribution_strategy(params)
  with strategy.scope():
    unet_model = unet_model_lib.build_unet_model(params)
    if params.mode == 'train':
      train(params, strategy, unet_model,
            functools.partial(get_train_dataset, params),
            functools.partial(get_eval_dataset, params))
    elif params.mode == 'eval':
      evaluate(params, strategy, unet_model,
               functools.partial(get_eval_dataset, params))
    else:
      raise Exception('Only `train` mode and `eval` mode are supported.')


if __name__ == '__main__':
  define_unet3d_flags()
  app.run(main)
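To make the partition-to-topology mapping concrete, a small worked example
based on `get_computation_shape_for_model_parallelism` above:

```python
import numpy as np

# The 16-way partition from the README: np.prod gives the number of logical
# devices each replica is split across, and the lookup table above maps it
# to a computation shape whose product matches.
input_partition_dims = [1, 16, 1, 1, 1]
num_logical_devices = int(np.prod(input_partition_dims))  # 16
computation_shape = [4, 2, 1, 2]  # value returned for 16 logical devices
assert int(np.prod(computation_shape)) == num_logical_devices
```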
official/vision/segmentation/unet_main_test.py
deleted
100644 → 0
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tempfile

from absl import flags
from absl.testing import flagsaver
from absl.testing import parameterized
import numpy as np
import tensorflow as tf

from tensorflow.contrib import cluster_resolver as contrib_cluster_resolver
from tensorflow.contrib.tpu.python.tpu import device_assignment as device_lib
from tensorflow.python.distribute import tpu_strategy as tpu_strategy_lib
from tensorflow.python.tpu import tpu_strategy_util
from official.modeling.hyperparams import params_dict
from official.vision.segmentation import unet_config
from official.vision.segmentation import unet_main as unet_main_lib
from official.vision.segmentation import unet_metrics
from official.vision.segmentation import unet_model as unet_model_lib

FLAGS = flags.FLAGS


def create_fake_input_fn(params, features_size, labels_size,
                         use_bfloat16=False):
  """Returns fake input function for testing."""

  def fake_data_input_fn(unused_ctx=None):
    """An input function for generating fake data."""
    batch_size = params.train_batch_size
    features = np.random.rand(64, *features_size)
    labels = np.random.randint(2, size=[64] + labels_size)

    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))

    def _assign_dtype(features, labels):
      if use_bfloat16:
        features = tf.cast(features, tf.bfloat16)
        labels = tf.cast(labels, tf.bfloat16)
      else:
        features = tf.cast(features, tf.float32)
        labels = tf.cast(labels, tf.float32)
      return features, labels

    # Shuffle, repeat, and batch the examples.
    dataset = dataset.map(_assign_dtype)
    dataset = dataset.shuffle(64).repeat()
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    # Return the dataset.
    return dataset

  return fake_data_input_fn


class UnetMainTest(parameterized.TestCase, tf.test.TestCase):

  def setUp(self):
    super(UnetMainTest, self).setUp()
    self._model_dir = os.path.join(tempfile.mkdtemp(), 'model_dir')
    tf.io.gfile.makedirs(self._model_dir)

  def tearDown(self):
    tf.io.gfile.rmtree(self._model_dir)
    super(UnetMainTest, self).tearDown()

  @flagsaver.flagsaver
  def testUnet3DModel(self):
    FLAGS.tpu = ''
    FLAGS.mode = 'train'
    params = params_dict.ParamsDict(unet_config.UNET_CONFIG,
                                    unet_config.UNET_RESTRICTIONS)
    params.override(
        {
            'input_image_size': [64, 64, 64],
            'train_item_count': 4,
            'eval_item_count': 4,
            'train_batch_size': 2,
            'eval_batch_size': 2,
            'batch_size': 2,
            'num_base_filters': 16,
            'dtype': 'bfloat16',
            'depth': 1,
            'train_steps': 2,
            'eval_steps': 2,
            'mode': FLAGS.mode,
            'tpu': FLAGS.tpu,
            'num_gpus': 0,
            'checkpoint_interval': 1,
            'use_tpu': True,
            'input_partition_dims': None,
        },
        is_strict=False)
    params.validate()
    params.lock()
    image_size = params.input_image_size + [params.num_channels]
    label_size = params.input_image_size + [params.num_classes]
    input_fn = create_fake_input_fn(
        params, features_size=image_size, labels_size=label_size)

    resolver = contrib_cluster_resolver.TPUClusterResolver(tpu=params.tpu)
    topology = tpu_strategy_util.initialize_tpu_system(resolver)
    device_assignment = None
    if params.input_partition_dims is not None:
      assert np.prod(
          params.input_partition_dims) == 2, 'invalid unit test configuration'
      computation_shape = [1, 1, 1, 2]
      partition_dimension = params.input_partition_dims
      num_replicas = resolver.get_tpu_system_metadata().num_cores // np.prod(
          partition_dimension)
      device_assignment = device_lib.device_assignment(
          topology,
          computation_shape=computation_shape,
          num_replicas=num_replicas)
    strategy = tpu_strategy_lib.TPUStrategy(
        resolver, device_assignment=device_assignment)

    with strategy.scope():
      model = unet_model_lib.build_unet_model(params)
      optimizer = unet_model_lib.create_optimizer(params.init_learning_rate,
                                                  params)
      loss_fn = unet_metrics.get_loss_fn(params.mode, params)
      model.compile(loss=loss_fn, optimizer=optimizer, metrics=[loss_fn])
      eval_ds = input_fn()
      iterator = iter(eval_ds)
      image, _ = next(iterator)
      logits = model(image, training=False)
      self.assertEqual(logits.shape[1:], params.input_image_size + [3])

  @parameterized.parameters(
      {'use_mlir': True, 'dtype': 'bfloat16', 'input_partition_dims': None},
      {'use_mlir': False, 'dtype': 'bfloat16', 'input_partition_dims': None},
      {'use_mlir': True, 'dtype': 'bfloat16', 'input_partition_dims': None},
      {'use_mlir': False, 'dtype': 'bfloat16', 'input_partition_dims': None},
      {'use_mlir': True, 'dtype': 'bfloat16',
       'input_partition_dims': [1, 2, 1, 1, 1]},
      {'use_mlir': False, 'dtype': 'bfloat16',
       'input_partition_dims': [1, 2, 1, 1, 1]},
      {'use_mlir': True, 'dtype': 'bfloat16',
       'input_partition_dims': [1, 2, 1, 1, 1]},
      {'use_mlir': False, 'dtype': 'bfloat16',
       'input_partition_dims': [1, 2, 1, 1, 1]})
  @flagsaver.flagsaver
  def testUnetTrain(self, use_mlir, dtype, input_partition_dims):
    FLAGS.tpu = ''
    FLAGS.mode = 'train'
    if use_mlir:
      tf.config.experimental.enable_mlir_bridge()
    params = params_dict.ParamsDict(unet_config.UNET_CONFIG,
                                    unet_config.UNET_RESTRICTIONS)
    params.override(
        {
            'model_dir': self._model_dir,
            'input_image_size': [8, 8, 8],
            'train_item_count': 2,
            'eval_item_count': 2,
            'train_batch_size': 2,
            'eval_batch_size': 2,
            'batch_size': 2,
            'num_base_filters': 1,
            'dtype': 'bfloat16',
            'depth': 1,
            'epochs': 1,
            'checkpoint_interval': 1,
            'train_steps': 1,
            'eval_steps': 1,
            'mode': FLAGS.mode,
            'tpu': FLAGS.tpu,
            'use_tpu': True,
            'num_gpus': 0,
            'distribution_strategy': 'tpu',
            'steps_per_loop': 1,
            'input_partition_dims': input_partition_dims,
        },
        is_strict=False)
    params.validate()
    params.lock()
    image_size = params.input_image_size + [params.num_channels]
    label_size = params.input_image_size + [params.num_classes]
    input_fn = create_fake_input_fn(
        params, features_size=image_size, labels_size=label_size)

    input_dtype = params.dtype
    if input_dtype == 'float16' or input_dtype == 'bfloat16':
      policy = tf.keras.mixed_precision.experimental.Policy(
          'mixed_bfloat16' if input_dtype == 'bfloat16' else 'mixed_float16')
      tf.keras.mixed_precision.experimental.set_policy(policy)

    strategy = unet_main_lib.create_distribution_strategy(params)
    with strategy.scope():
      unet_model = unet_model_lib.build_unet_model(params)
      unet_main_lib.train(params, strategy, unet_model, input_fn, input_fn)


if __name__ == '__main__':
  unet_main_lib.define_unet3d_flags()
  tf.test.main()
official/vision/segmentation/unet_metrics.py
deleted
100644 → 0
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define metrics for the UNet 3D Model."""
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
import
tensorflow
as
tf
def
dice
(
y_true
,
y_pred
,
axis
=
(
1
,
2
,
3
,
4
)):
"""DICE coefficient.
Taha AA, Hanbury A. Metrics for evaluating 3D medical image segmentation:
analysis, selection, and tool. BMC Med Imaging. 2015;15:29. Published
2015
Aug 12. doi:10.1186/s12880-015-0068-x
Implemented according to
https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4533825/#Equ6
Args:
y_true: the ground truth matrix. Shape [batch_size, x, y, z, num_classes].
y_pred: the prediction matrix. Shape [batch_size, x, y, z, num_classes].
axis: axises of features.
Returns:
DICE coefficient.
"""
y_true
=
tf
.
cast
(
y_true
,
y_pred
.
dtype
)
eps
=
tf
.
keras
.
backend
.
epsilon
()
intersection
=
tf
.
reduce_sum
(
input_tensor
=
y_true
*
y_pred
,
axis
=
axis
)
summation
=
tf
.
reduce_sum
(
input_tensor
=
y_true
,
axis
=
axis
)
+
tf
.
reduce_sum
(
input_tensor
=
y_pred
,
axis
=
axis
)
return
(
2
*
intersection
+
eps
)
/
(
summation
+
eps
)
def
generalized_dice
(
y_true
,
y_pred
,
axis
=
(
1
,
2
,
3
)):
"""Generalized Dice coefficient, for multi-class predictions.
For output of a multi-class model, where the shape of the output is
(batch, x, y, z, n_classes), the axis argument should be (1, 2, 3).
Args:
y_true: the ground truth matrix. Shape [batch_size, x, y, z, num_classes].
y_pred: the prediction matrix. Shape [batch_size, x, y, z, num_classes].
axis: axises of features.
Returns:
DICE coefficient.
"""
y_true
=
tf
.
cast
(
y_true
,
y_pred
.
dtype
)
if
y_true
.
get_shape
().
ndims
<
2
or
y_pred
.
get_shape
().
ndims
<
2
:
raise
ValueError
(
'y_true and y_pred must be at least rank 2.'
)
epsilon
=
tf
.
keras
.
backend
.
epsilon
()
w
=
tf
.
math
.
reciprocal
(
tf
.
square
(
tf
.
reduce_sum
(
y_true
,
axis
=
axis
))
+
epsilon
)
num
=
2
*
tf
.
reduce_sum
(
w
*
tf
.
reduce_sum
(
y_true
*
y_pred
,
axis
=
axis
),
axis
=-
1
)
den
=
tf
.
reduce_sum
(
w
*
tf
.
reduce_sum
(
y_true
+
y_pred
,
axis
=
axis
),
axis
=-
1
)
return
(
num
+
epsilon
)
/
(
den
+
epsilon
)
def
hamming
(
y_true
,
y_pred
,
axis
=
(
1
,
2
,
3
)):
"""Hamming distance.
Args:
y_true: the ground truth matrix. Shape [batch_size, x, y, z].
y_pred: the prediction matrix. Shape [batch_size, x, y, z].
axis: a list, axises of the feature dimensions.
Returns:
Hamming distance value.
"""
y_true
=
tf
.
cast
(
y_true
,
y_pred
.
dtype
)
return
tf
.
reduce_mean
(
input_tensor
=
tf
.
not_equal
(
y_pred
,
y_true
),
axis
=
axis
)
def
jaccard
(
y_true
,
y_pred
,
axis
=
(
1
,
2
,
3
,
4
)):
"""Jaccard Similarity.
Taha AA, Hanbury A. Metrics for evaluating 3D medical image segmentation:
analysis, selection, and tool. BMC Med Imaging. 2015;15:29. Published
2015
Aug 12. doi:10.1186/s12880-015-0068-x
Implemented according to
https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4533825/#Equ7
Args:
y_true: the ground truth matrix. Shape [batch_size, x, y, z, num_classes].
y_pred: the prediction matrix. Shape [batch_size, x, y, z, num_classes].
axis: axises of features.
Returns:
Jaccard similarity.
"""
y_true
=
tf
.
cast
(
y_true
,
y_pred
.
dtype
)
eps
=
tf
.
keras
.
backend
.
epsilon
()
intersection
=
tf
.
reduce_sum
(
input_tensor
=
y_true
*
y_pred
,
axis
=
axis
)
union
=
tf
.
reduce_sum
(
y_true
,
axis
=
axis
)
+
tf
.
reduce_sum
(
y_pred
,
axis
=
axis
)
return
(
intersection
+
eps
)
/
(
union
-
intersection
+
eps
)
def
tversky
(
y_true
,
y_pred
,
axis
=
(
1
,
2
,
3
),
alpha
=
0.3
,
beta
=
0.7
):
"""Tversky similarity.
Args:
y_true: the ground truth matrix. Shape [batch_size, x, y, z, num_classes].
y_pred: the prediction matrix. Shape [batch_size, x, y, z, num_classes].
axis: axises of spatial dimensions.
alpha: weight of the prediction.
beta: weight of the groundtruth.
Returns:
Tversky similarity coefficient.
"""
y_true
=
tf
.
cast
(
y_true
,
y_pred
.
dtype
)
if
y_true
.
get_shape
().
ndims
<
2
or
y_pred
.
get_shape
().
ndims
<
2
:
raise
ValueError
(
'y_true and y_pred must be at least rank 2.'
)
eps
=
tf
.
keras
.
backend
.
epsilon
()
num
=
tf
.
reduce_sum
(
input_tensor
=
y_pred
*
y_true
,
axis
=
axis
)
den
=
(
num
+
alpha
*
tf
.
reduce_sum
(
y_pred
*
(
1
-
y_true
),
axis
=
axis
)
+
beta
*
tf
.
reduce_sum
((
1
-
y_pred
)
*
y_true
,
axis
=
axis
))
# Sum over classes.
return
tf
.
reduce_sum
(
input_tensor
=
(
num
+
eps
)
/
(
den
+
eps
),
axis
=-
1
)
def
adaptive_dice32
(
y_true
,
y_pred
,
data_format
=
'channels_last'
):
"""Adaptive dice metric.
Args:
y_true: the ground truth matrix. Shape [batch_size, x, y, z, num_classes].
y_pred: the prediction matrix. Shape [batch_size, x, y, z, num_classes].
data_format: channel last of channel first.
Returns:
Adaptive dice value.
"""
epsilon
=
10
**-
7
y_true
=
tf
.
cast
(
y_true
,
dtype
=
y_pred
.
dtype
)
# Determine axes to pass to tf.reduce_sum
if
data_format
==
'channels_last'
:
ndim
=
len
(
y_pred
.
shape
)
reduction_axes
=
list
(
range
(
ndim
-
1
))
else
:
reduction_axes
=
1
# Calculate intersections and unions per class
intersections
=
tf
.
reduce_sum
(
y_true
*
y_pred
,
axis
=
reduction_axes
)
unions
=
tf
.
reduce_sum
(
y_true
+
y_pred
,
axis
=
reduction_axes
)
# Calculate Dice scores per class
dice_scores
=
2.0
*
(
intersections
+
epsilon
)
/
(
unions
+
epsilon
)
# Calculate weights based on Dice scores
weights
=
tf
.
exp
(
-
1.0
*
dice_scores
)
# Multiply weights by corresponding scores and get sum
weighted_dice
=
tf
.
reduce_sum
(
weights
*
dice_scores
)
# Calculate normalization factor
norm_factor
=
tf
.
size
(
input
=
dice_scores
,
out_type
=
tf
.
float32
)
*
tf
.
exp
(
-
1.0
)
weighted_dice
=
tf
.
cast
(
weighted_dice
,
dtype
=
tf
.
float32
)
# Return 1 - adaptive Dice score
return
1
-
(
weighted_dice
/
norm_factor
)
def
assert_shape_equal
(
pred_shape
,
label_shape
):
"""Asserts that `pred_shape` and `label_shape` is equal."""
assert
(
label_shape
==
pred_shape
),
'pred. shape {} is not equal to label shape {}'
.
format
(
label_shape
,
pred_shape
)
def
get_loss_fn
(
mode
,
params
):
"""Return loss_fn for unet training.
Args:
mode: training or eval. This is a legacy parameter from TF1.
params: unet configuration parameter.
Returns:
loss_fn.
"""
def
loss_fn
(
y_true
,
y_pred
):
"""Returns scalar loss from labels and netowrk outputs."""
loss
=
None
label_shape
=
y_true
.
get_shape
().
as_list
()
pred_shape
=
y_pred
.
get_shape
().
as_list
()
assert_shape_equal
(
label_shape
,
pred_shape
)
if
params
.
loss
==
'adaptive_dice32'
:
loss
=
adaptive_dice32
(
y_true
,
y_pred
)
elif
params
.
loss
==
'cross_entropy'
:
if
mode
==
tf
.
estimator
.
ModeKeys
.
TRAIN
and
params
.
use_index_label_in_train
:
labels_idx
=
tf
.
cast
(
y_true
,
dtype
=
tf
.
int32
)
else
:
# Use one-hot label representation, convert to label index.
labels_idx
=
tf
.
argmax
(
input
=
y_true
,
axis
=-
1
,
output_type
=
tf
.
int32
)
y_pred
=
tf
.
cast
(
y_pred
,
dtype
=
tf
.
float32
)
loss
=
tf
.
keras
.
losses
.
sparse_categorical_crossentropy
(
labels_idx
,
y_pred
,
from_logits
=
False
)
else
:
raise
Exception
(
'Unexpected loss type'
)
return
loss
return
loss_fn
def metric_accuracy(labels, predictions):
  """Returns accuracy of model outputs.

  Args:
    labels: ground truth tensor (one-hot labels).
    predictions: network output (softmax probabilities).

  Returns:
    The accuracy value.
  """
  if labels.dtype == tf.bfloat16:
    labels = tf.cast(labels, tf.float32)
  if predictions.dtype == tf.bfloat16:
    predictions = tf.cast(predictions, tf.float32)
  return tf.keras.backend.mean(
      tf.keras.backend.equal(
          tf.argmax(input=labels, axis=-1),
          tf.argmax(input=predictions, axis=-1)))


def metric_ce(labels, predictions):
  """Returns categorical cross-entropy given outputs and labels.

  Args:
    labels: ground truth tensor (one-hot labels).
    predictions: network output (softmax probabilities).

  Returns:
    The categorical cross-entropy value.
  """
  if labels.dtype == tf.bfloat16:
    labels = tf.cast(labels, tf.float32)
  if predictions.dtype == tf.bfloat16:
    predictions = tf.cast(predictions, tf.float32)
  return tf.keras.losses.categorical_crossentropy(
      labels, predictions, from_logits=False)


def metric_dice(labels, predictions):
  """Returns the adaptive dice loss from `adaptive_dice32`."""
  if labels.dtype == tf.bfloat16:
    labels = tf.cast(labels, tf.float32)
  if predictions.dtype == tf.bfloat16:
    predictions = tf.cast(predictions, tf.float32)
  return adaptive_dice32(labels, predictions)
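
# Hypothetical wiring sketch (not part of the original file): each metric has
# the (y_true, y_pred) signature Keras expects, so a model such as the one
# returned by build_unet_model in unet_model.py below could be compiled with
# them directly:
#
#   model.compile(
#       optimizer=model.optimizer,
#       loss=get_loss_fn(tf.estimator.ModeKeys.TRAIN, params),
#       metrics=[metric_accuracy, metric_ce, metric_dice])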
official/vision/segmentation/unet_model.py deleted 100644 → 0
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model definition for the TF2 Keras UNet 3D Model."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import tensorflow as tf
def create_optimizer(init_learning_rate, params):
  """Creates optimizer for training."""
  learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(
      initial_learning_rate=init_learning_rate,
      decay_steps=params.lr_decay_steps,
      decay_rate=params.lr_decay_rate)
  # TODO(hongjunchoi): Provide alternative optimizer options depending on
  # model config parameters.
  optimizer = tf.keras.optimizers.Adam(learning_rate)
  return optimizer
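
# Illustrative sketch (not part of the original file; the decay values are
# assumptions, not the config defaults). The schedule computes
# lr(step) = initial_learning_rate * decay_rate ** (step / decay_steps).
import tensorflow as tf

schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-3, decay_steps=1000, decay_rate=0.9)
print(schedule(0).numpy(), schedule(1000).numpy())  # 0.001, 0.0009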
def create_convolution_block(input_layer,
                             n_filters,
                             batch_normalization=False,
                             kernel=(3, 3, 3),
                             activation=tf.nn.relu,
                             padding='SAME',
                             strides=(1, 1, 1),
                             data_format='channels_last',
                             instance_normalization=False):
  """UNet convolution block.

  Args:
    input_layer: tf.Tensor, the input tensor.
    n_filters: integer, the number of output channels of the convolution.
    batch_normalization: boolean, use batch normalization after the
      convolution.
    kernel: kernel size of the convolution.
    activation: TensorFlow activation function to use (default: tf.nn.relu).
    padding: padding type of the convolution.
    strides: strides of the convolution.
    data_format: data format of the convolution. One of 'channels_first' or
      'channels_last'.
    instance_normalization: use instance normalization. Exclusive with batch
      normalization.

  Returns:
    The tensor after applying the convolution block to the input.
  """
  assert not instance_normalization, 'TF 2.0 does not support inst. norm.'
  layer = tf.keras.layers.Conv3D(
      filters=n_filters,
      kernel_size=kernel,
      strides=strides,
      padding=padding,
      data_format=data_format,
      activation=None,
  )(inputs=input_layer)
  if batch_normalization:
    # Normalize over the channel axis implied by `data_format`.
    channel_axis = -1 if data_format == 'channels_last' else 1
    layer = tf.keras.layers.BatchNormalization(axis=channel_axis)(
        inputs=layer)
  return activation(layer)
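
# Shape sketch (not part of the original file): with 'SAME' padding and unit
# strides the block preserves spatial size and only changes the channel count.
import tensorflow as tf

block_in = tf.keras.layers.Input(shape=(8, 8, 8, 1))
block_out = create_convolution_block(block_in, n_filters=16)
print(block_out.shape)  # (None, 8, 8, 8, 16)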
def apply_up_convolution(inputs,
                         num_filters,
                         pool_size,
                         kernel_size=(2, 2, 2),
                         strides=(2, 2, 2),
                         deconvolution=False):
  """Applies up convolution on inputs.

  Args:
    inputs: input feature tensor.
    num_filters: number of deconvolution output feature channels.
    pool_size: pool size of the up-scaling.
    kernel_size: kernel size of the deconvolution.
    strides: strides of the deconvolution.
    deconvolution: use deconvolution or upsampling.

  Returns:
    The tensor of the up-scaled features.
  """
  if deconvolution:
    return tf.keras.layers.Conv3DTranspose(
        filters=num_filters, kernel_size=kernel_size,
        strides=strides)(inputs=inputs)
  else:
    return tf.keras.layers.UpSampling3D(size=pool_size)(inputs)
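
# Contrast sketch (not part of the original file): UpSampling3D repeats
# voxels with no learned weights, while Conv3DTranspose learns its filter;
# both double the spatial size here.
import tensorflow as tf

feat = tf.keras.layers.Input(shape=(4, 4, 4, 32))
up = apply_up_convolution(feat, num_filters=32, pool_size=(2, 2, 2))
dc = apply_up_convolution(feat, num_filters=32, pool_size=(2, 2, 2),
                          deconvolution=True)
print(up.shape, dc.shape)  # both (None, 8, 8, 8, 32)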
def unet3d_base(input_layer,
                pool_size=(2, 2, 2),
                n_labels=1,
                deconvolution=False,
                depth=4,
                n_base_filters=32,
                batch_normalization=False,
                data_format='channels_last'):
  """Builds the 3D UNet Tensorflow model and returns the last layer logits.

  Args:
    input_layer: the input Tensor.
    pool_size: pool size for the max pooling operations.
    n_labels: number of binary labels that the model is learning.
    deconvolution: if set to True, will use transpose convolution
      (deconvolution) instead of up-sampling. This increases the amount of
      memory required during training.
    depth: indicates the depth of the U-shape for the model. The greater the
      depth, the more max pooling layers will be added to the model. Lowering
      the depth may reduce the amount of memory required for training.
    n_base_filters: the number of filters that the first layer in the
      convolution network will have. Following layers will contain a multiple
      of this number. Lowering this number will likely reduce the amount of
      memory required to train the model.
    batch_normalization: boolean. True to use batch normalization after
      convolution and before activation.
    data_format: string, 'channels_last' (default) or 'channels_first'.

  Returns:
    The last layer logits of the 3D UNet.
  """
  levels = []
  current_layer = input_layer
  if data_format == 'channels_last':
    channel_dim = -1
  else:
    channel_dim = 1

  # Add levels with max pooling (the contracting path).
  for layer_depth in range(depth):
    layer1 = create_convolution_block(
        input_layer=current_layer,
        n_filters=n_base_filters * (2**layer_depth),
        batch_normalization=batch_normalization,
        kernel=(3, 3, 3),
        activation=tf.nn.relu,
        padding='SAME',
        strides=(1, 1, 1),
        data_format=data_format,
        instance_normalization=False)
    layer2 = create_convolution_block(
        input_layer=layer1,
        n_filters=n_base_filters * (2**layer_depth) * 2,
        batch_normalization=batch_normalization,
        kernel=(3, 3, 3),
        activation=tf.nn.relu,
        padding='SAME',
        strides=(1, 1, 1),
        data_format=data_format,
        instance_normalization=False)
    if layer_depth < depth - 1:
      current_layer = tf.keras.layers.MaxPool3D(
          pool_size=pool_size,
          strides=(2, 2, 2),
          padding='VALID',
          data_format=data_format)(inputs=layer2)
      levels.append([layer1, layer2, current_layer])
    else:
      current_layer = layer2
      levels.append([layer1, layer2])

  # Add levels with up-convolution or up-sampling (the expanding path).
  for layer_depth in range(depth - 2, -1, -1):
    up_convolution = apply_up_convolution(
        current_layer,
        pool_size=pool_size,
        deconvolution=deconvolution,
        num_filters=current_layer.get_shape().as_list()[channel_dim])
    concat = tf.concat([up_convolution, levels[layer_depth][1]],
                       axis=channel_dim)
    current_layer = create_convolution_block(
        n_filters=levels[layer_depth][1].get_shape().as_list()[channel_dim],
        input_layer=concat,
        batch_normalization=batch_normalization,
        kernel=(3, 3, 3),
        activation=tf.nn.relu,
        padding='SAME',
        strides=(1, 1, 1),
        data_format=data_format,
        instance_normalization=False)
    current_layer = create_convolution_block(
        n_filters=levels[layer_depth][1].get_shape().as_list()[channel_dim],
        input_layer=current_layer,
        batch_normalization=batch_normalization,
        kernel=(3, 3, 3),
        activation=tf.nn.relu,
        padding='SAME',
        strides=(1, 1, 1),
        data_format=data_format,
        instance_normalization=False)

  final_convolution = tf.keras.layers.Conv3D(
      filters=n_labels,
      kernel_size=(1, 1, 1),
      padding='VALID',
      data_format=data_format,
      activation=None)(current_layer)
  return final_convolution
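
# Shape sketch (not part of the original file; sizes are illustrative). The
# input spatial dimensions must be divisible by 2 ** (depth - 1) to survive
# the max pooling at each contracting level.
import tensorflow as tf

volume = tf.keras.layers.Input(shape=(32, 32, 32, 1))
volume_logits = unet3d_base(volume, n_labels=3, depth=4, n_base_filters=8)
print(volume_logits.shape)  # (None, 32, 32, 32, 3)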
def build_unet_model(params):
  """Builds the unet model, optimizer included."""
  input_shape = params.input_image_size + [1]
  input_layer = tf.keras.layers.Input(shape=input_shape)
  logits = unet3d_base(
      input_layer,
      pool_size=(2, 2, 2),
      n_labels=params.num_classes,
      deconvolution=params.deconvolution,
      depth=params.depth,
      n_base_filters=params.num_base_filters,
      batch_normalization=params.use_batch_norm,
      data_format=params.data_format)
  # Set output of softmax to float32 to avoid potential numerical overflow.
  predictions = tf.keras.layers.Softmax(dtype='float32')(logits)
  model = tf.keras.models.Model(inputs=input_layer, outputs=predictions)
  model.optimizer = create_optimizer(params.init_learning_rate, params)
  return model
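
# Hypothetical end-to-end sketch (not part of the original file); the
# SimpleNamespace mirrors the unet_config fields referenced above, and every
# value shown is an assumption.
import types
import tensorflow as tf

example_params = types.SimpleNamespace(
    input_image_size=[32, 32, 32], num_classes=3, deconvolution=False,
    depth=4, num_base_filters=8, use_batch_norm=False,
    data_format='channels_last', init_learning_rate=1e-3,
    lr_decay_steps=1000, lr_decay_rate=0.9)
example_model = build_unet_model(example_params)
example_model.summary()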