Unverified Commit 976189a6 authored by Xuehai Pan's avatar Xuehai Pan Committed by GitHub
Browse files

Fix initialization for missing parameters in `from_pretrained` under ZeRO-3 (#28245)

* Fix initialization for missing parameters in `from_pretrained` under ZeRO-3

* Test initialization for missing parameters under ZeRO-3

* Add more tests

* Only enable deepspeed context for per-module level parameters

* Enable deepspeed context only once

* Move class definition inside test case body
parent 357971ec
......@@ -19,6 +19,7 @@ import functools
import gc
import importlib.metadata
import inspect
import itertools
import json
import os
import re
......@@ -544,10 +545,14 @@ def set_initialized_submodules(model, state_dict_keys):
Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state
dict.
"""
not_initialized_submodules = {}
for module_name, module in model.named_modules():
loaded_keys = [k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")]
if len(set(module.state_dict().keys()) - set(loaded_keys)) == 0:
loaded_keys = {k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")}
if loaded_keys.issuperset(module.state_dict()):
module._is_hf_initialized = True
else:
not_initialized_submodules[module_name] = module
return not_initialized_submodules
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
......@@ -3917,7 +3922,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
elif add_prefix_to_model:
expected_keys = [".".join([prefix, s]) for s in expected_keys]
missing_keys = list(set(expected_keys) - set(loaded_keys))
missing_keys = sorted(set(expected_keys) - set(loaded_keys))
unexpected_keys = set(loaded_keys) - set(expected_keys)
# Remove nonpersistent buffers from unexpected keys: they are not in the state dict but will be in the model
# buffers
......@@ -3926,10 +3931,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
model_buffers = {key[len(_prefix) :] if key.startswith(_prefix) else key for key in model_buffers}
elif add_prefix_to_model:
model_buffers = {".".join([prefix, key]) for key in model_buffers}
unexpected_keys = list(unexpected_keys - model_buffers)
unexpected_keys = sorted(unexpected_keys - model_buffers)
model.tie_weights()
if device_map is None and not is_fsdp_enabled():
if device_map is None and not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
ptrs = collections.defaultdict(list)
for name, tensor in model.state_dict().items():
id_tensor = id_tensor_storage(tensor)
......@@ -4000,8 +4005,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
_loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys]
else:
_loaded_keys = loaded_keys
set_initialized_submodules(model, _loaded_keys)
not_initialized_submodules = set_initialized_submodules(model, _loaded_keys)
else:
not_initialized_submodules = dict(model.named_modules())
# This will only initialize submodules that are not marked as initialized by the line above.
if is_deepspeed_zero3_enabled():
import deepspeed
not_initialized_parameters = list(
set(
itertools.chain.from_iterable(
submodule.parameters(recurse=False) for submodule in not_initialized_submodules.values()
)
)
)
with deepspeed.zero.GatheredParameters(not_initialized_parameters, modifier_rank=0):
model.apply(model._initialize_weights)
else:
model.apply(model._initialize_weights)
# Set some modules to fp32 if any
......
......@@ -225,6 +225,78 @@ class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
AutoModel.from_pretrained(T5_TINY)
self.assertNotIn("Detected DeepSpeed ZeRO-3", cl.out)
def test_init_zero3_missing_params(self):
# test that zero.Init() for missing parameters works correctly under zero3
import deepspeed
import torch
from transformers.models.gpt2.modeling_gpt2 import GPT2PreTrainedModel
class TinyGPT2WithUninitializedWeights(GPT2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.transformer = AutoModel.from_pretrained(GPT2_TINY, config=config)
self.new_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=True)
def forward(self, *args, **kwargs):
transformer_outputs = self.transformer(*args, **kwargs)
hidden_states = transformer_outputs[0]
return self.new_head(hidden_states).float()
def _init_weights(self, module):
super()._init_weights(module)
if module is self.new_head:
self.new_head.weight.data.fill_(-100.0)
self.new_head.bias.data.fill_(+100.0)
ds_config = {
"train_batch_size": 1,
"zero_optimization": {
"stage": 3,
},
}
dschf = HfDeepSpeedConfig(ds_config)
self.assertTrue(dschf.is_zero3())
self.assertTrue(is_deepspeed_zero3_enabled())
with LoggingLevel(logging.INFO):
with mockenv_context(**self.dist_env_1_gpu):
logger = logging.get_logger("transformers.modeling_utils")
with CaptureLogger(logger) as cl:
model = TinyGPT2WithUninitializedWeights.from_pretrained(GPT2_TINY)
self.assertIn("Detected DeepSpeed ZeRO-3", cl.out)
self.assertRegex(cl.out, r"newly initialized.*new_head\.bias.*new_head\.weight")
with deepspeed.zero.GatheredParameters([model.new_head.weight, model.new_head.bias]):
self.assertTrue(
torch.allclose(model.new_head.weight, torch.tensor(-100.0, device=model.new_head.weight.device)),
)
self.assertTrue(
torch.allclose(model.new_head.bias, torch.tensor(+100.0, device=model.new_head.bias.device)),
)
# now remove zero optimization
del ds_config["zero_optimization"]
dschf = HfDeepSpeedConfig(ds_config)
self.assertFalse(dschf.is_zero3())
self.assertFalse(is_deepspeed_zero3_enabled())
with LoggingLevel(logging.INFO):
with mockenv_context(**self.dist_env_1_gpu):
logger = logging.get_logger("transformers.modeling_utils")
with CaptureLogger(logger) as cl:
model = TinyGPT2WithUninitializedWeights.from_pretrained(GPT2_TINY)
self.assertNotIn("Detected DeepSpeed ZeRO-3", cl.out)
self.assertRegex(cl.out, r"newly initialized.*new_head\.bias.*new_head\.weight")
self.assertTrue(
torch.allclose(model.new_head.weight, torch.tensor(-100.0, device=model.new_head.weight.device)),
)
self.assertTrue(
torch.allclose(model.new_head.bias, torch.tensor(+100.0, device=model.new_head.bias.device)),
)
class TrainerIntegrationDeepSpeedWithCustomConfig(TestCasePlus):
def setUp(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment