"dialogctrl/ner/src/config.py" did not exist on "7d044e4e58b5b479c908c94442a413af36b1a97e"
Commit a684cbb8 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 366496850
parent 11afc46a
......@@ -112,8 +112,8 @@ class Parser(parser.Parser):
self._use_autoaugment = use_autoaugment
self._autoaugment_policy_name = autoaugment_policy_name
# Device.
self._use_bfloat16 = True if dtype == 'bfloat16' else False
# Data type.
self._dtype = dtype
def _parse_train_data(self, data):
"""Parses data for training and evaluation."""
......@@ -180,9 +180,8 @@ class Parser(parser.Parser):
box_weights) = anchor_labeler.label_anchors(
anchor_boxes, boxes, tf.expand_dims(classes, axis=1))
# If bfloat16 is used, casts input image to tf.bfloat16.
if self._use_bfloat16:
image = tf.cast(image, dtype=tf.bfloat16)
# Casts input image to desired data type.
image = tf.cast(image, dtype=self._dtype)
# Packs labels for model_fn outputs.
labels = {
......@@ -245,9 +244,8 @@ class Parser(parser.Parser):
box_weights) = anchor_labeler.label_anchors(
anchor_boxes, boxes, tf.expand_dims(classes, axis=1))
# If bfloat16 is used, casts input image to tf.bfloat16.
if self._use_bfloat16:
image = tf.cast(image, dtype=tf.bfloat16)
# Casts input image to desired data type.
image = tf.cast(image, dtype=self._dtype)
# Sets up groundtruth data for evaluation.
groundtruths = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment