"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "4a872caef4e70595202c64687a074f99772d8e92"
Unverified Commit 60b69f7d authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: detect special architectures when loaded from PEFT (#24198)

parent 97527898
......@@ -4232,6 +4232,15 @@ class GenerationMixin:
# other auxiliary variables
max_len = stopping_criteria[0].max_length
assistant_kv_indexing = (
1
if "bloom" in assistant_model.__class__.__name__.lower()
or (
assistant_model.config.architectures is not None
and "bloom" in assistant_model.config.architectures[0].lower()
)
else 0
)
this_peer_finished = False # used by synced_gpus only
while True:
......@@ -4247,7 +4256,6 @@ class GenerationMixin:
# Assistant: main logic start
cur_len = input_ids.shape[-1]
assistant_kv_indexing = 0 if "bloom" not in assistant_model.__class__.__name__.lower() else 1
# 1. Forecast next N tokens using the assistant model. This `for` block can be replaced with a
# `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we
......@@ -4512,7 +4520,10 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
)
)
past_key_values = tuple(new_past)
elif "bloom" in model.__class__.__name__.lower(): # bloom is special
# bloom is special
elif "bloom" in model.__class__.__name__.lower() or (
model.config.architectures is not None and "bloom" in model.config.architectures[0].lower()
):
for idx in range(len(past_key_values)):
new_past.append(
(
......@@ -4521,7 +4532,10 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
)
)
past_key_values = tuple(new_past)
elif "gptbigcode" in model.__class__.__name__.lower(): # gptbigcode is too
# gptbigcode is too
elif "gptbigcode" in model.__class__.__name__.lower() or (
model.config.architectures is not None and "gptbigcode" in model.config.architectures[0].lower()
):
if model.config.multi_query:
for idx in range(len(past_key_values)):
past_key_values[idx] = past_key_values[idx][:, :maximum_length, :]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment