Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
afab77f6
Commit
afab77f6
authored
Jan 05, 2021
by
Abdullah Rashwan
Committed by
A. Unique TensorFlower
Jan 05, 2021
Browse files
Internal change
PiperOrigin-RevId: 350255525
parent
2832bca8
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
37 additions
and
2 deletions
+37
-2
official/vision/beta/configs/semantic_segmentation.py
official/vision/beta/configs/semantic_segmentation.py
+6
-0
official/vision/beta/dataloaders/dataset_fn.py
official/vision/beta/dataloaders/dataset_fn.py
+28
-0
official/vision/beta/losses/segmentation_losses.py
official/vision/beta/losses/segmentation_losses.py
+1
-1
official/vision/beta/tasks/semantic_segmentation.py
official/vision/beta/tasks/semantic_segmentation.py
+2
-1
No files found.
official/vision/beta/configs/semantic_segmentation.py
View file @
afab77f6
...
...
@@ -33,6 +33,8 @@ from official.vision.beta.configs import decoders
@dataclasses.dataclass
class DataConfig(cfg.DataConfig):
  """Input config for training.

  Attributes are the knobs of the segmentation input pipeline; only the
  fields visible in this diff hunk are reproduced here — two interior
  lines are elided in the source diff (TODO confirm against full file).
  """
  # Target [height, width] of the model input; empty means unset.
  output_size: List[int] = dataclasses.field(default_factory=list)
  # If train_on_crops is set to True, a patch of size output_size is cropped
  # from the input image.
  train_on_crops: bool = False
  input_path: str = ''
  global_batch_size: int = 0
  dtype: str = 'float32'
  shuffle_buffer_size: int = 1000
  cycle_length: int = 10
  # If resize_eval_groundtruth is set to False, original image sizes are used
  # for eval. In that case, groundtruth_padded_size has to be specified too to
  # allow for batching the variable input sizes of images.
  resize_eval_groundtruth: bool = True
  groundtruth_padded_size: List[int] = dataclasses.field(default_factory=list)
  aug_scale_min: float = 1.0
  aug_scale_max: float = 1.0
  aug_rand_hflip: bool = True
  drop_remainder: bool = True
  # BUG FIX: default was the typo 'tfrecod', which matches neither spelling the
  # pipeline understands; 'tf_record' is the only value accepted by
  # dataset_fn.pick_dataset_fn, so use it as the default. ('sstable' mentioned
  # in the original comment is not handled by pick_dataset_fn.)
  file_type: str = 'tf_record'  # tf_record, or sstable
@
dataclasses
.
dataclass
...
...
official/vision/beta/dataloaders/dataset_fn.py
0 → 100644
View file @
afab77f6
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility library for picking an appropriate dataset function."""

from typing import Any, Callable, Type, Union

import tensorflow as tf

# Either a tf.data dataset class (e.g. tf.data.TFRecordDataset, instantiated
# by the input reader with file names) or a callable that maps a filename
# tensor to a dataset-like object.
PossibleDatasetType = Union[Type[tf.data.Dataset], Callable[[tf.Tensor], Any]]


def pick_dataset_fn(file_type: str) -> PossibleDatasetType:
  """Returns the dataset constructor matching `file_type`.

  Args:
    file_type: A string identifying the on-disk data format. Currently only
      'tf_record' is supported.

  Returns:
    The dataset class to instantiate for the given format (e.g.
    `tf.data.TFRecordDataset` for 'tf_record').

  Raises:
    ValueError: If `file_type` is not a recognized format string.
  """
  if file_type == 'tf_record':
    return tf.data.TFRecordDataset
  raise ValueError('Unrecognized file_type: {}'.format(file_type))
official/vision/beta/losses/segmentation_losses.py
View file @
afab77f6
...
...
@@ -83,7 +83,7 @@ class SegmentationLoss:
top_k_losses
,
_
=
tf
.
math
.
top_k
(
cross_entropy_loss
,
k
=
top_k_pixels
,
sorted
=
True
)
normalizer
=
tf
.
reduce_sum
(
tf
.
cast
(
tf
.
not_equal
(
top_k_losses
,
0.0
),
tf
.
float32
)
+
EPSILON
)
tf
.
cast
(
tf
.
not_equal
(
top_k_losses
,
0.0
),
tf
.
float32
)
)
+
EPSILON
loss
=
tf
.
reduce_sum
(
top_k_losses
)
/
normalizer
return
loss
official/vision/beta/tasks/semantic_segmentation.py
View file @
afab77f6
...
...
@@ -23,6 +23,7 @@ from official.core import input_reader
from
official.core
import
task_factory
from
official.vision.beta.configs
import
semantic_segmentation
as
exp_cfg
from
official.vision.beta.dataloaders
import
segmentation_input
from
official.vision.beta.dataloaders
import
dataset_fn
from
official.vision.beta.evaluation
import
segmentation_metrics
from
official.vision.beta.losses
import
segmentation_losses
from
official.vision.beta.modeling
import
factory
...
...
@@ -97,7 +98,7 @@ class SemanticSegmentationTask(base_task.Task):
reader
=
input_reader
.
InputReader
(
params
,
dataset_fn
=
tf
.
data
.
TFRecordDataset
,
dataset_fn
=
dataset_fn
.
pick_dataset_fn
(
params
.
file_type
)
,
decoder_fn
=
decoder
.
decode
,
parser_fn
=
parser
.
parse_fn
(
params
.
is_training
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment