Unverified Commit 0676d992 authored by Arthur, committed by GitHub

[`from_pretrained`] Make from_pretrained fast again (#27709)



* Skip nn.Module.reset_parameters

* Actually skip

* Check quality

* Maybe change all inits

* Fix init issues: only modify public functions

* Add a small test for now

* Style

* test updates

* style

* nice test

* style

* make it even faster

* one more second

* remove fx-incompatible

* Update tests/test_modeling_common.py
Co-authored-by: Lysandre Debut <hi@lysand.re>

* Update tests/test_modeling_common.py
Co-authored-by: Lysandre Debut <hi@lysand.re>

* skip

* fix quality

* protect the import

---------
Co-authored-by: Lysandre Debut <hi@lysand.re>
parent 9f18cc6d
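The speed-up comes from temporarily replacing the `torch.nn.init` functions with no-ops while the model skeleton is built, so weights that the checkpoint will overwrite anyway are never randomly initialized. The snippet below is a minimal standalone sketch of that pattern, not the library code: the helper name `skip_torch_init`, the shortened function list, and the layer sizes are illustrative.

# Minimal standalone sketch of the patching pattern used by `no_init_weights`
# in the diff below. Names (`skip_torch_init`, `_PATCHED_INITS`) and the
# reduced function list are illustrative, not the transformers implementation.
import time
from contextlib import contextmanager

from torch import nn

# Keep references to the original init functions so they can be restored.
_PATCHED_INITS = {
    "uniform_": nn.init.uniform_,
    "normal_": nn.init.normal_,
    "kaiming_uniform_": nn.init.kaiming_uniform_,
    "kaiming_normal_": nn.init.kaiming_normal_,
    "xavier_uniform_": nn.init.xavier_uniform_,
    "xavier_normal_": nn.init.xavier_normal_,
}


@contextmanager
def skip_torch_init():
    """Temporarily turn the torch.nn.init functions into no-ops.

    Layers such as nn.Linear call these functions from reset_parameters()
    inside their constructor; skipping them avoids filling tensors with
    random values that a checkpoint would overwrite anyway.
    """

    def _skip(tensor, *args, **kwargs):
        return tensor  # leave the freshly allocated tensor untouched

    for name in _PATCHED_INITS:
        setattr(nn.init, name, _skip)
    try:
        yield
    finally:
        # Always restore the originals, even if construction raised.
        for name, fn in _PATCHED_INITS.items():
            setattr(nn.init, name, fn)


if __name__ == "__main__":
    t0 = time.perf_counter()
    with skip_torch_init():
        nn.Linear(8192, 8192)  # built without running kaiming/uniform init
    print(f"init skipped: {time.perf_counter() - t0:.3f}s")

    t0 = time.perf_counter()
    nn.Linear(8192, 8192)  # regular construction, init runs
    print(f"init run:     {time.perf_counter() - t0:.3f}s")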
src/transformers/modeling_utils.py

@@ -154,6 +154,23 @@ else:
 if is_peft_available():
     from .utils import find_adapter_config_file
+
+TORCH_INIT_FUNCTIONS = {
+    "uniform_": nn.init.uniform_,
+    "normal_": nn.init.normal_,
+    "trunc_normal_": nn.init.trunc_normal_,
+    "constant_": nn.init.constant_,
+    "xavier_uniform_": nn.init.xavier_uniform_,
+    "xavier_normal_": nn.init.xavier_normal_,
+    "kaiming_uniform_": nn.init.kaiming_uniform_,
+    "kaiming_normal_": nn.init.kaiming_normal_,
+    "uniform": nn.init.uniform,
+    "normal": nn.init.normal,
+    "xavier_uniform": nn.init.xavier_uniform,
+    "xavier_normal": nn.init.xavier_normal,
+    "kaiming_uniform": nn.init.kaiming_uniform,
+    "kaiming_normal": nn.init.kaiming_normal,
+}
+
 @contextmanager
 def no_init_weights(_enable=True):
@@ -164,12 +181,24 @@ def no_init_weights(_enable=True):
     """
     global _init_weights
     old_init_weights = _init_weights
+
     if _enable:
         _init_weights = False
+
+        def _skip_init(*args, **kwargs):
+            pass
+
+        # Replace the torch.nn.init functions with a no-op; the originals are kept in TORCH_INIT_FUNCTIONS
+        for name, init_func in TORCH_INIT_FUNCTIONS.items():
+            setattr(torch.nn.init, name, _skip_init)
     try:
         yield
     finally:
         _init_weights = old_init_weights
+        if _enable:
+            # Restore the original initialization functions
+            for name, init_func in TORCH_INIT_FUNCTIONS.items():
+                setattr(torch.nn.init, name, init_func)

 def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
@@ -1506,7 +1535,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
     def _init_weights(self, module):
         """
-        Initialize the weights. This method should be overridden by derived class.
+        Initialize the weights. This method should be overridden by a derived class and is
+        the only initialization method that will be called when loading a checkpoint
+        using `from_pretrained`. Any attempt to initialize outside of this function
+        will be useless, as the torch.nn.init functions are all replaced with no-ops.
         """
         pass
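For context, the kind of override the new docstring refers to looks like the sketch below. It is a hedged illustration loosely following the pattern used across Transformers models; the class name and the 0.02 fallback are made up for the example. Note that in-place tensor methods such as `Tensor.normal_()` are not affected by the `torch.nn.init` patching, which is why initialization done here still takes effect for weights missing from the checkpoint.

# Illustrative sketch of a derived-class `_init_weights` override (class name
# and the 0.02 fallback are invented for the example).
from torch import nn

from transformers import PretrainedConfig, PreTrainedModel


class ToyPreTrainedModel(PreTrainedModel):
    config_class = PretrainedConfig
    base_model_prefix = "toy"

    def _init_weights(self, module):
        # Real models usually read this from their config (e.g. initializer_range).
        std = getattr(self.config, "initializer_range", 0.02)
        if isinstance(module, nn.Linear):
            # In-place tensor methods are not patched by `no_init_weights`,
            # so this still runs for weights absent from the checkpoint.
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)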
@@ -3414,6 +3446,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             )

         with ContextManagers(init_contexts):
+            # Let's make sure we don't run the init function of buffer modules
             model = cls(config, *model_args, **model_kwargs)

         # make sure we use the model's config since the __init__ call might have copied it
......
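Putting the pieces together, the loading path roughly does the following: build the model skeleton under `no_init_weights`, copy the checkpoint tensors in, then run `_init_weights` only for parameters the checkpoint did not provide. The sketch below is a heavily simplified illustration of that flow, not the actual `from_pretrained` body; `load_sketch` and its missing-key handling are assumptions for the example.

# Heavily simplified illustration of the loading flow (not the real
# `from_pretrained` implementation; `load_sketch` is a made-up helper).
from transformers.modeling_utils import no_init_weights
from transformers.utils.generic import ContextManagers


def load_sketch(cls, config, state_dict):
    # 1. Build the skeleton with torch.nn.init patched to no-ops (fast).
    init_contexts = [no_init_weights(_enable=True)]  # the real code may add more contexts here
    with ContextManagers(init_contexts):
        model = cls(config)

    # 2. Copy in whatever the checkpoint provides.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)

    # 3. Only parameters absent from the checkpoint still need initialization,
    #    and `_init_weights` is the single hook that performs it.
    for name, module in model.named_modules():
        if name and any(key.startswith(name + ".") for key in missing):
            model._init_weights(module)
    return model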
tests/test_modeling_common.py

@@ -36,8 +36,10 @@ from transformers import (
     AutoModelForCausalLM,
     AutoModelForSequenceClassification,
     PretrainedConfig,
+    PreTrainedModel,
     is_torch_available,
     logging,
+    set_seed,
 )
 from transformers.models.auto import get_values
 from transformers.models.auto.modeling_auto import (
@@ -85,7 +87,7 @@ from transformers.utils import (
     is_torch_fx_available,
     is_torch_sdpa_available,
 )
-from transformers.utils.generic import ModelOutput
+from transformers.utils.generic import ContextManagers, ModelOutput

 if is_accelerate_available():
@@ -99,6 +101,7 @@ if is_torch_available():
     from torch import nn

     from transformers import MODEL_MAPPING, AdaptiveEmbedding
+    from transformers.modeling_utils import no_init_weights
     from transformers.pytorch_utils import id_tensor_storage
@@ -428,6 +431,56 @@ class ModelTesterMixin:
                 max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
                 self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

+    def test_fast_init_context_manager(self):
+        # 1. Create a dummy class. Should have buffers as well? To make sure we test __init__
+        class MyClass(PreTrainedModel):
+            config_class = PretrainedConfig
+
+            def __init__(self, config=None):
+                super().__init__(config if config is not None else PretrainedConfig())
+                self.linear = nn.Linear(10, 10, bias=True)
+                self.embedding = nn.Embedding(10, 10)
+                self.std = 1
+
+            def _init_weights(self, module):
+                if isinstance(module, nn.Linear):
+                    module.weight.data = nn.init.kaiming_uniform_(module.weight.data, np.sqrt(5))
+                    if module.bias is not None:
+                        module.bias.data.normal_(mean=0.0, std=self.std)
+
+        # 2. Make sure a linear layer's reset_parameters is properly skipped:
+        with ContextManagers([no_init_weights(True)]):
+            no_init_instance = MyClass()
+
+        set_seed(0)
+        expected_bias = torch.tensor(
+            [0.2975, 0.2131, -0.1379, -0.0796, -0.3012, -0.0057, -0.2381, -0.2439, -0.0174, 0.0475]
+        )
+        init_instance = MyClass()
+        torch.testing.assert_allclose(init_instance.linear.bias, expected_bias, rtol=1e-3, atol=1e-4)
+
+        set_seed(0)
+        torch.testing.assert_allclose(
+            init_instance.linear.weight, nn.init.kaiming_uniform_(no_init_instance.linear.weight, np.sqrt(5))
+        )
+
+        # 3. Make sure weights that are not present in the checkpoint go through `_init_weights` and get expected values
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            state_dict = init_instance.state_dict()
+            del state_dict["linear.weight"]
+
+            init_instance.config.save_pretrained(tmpdirname)
+            torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))
+            set_seed(0)
+            model_fast_init = MyClass.from_pretrained(tmpdirname)
+
+            set_seed(0)
+            model_slow_init = MyClass.from_pretrained(tmpdirname, _fast_init=False)
+
+            for key in model_fast_init.state_dict().keys():
+                max_diff = torch.max(torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]))
+                self.assertLessEqual(max_diff.item(), 1e-3, msg=f"{key} not identical")
+
     def test_save_load_fast_init_to_base(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         if config.__class__ not in MODEL_MAPPING:
......
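A quick way to observe the effect end to end is to time `from_pretrained` with the fast path (the default) against the old behaviour via the `_fast_init=False` escape hatch used in the test above. The checkpoint name and the timings in the sketch below are illustrative.

# Rough end-to-end comparison; "gpt2" is just a small public checkpoint and
# the measured numbers will vary by machine (the second call reuses the local cache).
import time

from transformers import AutoModelForCausalLM

t0 = time.perf_counter()
AutoModelForCausalLM.from_pretrained("gpt2")                    # fast init (default)
print(f"fast init: {time.perf_counter() - t0:.2f}s")

t0 = time.perf_counter()
AutoModelForCausalLM.from_pretrained("gpt2", _fast_init=False)  # previous behaviour
print(f"slow init: {time.perf_counter() - t0:.2f}s")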