Unverified Commit 184f60dc authored by jiqing-feng, committed by GitHub

add attention_mask and position_ids in assisted model (#26892)

* add attention_mask and position_ids in assisted model

* fix bug

* fix attention mask

* fix attention_mask

* check assist inputs

* check assist input ids length

* fix assist model type

* set assist attention mask device
parent cf32c941
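
The gist of the change: during assisted (speculative) decoding, the assistant model keeps appending candidate tokens to the sequence, so its attention_mask has to grow in step, and its position_ids have to be derived from that mask rather than assumed contiguous. A minimal sketch of that bookkeeping, assuming a batch-size-1 decoder-only assistant (variable names here are illustrative, not from the patch):

    import torch

    # toy state: a 5-token prompt, all positions attended to
    attention_mask = torch.ones(1, 5, dtype=torch.long)

    # each drafted candidate token adds one attended column to the mask
    attention_mask = torch.cat(
        [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
    )

    # position_ids follow from the mask: a running 0-indexed count of
    # attended tokens, so left-padding does not shift real positions
    position_ids = attention_mask.long().cumsum(-1) - 1
    print(attention_mask.shape)  # torch.Size([1, 6])
    print(position_ids[0, -1])   # tensor(5)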
@@ -4488,11 +4488,6 @@ class GenerationMixin:
         else:
             num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens
 
-        # check if assistant model accepts encoder_outputs
-        assistant_accepts_encoder_outputs = "encoder_outputs" in set(
-            inspect.signature(assistant_model.forward).parameters.keys()
-        )
-
         # init values
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
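
For context, the deleted lines probed the assistant's forward signature to decide whether encoder_outputs could be passed; that decision now lives in prepare_inputs_for_generation, which returns only the keys the model actually uses. A sketch of that signature-probing pattern with the standard-library inspect module (DummyModel is a made-up stand-in):

    import inspect

    class DummyModel:
        def forward(self, input_ids, attention_mask=None, encoder_outputs=None):
            return input_ids

    # True when the forward pass declares an `encoder_outputs` parameter
    accepts = "encoder_outputs" in inspect.signature(DummyModel.forward).parameters
    print(accepts)  # True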
@@ -4535,15 +4530,6 @@ class GenerationMixin:
         # other auxiliary variables
         max_len = stopping_criteria[0].max_length
-        assistant_kv_indexing = (
-            1
-            if "bloom" in assistant_model.__class__.__name__.lower()
-            or (
-                assistant_model.config.architectures is not None
-                and "bloom" in assistant_model.config.architectures[0].lower()
-            )
-            else 0
-        )
 
         this_peer_finished = False  # used by synced_gpus only
         while True:
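
The deleted assistant_kv_indexing picked which element of a layer's KV cache to read the sequence length from, because Bloom-era caches used a different layout from most models. A sketch of the shape difference being compensated for, with dummy tensors (layouts as documented for those model families; the snippet is only illustrative):

    import torch

    batch, heads, head_dim, seq_len = 1, 4, 8, 10

    # most models: past_key_values[layer][0] is the key tensor,
    # shaped (batch, num_heads, seq_len, head_dim) -> seq_len at dim -2
    standard_key = torch.zeros(batch, heads, seq_len, head_dim)

    # Bloom: keys are (batch * num_heads, head_dim, seq_len), so the cached
    # length is instead read from the value tensor at index 1, which is
    # (batch * num_heads, seq_len, head_dim) -> seq_len again at dim -2
    bloom_value = torch.zeros(batch * heads, seq_len, head_dim)

    print(standard_key.shape[-2], bloom_value.shape[-2])  # 10 10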
@@ -4564,42 +4550,28 @@ class GenerationMixin:
             # `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we
             # need access to the assistant cache to secure strong speedups.
             candidate_input_ids = input_ids
+            assistant_attention_mask = model_kwargs.get("attention_mask", None)
+            assistant_decoder_attention_mask = model_kwargs.get("decoder_attention_mask", None)
+            assistant_encoder_outputs = (model_kwargs.get("assistant_encoder_outputs", None),)
             for _ in range(int(num_assistant_tokens)):
                 # 1.1. use the assistant model to obtain the next candidate logits
-                if "assistant_past_key_values" in model_kwargs:
-                    prev_seq_len = model_kwargs["assistant_past_key_values"][0][assistant_kv_indexing].shape[-2]
-                    # `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
-                    new_token_len = candidate_input_ids.shape[1] - prev_seq_len
-                    assist_inputs = candidate_input_ids[:, -new_token_len:]
                 # TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2
+                assistant_inputs = assistant_model.prepare_inputs_for_generation(
+                    candidate_input_ids,
+                    attention_mask=assistant_attention_mask,
+                    decoder_attention_mask=assistant_decoder_attention_mask,
+                    encoder_outputs=assistant_encoder_outputs,
+                    past_key_values=model_kwargs.get("assistant_past_key_values", None),
+                )
+                if assistant_inputs.get("past_key_values", None) is not None:
                     if assistant_model.config.is_encoder_decoder:
-                        assistant_model_outputs = assistant_model(
-                            decoder_input_ids=assist_inputs,
-                            past_key_values=model_kwargs["assistant_past_key_values"],
-                            encoder_outputs=model_kwargs["assistant_encoder_outputs"],
-                        )
+                        input_ids_len = assistant_inputs["decoder_input_ids"].shape[-1]
                     else:
-                        encoder_kwargs = {}
+                        input_ids_len = assistant_inputs["input_ids"].shape[-1]
-                        if assistant_accepts_encoder_outputs and "assistant_encoder_outputs" in model_kwargs:
-                            encoder_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
-                        assistant_model_outputs = assistant_model(
-                            assist_inputs, past_key_values=model_kwargs["assistant_past_key_values"], **encoder_kwargs
-                        )
-                else:
-                    if assistant_model.config.is_encoder_decoder:
-                        assistant_model_outputs = assistant_model(
-                            decoder_input_ids=candidate_input_ids,
-                            encoder_outputs=model_kwargs["assistant_encoder_outputs"],
-                        )
-                    else:
-                        encoder_kwargs = {}
+                    if input_ids_len not in (1, 2):
+                        raise ValueError("The length of the input ids in assistant inputs should be 1 or 2")
-                        if assistant_accepts_encoder_outputs and "assistant_encoder_outputs" in model_kwargs:
-                            encoder_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
-                        assistant_model_outputs = assistant_model(candidate_input_ids, **encoder_kwargs)
+                assistant_model_outputs = assistant_model(**assistant_inputs)
 
                 # 1.2. greedily select the next candidate token
                 model_kwargs["assistant_past_key_values"] = assistant_model_outputs.past_key_values
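
The net effect of the hunk above: instead of hand-slicing candidate_input_ids against the cache length, the assistant's own prepare_inputs_for_generation builds its inputs, which also yields position_ids consistent with the attention mask. The new input_ids_len check asserts that, with a warm cache, the uncached tail is 1 or 2 tokens (the assistant's latest draft, plus possibly the one token the main model appended after verification). Roughly what a Llama-style decoder-only version of that hook does (a simplified sketch, not the exact transformers source; prepare_inputs_sketch is a hypothetical name):

    import torch

    def prepare_inputs_sketch(input_ids, past_key_values=None, attention_mask=None):
        if past_key_values is not None:
            # with a warm cache, feed only the tokens the cache has not seen
            past_length = past_key_values[0][0].shape[-2]
            input_ids = input_ids[:, past_length:]
        position_ids = None
        if attention_mask is not None:
            # 0-indexed running count of attended tokens; padded slots are
            # clamped so they do not shift the real tokens' positions
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            position_ids = position_ids[:, -input_ids.shape[1]:]
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "past_key_values": past_key_values,
        }

    ids = torch.tensor([[11, 12, 13]])
    mask = torch.ones(1, 3, dtype=torch.long)
    print(prepare_inputs_sketch(ids, attention_mask=mask)["position_ids"])
    # tensor([[0, 1, 2]])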
@@ -4607,8 +4579,31 @@ class GenerationMixin:
                 assistant_model_outputs.logits[:, -1, :] = logits_processor(
                     candidate_input_ids, assistant_model_outputs.logits[:, -1, :]
                 )
                 new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1)
                 candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1)
+                if assistant_model.config.is_encoder_decoder and assistant_decoder_attention_mask is not None:
+                    assistant_decoder_attention_mask = torch.cat(
+                        (
+                            assistant_decoder_attention_mask,
+                            torch.ones(
+                                [1, 1],
+                                dtype=assistant_decoder_attention_mask.dtype,
+                                device=assistant_decoder_attention_mask.device,
+                            ),
+                        ),
+                        dim=-1,
+                    )
+                elif not assistant_model.config.is_encoder_decoder and assistant_attention_mask is not None:
+                    assistant_attention_mask = torch.cat(
+                        (
+                            assistant_attention_mask,
+                            torch.ones(
+                                [1, 1], dtype=assistant_attention_mask.dtype, device=assistant_attention_mask.device
+                            ),
+                        ),
+                        dim=-1,
+                    )
 
                 # 1.3. stop assistant generation on EOS
                 if eos_token_id_tensor is not None:
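
One detail in the added mask updates: the hard-coded [1, 1] shape relies on assisted generation supporting only batch size 1 at this point, and dtype/device are copied explicitly from the existing mask (the "set assist attention mask device" fix in the commit history). Tensor.new_ones gives the same result more compactly; a sketch:

    import torch

    assistant_attention_mask = torch.ones(1, 7, dtype=torch.long)

    # equivalent to torch.cat((mask, torch.ones([1, 1], dtype=..., device=...)), dim=-1):
    # new_ones inherits dtype and device from the source tensor
    assistant_attention_mask = torch.cat(
        (assistant_attention_mask, assistant_attention_mask.new_ones((1, 1))), dim=-1
    )
    print(assistant_attention_mask.shape)  # torch.Size([1, 8])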
@@ -4744,6 +4739,13 @@ class GenerationMixin:
                 outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
             )
 
+            # Update attention_mask for the assistant's next round of generations
+            if n_matches > 0 and model_kwargs.get("attention_mask", None) is not None:
+                attention_mask = model_kwargs["attention_mask"]
+                model_kwargs["attention_mask"] = torch.cat(
+                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], n_matches))], dim=-1
+                )
+
             # if eos_token was found in one sentence, set sentence to finished
             if eos_token_id_tensor is not None:
                 unfinished_sequences = unfinished_sequences.mul(
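
After the main model verifies a round of drafts, n_matches of the candidate tokens are accepted, so the stored attention_mask is widened by that many columns in one shot rather than one at a time. A toy run, assuming 10 tokens already in context and n_matches = 3:

    import torch

    attention_mask = torch.ones(1, 10, dtype=torch.long)
    n_matches = 3  # draft tokens accepted by the main model this round

    attention_mask = torch.cat(
        [attention_mask, attention_mask.new_ones((attention_mask.shape[0], n_matches))], dim=-1
    )
    print(attention_mask.shape)  # torch.Size([1, 13])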