Unverified commit 776e82d2, authored by ayushtiku5 and committed by GitHub
Browse files

Add support to provide initial tokens to decoder of encoder-decoder type models (#7577)



* Add support to provide initial tokens for decoding

* Add docstring

* improve code quality

* code reformat

* code reformat

* minor change

* remove appending decoder start token
Co-authored-by: Ayush Jain <a.jain@sprinklr.com>
parent 406a49df
...@@ -111,6 +111,7 @@ class GenerationMixin: ...@@ -111,6 +111,7 @@ class GenerationMixin:
def generate( def generate(
self, self,
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
max_length: Optional[int] = None, max_length: Optional[int] = None,
min_length: Optional[int] = None, min_length: Optional[int] = None,
do_sample: Optional[bool] = None, do_sample: Optional[bool] = None,
...@@ -151,6 +152,9 @@ class GenerationMixin: ...@@ -151,6 +152,9 @@ class GenerationMixin:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
The sequence used as a prompt for the generation. If :obj:`None` the method initializes The sequence used as a prompt for the generation. If :obj:`None` the method initializes
it as an empty :obj:`torch.LongTensor` of shape :obj:`(1,)`. it as an empty :obj:`torch.LongTensor` of shape :obj:`(1,)`.
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
                Initial input ids for the decoder of encoder-decoder type models. If :obj:`None`, then only
                the :obj:`decoder_start_token_id` is passed as the first token to the decoder.
max_length (:obj:`int`, `optional`, defaults to 20): max_length (:obj:`int`, `optional`, defaults to 20):
The maximum length of the sequence to be generated. The maximum length of the sequence to be generated.
min_length (:obj:`int`, `optional`, defaults to 10): min_length (:obj:`int`, `optional`, defaults to 10):
...@@ -417,14 +421,19 @@ class GenerationMixin: ...@@ -417,14 +421,19 @@ class GenerationMixin:
) # shape: (batch_size * num_return_sequences * num_beams, cur_len) ) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
if self.config.is_encoder_decoder: if self.config.is_encoder_decoder:
# create empty decoder input_ids device = next(self.parameters()).device
input_ids = torch.full( if decoder_input_ids is not None:
(effective_batch_size * num_beams, 1), # give initial decoder input ids
decoder_start_token_id, input_ids = decoder_input_ids.repeat(effective_batch_size * num_beams, 1).to(device)
dtype=torch.long, else:
device=next(self.parameters()).device, # create empty decoder input_ids
) input_ids = torch.full(
cur_len = 1 (effective_batch_size * num_beams, 1),
decoder_start_token_id,
dtype=torch.long,
device=device,
)
cur_len = input_ids.shape[-1]
assert ( assert (
batch_size == encoder_outputs.last_hidden_state.shape[0] batch_size == encoder_outputs.last_hidden_state.shape[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment