Unverified commit bbfb9fc2, authored by Joao Gante, committed by GitHub

Generate: starcoder 🤜 🤛 assisted generation (#23182)

* starcoder has joined the chat

* indexing that works for all
parent dbc12269
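
In short, this commit wires GPTBigCode (StarCoder) models into assisted generation. A minimal usage sketch, assuming the `bigcode/starcoder` and `bigcode/tiny_starcoder_py` checkpoints (names are illustrative, not part of this commit):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoint names are assumptions; any large/small pair sharing a tokenizer works.
tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder")
model = AutoModelForCausalLM.from_pretrained(
    "bigcode/starcoder", torch_dtype=torch.float16, device_map="auto"
)
assistant_model = AutoModelForCausalLM.from_pretrained(
    "bigcode/tiny_starcoder_py", torch_dtype=torch.float16, device_map="auto"
)

inputs = tokenizer("def fibonacci(n):", return_tensors="pt").to(model.device)
# The small model drafts candidate tokens; the large model validates them in a
# single forward pass, preserving greedy-search output while cutting latency.
outputs = model.generate(**inputs, assistant_model=assistant_model, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```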
@@ -4221,6 +4221,9 @@ class GenerationMixin:
     # keep track of which sequences are already finished
     unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+    # other auxiliary variables
+    max_len = stopping_criteria[0].max_length
+
     this_peer_finished = False  # used by synced_gpus only
     while True:
         if synced_gpus:
@@ -4235,7 +4238,7 @@ class GenerationMixin:
     # Assistant: main logic start
     cur_len = input_ids.shape[-1]
-    max_len = stopping_criteria[0].max_length
+    assistant_kv_indexing = 0 if "bloom" not in assistant_model.__class__.__name__.lower() else 1

     # 1. Forecast next N tokens using the assistant model. This `for` block can be replaced with a
     # `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we
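
The `assistant_kv_indexing` switch exists because Bloom lays out its per-layer cache differently from most decoder-only models. A minimal sketch with dummy tensors, assuming the usual (batch, num_heads, seq_len, head_dim) layout for typical models and Bloom's flattened, transposed layout (both layouts are my reading of the model code, not part of this diff):

```python
import torch

batch, num_heads, head_dim, seq_len = 1, 4, 8, 5

# Typical layer cache: (key, value), each (batch, num_heads, seq_len, head_dim)
typical_layer = (
    torch.zeros(batch, num_heads, seq_len, head_dim),
    torch.zeros(batch, num_heads, seq_len, head_dim),
)
# Bloom's layer cache: key is (batch*num_heads, head_dim, seq_len),
# value is (batch*num_heads, seq_len, head_dim)
bloom_layer = (
    torch.zeros(batch * num_heads, head_dim, seq_len),
    torch.zeros(batch * num_heads, seq_len, head_dim),
)

# `assistant_kv_indexing` picks the tensor whose dim -2 is the sequence length:
assert typical_layer[0].shape[-2] == seq_len  # index 0 works for most models
assert bloom_layer[1].shape[-2] == seq_len    # bloom needs index 1
```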
@@ -4244,7 +4247,7 @@ class GenerationMixin:
     for _ in range(int(assistant_model.max_assistant_tokens)):
         # 1.1. use the assistant model to obtain the next candidate logits
         if "assistant_past_key_values" in model_kwargs:
-            prev_seq_len = model_kwargs["assistant_past_key_values"][0][0].shape[2]
+            prev_seq_len = model_kwargs["assistant_past_key_values"][0][assistant_kv_indexing].shape[-2]
             # `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
             new_token_len = candidate_input_ids.shape[1] - prev_seq_len
             assist_inputs = candidate_input_ids[:, -new_token_len:]
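
The unchanged context lines show why only the new suffix is fed back to the assistant: its cache already covers `prev_seq_len` tokens. A small worked example of the slicing, with made-up lengths:

```python
import torch

# Suppose the assistant's cache already covers 10 tokens...
prev_seq_len = 10
# ...and the candidate sequence has grown to 12 tokens (one drafted by the
# assistant plus one appended by the larger model after validation).
candidate_input_ids = torch.arange(12).unsqueeze(0)  # shape (1, 12)

new_token_len = candidate_input_ids.shape[1] - prev_seq_len  # 2
assist_inputs = candidate_input_ids[:, -new_token_len:]      # only the new suffix
assert assist_inputs.shape == (1, 2)
```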
@@ -4505,6 +4508,13 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
                )
            )
        past_key_values = tuple(new_past)
+    elif "gptbigcode" in model.__class__.__name__.lower():  # gptbigcode is too
+        if model.config.multi_query:
+            for idx in range(len(past_key_values)):
+                past_key_values[idx] = past_key_values[idx][:, :maximum_length, :]
+        else:
+            for idx in range(len(past_key_values)):
+                past_key_values[idx] = past_key_values[idx][:, :, :maximum_length, :]
    else:
        for idx in range(len(past_key_values)):
            new_past.append(
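
The new `gptbigcode` branch handles GPTBigCode's fused key-value cache, which to my understanding is a single tensor per layer: 3D (batch, seq_len, 2*head_dim) under multi-query attention, 4D (batch, num_heads, seq_len, 2*head_dim) otherwise. A minimal sketch of the two crop patterns (shapes assumed, not taken from this diff):

```python
import torch

batch, num_heads, head_dim, seq_len, maximum_length = 1, 4, 8, 12, 9

# Multi-query (StarCoder default): one fused key-value tensor per layer,
# shape (batch, seq_len, 2 * head_dim); the sequence length is dim 1.
mqa_cache = [torch.zeros(batch, seq_len, 2 * head_dim)]
mqa_cache[0] = mqa_cache[0][:, :maximum_length, :]
assert mqa_cache[0].shape == (batch, maximum_length, 2 * head_dim)

# Multi-head variant: (batch, num_heads, seq_len, 2 * head_dim); dim 2.
mha_cache = [torch.zeros(batch, num_heads, seq_len, 2 * head_dim)]
mha_cache[0] = mha_cache[0][:, :, :maximum_length, :]
assert mha_cache[0].shape == (batch, num_heads, maximum_length, 2 * head_dim)
```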
@@ -1473,7 +1473,7 @@ class GenerationTesterMixin:
     # may fix in the future: the following models fail with assisted decoding, and need model-specific fixes
     if any(
         model_name in model_class.__name__.lower()
-        for model_name in ["bigbirdpegasus", "gptbigcode", "led", "mega", "speech2text", "git", "prophetnet"]
+        for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"]
     ):
         return
@@ -1529,7 +1529,7 @@ class GenerationTesterMixin:
     # may fix in the future: the following models fail with assisted decoding, and need model-specific fixes
     if any(
         model_name in model_class.__name__.lower()
-        for model_name in ["bigbirdpegasus", "gptbigcode", "led", "mega", "speech2text", "git", "prophetnet"]
+        for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"]
     ):
         return
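
With `gptbigcode` dropped from the two skip lists, these models now run the generic assisted-decoding tests. A rough standalone version of the property those tests check (checkpoint name and exact assertions are my assumptions, not copied from the test suite):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoint name is an assumption for illustration; a tiny model keeps this fast.
name = "bigcode/tiny_starcoder_py"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)
assistant = AutoModelForCausalLM.from_pretrained(name)

inputs = tokenizer("def hello():", return_tensors="pt")
greedy = model.generate(**inputs, do_sample=False, max_new_tokens=16)
assisted = model.generate(
    **inputs, do_sample=False, max_new_tokens=16, assistant_model=assistant
)
# Assisted decoding only accepts draft tokens the main model would have produced
# itself, so the greedy output should be reproduced token for token.
assert torch.equal(greedy, assisted)
```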