Unverified Commit 69122657 authored by Abel's avatar Abel Committed by GitHub
Browse files

Make T5 compatible with ONNX (#5518)



* Default decoder inputs to encoder ones for T5 if neither are specified.

* Fixing typo, now all tests are passing.

* Changing einsum to operations supported by onnx

* Adding a test to ensure T5 can be exported to onnx op>9

* Modified test for onnx export to make it faster

* Styling changes.

* Styling changes.

* Changing notation for matrix multiplication
Co-authored-by: default avatarAbel Riboulot <tkai@protomail.com>
parent 989ae326
...@@ -358,7 +358,10 @@ class T5Attention(nn.Module): ...@@ -358,7 +358,10 @@ class T5Attention(nn.Module):
else: else:
present_key_value_state = (None,) present_key_value_state = (None,)
scores = torch.einsum("bnqd,bnkd->bnqk", q, k) # (bs, n_heads, qlen, klen) # (bs, n_heads, qlen, klen)
scores = torch.matmul(
q, k.transpose(3, 2)
) # equivalent of torch.einsum("bnqd,bnkd->bnqk", q, k), compatible with onnx op>9
if position_bias is None: if position_bias is None:
if not self.has_relative_attention_bias: if not self.has_relative_attention_bias:
...@@ -818,7 +821,8 @@ T5_INPUTS_DOCSTRING = r""" ...@@ -818,7 +821,8 @@ T5_INPUTS_DOCSTRING = r"""
Provide for sequence to sequence training. T5 uses the pad_token_id as the starting token for decoder_input_ids generation. Provide for sequence to sequence training. T5 uses the pad_token_id as the starting token for decoder_input_ids generation.
If `decoder_past_key_value_states` is used, optionally only the last `decoder_input_ids` have to be input (see `decoder_past_key_value_states`). If `decoder_past_key_value_states` is used, optionally only the last `decoder_input_ids` have to be input (see `decoder_past_key_value_states`).
To know more on how to prepare :obj:`decoder_input_ids` for pre-training take a look at To know more on how to prepare :obj:`decoder_input_ids` for pre-training take a look at
`T5 Training <./t5.html#training>`__. `T5 Training <./t5.html#training>`__. If decoder_input_ids and decoder_inputs_embeds are both None,
decoder_input_ids takes the value of input_ids.
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`): decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default. Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
...@@ -837,7 +841,8 @@ T5_INPUTS_DOCSTRING = r""" ...@@ -837,7 +841,8 @@ T5_INPUTS_DOCSTRING = r"""
Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation. Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation.
If `decoder_past_key_value_states` is used, optionally only the last `decoder_inputs_embeds` have to be input (see `decoder_past_key_value_states`). If `decoder_past_key_value_states` is used, optionally only the last `decoder_inputs_embeds` have to be input (see `decoder_past_key_value_states`).
This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors
than the model's internal embedding lookup matrix. than the model's internal embedding lookup matrix. If decoder_input_ids and decoder_inputs_embeds are both None,
decoder_inputs_embeds takes the value of inputs_embeds.
head_mask: (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`): head_mask: (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
Mask to nullify selected heads of the self-attention modules. Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``: Mask values selected in ``[0, 1]``:
...@@ -934,7 +939,7 @@ class T5Model(T5PreTrainedModel): ...@@ -934,7 +939,7 @@ class T5Model(T5PreTrainedModel):
>>> model = T5Model.from_pretrained('t5-small') >>> model = T5Model.from_pretrained('t5-small')
>>> input_ids = tokenizer.encode("Hello, my dog is cute", return_tensors="pt") # Batch size 1 >>> input_ids = tokenizer.encode("Hello, my dog is cute", return_tensors="pt") # Batch size 1
>>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids) >>> outputs = model(input_ids=input_ids)
>>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple >>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
""" """
...@@ -953,6 +958,12 @@ class T5Model(T5PreTrainedModel): ...@@ -953,6 +958,12 @@ class T5Model(T5PreTrainedModel):
hidden_states = encoder_outputs[0] hidden_states = encoder_outputs[0]
# If the model is only provided with either input_ids or inputs_embeds,
# use them as the inputs of the decoder. self.encoder checks for input_ids XOR inputs_embeds
if (decoder_input_ids is None) and (decoder_inputs_embeds is None):
decoder_input_ids = input_ids
decoder_inputs_embeds = inputs_embeds
# If decoding with past key value states, only the last tokens # If decoding with past key value states, only the last tokens
# should be given as an input # should be given as an input
if decoder_past_key_value_states is not None: if decoder_past_key_value_states is not None:
...@@ -1076,7 +1087,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel): ...@@ -1076,7 +1087,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small') >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
>>> model = T5ForConditionalGeneration.from_pretrained('t5-small') >>> model = T5ForConditionalGeneration.from_pretrained('t5-small')
>>> input_ids = tokenizer.encode("Hello, my dog is cute", return_tensors="pt") # Batch size 1 >>> input_ids = tokenizer.encode("Hello, my dog is cute", return_tensors="pt") # Batch size 1
>>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=input_ids) >>> outputs = model(input_ids=input_ids, labels=input_ids)
>>> loss, prediction_scores = outputs[:2] >>> loss, prediction_scores = outputs[:2]
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small') >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
......
...@@ -351,6 +351,16 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -351,6 +351,16 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
model = T5Model.from_pretrained(model_name) model = T5Model.from_pretrained(model_name)
self.assertIsNotNone(model) self.assertIsNotNone(model)
def test_export_to_onnx(self):
import tempfile
config_and_inputs = self.model_tester.prepare_config_and_inputs()
model = T5Model(config_and_inputs[0])
with tempfile.TemporaryDirectory() as tmpdirname:
torch.onnx.export(
model, config_and_inputs[1], f"{tmpdirname}/t5_test.onnx", export_params=True, opset_version=9,
)
@require_torch @require_torch
class T5ModelIntegrationTests(unittest.TestCase): class T5ModelIntegrationTests(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment