"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "72b19ca680b5b9fb4cef6ed8c599c48d2449cb8b"
Unverified commit 9895670e, authored by Younes Belkada and committed by GitHub
Browse files

[`InstructBlip`] Add accelerate support for instructblip (#24488)

* add accelerate support for instructblip

* add `_keep_in_fp32_modules`

* dynamically adapt `_no_split_modules`

* better fix

* same logic for `_keep_in_fp32_modules`
parent 57579238
......@@ -281,6 +281,8 @@ class InstructBlipPreTrainedModel(PreTrainedModel):
r"language_model.decoder.embed_tokens.weight",
r"language_model.lm_head.weight",
]
_no_split_modules = ["InstructBlipAttention", "InstructBlipQFormerMultiHeadAttention"]
_keep_in_fp32_modules = []
# Copied from transformers.models.blip_2.modeling_blip_2.Blip2PreTrainedModel._init_weights with Blip2->InstructBlip
def _init_weights(self, module):
......@@ -1264,11 +1266,18 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel):
self.qformer = InstructBlipQFormerModel(config.qformer_config)
self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
if config.use_decoder_only_language_model:
language_model = AutoModelForCausalLM.from_config(config.text_config)
else:
language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
if language_model._no_split_modules is not None:
self._no_split_modules.extend(language_model._no_split_modules)
if language_model._keep_in_fp32_modules is not None:
self._keep_in_fp32_modules.extend(language_model._keep_in_fp32_modules)
self.language_model = language_model
# Initialize weights and apply final processing
......@@ -1422,7 +1431,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
- attention_mask = torch.cat([language_model_attention_mask, attention_mask], dim=1)
+ attention_mask = torch.cat([language_model_attention_mask.to(attention_mask.device), attention_mask], dim=1)
if self.config.use_decoder_only_language_model:
outputs = self.language_model(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment