# Copyright 2024 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Preprocessing utilities for text/image models.""" import dataclasses import numpy as np import tensorflow as tf import tensorflow_text def get_tokenizer(tokenizer_name): """Returns a tokenizer specified by name ("bert" or "sentencpiece").""" return { 'bert': BertTokenizer, 'sentencepiece': SentencepieceTokenizer, }[tokenizer_name] @dataclasses.dataclass(frozen=True) class BertTokenizer: """BERT tokenizer with prepended CLS token and fixed sequence length. This class can be used to tokenize batches of text tokens to numpy arrays (by calling `__call__()`), or as part of a TensorFlow preprocessing graph (via the method `preprocess_tf()`). Attributes: vocab_path: Path pointing to the vocabulary file. Can be any path string that is understood by `tf.io.gfile`. max_len: Length of tokenized sequences. If the provided texts result in fewer tokens, then the sequence is zero-padded. If the provided texts result in more tokens, then the tokens are clipped. cls_token: Will be set during class construction. """ vocab_path: str max_len: int cls_token: int = dataclasses.field(init=False) _tokenizer: tensorflow_text.BertTokenizer = dataclasses.field(init=False) def __post_init__(self): tokenizer = tensorflow_text.BertTokenizer( self.vocab_path, token_out_type=tf.int32, lower_case=True) with tf.io.gfile.GFile(self.vocab_path) as f: vocab = f.read().split('\n') cls_token = vocab.index('[CLS]') # Work-around for frozen dataclasses: # https://stackoverflow.com/questions/53756788 object.__setattr__(self, 'cls_token', cls_token) object.__setattr__(self, '_tokenizer', tokenizer) def preprocess_tf(self, text): """Tokenizes a single text as part of a TensorFlow graph.""" return self._preprocess(text[None])[0] def _preprocess(self, texts): token_ids = self._tokenizer.tokenize(texts) tokens, mask = tensorflow_text.pad_model_inputs(token_ids, self.max_len - 1) del mask # Recovered from zero padding in model. count = tf.shape(tokens)[0] return tf.concat([tf.fill([count, 1], self.cls_token), tokens], axis=1) def __call__(self, texts): """Tokenizes a batch of texts to a numpy array.""" return self._preprocess(tf.constant(texts)).numpy() @dataclasses.dataclass(frozen=True) class SentencepieceTokenizer: """SentencePiece tokenizer with sticky eos. Models that use this tokanizer usually use the *last* token, which is guaranteed to be the "" token (even if tokens are capped to `max_len`). The same token is used for padding (and exposed as `eos_token`). This class can be used to tokenize batches of text tokens to numpy arrays (by calling `__call__()`), or as part of a TensorFlow preprocessing graph (via the method `preprocess_tf()`). Attributes: vocab_path: Path pointing to the vocabulary file. Can be any path string that is understood by `tf.io.gfile`. max_len: Length of tokenized sequences. If the provided texts result in fewer tokens, then the sequence is zero-padded. If the provided texts result in more tokens, then the tokens are clipped. eos_token: Token used for padding. Last token is guaranteed to be padded. """ vocab_path: str max_len: int eos_token: int = dataclasses.field(init=False) _tokenizer: tensorflow_text.BertTokenizer = dataclasses.field(init=False) def __post_init__(self): tokenizer = tensorflow_text.SentencepieceTokenizer( model=tf.io.gfile.GFile(self.vocab_path, 'rb').read(), add_eos=True) eos_token = tokenizer.string_to_id('') # Work-around for frozen dataclasses: # https://stackoverflow.com/questions/53756788 object.__setattr__(self, 'eos_token', eos_token) object.__setattr__(self, '_tokenizer', tokenizer) def preprocess_tf(self, text): """Tokenizes a single text as part of a TensorFlow graph.""" tokens = self._tokenizer.tokenize(text) tokens = tokens[:self.max_len - 1] # to guarantee eos at end return tf.pad( tokens, [(0, self.max_len - tf.shape(tokens)[0])], constant_values=self.eos_token) def __call__(self, texts): """Tokenizes a batch of texts to a numpy array.""" return tf.stack([self.preprocess_tf(text) for text in texts]).numpy() @dataclasses.dataclass(frozen=True) class PreprocessImages: """Resizes images and sets value range to [-1, 1]. This class can be used to tokenize batches of text tokens to numpy arrays (by calling `__call__()`), or as part of a TensorFlow preprocessing graph (via the method `preprocess_tf()`). Attributes: size: Target size of images. crop: If set to true, then the image will first be resized maintaining the original aspect ratio, and then a central crop of that resized image will be returned. """ size: int crop: bool = False def _resize_small(self, image): # pylint: disable=missing-docstring h, w = tf.shape(image)[0], tf.shape(image)[1] # Figure out the necessary h/w. ratio = ( tf.cast(self.size, tf.float32) / tf.cast(tf.minimum(h, w), tf.float32)) h = tf.cast(tf.round(tf.cast(h, tf.float32) * ratio), tf.int32) w = tf.cast(tf.round(tf.cast(w, tf.float32) * ratio), tf.int32) return tf.image.resize(image, (h, w), method='bilinear') def _crop(self, image): h, w = self.size, self.size dy = (tf.shape(image)[0] - h) // 2 dx = (tf.shape(image)[1] - w) // 2 return tf.image.crop_to_bounding_box(image, dy, dx, h, w) def _resize(self, image): return tf.image.resize( image, size=[self.size, self.size], method='bilinear') def _value_range(self, image): image = tf.cast(image, tf.float32) / 255 return -1 + image * 2 def preprocess_tf(self, image): """Resizes a single image as part of a TensorFlowg graph.""" assert image.dtype == tf.uint8 if self.crop: image = self._resize_small(image) image = self._crop(image) else: image = self._resize(image) image = tf.cast(image, tf.uint8) return self._value_range(image) def __call__(self, images): """Resizes a sequence of images, returns a numpy array.""" return np.stack([ self.preprocess_tf(tf.constant(image)) for image in images ]) def get_pp(*, tokenizer_name, vocab_path, max_len, size, crop=False): """Returns preprocessing function for "image" and "text" features. The returned function can directly be used with `tf.data.Dataset.map()`. If either the text feature (feature key "text") or the image feature (feature key "image") are not found, then they will be left untouched. Note that the "image" feature is overwritten with the resized image, but the "text" feature is tokenized into a new feature "tokens". Args: tokenizer_name: Name of tokenizer (either "bert", or "sentencepiece"). vocab_path: Argument passed to tokenizer. max_len: Argument passed to tokenizer. size: Argument passed to `PreprocessImages`. crop: Argument passed to `PreprocessImages`. """ tokenizer_class = get_tokenizer(tokenizer_name) tokenizer = tokenizer_class(vocab_path=vocab_path, max_len=max_len) preprocess_images = PreprocessImages(size=size, crop=crop) def pp(features): features = {**features} if 'image' in features: features['image'] = preprocess_images.preprocess_tf(features['image']) if 'text' in features: features['tokens'] = tokenizer.preprocess_tf(features['text']) return features return pp