"vscode:/vscode.git/clone" did not exist on "7fa4b36ebaebc6f5ca72ccf08ed0fe2a84422c5c"
Unverified commit fc1d841b authored by nlpcat, committed by GitHub

change shape to support dynamic batch input in tf.function XLA generate for tf serving (#18372)



* change shape to support dynamic batch input in tf.generate

* add tests
Co-authored-by: nlpcatcode <nlpcodecat@gmail.com>
parent b69a62d5
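Background on the change: inside a `tf.function` traced with a `tf.TensorSpec((None, ...))` input signature, `tensor.shape` is the static shape, so the batch dimension is `None` and cannot be fed to ops such as `tf.ones` or `tf.broadcast_to`. `shape_list` mixes static dimensions with `tf.shape(tensor)` entries that are resolved at run time, which is what lets the exported generate function accept any batch size. A minimal sketch of the idea (the `shape_list` below is a simplified stand-in for the helper in `transformers.modeling_tf_utils`, and `pad_to_length` is a made-up example, not code from this commit):

```python
import tensorflow as tf


def shape_list(tensor):
    # Simplified stand-in: static dims where known, dynamic tf.shape() entries otherwise.
    static = tensor.shape.as_list()
    dynamic = tf.shape(tensor)
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]


@tf.function(input_signature=(tf.TensorSpec((None, 8), tf.int32),))
def pad_to_length(input_ids):
    # input_ids.shape[0] is None at trace time, so tf.ones((input_ids.shape[0], 4)) would fail.
    batch_size = shape_list(input_ids)[0]
    padding = tf.ones((batch_size, 4), dtype=tf.int32)
    return tf.concat([input_ids, padding], axis=-1)


print(pad_to_length(tf.ones((2, 8), dtype=tf.int32)).shape)  # (2, 12)
print(pad_to_length(tf.ones((5, 8), dtype=tf.int32)).shape)  # (5, 12)
```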
@@ -1533,7 +1533,7 @@ class TFGenerationMixin:
         # 2. Define model inputs
         input_ids = self._prepare_model_inputs(input_ids, bos_token_id)
         # inputs_ids now has to be defined and cannot be None anymore
-        batch_size = input_ids.shape[0]
+        batch_size = shape_list(input_ids)[0]

         # 3. Prepare other model kwargs
         if output_attentions is not None:
@@ -1702,7 +1702,8 @@ class TFGenerationMixin:
     @staticmethod
     def _expand_to_num_beams(tensor: tf.Tensor, num_beams: int) -> tf.Tensor:
-        return tf.broadcast_to(tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:])
+        shape = shape_list(tensor)
+        return tf.broadcast_to(tensor[:, None], (shape[0], num_beams) + tuple(shape[1:]))

     def _prepare_attention_mask_for_generation(
         self,
@@ -2162,7 +2163,7 @@ class TFGenerationMixin:
         decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None

         # 3. init tensors to use for "xla-compileable" generate function
-        batch_size, cur_len = input_ids.shape
+        batch_size, cur_len = shape_list(input_ids)

         # initialize `generated` (`input_ids` padded with `pad_token_id`), `finished_sequences`
         input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0)
@@ -2432,7 +2433,7 @@ class TFGenerationMixin:
         decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None

         # 3. init tensors to use for "xla-compileable" generate function
-        batch_size, cur_len = input_ids.shape
+        batch_size, cur_len = shape_list(input_ids)

         # initialize `generated` (pre-populated with `pad_token_id`), `finished_sequences`
         input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0)
@@ -2678,18 +2679,16 @@ class TFGenerationMixin:
         def flatten_beam_dim(tensor, batch_axis=0):
             """Flattens the first two dimensions of a non-scalar array."""
+            shape = shape_list(tensor)
             return tf.reshape(
                 tensor,
-                tensor.shape[:batch_axis]
-                + [tensor.shape[batch_axis] * tensor.shape[batch_axis + 1]]
-                + tensor.shape[batch_axis + 2 :],
+                shape[:batch_axis] + [shape[batch_axis] * shape[batch_axis + 1]] + shape[batch_axis + 2 :],
             )

         def unflatten_beam_dim(tensor, batch_size, num_beams, batch_axis=0):
             """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
-            return tf.reshape(
-                tensor, tensor.shape[:batch_axis] + [batch_size, num_beams] + tensor.shape[batch_axis + 1 :]
-            )
+            shape = shape_list(tensor)
+            return tf.reshape(tensor, shape[:batch_axis] + [batch_size, num_beams] + shape[batch_axis + 1 :])

         def gather_beams(nested, beam_indices, batch_axis=0):
             """Gathers the beam slices indexed by beam_indices into new beam array."""
@@ -2748,7 +2747,7 @@ class TFGenerationMixin:
         decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None

         # 3. init tensors to use for "xla-compileable" generate function
-        batch_size, num_beams, cur_len = input_ids.shape
+        batch_size, num_beams, cur_len = shape_list(input_ids)

         # per batch, beam-item holding current token in loop, pre-populated with `pad_token_id`
         input_ids_padding = tf.ones((batch_size, num_beams, max_length - cur_len), dtype=tf.int32) * (
@@ -2894,7 +2893,7 @@ class TFGenerationMixin:
             eos_in_next_token = tf.broadcast_to(eos_in_next_token, topk_sequences[:, :, cur_len].shape)
             did_topk_just_finished = eos_in_next_token & tf.broadcast_to(
                 tf.concat((tf.ones((num_beams), dtype=tf.bool), tf.zeros((num_beams), dtype=tf.bool)), axis=0),
-                eos_in_next_token.shape,
+                shape_list(eos_in_next_token),
             )

             # non-top `num_beams` eos tokens can't be used to finish a beam, but the others can't be used in the next
@@ -2917,7 +2916,7 @@ class TFGenerationMixin:
             topk_log_probs = topk_log_probs / (tf.cast(cur_len, dtype=tf.float32) ** length_penalty)
             beams_in_batch_are_full = (
                 tf.broadcast_to(
-                    tf.math.reduce_all(is_sent_finished, axis=-1, keepdims=True), did_topk_just_finished.shape
+                    tf.math.reduce_all(is_sent_finished, axis=-1, keepdims=True), shape_list(did_topk_just_finished)
                 )
                 & early_stopping
             )
@@ -73,6 +73,7 @@ if is_tf_available():
         TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
         BertConfig,
         TFAutoModel,
+        TFAutoModelForSeq2SeqLM,
         TFAutoModelForSequenceClassification,
         TFBertModel,
         TFSharedEmbeddings,
@@ -2163,6 +2164,46 @@ class UtilsFunctionsTest(unittest.TestCase):
             for p1, p2 in zip(model.weights, new_model.weights):
                 self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))

+    def test_generate_tf_function_export(self):
+        test_model = TFAutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
+        max_length = 8
+
+        class DummyModel(tf.Module):
+            def __init__(self, model):
+                super(DummyModel, self).__init__()
+                self.model = model
+
+            @tf.function(
+                input_signature=(
+                    tf.TensorSpec((None, max_length), tf.int32, name="input_ids"),
+                    tf.TensorSpec((None, max_length), tf.int32, name="attention_mask"),
+                ),
+                jit_compile=True,
+            )
+            def serving(self, input_ids, attention_mask):
+                outputs = self.model.generate(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    max_new_tokens=max_length,
+                    return_dict_in_generate=True,
+                )
+                return {"sequences": outputs["sequences"]}
+
+        dummy_input_ids = [[2, 3, 4, 1, 0, 0, 0, 0], [102, 103, 104, 105, 1, 0, 0, 0]]
+        dummy_attention_masks = [[1, 1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 1, 0, 0, 0]]
+        dummy_model = DummyModel(model=test_model)
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            tf.saved_model.save(dummy_model, tmp_dir, signatures={"serving_default": dummy_model.serving})
+            serving_func = tf.saved_model.load(tmp_dir).signatures["serving_default"]
+            for batch_size in range(1, len(dummy_input_ids) + 1):
+                inputs = {
+                    "input_ids": tf.constant(dummy_input_ids[:batch_size]),
+                    "attention_mask": tf.constant(dummy_attention_masks[:batch_size]),
+                }
+                tf_func_outputs = serving_func(**inputs)["sequences"]
+                tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_length)
+                tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs)

 @require_tf
 @is_staging_test
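The exported signature in the new test fixes the sequence length but leaves the batch dimension as `None`, so a SavedModel produced this way can be handed to stock TensorFlow Serving and queried with whatever batch size a client sends. An illustrative sketch of such a request (the model name `t5_generate`, the directory layout, and the port are assumptions for illustration, not part of this commit):

```python
# Assuming the SavedModel from the test was copied to /models/t5_generate/1/ and
# TensorFlow Serving is running, e.g.:
#   docker run -p 8501:8501 -v /models/t5_generate:/models/t5_generate \
#       -e MODEL_NAME=t5_generate tensorflow/serving
import requests

payload = {
    "signature_name": "serving_default",
    "inputs": {
        "input_ids": [[2, 3, 4, 1, 0, 0, 0, 0], [102, 103, 104, 105, 1, 0, 0, 0]],
        "attention_mask": [[1, 1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 1, 0, 0, 0]],
    },
}
response = requests.post("http://localhost:8501/v1/models/t5_generate:predict", json=payload)
print(response.json()["outputs"])  # generated token ids for both inputs
```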