Unverified Commit 0225b135 authored by Srihari Humbarwadi, committed by GitHub

Merge branch 'tensorflow:master' into panoptic-deeplab-modeling

parents 7479dbb8 4c571a3c
@@ -68,8 +68,17 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta):
     if inference_step is not None:
       self.inference_step = functools.partial(inference_step, model=self.model)
     else:
-      self.inference_step = functools.partial(
-          self.model.__call__, training=False)
+      if issubclass(type(model), tf.keras.Model):
+        # Default to self.model.call instead of self.model.__call__ to avoid
+        # the Keras tracing logic designed for training.
+        # Since most Model Garden models' call() either has no training kwarg
+        # or defaults it to False, we don't pass anything here.
+        # Please pass a custom inference_step if your model defaults to
+        # training=True.
+        self.inference_step = self.model.call
+      else:
+        self.inference_step = functools.partial(
+            self.model.__call__, training=False)
     self.preprocessor = preprocessor
     self.postprocessor = postprocessor
...
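For context on the branch above: Keras' `__call__` wraps `call()` with training-phase and masking bookkeeping that can complicate SavedModel tracing. A minimal, self-contained sketch of the same selection logic (`DummyModel` is a hypothetical stand-in, not part of the diff):

import functools

import tensorflow as tf


class DummyModel(tf.keras.Model):
  """A hypothetical model whose call() takes no training kwarg."""

  def call(self, inputs):
    return inputs * 2.0


model = DummyModel()
if issubclass(type(model), tf.keras.Model):
  # Bind call() directly, skipping the Keras __call__ wrapper.
  inference_step = model.call
else:
  # Non-Keras callables: pin training=False explicitly.
  inference_step = functools.partial(model.__call__, training=False)

print(inference_step(tf.constant([1.0, 2.0])))  # [2. 4.]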
@@ -13,6 +13,7 @@
 # limitations under the License.

 """Registry utility."""
+from absl import logging


 def register(registered_collection, reg_key):
@@ -54,8 +55,16 @@ def register(registered_collection, reg_key):
       leaf_reg_key = reg_key
     if leaf_reg_key in collection:
-      raise KeyError("Function or class {} registered multiple times.".format(
-          leaf_reg_key))
+      if "beta" in fn_or_cls.__module__:
+        # TODO(yeqing): Clean up this temporary branch for beta.
+        logging.warning(
+            "Duplicate registration of beta module "
+            "name %r new %r old %r", reg_key, collection[leaf_reg_key],
+            fn_or_cls.__module__)
+        return fn_or_cls
+      else:
+        raise KeyError("Function or class {} registered multiple times."
+                       .format(leaf_reg_key))
     collection[leaf_reg_key] = fn_or_cls
     return fn_or_cls
...
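To see the new beta branch in action, here is a simplified sketch of the registration decorator (the real utility also resolves `/`-separated hierarchical keys, which is omitted here):

from absl import logging

_REGISTRY = {}


def register(registered_collection, reg_key):
  """Simplified: registers fn_or_cls under reg_key, tolerating beta dupes."""

  def decorator(fn_or_cls):
    if reg_key in registered_collection:
      if "beta" in fn_or_cls.__module__:
        # Modules migrating out of *.beta may register twice; warn only.
        logging.warning("Duplicate registration of beta module %r", reg_key)
        return fn_or_cls
      raise KeyError(
          "Function or class {} registered multiple times.".format(reg_key))
    registered_collection[reg_key] = fn_or_cls
    return fn_or_cls

  return decorator


@register(_REGISTRY, "my_task")
def my_task():
  return "hello"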
@@ -48,6 +48,8 @@ class OptimizerConfig(oneof.OneOfConfig):
   sgd_experimental: opt_cfg.SGDExperimentalConfig = (
       opt_cfg.SGDExperimentalConfig())
   adam: opt_cfg.AdamConfig = opt_cfg.AdamConfig()
+  adam_experimental: opt_cfg.AdamExperimentalConfig = (
+      opt_cfg.AdamExperimentalConfig())
   adamw: opt_cfg.AdamWeightDecayConfig = opt_cfg.AdamWeightDecayConfig()
   lamb: opt_cfg.LAMBConfig = opt_cfg.LAMBConfig()
   rmsprop: opt_cfg.RMSPropConfig = opt_cfg.RMSPropConfig()
...
@@ -67,6 +67,7 @@ class SGDExperimentalConfig(BaseOptimizerConfig):
     name: name of the optimizer.
     nesterov: nesterov for SGD optimizer.
     momentum: momentum for SGD optimizer.
+    jit_compile: if True, jit compilation will be used.
   """
   name: str = "SGD"
   nesterov: bool = False
@@ -135,6 +136,30 @@ class AdamConfig(BaseOptimizerConfig):
   amsgrad: bool = False


+@dataclasses.dataclass
+class AdamExperimentalConfig(BaseOptimizerConfig):
+  """Configuration for the experimental Adam optimizer.
+
+  The attributes of this class match the arguments of
+  `tf.keras.optimizers.experimental.Adam`.
+
+  Attributes:
+    name: name of the optimizer.
+    beta_1: decay rate for 1st order moments.
+    beta_2: decay rate for 2nd order moments.
+    epsilon: epsilon value used for numerical stability in the Adam optimizer.
+    amsgrad: boolean. Whether to apply the AMSGrad variant of this algorithm
+      from the paper "On the Convergence of Adam and Beyond".
+    jit_compile: if True, jit compilation will be used.
+  """
+  name: str = "Adam"
+  beta_1: float = 0.9
+  beta_2: float = 0.999
+  epsilon: float = 1e-07
+  amsgrad: bool = False
+  jit_compile: bool = False
+
+
 @dataclasses.dataclass
 class AdamWeightDecayConfig(BaseOptimizerConfig):
   """Configuration for Adam optimizer with weight decay.
...
@@ -30,6 +30,7 @@ OPTIMIZERS_CLS = {
     'sgd': tf.keras.optimizers.SGD,
     'sgd_experimental': tf.keras.optimizers.experimental.SGD,
     'adam': tf.keras.optimizers.Adam,
+    'adam_experimental': tf.keras.optimizers.experimental.Adam,
     'adamw': nlp_optimization.AdamWeightDecay,
     'lamb': tfa_optimizers.LAMB,
     'rmsprop': tf.keras.optimizers.RMSprop,
...
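The hunks above wire the new config through to the factory. A sketch of how a registered name resolves to an optimizer instance (assuming a TF release that still ships `tf.keras.optimizers.experimental`; the config plumbing is simplified to a plain dict of the `AdamExperimentalConfig` defaults from the diff):

import tensorflow as tf

OPTIMIZERS_CLS = {
    'adam': tf.keras.optimizers.Adam,
    'adam_experimental': tf.keras.optimizers.experimental.Adam,
}

# Values mirror AdamExperimentalConfig defaults above.
opt_kwargs = dict(beta_1=0.9, beta_2=0.999, epsilon=1e-07,
                  amsgrad=False, jit_compile=False)
optimizer = OPTIMIZERS_CLS['adam_experimental'](
    learning_rate=1e-3, **opt_kwargs)
print(type(optimizer).__name__)  # Adam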
@@ -115,7 +115,8 @@ class MaskedLM(tf.keras.layers.Layer):
     flat_offsets = tf.reshape(
         tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
-    flat_positions = tf.reshape(positions + flat_offsets, [-1])
+    flat_positions = tf.reshape(
+        positions + tf.cast(flat_offsets, positions.dtype), [-1])
     flat_sequence_tensor = tf.reshape(sequence_tensor,
                                       [batch_size * seq_length, width])
     output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
...
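The MaskedLM change is a dtype fix: `positions` can arrive as int64 (e.g., when decoded from tf.Example features) while the locally built offsets are int32, and TF refuses to add mixed integer dtypes. A toy reproduction of the fixed gather:

import tensorflow as tf

batch_size, seq_length, width = 2, 4, 3
sequence_tensor = tf.reshape(
    tf.range(batch_size * seq_length * width, dtype=tf.float32),
    [batch_size, seq_length, width])
positions = tf.constant([[0, 2], [1, 3]], dtype=tf.int64)  # int64 on purpose

flat_offsets = tf.reshape(
    tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
# Without the cast, int64 + int32 raises an error inside the graph.
flat_positions = tf.reshape(
    positions + tf.cast(flat_offsets, positions.dtype), [-1])
flat_sequence_tensor = tf.reshape(
    sequence_tensor, [batch_size * seq_length, width])
print(tf.gather(flat_sequence_tensor, flat_positions).shape)  # (4, 3)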
@@ -20,8 +20,10 @@ from absl import logging
 import tensorflow as tf

 try:
+  # pytype: disable=import-error
   import tensorflow_text as text
   from tensorflow_text.python.ops import bert_tokenizer
+  # pytype: enable=import-error
 except ImportError:
   text = None
   bert_tokenizer = None
...
@@ -226,6 +226,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
       funnel encoder relies on.
     share_rezero: bool. Whether to share ReZero alpha between the attention
       layer and the ffn layer. This option is specific to ReZero.
+    with_dense_inputs: Whether to accept dense embeddings as the input.
   """

   def __init__(
@@ -402,12 +403,22 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
             _transformer_cls2str.get(transformer_cls, str(transformer_cls))
     }

+    self.inputs = dict(
+        input_word_ids=tf.keras.Input(shape=(None,), dtype=tf.int32),
+        input_mask=tf.keras.Input(shape=(None,), dtype=tf.int32),
+        input_type_ids=tf.keras.Input(shape=(None,), dtype=tf.int32))
+
   def call(self, inputs):
     # inputs are [word_ids, mask, type_ids]
     if isinstance(inputs, (list, tuple)):
       logging.warning('List inputs to %s are discouraged.', self.__class__)
       if len(inputs) == 3:
         word_ids, mask, type_ids = inputs
+        dense_inputs = None
+        dense_mask = None
+        dense_type_ids = None
+      elif len(inputs) == 6:
+        word_ids, mask, type_ids, dense_inputs, dense_mask, dense_type_ids = inputs
       else:
         raise ValueError('Unexpected inputs to %s with length at %d.' %
                          (self.__class__, len(inputs)))
@@ -415,10 +426,21 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
       word_ids = inputs.get('input_word_ids')
       mask = inputs.get('input_mask')
       type_ids = inputs.get('input_type_ids')
+      dense_inputs = inputs.get('dense_inputs', None)
+      dense_mask = inputs.get('dense_mask', None)
+      dense_type_ids = inputs.get('dense_type_ids', None)
     else:
       raise ValueError('Unexpected inputs type to %s.' % self.__class__)

     word_embeddings = self._embedding_layer(word_ids)
+
+    if dense_inputs is not None:
+      # Concat the dense embeddings at the start of the sequence so that
+      # unpool_length can keep them from being pooled.
+      word_embeddings = tf.concat([dense_inputs, word_embeddings], axis=1)
+      type_ids = tf.concat([dense_type_ids, type_ids], axis=1)
+      mask = tf.concat([dense_mask, mask], axis=1)
+
     # absolute position embeddings
     position_embeddings = self._position_embedding_layer(word_embeddings)
     type_embeddings = self._type_embedding_layer(type_ids)
...
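The dense-input path prepends already-embedded tokens ahead of the word embeddings so that `unpool_length` can shield them from the funnel's pooling. A minimal tensor-level sketch of that concat (shapes are illustrative, not from the diff):

import tensorflow as tf

batch, text_len, dense_len, hidden = 2, 5, 3, 8
word_embeddings = tf.random.uniform([batch, text_len, hidden])
dense_inputs = tf.random.uniform([batch, dense_len, hidden])
mask = tf.ones([batch, text_len], tf.int32)
dense_mask = tf.ones([batch, dense_len], tf.int32)

# Dense embeddings go first so the leading unpool_length positions are theirs.
word_embeddings = tf.concat([dense_inputs, word_embeddings], axis=1)
mask = tf.concat([dense_mask, mask], axis=1)
print(word_embeddings.shape)  # (2, 8, 8)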
@@ -101,6 +101,55 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
     self.assertAllEqual(tf.float32, data.dtype)
     self.assertAllEqual(pooled_dtype, pooled.dtype)

+  def test_network_creation_dense(self):
+    tf.keras.mixed_precision.set_global_policy("mixed_float16")
+    pool_type = "avg"
+    hidden_size = 32
+    sequence_length = 21
+    dense_sequence_length = 3
+    pool_stride = 2
+    num_layers = 3
+    # Create a small FunnelTransformerEncoder for testing.
+    test_network = funnel_transformer.FunnelTransformerEncoder(
+        vocab_size=100,
+        hidden_size=hidden_size,
+        num_attention_heads=2,
+        num_layers=num_layers,
+        pool_stride=pool_stride,
+        pool_type=pool_type,
+        max_sequence_length=sequence_length + dense_sequence_length,
+        unpool_length=0,
+        transformer_cls="TransformerEncoderBlock")
+    # Create the inputs (note that the first dimension is implicit).
+    word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+    mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+    type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+    dense_inputs = tf.keras.Input(
+        shape=(dense_sequence_length, hidden_size), dtype=tf.float32)
+    dense_mask = tf.keras.Input(shape=(dense_sequence_length,), dtype=tf.int32)
+    dense_type_ids = tf.keras.Input(
+        shape=(dense_sequence_length,), dtype=tf.int32)
+    dict_outputs = test_network(
+        [word_ids, mask, type_ids, dense_inputs, dense_mask, dense_type_ids])
+    data = dict_outputs["sequence_output"]
+    pooled = dict_outputs["pooled_output"]
+
+    self.assertIsInstance(test_network.transformer_layers, list)
+    self.assertLen(test_network.transformer_layers, num_layers)
+    self.assertIsInstance(test_network.pooler_layer, tf.keras.layers.Dense)
+
+    # Stride=2 halves the sequence length at each layer. For pool_type = max
+    # or avg, this configuration gives per-layer sequence lengths of
+    # 24 -> 12 -> 6 -> 3.
+    expected_data_shape = [None, 3, hidden_size]
+    expected_pooled_shape = [None, hidden_size]
+    self.assertAllEqual(expected_data_shape, data.shape.as_list())
+    self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())
+
   def test_invalid_stride_and_num_layers(self):
     hidden_size = 32
     num_layers = 3
...
@@ -417,6 +417,8 @@ class Translation(export_base.ExportModule):

   @dataclasses.dataclass
   class Params(base_config.Config):
     sentencepiece_model_path: str = ""
+    # Needs to be specified if padded_decode is True (e.g., on TPUs).
+    batch_size: Optional[int] = None

   def __init__(self, params, model: tf.keras.Model, inference_step=None):
     super().__init__(params, model, inference_step)
@@ -431,6 +433,7 @@ class Translation(export_base.ExportModule):
           "Please make sure the tokenizer generates a single token for an "
           "empty string.")
     self._eos_id = empty_str_tokenized.item()
+    self._batch_size = params.batch_size

   @tf.function
   def serve(self, inputs) -> Dict[str, tf.Tensor]:
@@ -452,5 +455,6 @@ class Translation(export_base.ExportModule):
                          (self.__class__, func_key, valid_keys))
       if func_key == "serve_text":
         signatures[signature_key] = self.serve_text.get_concrete_function(
-            tf.TensorSpec(shape=[None], dtype=tf.string, name="text"))
+            tf.TensorSpec(shape=[self._batch_size],
+                          dtype=tf.string, name="text"))
     return signatures
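Pinning `shape=[self._batch_size]` is what makes padded decoding exportable for TPUs: with `batch_size=None` the signature keeps a dynamic batch dimension, while a concrete value gives the decoder fully static shapes. A stand-alone sketch of the difference (`Demo.serve_text` is a hypothetical stand-in for the module's method):

import tensorflow as tf


class Demo(tf.Module):

  def __init__(self, batch_size=None):
    self._batch_size = batch_size  # None => dynamic batch dimension.

  @tf.function
  def serve_text(self, text):
    return {"outputs": tf.strings.upper(text)}


demo = Demo(batch_size=2)
fn = demo.serve_text.get_concrete_function(
    tf.TensorSpec(shape=[demo._batch_size], dtype=tf.string, name="text"))
print(fn.structured_input_signature)  # TensorSpec(shape=(2,), ...)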
@@ -20,6 +20,7 @@ from absl.testing import parameterized
 import tensorflow as tf
 from sentencepiece import SentencePieceTrainer
+from official.core import export_base
 from official.nlp.configs import bert
 from official.nlp.configs import encoders
 from official.nlp.serving import serving_modules
@@ -343,7 +344,10 @@ class ServingModulesTest(tf.test.TestCase, parameterized.TestCase):
     with self.assertRaises(ValueError):
       _ = export_module.get_inference_signatures({"foo": None})

-  def test_translation(self):
+  @parameterized.parameters(
+      (False, None),
+      (True, 2))
+  def test_translation(self, padded_decode, batch_size):
     sp_path = _make_sentencepeice(self.get_temp_dir())
     encdecoder = translation.EncDecoder(
         num_attention_heads=4, intermediate_size=256)
@@ -352,7 +356,7 @@ class ServingModulesTest(tf.test.TestCase, parameterized.TestCase):
             encoder=encdecoder,
             decoder=encdecoder,
             embedding_width=256,
-            padded_decode=False,
+            padded_decode=padded_decode,
             decode_max_length=100),
         sentencepiece_model_path=sp_path,
     )
@@ -360,7 +364,7 @@ class ServingModulesTest(tf.test.TestCase, parameterized.TestCase):
     model = task.build_model()
     params = serving_modules.Translation.Params(
-        sentencepiece_model_path=sp_path)
+        sentencepiece_model_path=sp_path, batch_size=batch_size)
     export_module = serving_modules.Translation(params=params, model=model)
     functions = export_module.get_inference_signatures({
         "serve_text": "serving_default"
@@ -369,5 +373,19 @@ class ServingModulesTest(tf.test.TestCase, parameterized.TestCase):
     self.assertEqual(outputs.shape, (2,))
     self.assertEqual(outputs.dtype, tf.string)
+    tmp_dir = self.get_temp_dir()
+    tmp_dir = os.path.join(tmp_dir, "padded_decode", str(padded_decode))
+    export_base_dir = os.path.join(tmp_dir, "export")
+    ckpt_dir = os.path.join(tmp_dir, "ckpt")
+    ckpt_path = tf.train.Checkpoint(model=model).save(ckpt_dir)
+    export_dir = export_base.export(export_module,
+                                    {"serve_text": "serving_default"},
+                                    export_base_dir, ckpt_path)
+    loaded = tf.saved_model.load(export_dir)
+    infer = loaded.signatures["serving_default"]
+    out = infer(text=tf.constant(["abcd", "ef gh"]))
+    self.assertLen(out["output_0"], 2)


 if __name__ == "__main__":
   tf.test.main()
@@ -84,13 +84,13 @@ def _create_model(
   """Creates the model to export and the model to restore the checkpoint.

   Args:
-    bert_config: A legacy `BertConfig` to create a `BertEncoder` object.
-      Exactly one of encoder_config and bert_config must be set.
+    bert_config: A legacy `BertConfig` to create a `BertEncoder` object. Exactly
+      one of encoder_config and bert_config must be set.
     encoder_config: An `EncoderConfig` to create an encoder of the configured
       type (`BertEncoder` or other).
-    with_mlm: A bool to control the second component of the result.
-      If True, will create a `BertPretrainerV2` object; otherwise, will
-      create a `BertEncoder` object.
+    with_mlm: A bool to control the second component of the result. If True,
+      will create a `BertPretrainerV2` object; otherwise, will create a
+      `BertEncoder` object.

   Returns:
     A Tuple of (1) a Keras model that will be exported, (2) a `BertPretrainerV2`
@@ -110,7 +110,11 @@ def _create_model(
   # Convert from list of named inputs to dict of inputs keyed by name.
   # Only the latter accepts a dict of inputs after restoring from SavedModel.
-  encoder_inputs_dict = {x.name: x for x in encoder.inputs}
+  if isinstance(encoder.inputs, (list, tuple)):
+    encoder_inputs_dict = {x.name: x for x in encoder.inputs}
+  else:
+    # encoder.inputs is a dict by default for BertEncoderV2.
+    encoder_inputs_dict = encoder.inputs
   encoder_output_dict = encoder(encoder_inputs_dict)
   # For interchangeability with other text representations,
   # add "default" as an alias for BERT's whole-input representations.
@@ -206,26 +210,28 @@ def export_model(export_path: Text,
     encoder_config: An optional `encoders.EncoderConfig` object.
     model_checkpoint_path: The path to the checkpoint.
     with_mlm: Whether to export the additional mlm sub-object.
-    copy_pooler_dense_to_encoder: Whether to copy the pooler's dense layer
-      used in the next sentence prediction task to the encoder.
+    copy_pooler_dense_to_encoder: Whether to copy the pooler's dense layer used
+      in the next sentence prediction task to the encoder.
     vocab_file: The path to the wordpiece vocab file, or None.
-    sp_model_file: The path to the sentencepiece model file, or None.
-      Exactly one of vocab_file and sp_model_file must be set.
+    sp_model_file: The path to the sentencepiece model file, or None. Exactly
+      one of vocab_file and sp_model_file must be set.
     do_lower_case: Whether to lower-case text before tokenization.
   """
   if with_mlm:
-    core_model, pretrainer = _create_model(bert_config=bert_config,
-                                           encoder_config=encoder_config,
-                                           with_mlm=with_mlm)
+    core_model, pretrainer = _create_model(
+        bert_config=bert_config,
+        encoder_config=encoder_config,
+        with_mlm=with_mlm)
     encoder = pretrainer.encoder_network
     # It supports both the new pretrainer checkpoint produced by TF-NLP and
     # the checkpoint converted from TF1 (original BERT, SmallBERTs).
     checkpoint_items = pretrainer.checkpoint_items
     checkpoint = tf.train.Checkpoint(**checkpoint_items)
   else:
-    core_model, encoder = _create_model(bert_config=bert_config,
-                                        encoder_config=encoder_config,
-                                        with_mlm=with_mlm)
+    core_model, encoder = _create_model(
+        bert_config=bert_config,
+        encoder_config=encoder_config,
+        with_mlm=with_mlm)
     checkpoint = tf.train.Checkpoint(
         model=encoder,  # Legacy checkpoints.
         encoder=encoder)
@@ -279,21 +285,26 @@ class BertPackInputsSavedModelWrapper(tf.train.Checkpoint):
     # overridable. Having this dynamically determined default argument
     # requires self.__call__ to be defined in this indirect way.
     default_seq_length = bert_pack_inputs.seq_length
+
     @tf.function(autograph=False)
     def call(inputs, seq_length=default_seq_length):
       return layers.BertPackInputs.bert_pack_inputs(
-          inputs, seq_length=seq_length,
+          inputs,
+          seq_length=seq_length,
           start_of_sequence_id=bert_pack_inputs.start_of_sequence_id,
           end_of_segment_id=bert_pack_inputs.end_of_segment_id,
           padding_id=bert_pack_inputs.padding_id)
+
     self.__call__ = call

     for ragged_rank in range(1, 3):
       for num_segments in range(1, 3):
-        _ = self.__call__.get_concrete_function(
-            [tf.RaggedTensorSpec([None] * (ragged_rank + 1), dtype=tf.int32)
-             for _ in range(num_segments)],
-            seq_length=tf.TensorSpec([], tf.int32))
+        _ = self.__call__.get_concrete_function([
+            tf.RaggedTensorSpec([None] * (ragged_rank + 1), dtype=tf.int32)
+            for _ in range(num_segments)
+        ],
+                                                seq_length=tf.TensorSpec(
+                                                    [], tf.int32))
@@ -311,14 +322,14 @@ def create_preprocessing(*,

   Args:
     vocab_file: The path to the wordpiece vocab file, or None.
-    sp_model_file: The path to the sentencepiece model file, or None.
-      Exactly one of vocab_file and sp_model_file must be set.
-      This determines the type of tokenzer that is used.
+    sp_model_file: The path to the sentencepiece model file, or None. Exactly
+      one of vocab_file and sp_model_file must be set. This determines the type
+      of tokenizer that is used.
     do_lower_case: Whether to do lower case.
     tokenize_with_offsets: Whether to include the .tokenize_with_offsets
       subobject.
-    default_seq_length: The sequence length of preprocessing results from
-      root callable. This is also the default sequence length for the
+    default_seq_length: The sequence length of preprocessing results from root
+      callable. This is also the default sequence length for the
       bert_pack_inputs subobject.

   Returns:
@@ -378,7 +389,8 @@ def create_preprocessing(*,
 def _move_to_tmpdir(file_path: Optional[Text], tmpdir: Text) -> Optional[Text]:
   """Returns new path with same basename and hash of original path."""
-  if file_path is None: return None
+  if file_path is None:
+    return None
   olddir, filename = os.path.split(file_path)
   hasher = hashlib.sha1()
   hasher.update(olddir.encode("utf-8"))
@@ -460,12 +472,17 @@ def _check_no_assert(saved_model_path):
   assert_nodes = []
   graph_def = saved_model.meta_graphs[0].graph_def
-  assert_nodes += ["node '{}' in global graph".format(n.name)
-                   for n in graph_def.node if n.op == "Assert"]
+  assert_nodes += [
+      "node '{}' in global graph".format(n.name)
+      for n in graph_def.node
+      if n.op == "Assert"
+  ]
   for fdef in graph_def.library.function:
     assert_nodes += [
         "node '{}' in function '{}'".format(n.name, fdef.signature.name)
-        for n in fdef.node_def if n.op == "Assert"]
+        for n in fdef.node_def
+        if n.op == "Assert"
+    ]
   if assert_nodes:
     raise AssertionError(
         "Internal tool error: "
...
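One non-mechanical change in this file is the `encoder.inputs` normalization: functional Keras models expose `inputs` as a list, while `BertEncoderV2`-style models expose a dict keyed by input name. A small sketch of the same branch outside the exporter:

import tensorflow as tf

inputs = [
    tf.keras.Input(shape=(None,), dtype=tf.int32, name="input_word_ids"),
    tf.keras.Input(shape=(None,), dtype=tf.int32, name="input_mask"),
]
if isinstance(inputs, (list, tuple)):
  encoder_inputs_dict = {x.name: x for x in inputs}
else:
  # BertEncoderV2 already hands back a dict; use it as-is.
  encoder_inputs_dict = inputs
print(sorted(encoder_inputs_dict))  # ['input_mask', 'input_word_ids']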
@@ -40,9 +40,9 @@ from typing import List, Optional, Tuple
 from official.core import config_definitions as cfg
 from official.core import exp_factory
 from official.modeling import hyperparams
-from official.vision.beta.configs import backbones_3d
-from official.vision.beta.configs import common
-from official.vision.beta.configs import video_classification
+from official.vision.configs import backbones_3d
+from official.vision.configs import common
+from official.vision.configs import video_classification


 @dataclasses.dataclass
...
@@ -18,7 +18,7 @@ import tensorflow as tf
 from official.core import config_definitions as cfg
 from official.core import exp_factory
 from official.projects.assemblenet.configs import assemblenet
-from official.vision.beta.configs import video_classification as exp_cfg
+from official.vision.configs import video_classification as exp_cfg


 class AssemblenetTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -57,8 +57,8 @@ import tensorflow as tf
 from official.modeling import hyperparams
 from official.projects.assemblenet.configs import assemblenet as cfg
 from official.projects.assemblenet.modeling import rep_flow_2d_layer as rf
-from official.vision.beta.modeling import factory_3d as model_factory
-from official.vision.beta.modeling.backbones import factory as backbone_factory
+from official.vision.modeling import factory_3d as model_factory
+from official.vision.modeling.backbones import factory as backbone_factory

 layers = tf.keras.layers
 intermediate_channel_size = [64, 128, 256, 512]
...
@@ -64,8 +64,8 @@ from official.modeling import hyperparams
 from official.projects.assemblenet.configs import assemblenet as cfg
 from official.projects.assemblenet.modeling import assemblenet as asn
 from official.projects.assemblenet.modeling import rep_flow_2d_layer as rf
-from official.vision.beta.modeling import factory_3d as model_factory
-from official.vision.beta.modeling.backbones import factory as backbone_factory
+from official.vision.modeling import factory_3d as model_factory
+from official.vision.modeling.backbones import factory as backbone_factory

 layers = tf.keras.layers
...
@@ -29,9 +29,6 @@ from absl import flags
 from absl import logging
 import gin
-# pylint: disable=unused-import
-from official.common import registry_imports
-# pylint: enable=unused-import
 from official.common import distribute_utils
 from official.common import flags as tfm_flags
 from official.core import task_factory
@@ -42,6 +39,7 @@ from official.modeling import performance
 from official.projects.assemblenet.configs import assemblenet as asn_configs
 from official.projects.assemblenet.modeling import assemblenet as asn
 from official.projects.assemblenet.modeling import assemblenet_plus as asnp
+from official.vision import registry_imports
 # pylint: enable=unused-import

 FLAGS = flags.FLAGS
...
@@ -22,7 +22,7 @@ from absl import logging
 from absl.testing import flagsaver
 import tensorflow as tf
 from official.projects.assemblenet import train as train_lib
-from official.vision.beta.dataloaders import tfexample_utils
+from official.vision.dataloaders import tfexample_utils

 FLAGS = flags.FLAGS
...
@@ -20,8 +20,8 @@ import tensorflow as tf
 from official.core import config_definitions as cfg
 from official.core import input_reader
-from official.vision.beta.ops import box_ops
-from official.vision.beta.ops import preprocess_ops
+from official.vision.ops import box_ops
+from official.vision.ops import preprocess_ops


 @dataclasses.dataclass
...