Unverified Commit fc1d841b authored by nlpcat, committed by GitHub

change shape to support dynamic batch input in tf.function XLA generate for tf serving (#18372)



* change shape to support dynamic batch input in tf.generate

* add tests
Co-authored-by: nlpcatcode <nlpcodecat@gmail.com>
parent b69a62d5
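The core of the change: `tensor.shape` is the static shape, so `input_ids.shape[0]` is `None` when `generate` is traced inside a `tf.function` whose input signature leaves the batch axis dynamic, while `shape_list` falls back to `tf.shape` for axes that are unknown at trace time. A minimal sketch of the difference, assuming `shape_list` from `transformers.tf_utils` (illustration only, not part of the diff):

```python
# Illustration only: static vs. dynamic shapes when the batch axis is left as None.
import tensorflow as tf
from transformers.tf_utils import shape_list  # static dims where known, tf.shape(...) otherwise


@tf.function(input_signature=[tf.TensorSpec((None, 8), tf.int32, name="input_ids")])
def batch_size_of(input_ids):
    static_batch = input_ids.shape[0]         # None at trace time: the batch axis is dynamic
    dynamic_batch = shape_list(input_ids)[0]  # scalar int32 tensor, resolved at run time
    assert static_batch is None               # evaluated while tracing, not per call
    return dynamic_batch


print(batch_size_of(tf.zeros((3, 8), dtype=tf.int32)))  # tf.Tensor(3, ...)
print(batch_size_of(tf.zeros((5, 8), dtype=tf.int32)))  # tf.Tensor(5, ...), no retracing needed
```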
@@ -1533,7 +1533,7 @@ class TFGenerationMixin:
         # 2. Define model inputs
         input_ids = self._prepare_model_inputs(input_ids, bos_token_id)
         # inputs_ids now has to be defined and cannot be None anymore
-        batch_size = input_ids.shape[0]
+        batch_size = shape_list(input_ids)[0]

         # 3. Prepare other model kwargs
         if output_attentions is not None:
@@ -1702,7 +1702,8 @@ class TFGenerationMixin:
     @staticmethod
     def _expand_to_num_beams(tensor: tf.Tensor, num_beams: int) -> tf.Tensor:
-        return tf.broadcast_to(tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:])
+        shape = shape_list(tensor)
+        return tf.broadcast_to(tensor[:, None], (shape[0], num_beams) + tuple(shape[1:]))

     def _prepare_attention_mask_for_generation(
         self,
@@ -2162,7 +2163,7 @@ class TFGenerationMixin:
         decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None

         # 3. init tensors to use for "xla-compileable" generate function
-        batch_size, cur_len = input_ids.shape
+        batch_size, cur_len = shape_list(input_ids)

         # initialize `generated` (`input_ids` padded with `pad_token_id`), `finished_sequences`
         input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0)
@@ -2432,7 +2433,7 @@ class TFGenerationMixin:
         decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None

         # 3. init tensors to use for "xla-compileable" generate function
-        batch_size, cur_len = input_ids.shape
+        batch_size, cur_len = shape_list(input_ids)

         # initialize `generated` (pre-populated with `pad_token_id`), `finished_sequences`
         input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0)
@@ -2678,18 +2679,16 @@ class TFGenerationMixin:
         def flatten_beam_dim(tensor, batch_axis=0):
             """Flattens the first two dimensions of a non-scalar array."""
+            shape = shape_list(tensor)
             return tf.reshape(
                 tensor,
-                tensor.shape[:batch_axis]
-                + [tensor.shape[batch_axis] * tensor.shape[batch_axis + 1]]
-                + tensor.shape[batch_axis + 2 :],
+                shape[:batch_axis] + [shape[batch_axis] * shape[batch_axis + 1]] + shape[batch_axis + 2 :],
             )

         def unflatten_beam_dim(tensor, batch_size, num_beams, batch_axis=0):
             """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
-            return tf.reshape(
-                tensor, tensor.shape[:batch_axis] + [batch_size, num_beams] + tensor.shape[batch_axis + 1 :]
-            )
+            shape = shape_list(tensor)
+            return tf.reshape(tensor, shape[:batch_axis] + [batch_size, num_beams] + shape[batch_axis + 1 :])

         def gather_beams(nested, beam_indices, batch_axis=0):
             """Gathers the beam slices indexed by beam_indices into new beam array."""
@@ -2748,7 +2747,7 @@ class TFGenerationMixin:
         decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None

         # 3. init tensors to use for "xla-compileable" generate function
-        batch_size, num_beams, cur_len = input_ids.shape
+        batch_size, num_beams, cur_len = shape_list(input_ids)

         # per batch, beam-item holding current token in loop, pre-populated with `pad_token_id`
         input_ids_padding = tf.ones((batch_size, num_beams, max_length - cur_len), dtype=tf.int32) * (
@@ -2894,7 +2893,7 @@ class TFGenerationMixin:
             eos_in_next_token = tf.broadcast_to(eos_in_next_token, topk_sequences[:, :, cur_len].shape)
             did_topk_just_finished = eos_in_next_token & tf.broadcast_to(
                 tf.concat((tf.ones((num_beams), dtype=tf.bool), tf.zeros((num_beams), dtype=tf.bool)), axis=0),
-                eos_in_next_token.shape,
+                shape_list(eos_in_next_token),
             )
             # non-top `num_beams` eos tokens can't be used to finish a beam, but the others can't be used in the next
@@ -2917,7 +2916,7 @@ class TFGenerationMixin:
             topk_log_probs = topk_log_probs / (tf.cast(cur_len, dtype=tf.float32) ** length_penalty)
             beams_in_batch_are_full = (
                 tf.broadcast_to(
-                    tf.math.reduce_all(is_sent_finished, axis=-1, keepdims=True), did_topk_just_finished.shape
+                    tf.math.reduce_all(is_sent_finished, axis=-1, keepdims=True), shape_list(did_topk_just_finished)
                 )
                 & early_stopping
             )
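One detail worth calling out in the `_expand_to_num_beams` hunk above: `shape_list` returns a Python list (possibly mixing static ints with scalar tensors for dynamic axes), and a tuple cannot be concatenated with a list, hence the `tuple(shape[1:])` conversion before building the broadcast target. A small sketch, again assuming `shape_list` from `transformers.tf_utils`:

```python
# Sketch only: why _expand_to_num_beams converts the shape tail with tuple().
import tensorflow as tf
from transformers.tf_utils import shape_list

tensor = tf.zeros((2, 5), dtype=tf.int32)
num_beams = 3

shape = shape_list(tensor)                                # a Python list, here [2, 5]
# (shape[0], num_beams) is a tuple; tuple + list raises TypeError, so convert the tail first.
target_shape = (shape[0], num_beams) + tuple(shape[1:])   # (2, 3, 5)
expanded = tf.broadcast_to(tensor[:, None], target_shape)
assert expanded.shape == (2, 3, 5)
```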
@@ -73,6 +73,7 @@ if is_tf_available():
         TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
         BertConfig,
         TFAutoModel,
+        TFAutoModelForSeq2SeqLM,
         TFAutoModelForSequenceClassification,
         TFBertModel,
         TFSharedEmbeddings,
@@ -2163,6 +2164,46 @@ class UtilsFunctionsTest(unittest.TestCase):
         for p1, p2 in zip(model.weights, new_model.weights):
             self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))

+    def test_generate_tf_function_export(self):
+        test_model = TFAutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
+        max_length = 8
+
+        class DummyModel(tf.Module):
+            def __init__(self, model):
+                super(DummyModel, self).__init__()
+                self.model = model
+
+            @tf.function(
+                input_signature=(
+                    tf.TensorSpec((None, max_length), tf.int32, name="input_ids"),
+                    tf.TensorSpec((None, max_length), tf.int32, name="attention_mask"),
+                ),
+                jit_compile=True,
+            )
+            def serving(self, input_ids, attention_mask):
+                outputs = self.model.generate(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    max_new_tokens=max_length,
+                    return_dict_in_generate=True,
+                )
+                return {"sequences": outputs["sequences"]}
+
+        dummy_input_ids = [[2, 3, 4, 1, 0, 0, 0, 0], [102, 103, 104, 105, 1, 0, 0, 0]]
+        dummy_attention_masks = [[1, 1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 1, 0, 0, 0]]
+        dummy_model = DummyModel(model=test_model)
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            tf.saved_model.save(dummy_model, tmp_dir, signatures={"serving_default": dummy_model.serving})
+            serving_func = tf.saved_model.load(tmp_dir).signatures["serving_default"]
+            for batch_size in range(1, len(dummy_input_ids) + 1):
+                inputs = {
+                    "input_ids": tf.constant(dummy_input_ids[:batch_size]),
+                    "attention_mask": tf.constant(dummy_attention_masks[:batch_size]),
+                }
+                tf_func_outputs = serving_func(**inputs)["sequences"]
+                tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_length)
+                tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs)
+

 @require_tf
 @is_staging_test
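Because the exported `serving_default` signature leaves the batch axis as `None`, a TensorFlow Serving deployment of the resulting SavedModel can take request batches of any size. A hypothetical client sketch (the model name `t5_generate`, host, and port are assumptions, not part of this commit):

```python
# Hypothetical client: querying a TF Serving deployment of the exported SavedModel.
# The model name "t5_generate" and the endpoint are assumptions, not part of the diff.
import requests

payload = {
    "signature_name": "serving_default",
    "instances": [
        {"input_ids": [2, 3, 4, 1, 0, 0, 0, 0], "attention_mask": [1, 1, 1, 1, 0, 0, 0, 0]},
        {"input_ids": [102, 103, 104, 105, 1, 0, 0, 0], "attention_mask": [1, 1, 1, 1, 1, 0, 0, 0]},
    ],
}
response = requests.post("http://localhost:8501/v1/models/t5_generate:predict", json=payload)
print(response.json()["predictions"])  # one generated "sequences" row per instance in the batch
```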