Unverified commit 60b69f7d, authored by Joao Gante and committed by GitHub
Browse files

Generate: detect special architectures when loaded from PEFT (#24198)

parent 97527898
...@@ -4232,6 +4232,15 @@ class GenerationMixin: ...@@ -4232,6 +4232,15 @@ class GenerationMixin:
# other auxiliary variables # other auxiliary variables
max_len = stopping_criteria[0].max_length max_len = stopping_criteria[0].max_length
assistant_kv_indexing = (
1
if "bloom" in assistant_model.__class__.__name__.lower()
or (
assistant_model.config.architectures is not None
and "bloom" in assistant_model.config.architectures[0].lower()
)
else 0
)
this_peer_finished = False # used by synced_gpus only this_peer_finished = False # used by synced_gpus only
while True: while True:
...@@ -4247,7 +4256,6 @@ class GenerationMixin: ...@@ -4247,7 +4256,6 @@ class GenerationMixin:
# Assistant: main logic start # Assistant: main logic start
cur_len = input_ids.shape[-1] cur_len = input_ids.shape[-1]
assistant_kv_indexing = 0 if "bloom" not in assistant_model.__class__.__name__.lower() else 1
# 1. Forecast next N tokens using the assistant model. This `for` block can be replaced with a # 1. Forecast next N tokens using the assistant model. This `for` block can be replaced with a
# `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we # `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we
...@@ -4512,7 +4520,10 @@ def _crop_past_key_values(model, past_key_values, maximum_length): ...@@ -4512,7 +4520,10 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
) )
) )
past_key_values = tuple(new_past) past_key_values = tuple(new_past)
elif "bloom" in model.__class__.__name__.lower(): # bloom is special # bloom is special
elif "bloom" in model.__class__.__name__.lower() or (
model.config.architectures is not None and "bloom" in model.config.architectures[0].lower()
):
for idx in range(len(past_key_values)): for idx in range(len(past_key_values)):
new_past.append( new_past.append(
( (
...@@ -4521,7 +4532,10 @@ def _crop_past_key_values(model, past_key_values, maximum_length): ...@@ -4521,7 +4532,10 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
) )
) )
past_key_values = tuple(new_past) past_key_values = tuple(new_past)
elif "gptbigcode" in model.__class__.__name__.lower(): # gptbigcode is too # gptbigcode is too
elif "gptbigcode" in model.__class__.__name__.lower() or (
model.config.architectures is not None and "gptbigcode" in model.config.architectures[0].lower()
):
if model.config.multi_query: if model.config.multi_query:
for idx in range(len(past_key_values)): for idx in range(len(past_key_values)):
past_key_values[idx] = past_key_values[idx][:, :maximum_length, :] past_key_values[idx] = past_key_values[idx][:, :maximum_length, :]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment