Unverified Commit dfcc691c authored by Srihari Humbarwadi's avatar Srihari Humbarwadi Committed by GitHub
Browse files

Merge branch 'master' into panoptic-deeplab

parents 83b87f05 a9d9e633
...@@ -285,8 +285,17 @@ class InputReader: ...@@ -285,8 +285,17 @@ class InputReader:
if self._enable_tf_data_service: if self._enable_tf_data_service:
# Add a random seed as the tf.data service job name suffix, so tf.data # Add a random seed as the tf.data service job name suffix, so tf.data
# service doesn't reuse the previous state if TPU worker gets preempted. # service doesn't reuse the previous state if TPU worker gets preempted.
# It's necessary to add global batch size into the tf data service job
# name because when tuning batch size with vizier and tf data service is
# also enabled, the tf data service job name should be different for
# different vizier trials since once batch size is changed, from the
# tf.data perspective, the dataset is a different instance, and a
# different job name should be used for tf data service. Otherwise, the
# model would read tensors from the incorrect tf data service job, which
# would cause a dimension mismatch on the batch size dimension.
self._tf_data_service_job_name = ( self._tf_data_service_job_name = (
params.tf_data_service_job_name + str(self.static_randnum)) f'{params.tf_data_service_job_name}_bs{params.global_batch_size}_'
f'{self.static_randnum}')
self._enable_round_robin_tf_data_service = params.get( self._enable_round_robin_tf_data_service = params.get(
'enable_round_robin_tf_data_service', False) 'enable_round_robin_tf_data_service', False)
...@@ -463,9 +472,8 @@ class InputReader: ...@@ -463,9 +472,8 @@ class InputReader:
dataset: Optional[tf.data.Dataset] = None) -> tf.data.Dataset: dataset: Optional[tf.data.Dataset] = None) -> tf.data.Dataset:
"""Generates a tf.data.Dataset object.""" """Generates a tf.data.Dataset object."""
if dataset is None: if dataset is None:
dataset = self._read_data_source( dataset = self._read_data_source(self._matched_files, self._dataset_fn,
self._matched_files, self._dataset_fn, input_context, input_context, self._tfds_builder)
self._tfds_builder)
dataset = self._decode_and_parse_dataset(dataset, self._global_batch_size, dataset = self._decode_and_parse_dataset(dataset, self._global_batch_size,
input_context) input_context)
dataset = _maybe_map_fn(dataset, self._postprocess_fn) dataset = _maybe_map_fn(dataset, self._postprocess_fn)
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
from absl import logging from absl import logging
from official.nlp.xlnet import data_utils from official.legacy.xlnet import data_utils
SEG_ID_A = 0 SEG_ID_A = 0
SEG_ID_B = 1 SEG_ID_B = 1
......
...@@ -26,8 +26,8 @@ import numpy as np ...@@ -26,8 +26,8 @@ import numpy as np
import tensorflow as tf import tensorflow as tf
import sentencepiece as spm import sentencepiece as spm
from official.nlp.xlnet import classifier_utils from official.legacy.xlnet import classifier_utils
from official.nlp.xlnet import preprocess_utils from official.legacy.xlnet import preprocess_utils
flags.DEFINE_bool( flags.DEFINE_bool(
......
...@@ -28,7 +28,7 @@ import numpy as np ...@@ -28,7 +28,7 @@ import numpy as np
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
import sentencepiece as spm import sentencepiece as spm
from official.nlp.xlnet import preprocess_utils from official.legacy.xlnet import preprocess_utils
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
......
...@@ -25,7 +25,7 @@ from absl import logging ...@@ -25,7 +25,7 @@ from absl import logging
import tensorflow as tf import tensorflow as tf
import sentencepiece as spm import sentencepiece as spm
from official.nlp.xlnet import squad_utils from official.legacy.xlnet import squad_utils
flags.DEFINE_integer( flags.DEFINE_integer(
"num_proc", default=1, help="Number of preprocessing processes.") "num_proc", default=1, help="Number of preprocessing processes.")
......
...@@ -24,12 +24,12 @@ import numpy as np ...@@ -24,12 +24,12 @@ import numpy as np
import tensorflow as tf import tensorflow as tf
# pylint: disable=unused-import # pylint: disable=unused-import
from official.common import distribute_utils from official.common import distribute_utils
from official.nlp.xlnet import common_flags from official.legacy.xlnet import common_flags
from official.nlp.xlnet import data_utils from official.legacy.xlnet import data_utils
from official.nlp.xlnet import optimization from official.legacy.xlnet import optimization
from official.nlp.xlnet import training_utils from official.legacy.xlnet import training_utils
from official.nlp.xlnet import xlnet_config from official.legacy.xlnet import xlnet_config
from official.nlp.xlnet import xlnet_modeling as modeling from official.legacy.xlnet import xlnet_modeling as modeling
flags.DEFINE_integer("n_class", default=2, help="Number of classes.") flags.DEFINE_integer("n_class", default=2, help="Number of classes.")
flags.DEFINE_string( flags.DEFINE_string(
......
...@@ -24,12 +24,12 @@ from absl import logging ...@@ -24,12 +24,12 @@ from absl import logging
import tensorflow as tf import tensorflow as tf
# pylint: disable=unused-import # pylint: disable=unused-import
from official.common import distribute_utils from official.common import distribute_utils
from official.nlp.xlnet import common_flags from official.legacy.xlnet import common_flags
from official.nlp.xlnet import data_utils from official.legacy.xlnet import data_utils
from official.nlp.xlnet import optimization from official.legacy.xlnet import optimization
from official.nlp.xlnet import training_utils from official.legacy.xlnet import training_utils
from official.nlp.xlnet import xlnet_config from official.legacy.xlnet import xlnet_config
from official.nlp.xlnet import xlnet_modeling as modeling from official.legacy.xlnet import xlnet_modeling as modeling
flags.DEFINE_integer( flags.DEFINE_integer(
"num_predict", "num_predict",
......
...@@ -28,13 +28,13 @@ import tensorflow as tf ...@@ -28,13 +28,13 @@ import tensorflow as tf
# pylint: disable=unused-import # pylint: disable=unused-import
import sentencepiece as spm import sentencepiece as spm
from official.common import distribute_utils from official.common import distribute_utils
from official.nlp.xlnet import common_flags from official.legacy.xlnet import common_flags
from official.nlp.xlnet import data_utils from official.legacy.xlnet import data_utils
from official.nlp.xlnet import optimization from official.legacy.xlnet import optimization
from official.nlp.xlnet import squad_utils from official.legacy.xlnet import squad_utils
from official.nlp.xlnet import training_utils from official.legacy.xlnet import training_utils
from official.nlp.xlnet import xlnet_config from official.legacy.xlnet import xlnet_config
from official.nlp.xlnet import xlnet_modeling as modeling from official.legacy.xlnet import xlnet_modeling as modeling
flags.DEFINE_string( flags.DEFINE_string(
"test_feature_path", default=None, help="Path to feature of test set.") "test_feature_path", default=None, help="Path to feature of test set.")
......
...@@ -32,8 +32,8 @@ import numpy as np ...@@ -32,8 +32,8 @@ import numpy as np
import six import six
import tensorflow as tf import tensorflow as tf
from official.nlp.xlnet import data_utils from official.legacy.xlnet import data_utils
from official.nlp.xlnet import preprocess_utils from official.legacy.xlnet import preprocess_utils
SPIECE_UNDERLINE = u"▁" SPIECE_UNDERLINE = u"▁"
......
...@@ -22,7 +22,7 @@ from absl import logging ...@@ -22,7 +22,7 @@ from absl import logging
import tensorflow as tf import tensorflow as tf
from official.legacy.bert import model_training_utils from official.legacy.bert import model_training_utils
from official.nlp.xlnet import data_utils from official.legacy.xlnet import data_utils
# pytype: disable=attribute-error # pytype: disable=attribute-error
# pylint: disable=g-bare-generic,unused-import # pylint: disable=g-bare-generic,unused-import
......
...@@ -18,9 +18,8 @@ import copy ...@@ -18,9 +18,8 @@ import copy
import warnings import warnings
import tensorflow as tf import tensorflow as tf
from official.legacy.xlnet import data_utils
from official.nlp.modeling import networks from official.nlp.modeling import networks
from official.nlp.xlnet import data_utils
def gelu(x): def gelu(x):
......
...@@ -42,6 +42,7 @@ from official.nlp.modeling.layers.relative_attention import TwoStreamRelativeAtt ...@@ -42,6 +42,7 @@ from official.nlp.modeling.layers.relative_attention import TwoStreamRelativeAtt
from official.nlp.modeling.layers.reuse_attention import ReuseMultiHeadAttention from official.nlp.modeling.layers.reuse_attention import ReuseMultiHeadAttention
from official.nlp.modeling.layers.reuse_transformer import ReuseTransformer from official.nlp.modeling.layers.reuse_transformer import ReuseTransformer
from official.nlp.modeling.layers.rezero_transformer import ReZeroTransformer from official.nlp.modeling.layers.rezero_transformer import ReZeroTransformer
from official.nlp.modeling.layers.routing import *
from official.nlp.modeling.layers.self_attention_mask import SelfAttentionMask from official.nlp.modeling.layers.self_attention_mask import SelfAttentionMask
from official.nlp.modeling.layers.spectral_normalization import * from official.nlp.modeling.layers.spectral_normalization import *
from official.nlp.modeling.layers.talking_heads_attention import TalkingHeadsAttention from official.nlp.modeling.layers.talking_heads_attention import TalkingHeadsAttention
......
...@@ -24,7 +24,7 @@ MultiHeadAttention = tf.keras.layers.MultiHeadAttention ...@@ -24,7 +24,7 @@ MultiHeadAttention = tf.keras.layers.MultiHeadAttention
@tf.keras.utils.register_keras_serializable(package="Text") @tf.keras.utils.register_keras_serializable(package="Text")
class CachedAttention(tf.keras.layers.MultiHeadAttention): class CachedAttention(tf.keras.layers.MultiHeadAttention):
"""Attention layer with cache used for auto-agressive decoding. """Attention layer with cache used for autoregressive decoding.
Arguments are the same as `tf.keras.layers.MultiHeadAttention` layer. Arguments are the same as `tf.keras.layers.MultiHeadAttention` layer.
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment