Commit 07dfdc7f authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Remove set_shape() when using preprocessing_hub_module.bert_pack_inputs()

PiperOrigin-RevId: 360364244
parent f191e76b
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Loads dataset for the sentence prediction (classification) task.""" """Loads dataset for the sentence prediction (classification) task."""
import functools
from typing import List, Mapping, Optional from typing import List, Mapping, Optional
import dataclasses import dataclasses
...@@ -137,20 +138,9 @@ class TextProcessor(tf.Module): ...@@ -137,20 +138,9 @@ class TextProcessor(tf.Module):
if preprocessing_hub_module_url: if preprocessing_hub_module_url:
self._preprocessing_hub_module = hub.load(preprocessing_hub_module_url) self._preprocessing_hub_module = hub.load(preprocessing_hub_module_url)
self._tokenizer = self._preprocessing_hub_module.tokenize self._tokenizer = self._preprocessing_hub_module.tokenize
def set_shape(t): self._pack_inputs = functools.partial(
# Before TF2.4, the sequence length dimension loaded from the self._preprocessing_hub_module.bert_pack_inputs,
# preprocessing hub module is None, so we recover the shape here. seq_length=seq_length)
# TODO(b/157636658): Remove once TF2.4 is released and being used.
t.set_shape([None, seq_length])
return t
def pack_inputs_fn(inputs):
result = self._preprocessing_hub_module.bert_pack_inputs(
inputs, seq_length=seq_length)
result = tf.nest.map_structure(set_shape, result)
return result
self._pack_inputs = pack_inputs_fn
return return
if tokenization == 'WordPiece': if tokenization == 'WordPiece':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment