Unverified Commit fc1d841b authored by nlpcat's avatar nlpcat Committed by GitHub
Browse files

change shape to support dynamic batch input in tf.function XLA generate for tf serving (#18372)



* change shape to support dynamic batch input in tf.generate

* add tests
Co-authored-by: default avatarnlpcatcode <nlpcodecat@gmail.com>
parent b69a62d5
...@@ -1533,7 +1533,7 @@ class TFGenerationMixin: ...@@ -1533,7 +1533,7 @@ class TFGenerationMixin:
# 2. Define model inputs # 2. Define model inputs
input_ids = self._prepare_model_inputs(input_ids, bos_token_id) input_ids = self._prepare_model_inputs(input_ids, bos_token_id)
# inputs_ids now has to be defined and cannot be None anymore # inputs_ids now has to be defined and cannot be None anymore
batch_size = input_ids.shape[0] batch_size = shape_list(input_ids)[0]
# 3. Prepare other model kwargs # 3. Prepare other model kwargs
if output_attentions is not None: if output_attentions is not None:
...@@ -1702,7 +1702,8 @@ class TFGenerationMixin: ...@@ -1702,7 +1702,8 @@ class TFGenerationMixin:
@staticmethod @staticmethod
def _expand_to_num_beams(tensor: tf.Tensor, num_beams: int) -> tf.Tensor: def _expand_to_num_beams(tensor: tf.Tensor, num_beams: int) -> tf.Tensor:
return tf.broadcast_to(tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:]) shape = shape_list(tensor)
return tf.broadcast_to(tensor[:, None], (shape[0], num_beams) + tuple(shape[1:]))
def _prepare_attention_mask_for_generation( def _prepare_attention_mask_for_generation(
self, self,
...@@ -2162,7 +2163,7 @@ class TFGenerationMixin: ...@@ -2162,7 +2163,7 @@ class TFGenerationMixin:
decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None
# 3. init tensors to use for "xla-compileable" generate function # 3. init tensors to use for "xla-compileable" generate function
batch_size, cur_len = input_ids.shape batch_size, cur_len = shape_list(input_ids)
# initialize `generated` (`input_ids` padded with `pad_token_id`), `finished_sequences` # initialize `generated` (`input_ids` padded with `pad_token_id`), `finished_sequences`
input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0) input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0)
...@@ -2432,7 +2433,7 @@ class TFGenerationMixin: ...@@ -2432,7 +2433,7 @@ class TFGenerationMixin:
decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None
# 3. init tensors to use for "xla-compileable" generate function # 3. init tensors to use for "xla-compileable" generate function
batch_size, cur_len = input_ids.shape batch_size, cur_len = shape_list(input_ids)
# initialize `generated` (pre-populated with `pad_token_id`), `finished_sequences` # initialize `generated` (pre-populated with `pad_token_id`), `finished_sequences`
input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0) input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0)
...@@ -2678,18 +2679,16 @@ class TFGenerationMixin: ...@@ -2678,18 +2679,16 @@ class TFGenerationMixin:
def flatten_beam_dim(tensor, batch_axis=0): def flatten_beam_dim(tensor, batch_axis=0):
"""Flattens the first two dimensions of a non-scalar array.""" """Flattens the first two dimensions of a non-scalar array."""
shape = shape_list(tensor)
return tf.reshape( return tf.reshape(
tensor, tensor,
tensor.shape[:batch_axis] shape[:batch_axis] + [shape[batch_axis] * shape[batch_axis + 1]] + shape[batch_axis + 2 :],
+ [tensor.shape[batch_axis] * tensor.shape[batch_axis + 1]]
+ tensor.shape[batch_axis + 2 :],
) )
def unflatten_beam_dim(tensor, batch_size, num_beams, batch_axis=0): def unflatten_beam_dim(tensor, batch_size, num_beams, batch_axis=0):
"""Unflattens the first, flat batch*beam dimension of a non-scalar array.""" """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
return tf.reshape( shape = shape_list(tensor)
tensor, tensor.shape[:batch_axis] + [batch_size, num_beams] + tensor.shape[batch_axis + 1 :] return tf.reshape(tensor, shape[:batch_axis] + [batch_size, num_beams] + shape[batch_axis + 1 :])
)
def gather_beams(nested, beam_indices, batch_axis=0): def gather_beams(nested, beam_indices, batch_axis=0):
"""Gathers the beam slices indexed by beam_indices into new beam array.""" """Gathers the beam slices indexed by beam_indices into new beam array."""
...@@ -2748,7 +2747,7 @@ class TFGenerationMixin: ...@@ -2748,7 +2747,7 @@ class TFGenerationMixin:
decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None
# 3. init tensors to use for "xla-compileable" generate function # 3. init tensors to use for "xla-compileable" generate function
batch_size, num_beams, cur_len = input_ids.shape batch_size, num_beams, cur_len = shape_list(input_ids)
# per batch, beam-item holding current token in loop, pre-populated with `pad_token_id` # per batch, beam-item holding current token in loop, pre-populated with `pad_token_id`
input_ids_padding = tf.ones((batch_size, num_beams, max_length - cur_len), dtype=tf.int32) * ( input_ids_padding = tf.ones((batch_size, num_beams, max_length - cur_len), dtype=tf.int32) * (
...@@ -2894,7 +2893,7 @@ class TFGenerationMixin: ...@@ -2894,7 +2893,7 @@ class TFGenerationMixin:
eos_in_next_token = tf.broadcast_to(eos_in_next_token, topk_sequences[:, :, cur_len].shape) eos_in_next_token = tf.broadcast_to(eos_in_next_token, topk_sequences[:, :, cur_len].shape)
did_topk_just_finished = eos_in_next_token & tf.broadcast_to( did_topk_just_finished = eos_in_next_token & tf.broadcast_to(
tf.concat((tf.ones((num_beams), dtype=tf.bool), tf.zeros((num_beams), dtype=tf.bool)), axis=0), tf.concat((tf.ones((num_beams), dtype=tf.bool), tf.zeros((num_beams), dtype=tf.bool)), axis=0),
eos_in_next_token.shape, shape_list(eos_in_next_token),
) )
# non-top `num_beams` eos tokens can't be used to finish a beam, but the others can't be used in the next # non-top `num_beams` eos tokens can't be used to finish a beam, but the others can't be used in the next
...@@ -2917,7 +2916,7 @@ class TFGenerationMixin: ...@@ -2917,7 +2916,7 @@ class TFGenerationMixin:
topk_log_probs = topk_log_probs / (tf.cast(cur_len, dtype=tf.float32) ** length_penalty) topk_log_probs = topk_log_probs / (tf.cast(cur_len, dtype=tf.float32) ** length_penalty)
beams_in_batch_are_full = ( beams_in_batch_are_full = (
tf.broadcast_to( tf.broadcast_to(
tf.math.reduce_all(is_sent_finished, axis=-1, keepdims=True), did_topk_just_finished.shape tf.math.reduce_all(is_sent_finished, axis=-1, keepdims=True), shape_list(did_topk_just_finished)
) )
& early_stopping & early_stopping
) )
......
...@@ -73,6 +73,7 @@ if is_tf_available(): ...@@ -73,6 +73,7 @@ if is_tf_available():
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
BertConfig, BertConfig,
TFAutoModel, TFAutoModel,
TFAutoModelForSeq2SeqLM,
TFAutoModelForSequenceClassification, TFAutoModelForSequenceClassification,
TFBertModel, TFBertModel,
TFSharedEmbeddings, TFSharedEmbeddings,
...@@ -2163,6 +2164,46 @@ class UtilsFunctionsTest(unittest.TestCase): ...@@ -2163,6 +2164,46 @@ class UtilsFunctionsTest(unittest.TestCase):
for p1, p2 in zip(model.weights, new_model.weights): for p1, p2 in zip(model.weights, new_model.weights):
self.assertTrue(np.allclose(p1.numpy(), p2.numpy())) self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
def test_generate_tf_function_export(self):
test_model = TFAutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
max_length = 8
class DummyModel(tf.Module):
def __init__(self, model):
super(DummyModel, self).__init__()
self.model = model
@tf.function(
input_signature=(
tf.TensorSpec((None, max_length), tf.int32, name="input_ids"),
tf.TensorSpec((None, max_length), tf.int32, name="attention_mask"),
),
jit_compile=True,
)
def serving(self, input_ids, attention_mask):
outputs = self.model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_new_tokens=max_length,
return_dict_in_generate=True,
)
return {"sequences": outputs["sequences"]}
dummy_input_ids = [[2, 3, 4, 1, 0, 0, 0, 0], [102, 103, 104, 105, 1, 0, 0, 0]]
dummy_attention_masks = [[1, 1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 1, 0, 0, 0]]
dummy_model = DummyModel(model=test_model)
with tempfile.TemporaryDirectory() as tmp_dir:
tf.saved_model.save(dummy_model, tmp_dir, signatures={"serving_default": dummy_model.serving})
serving_func = tf.saved_model.load(tmp_dir).signatures["serving_default"]
for batch_size in range(1, len(dummy_input_ids) + 1):
inputs = {
"input_ids": tf.constant(dummy_input_ids[:batch_size]),
"attention_mask": tf.constant(dummy_attention_masks[:batch_size]),
}
tf_func_outputs = serving_func(**inputs)["sequences"]
tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_length)
tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs)
@require_tf @require_tf
@is_staging_test @is_staging_test
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment