Unverified Commit e8acb700 authored by amyeroberts, committed by GitHub

Pass attn_implementation when using AutoXXX.from_config (#30507)

* Pass attn_implementation when using AutoXXX.from_config

* Fix
parent 80126f98
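
The change itself is mechanical: every composite model now forwards its own config's attention implementation to the sub-models it builds with AutoXXX.from_config, instead of letting each sub-model fall back to its own default. Below is a minimal sketch of that call, assuming a transformers version where from_config accepts the attn_implementation keyword (the same kwarg this commit relies on); the tiny LlamaConfig is only an illustration and is not taken from the PR.

from transformers import AutoModelForCausalLM, LlamaConfig

# Small throwaway config so the model builds quickly; the values are placeholders.
text_config = LlamaConfig(
    hidden_size=64, intermediate_size=128, num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4
)

# Before this commit, composite models called from_config(text_config) without the kwarg,
# so the sub-model resolved its own attention backend instead of inheriting the parent's.
language_model = AutoModelForCausalLM.from_config(text_config, attn_implementation="eager")
print(language_model.config._attn_implementation)  # "eager"
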
@@ -1194,9 +1194,13 @@ class Blip2Model(Blip2PreTrainedModel):
         self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
         if config.use_decoder_only_language_model:
-            language_model = AutoModelForCausalLM.from_config(config.text_config)
+            language_model = AutoModelForCausalLM.from_config(
+                config.text_config, attn_implementation=config._attn_implementation
+            )
         else:
-            language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
+            language_model = AutoModelForSeq2SeqLM.from_config(
+                config.text_config, attn_implementation=config._attn_implementation
+            )
 
         # Update _tied_weights_keys using the base model used.
         if language_model._tied_weights_keys is not None:
@@ -1549,9 +1553,13 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel):
         self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
         if config.use_decoder_only_language_model:
-            language_model = AutoModelForCausalLM.from_config(config.text_config)
+            language_model = AutoModelForCausalLM.from_config(
+                config.text_config, attn_implementation=config._attn_implementation
+            )
         else:
-            language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
+            language_model = AutoModelForSeq2SeqLM.from_config(
+                config.text_config, attn_implementation=config._attn_implementation
+            )
 
         # Update _tied_weights_keys using the base model used.
         if language_model._tied_weights_keys is not None:
@@ -367,7 +367,9 @@ class DepthAnythingForDepthEstimation(DepthAnythingPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
 
-        self.backbone = AutoBackbone.from_config(config.backbone_config)
+        self.backbone = AutoBackbone.from_config(
+            config.backbone_config, attn_implementation=config._attn_implementation
+        )
         self.neck = DepthAnythingNeck(config)
         self.head = DepthAnythingDepthEstimationHead(config)
@@ -209,12 +209,12 @@ class EncoderDecoderModel(PreTrainedModel):
         if encoder is None:
             from ..auto.modeling_auto import AutoModel
 
-            encoder = AutoModel.from_config(config.encoder)
+            encoder = AutoModel.from_config(config.encoder, attn_implementation=config._attn_implementation)
 
         if decoder is None:
             from ..auto.modeling_auto import AutoModelForCausalLM
 
-            decoder = AutoModelForCausalLM.from_config(config.decoder)
+            decoder = AutoModelForCausalLM.from_config(config.decoder, attn_implementation=config._attn_implementation)
 
         self.encoder = encoder
         self.decoder = decoder
@@ -149,7 +149,9 @@ class FuyuForCausalLM(FuyuPreTrainedModel):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
-        self.language_model = AutoModelForCausalLM.from_config(config.text_config)
+        self.language_model = AutoModelForCausalLM.from_config(
+            config.text_config, attn_implementation=config._attn_implementation
+        )
 
         self.vision_embed_tokens = nn.Linear(
             config.patch_size * config.patch_size * config.num_channels, config.hidden_size
@@ -1476,7 +1476,7 @@ class Idefics2Model(Idefics2PreTrainedModel):
         self.vision_model = Idefics2VisionTransformer(config.vision_config)
         self.connector = Idefics2Connector(config)
-        self.text_model = AutoModel.from_config(config.text_config)
+        self.text_model = AutoModel.from_config(config.text_config, attn_implementation=config._attn_implementation)
 
         self.image_seq_len = config.perceiver_config.resampler_n_latents
         self.image_token_id = self.config.image_token_id
@@ -1251,9 +1251,13 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel):
         self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
         if config.use_decoder_only_language_model:
-            language_model = AutoModelForCausalLM.from_config(config.text_config)
+            language_model = AutoModelForCausalLM.from_config(
+                config.text_config, attn_implementation=config._attn_implementation
+            )
         else:
-            language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
+            language_model = AutoModelForSeq2SeqLM.from_config(
+                config.text_config, attn_implementation=config._attn_implementation
+            )
 
         if language_model._no_split_modules is not None:
             self._no_split_modules.extend(language_model._no_split_modules)
@@ -506,12 +506,16 @@ class RagModel(RagPreTrainedModel):
         if question_encoder is None:
             from ..auto.modeling_auto import AutoModel
 
-            question_encoder = AutoModel.from_config(config.question_encoder)
+            question_encoder = AutoModel.from_config(
+                config.question_encoder, attn_implementation=config._attn_implementation
+            )
 
         if generator is None:
             from ..auto.modeling_auto import AutoModelForSeq2SeqLM
 
-            generator = AutoModelForSeq2SeqLM.from_config(config.generator)
+            generator = AutoModelForSeq2SeqLM.from_config(
+                config.generator, attn_implementation=config._attn_implementation
+            )
 
         self.retriever = retriever
         if self.retriever is not None:
@@ -212,10 +212,10 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
         super().__init__(config)
 
         if encoder is None:
-            encoder = AutoModel.from_config(config.encoder)
+            encoder = AutoModel.from_config(config.encoder, attn_implementation=config._attn_implementation)
 
         if decoder is None:
-            decoder = AutoModelForCausalLM.from_config(config.decoder)
+            decoder = AutoModelForCausalLM.from_config(config.decoder, attn_implementation=config._attn_implementation)
 
         self.encoder = encoder
         self.decoder = decoder
@@ -190,10 +190,10 @@ class VisionEncoderDecoderModel(PreTrainedModel):
         super().__init__(config)
 
         if encoder is None:
-            encoder = AutoModel.from_config(config.encoder)
+            encoder = AutoModel.from_config(config.encoder, attn_implementation=config._attn_implementation)
 
         if decoder is None:
-            decoder = AutoModelForCausalLM.from_config(config.decoder)
+            decoder = AutoModelForCausalLM.from_config(config.decoder, attn_implementation=config._attn_implementation)
 
         self.encoder = encoder
         self.decoder = decoder
@@ -185,10 +185,12 @@ class VisionTextDualEncoderModel(PreTrainedModel):
             if isinstance(config.vision_config, CLIPVisionConfig):
                 vision_model = CLIPVisionModel(config.vision_config)
             else:
-                vision_model = AutoModel.from_config(config.vision_config)
+                vision_model = AutoModel.from_config(
+                    config.vision_config, attn_implementation=config._attn_implementation
+                )
 
         if text_model is None:
-            text_model = AutoModel.from_config(config.text_config)
+            text_model = AutoModel.from_config(config.text_config, attn_implementation=config._attn_implementation)
 
         self.vision_model = vision_model
         self.text_model = text_model
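
For an end user, the effect shows up when loading a composite checkpoint with an explicit attention backend. A hedged example follows, assuming the public Salesforce/blip2-opt-2.7b checkpoint (not part of this PR, and a large download); "eager" is chosen only because every backbone supports it.

from transformers import Blip2ForConditionalGeneration

# With this commit, the requested backend also reaches the inner language model.
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b", attn_implementation="eager"
)
print(model.language_model.config._attn_implementation)  # "eager"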