"torch_scatter/src/generic/cpu.c" did not exist on "fe98b763dd35b1d8ad6c2d8100b60807532ca12b"
Commit 07dfdc7f authored by Chen Chen, committed by A. Unique TensorFlower
Browse files

Remove set_shape() when using preprocessing_hub_module.bert_pack_inputs()

PiperOrigin-RevId: 360364244
parent f191e76b
......@@ -14,6 +14,7 @@
# limitations under the License.
# ==============================================================================
"""Loads dataset for the sentence prediction (classification) task."""
import functools
from typing import List, Mapping, Optional
import dataclasses
......@@ -137,20 +138,9 @@ class TextProcessor(tf.Module):
if preprocessing_hub_module_url:
self._preprocessing_hub_module = hub.load(preprocessing_hub_module_url)
self._tokenizer = self._preprocessing_hub_module.tokenize
def set_shape(t):
# Before TF2.4, the sequence length dimension loaded from the
# preprocessing hub module is None, so we recover the shape here.
# TODO(b/157636658): Remove once TF2.4 is released and being used.
t.set_shape([None, seq_length])
return t
def pack_inputs_fn(inputs):
result = self._preprocessing_hub_module.bert_pack_inputs(
inputs, seq_length=seq_length)
result = tf.nest.map_structure(set_shape, result)
return result
self._pack_inputs = pack_inputs_fn
self._pack_inputs = functools.partial(
self._preprocessing_hub_module.bert_pack_inputs,
seq_length=seq_length)
return
if tokenization == 'WordPiece':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment