Unverified commit 9895670e authored by Younes Belkada, committed by GitHub

[`InstructBlip`] Add accelerate support for instructblip (#24488)

* add accelerate support for instructblip

* add `_keep_in_fp32_modules`

* dynamically adapt `_no_split_modules`

* better fix

* same logic for `_keep_in_fp32_modules`
parent 57579238
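The changes below register the accelerate metadata (`_no_split_modules`, `_keep_in_fp32_modules`) on the InstructBLIP classes and fix a cross-device concatenation. Once in place, the model can be loaded with automatic device placement. A minimal usage sketch (the checkpoint name is illustrative; any InstructBLIP checkpoint applies):

```python
import torch
from transformers import InstructBlipForConditionalGeneration

# Sketch: load InstructBLIP sharded across the available devices via accelerate.
model = InstructBlipForConditionalGeneration.from_pretrained(
    "Salesforce/instructblip-flan-t5-xl",
    device_map="auto",          # accelerate places submodules on GPUs/CPU
    torch_dtype=torch.float16,  # fp16 weights; _keep_in_fp32_modules stay in fp32
)
```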
@@ -281,6 +281,8 @@ class InstructBlipPreTrainedModel(PreTrainedModel):
         r"language_model.decoder.embed_tokens.weight",
         r"language_model.lm_head.weight",
     ]
+    _no_split_modules = ["InstructBlipAttention", "InstructBlipQFormerMultiHeadAttention"]
+    _keep_in_fp32_modules = []
 
     # Copied from transformers.models.blip_2.modeling_blip_2.Blip2PreTrainedModel._init_weights with Blip2->InstructBlip
     def _init_weights(self, module):
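For context: `_no_split_modules` tells accelerate which module classes must stay whole on a single device when sharding (here, attention blocks whose internal residual connections would break if split across devices), while `_keep_in_fp32_modules` lists modules to pin to fp32 under half-precision loading. A rough sketch of how accelerate consumes the first attribute when computing a device map (checkpoint name is illustrative):

```python
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, InstructBlipForConditionalGeneration

# Build the model skeleton without allocating real weights, then ask
# accelerate for a device map that never splits the listed block classes.
config = AutoConfig.from_pretrained("Salesforce/instructblip-flan-t5-xl")
with init_empty_weights():
    model = InstructBlipForConditionalGeneration(config)

device_map = infer_auto_device_map(
    model,
    no_split_module_classes=model._no_split_modules,
)
```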
@@ -1264,11 +1266,18 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel):
         self.qformer = InstructBlipQFormerModel(config.qformer_config)
 
         self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
+
         if config.use_decoder_only_language_model:
             language_model = AutoModelForCausalLM.from_config(config.text_config)
         else:
             language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
 
+        if language_model._no_split_modules is not None:
+            self._no_split_modules.extend(language_model._no_split_modules)
+
+        if language_model._keep_in_fp32_modules is not None:
+            self._keep_in_fp32_modules.extend(language_model._keep_in_fp32_modules)
+
         self.language_model = language_model
 
         # Initialize weights and apply final processing
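Because the language model is instantiated from config at runtime (`AutoModelForCausalLM` or `AutoModelForSeq2SeqLM`), its constraints are merged into the composite model's lists rather than hard-coded. With a Flan-T5 backbone, for example, the merged attributes would look roughly like this (T5's values are shown for illustration, as defined in its own modeling file at the time):

```python
# After __init__ with a T5 language model (illustrative values):
model._no_split_modules
# ["InstructBlipAttention", "InstructBlipQFormerMultiHeadAttention", "T5Block"]
model._keep_in_fp32_modules
# ["wo"]  # T5 pins its feed-forward output projection to fp32
```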
@@ -1422,7 +1431,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel):
 
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)
-        attention_mask = torch.cat([language_model_attention_mask, attention_mask], dim=1)
+        attention_mask = torch.cat([language_model_attention_mask.to(attention_mask.device), attention_mask], dim=1)
 
         if self.config.use_decoder_only_language_model:
             outputs = self.language_model(
...
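The last hunk matters under `device_map="auto"`: the Q-Former output mask can live on a different device than the text attention mask, and `torch.cat` requires all inputs on the same device. A minimal sketch of the failure the one-line fix avoids (device placements are illustrative and assume a CUDA device is available):

```python
import torch

qformer_mask = torch.ones(1, 32, device="cuda:0")  # query-token mask
text_mask = torch.ones(1, 8, device="cpu")         # text attention mask

# torch.cat([qformer_mask, text_mask], dim=1)  # RuntimeError: devices differ
merged = torch.cat([qformer_mask.to(text_mask.device), text_mask], dim=1)
```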