Unverified Commit 04b2f13c authored by Sylvain Gugger, committed by GitHub

🚨🚨🚨 Enforce single model initialization (#21431)

* Enforce single model initialization

* Add OneFormer example for problem 3

* Do it the Stas way

* Actually rename the uses...

* Rewrite test

* Try to change the test this way

* Fix all init slow/fast tests

* Break connection

* Fix more tests

* Fix test for initialization

* Remove custom test

* Quality

* Fix last failing tests

* The end?
parent 2020ac4b
...@@ -492,6 +492,48 @@ model = BrandNewBertModel(BrandNewBertConfig())
The above command will create a model according to the default parameters as defined in `BrandNewBertConfig()` with
random weights, thus making sure that the `init()` methods of all components work.
Note that all random initialization should happen in the `_init_weights` method of your `BrandNewBertPreTrainedModel`
class. It should initialize all leaf modules depending on the variables of the config. Here is an example with the
BERT `_init_weights` method:
```py
def _init_weights(self, module):
    """Initialize the weights"""
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)
```
You can add custom schemes if some modules need a special initialization. For instance, in
`Wav2Vec2ForPreTraining`, the last two linear layers need the initialization of the regular PyTorch `nn.Linear`,
while all the other ones should use an initialization as above. This is coded as follows:
```py
def _init_weights(self, module):
    """Initialize the weights"""
    if isinstance(module, Wav2Vec2ForPreTraining):
        module.project_hid.reset_parameters()
        module.project_q.reset_parameters()
        module.project_hid._is_hf_initialized = True
        module.project_q._is_hf_initialized = True
    elif isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if module.bias is not None:
            module.bias.data.zero_()
```
The `_is_hf_initialized` flag is used internally to make sure we only initialize a submodule once. By setting it to
`True` for `module.project_q` and `module.project_hid`, we make sure the custom initialization done above is not
overridden later on: the `_init_weights` function won't be applied to these modules again.
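Under the hood, the loading machinery checks this flag before initializing each submodule; the `_initialize_weights` wrapper that `modeling_utils.py` gains in this commit boils down to:
```py
def _initialize_weights(self, module):
    """Initialize the weights if they are not already initialized."""
    if getattr(module, "_is_hf_initialized", False):
        # Already initialized (custom scheme or loaded from the checkpoint): skip it.
        return
    self._init_weights(module)
    module._is_hf_initialized = True
```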
**6. Write a conversion script**
Next, you should write a conversion script that lets you convert the checkpoint you used to debug *brand_new_bert* in
...
...@@ -436,6 +436,17 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
)
def set_initialized_submodules(model, state_dict_keys):
"""
Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state
dict.
"""
for module_name, module in model.named_modules():
loaded_keys = [k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")]
if len(set(module.state_dict().keys()) - set(loaded_keys)) == 0:
module._is_hf_initialized = True
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
# Convert old format to new format if needed from a PyTorch state_dict
old_keys = []
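For illustration, a minimal, self-contained sketch of what the new `set_initialized_submodules` helper above does: it flags only those submodules whose weights are fully covered by the loaded state dict keys. The two-layer `ToyModel` below is hypothetical and only serves the example.
```py
import torch.nn as nn


def set_initialized_submodules(model, state_dict_keys):
    """Sets `_is_hf_initialized` on every submodule whose weights are all present in `state_dict_keys`."""
    for module_name, module in model.named_modules():
        loaded_keys = [k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")]
        if len(set(module.state_dict().keys()) - set(loaded_keys)) == 0:
            module._is_hf_initialized = True


class ToyModel(nn.Module):
    """Hypothetical two-layer model, for illustration only."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 4)
        self.head = nn.Linear(4, 2)


model = ToyModel()
# Pretend the checkpoint only contains the encoder weights.
set_initialized_submodules(model, ["encoder.weight", "encoder.bias"])
print(getattr(model.encoder, "_is_hf_initialized", False))  # True
print(getattr(model.head, "_is_hf_initialized", False))  # False
```
Only `encoder` is flagged, so the later `model.apply(model._initialize_weights)` call in `from_pretrained` re-initializes `head` while leaving the loaded `encoder` weights untouched.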
...@@ -1176,7 +1187,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"""
Initialize the weights. This method should be overridden by derived class.
"""
- raise NotImplementedError(f"Make sure `_init_weights` is implemented for {self.__class__}")
pass
def _initialize_weights(self, module):
"""
Initialize the weights if they are not already initialized.
"""
if getattr(module, "_is_hf_initialized", False):
return
self._init_weights(module)
module._is_hf_initialized = True
def tie_weights(self):
"""
...@@ -1505,7 +1525,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
def init_weights(self):
"""
- If needed prunes and maybe initializes weights.
If needed prunes and maybe initializes weights. If using a custom `PreTrainedModel`, you need to implement any
initialization logic in `_init_weights`.
"""
# Prune heads if needed
if self.config.pruned_heads:
...@@ -1513,7 +1534,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if _init_weights:
# Initialize weights
- self.apply(self._init_weights)
self.apply(self._initialize_weights)
# Tie weights should be skipped when not initializing all weights
# since from_pretrained(...) calls tie weights anyways
...@@ -2713,11 +2734,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# retrieve uninitialized modules and initialize before maybe overriding that with the pretrained weights.
if _fast_init:
- uninitialized_modules = model.retrieve_modules_from_names(
- missing_keys, add_prefix=add_prefix_to_model, remove_prefix=remove_prefix_from_model
- )
- for module in uninitialized_modules:
- model._init_weights(module)
if remove_prefix_from_model:
_loaded_keys = [f"{prefix}.{k}" for k in loaded_keys]
elif add_prefix_to_model:
_loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys]
else:
_loaded_keys = loaded_keys
set_initialized_submodules(model, _loaded_keys)
# This will only initialize submodules that are not marked as initialized by the line above.
model.apply(model._initialize_weights)
# Set some modules to fp32 if any
if keep_in_fp32_modules is not None:
...
...@@ -1067,10 +1067,12 @@ class AltCLIPPreTrainedModel(PreTrainedModel):
module.text_projection.weight,
std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
)
module.text_projection._is_hf_initialized = True
nn.init.normal_(
module.visual_projection.weight,
std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
)
module.visual_projection._is_hf_initialized = True
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
...
...@@ -1473,8 +1473,9 @@ class BartForSequenceClassification(BartPretrainedModel):
config.num_labels,
config.classifier_dropout,
)
- self.model._init_weights(self.classification_head.dense)
- self.model._init_weights(self.classification_head.out_proj)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
...@@ -1601,7 +1602,8 @@ class BartForQuestionAnswering(BartPretrainedModel):
self.model = BartModel(config)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
- self.model._init_weights(self.qa_outputs)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
...
...@@ -2658,8 +2658,9 @@ class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
config.num_labels,
config.classifier_dropout,
)
- self.model._init_weights(self.classification_head.dense)
- self.model._init_weights(self.classification_head.out_proj)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(BIGBIRD_PEGASUS_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
...@@ -2785,7 +2786,8 @@ class BigBirdPegasusForQuestionAnswering(BigBirdPegasusPreTrainedModel):
self.model = BigBirdPegasusModel(config)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
- self.model._init_weights(self.qa_outputs)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(BIGBIRD_PEGASUS_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
...
...@@ -1186,6 +1186,9 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
base_model = FSMTModel(config)
self.model = base_model
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(FSMT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
@add_end_docstrings(FSMT_GENERATION_EXAMPLE)
...
...@@ -2543,8 +2543,9 @@ class LEDForSequenceClassification(LEDPreTrainedModel):
config.num_labels,
config.classifier_dropout,
)
- self.led._init_weights(self.classification_head.dense)
- self.led._init_weights(self.classification_head.out_proj)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
...@@ -2672,7 +2673,8 @@ class LEDForQuestionAnswering(LEDPreTrainedModel):
self.led = LEDModel(config)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
- self.led._init_weights(self.qa_outputs)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
...
...@@ -866,6 +866,9 @@ class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel, BackboneMixin):
self.hidden_states_norms = nn.ModuleList([nn.LayerNorm(num_channels) for num_channels in self.channels])
# Initialize weights and apply final processing
self.post_init()
@property
def channels(self):
return [self.out_feature_channels[name] for name in self.out_features]
...
...@@ -1447,8 +1447,9 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
config.num_labels,
config.classifier_dropout,
)
- self.model._init_weights(self.classification_head.dense)
- self.model._init_weights(self.classification_head.out_proj)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
...@@ -1574,7 +1575,8 @@ class MBartForQuestionAnswering(MBartPreTrainedModel):
self.model = MBartModel(config)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
- self.model._init_weights(self.qa_outputs)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
...
...@@ -1610,8 +1610,8 @@ class MvpForSequenceClassification(MvpPreTrainedModel):
config.classifier_dropout,
)
- self.model._init_weights(self.classification_head.dense)
- self.model._init_weights(self.classification_head.out_proj)
# Initialize weights and apply final processing
self.post_init()
def set_lightweight_tuning(self):
self.model.set_lightweight_tuning()
...@@ -1737,7 +1737,8 @@ class MvpForQuestionAnswering(MvpPreTrainedModel):
self.model = MvpModel(config)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
- self.model._init_weights(self.qa_outputs)
# Initialize weights and apply final processing
self.post_init()
def set_lightweight_tuning(self):
self.model.set_lightweight_tuning()
...
...@@ -2801,6 +2801,7 @@ class OneFormerPreTrainedModel(PreTrainedModel):
elif isinstance(module, OneFormerTransformerDecoder):
nn.init.xavier_uniform_(module.query_input_projection.weight, gain=xavier_std)
nn.init.constant_(module.query_input_projection.bias, 0)
module.query_input_projection._is_hf_initialized = True
elif isinstance(module, OneFormerPixelDecoderEncoderMultiscaleDeformableAttention):
nn.init.constant_(module.sampling_offsets.weight.data, 0.0)
thetas = torch.arange(module.n_heads, dtype=torch.float32) * (2.0 * math.pi / module.n_heads)
...
...@@ -1420,8 +1420,9 @@ class PLBartForSequenceClassification(PLBartPreTrainedModel):
config.num_labels,
config.classifier_dropout,
)
- self.model._init_weights(self.classification_head.dense)
- self.model._init_weights(self.classification_head.out_proj)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(PLBART_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
...
...@@ -301,6 +301,12 @@ class UperNetPreTrainedModel(PreTrainedModel):
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
def _init_weights(self, module):
if isinstance(module, UperNetPreTrainedModel):
module.backbone.init_weights()
module.decode_head.init_weights()
module.auxiliary_head.init_weights()
def init_weights(self):
"""Initialize the weights"""
self.backbone.init_weights()
...
...@@ -1049,8 +1049,14 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
"""Initialize the weights"""
# Wav2Vec2ForPreTraining last 2 linear layers need standard Linear init.
if isinstance(module, Wav2Vec2ForPreTraining):
module.project_hid.reset_parameters()
module.project_q.reset_parameters()
module.project_hid._is_hf_initialized = True
module.project_q._is_hf_initialized = True
# gumbel softmax requires special init
- if isinstance(module, Wav2Vec2GumbelVectorQuantizer):
elif isinstance(module, Wav2Vec2GumbelVectorQuantizer):
module.weight_proj.weight.data.normal_(mean=0.0, std=1)
module.weight_proj.bias.data.zero_()
nn.init.uniform_(module.codevectors)
...@@ -1345,13 +1351,12 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
self.quantizer = Wav2Vec2GumbelVectorQuantizer(config)
- # Initialize weights and apply final processing
- self.post_init()
- # make sure that project_hid & project_q are initialized like normal linear layers
self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
# Initialize weights and apply final processing
self.post_init()
def set_gumbel_temperature(self, temperature: int):
"""
Set the Gumbel softmax temperature to a given value. Only necessary for training
...
...@@ -1089,8 +1089,14 @@ class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
"""Initialize the weights"""
# Wav2Vec2ForPreTraining last 2 linear layers need standard Linear init.
if isinstance(module, Wav2Vec2ConformerForPreTraining):
module.project_hid.reset_parameters()
module.project_q.reset_parameters()
module.project_hid._is_hf_initialized = True
module.project_q._is_hf_initialized = True
# gumbel softmax requires special init
- if isinstance(module, Wav2Vec2ConformerGumbelVectorQuantizer):
elif isinstance(module, Wav2Vec2ConformerGumbelVectorQuantizer):
module.weight_proj.weight.data.normal_(mean=0.0, std=1)
module.weight_proj.bias.data.zero_()
nn.init.uniform_(module.codevectors)
...@@ -1381,13 +1387,12 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
self.quantizer = Wav2Vec2ConformerGumbelVectorQuantizer(config)
- # Initialize weights and apply final processing
- self.post_init()
- # make sure that project_hid & project_q are initialized like normal linear layers
self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
# Initialize weights and apply final processing
self.post_init()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.set_gumbel_temperature
def set_gumbel_temperature(self, temperature: int):
"""
...
...@@ -962,7 +962,6 @@ class WavLMAdapterLayer(nn.Module):
return hidden_states
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PreTrainedModel with Wav2Vec2->WavLM, wav2vec2->wavlm
class WavLMPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
...
...@@ -1496,3 +1496,6 @@ class BartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, un
def test_retain_grad_hidden_states_attentions(self):
# decoder cannot keep gradients
return
def test_save_load_fast_init_from_base(self):
pass
...@@ -410,17 +410,23 @@ class DetaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
# Skip the check for the backbone
for name, module in model.named_modules():
if module.__class__.__name__ == "DetaBackboneWithPositionalEncodings":
backbone_params = [f"{name}.{key}" for key in module.state_dict().keys()]
break
for name, param in model.named_parameters():
if param.requires_grad:
- if param.requires_grad:
- if (
- "level_embed" in name
- or "sampling_offsets.bias" in name
- or "value_proj" in name
- or "output_proj" in name
- or "reference_points" in name
- ):
if (
"level_embed" in name
or "sampling_offsets.bias" in name
or "value_proj" in name
or "output_proj" in name
or "reference_points" in name
or name in backbone_params
):
continue
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
...
...@@ -24,7 +24,7 @@ from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from ...test_configuration_common import ConfigTester
- from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
if is_torch_available():
...@@ -242,6 +242,29 @@ class DPTModelTest(ModelTesterMixin, unittest.TestCase):
loss = model(**inputs).loss
loss.backward()
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
# Skip the check for the backbone
backbone_params = []
for name, module in model.named_modules():
if module.__class__.__name__ == "DPTViTHybridEmbeddings":
backbone_params = [f"{name}.{key}" for key in module.state_dict().keys()]
break
for name, param in model.named_parameters():
if param.requires_grad:
if name in backbone_params:
continue
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
@slow
def test_model_from_pretrained(self):
for model_name in DPT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
...
...@@ -24,7 +24,7 @@ from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from ...test_configuration_common import ConfigTester
- from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
if is_torch_available():
...@@ -256,6 +256,29 @@ class DPTModelTest(ModelTesterMixin, unittest.TestCase):
loss = model(**inputs).loss
loss.backward()
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
# Skip the check for the backbone
backbone_params = []
for name, module in model.named_modules():
if module.__class__.__name__ == "DPTViTHybridEmbeddings":
backbone_params = [f"{name}.{key}" for key in module.state_dict().keys()]
break
for name, param in model.named_parameters():
if param.requires_grad:
if name in backbone_params:
continue
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
@slow
def test_model_from_pretrained(self):
for model_name in DPT_PRETRAINED_MODEL_ARCHIVE_LIST[1:]:
...