Unverified Commit b4fd49b6 authored by Marc Sun, committed by GitHub

Update unwrap from accelerate (#29933)



* Use unwrap with the one in accelerate

* oups

* update unwrap

* fix

* wording

* raise error instead

* comment

* doc

* Update src/transformers/modeling_utils.py
Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* style

* put else

---------
Co-authored-by: Zach Mueller <muellerzr@gmail.com>
parent fbd8c51f
src/transformers/modeling_utils.py
@@ -109,6 +109,7 @@ if is_accelerate_available():
     from accelerate.hooks import add_hook_to_module
     from accelerate.utils import (
         check_tied_parameters_on_same_device,
+        extract_model_from_parallel,
         find_tied_parameters,
         get_balanced_memory,
         get_max_memory,
@@ -4805,13 +4806,29 @@ class SequenceSummary(nn.Module):
         return output


-def unwrap_model(model: nn.Module) -> nn.Module:
+def unwrap_model(model: nn.Module, recursive: bool = False) -> nn.Module:
     """
     Recursively unwraps a model from potential containers (as used in distributed training).

     Args:
         model (`torch.nn.Module`): The model to unwrap.
+        recursive (`bool`, *optional*, defaults to `False`):
+            Whether to recursively extract all cases of `module.module` from `model` as well as unwrap child sublayers
+            recursively, not just the top-level distributed containers.
     """
+    # Use accelerate implementation if available (should always be the case when using torch)
+    # This is for pytorch, as we also have to handle things like dynamo
+    if is_accelerate_available():
+        kwargs = {}
+        if recursive:
+            if not is_accelerate_available("0.29.0"):
+                raise RuntimeError(
+                    "Setting `recursive=True` to `unwrap_model` requires `accelerate` v0.29.0. Please upgrade your version of accelerate"
+                )
+            else:
+                kwargs["recursive"] = recursive
+        return extract_model_from_parallel(model, **kwargs)
-    # since there could be multiple levels of wrapping, unwrap recursively
-    if hasattr(model, "module"):
-        return unwrap_model(model.module)
+    else:
+        # since there could be multiple levels of wrapping, unwrap recursively
+        if hasattr(model, "module"):
+            return unwrap_model(model.module)
...
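A minimal usage sketch of the new signature (not part of the diff; the `nn.DataParallel` wrapper is just an illustrative container, and the `recursive=True` call assumes accelerate v0.29.0+ is installed):

    import torch.nn as nn
    from transformers.modeling_utils import unwrap_model

    model = nn.Linear(4, 4)
    wrapped = nn.DataParallel(model)  # hides the model behind a `.module` attribute

    # The helper now delegates to accelerate's extract_model_from_parallel,
    # which strips known parallel containers until the real module is reached.
    assert unwrap_model(wrapped) is model

    # With accelerate >= 0.29.0, recursive=True also unwraps wrapped
    # submodules (e.g. FSDP-wrapped children), not just the outermost container.
    assert unwrap_model(wrapped, recursive=True) is model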
src/transformers/trainer.py
@@ -63,7 +63,7 @@ from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_h
 from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
 from .integrations.tpu import tpu_spmd_dataloader
 from .modelcard import TrainingSummary
-from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
+from .modeling_utils import PreTrainedModel, load_sharded_checkpoint
 from .models.auto.modeling_auto import (
     MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
     MODEL_MAPPING_NAMES,
@@ -684,7 +684,7 @@ class Trainer:
         Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper:
         https://arxiv.org/abs/2310.05914
         """
-        unwrapped_model = unwrap_model(model)
+        unwrapped_model = self.accelerator.unwrap_model(model)
         if _is_peft_model(unwrapped_model):
             embeddings = unwrapped_model.base_model.model.get_input_embeddings()
@@ -705,7 +705,7 @@ class Trainer:
         if not hasattr(self, "neftune_hook_handle"):
             raise ValueError("Neftune is not activated make sure to call `trainer._activate_neftune()` first")

-        unwrapped_model = unwrap_model(model)
+        unwrapped_model = self.accelerator.unwrap_model(model)
         if _is_peft_model(unwrapped_model):
             embeddings = unwrapped_model.base_model.model.get_input_embeddings()
@@ -1617,7 +1617,7 @@ class Trainer:
             return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)

         # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
-        if unwrap_model(model) is not model:
+        if self.accelerator.unwrap_model(model) is not model:
             return model

         # Mixed precision training with apex (torch < 1.6)
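A sketch of the guard this hunk relies on (illustrative, not from the diff): `Accelerator.unwrap_model` returns a different object than its input only when the model is actually wrapped, so the identity check detects prior wrapping. Note that a bare `Accelerator()` on a single CPU process leaves the model unwrapped:

    from torch import nn
    from accelerate import Accelerator

    accelerator = Accelerator()
    model = accelerator.prepare(nn.Linear(4, 4))

    # If prepare() wrapped the model (e.g. in DistributedDataParallel),
    # unwrapping yields the inner module, which is a different object.
    if accelerator.unwrap_model(model) is not model:
        print("already wrapped once; skip re-wrapping")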
@@ -3165,7 +3165,7 @@ class Trainer:
             self._past = outputs[self.args.past_index]

         if labels is not None:
-            unwrapped_model = unwrap_model(model)
+            unwrapped_model = self.accelerator.unwrap_model(model)
             if _is_peft_model(unwrapped_model):
                 model_name = unwrapped_model.base_model.model._get_name()
             else:
@@ -3272,8 +3272,8 @@ class Trainer:
             supported_classes = (PushToHubMixin,)
             xm.rendezvous("saving_checkpoint")
             if not isinstance(model, supported_classes):
-                if isinstance(unwrap_model(model), supported_classes):
-                    unwrap_model(model).save_pretrained(
+                if isinstance(self.accelerator.unwrap_model(model), supported_classes):
+                    self.accelerator.unwrap_model(model).save_pretrained(
                         output_dir,
                         is_main_process=self.args.should_save,
                         state_dict=model.state_dict(),
@@ -3311,8 +3311,8 @@ class Trainer:
             if state_dict is None:
                 state_dict = self.model.state_dict()

-            if isinstance(unwrap_model(self.model), supported_classes):
-                unwrap_model(self.model).save_pretrained(
+            if isinstance(self.accelerator.unwrap_model(self.model), supported_classes):
+                self.accelerator.unwrap_model(self.model).save_pretrained(
                     output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
                 )
             else:
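A hedged sketch of the save path these two hunks converge on (`trainer` is an assumed existing `Trainer` instance; `model_wrapped` is the Trainer attribute holding the possibly-wrapped model): unwrap through the Trainer's own `Accelerator`, then save while passing the wrapped model's `state_dict`, mirroring how the diff does it.

    # Illustrative only; not code from the PR.
    unwrapped = trainer.accelerator.unwrap_model(trainer.model_wrapped)
    unwrapped.save_pretrained(
        trainer.args.output_dir,
        state_dict=trainer.model_wrapped.state_dict(),
        safe_serialization=trainer.args.save_safetensors,
    )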
@@ -3969,7 +3969,7 @@ class Trainer:
             f.write(model_card)

         if is_peft_library:
-            unwrap_model(self.model).create_or_update_model_card(self.args.output_dir)
+            self.accelerator.unwrap_model(self.model).create_or_update_model_card(self.args.output_dir)

     def _push_from_checkpoint(self, checkpoint_folder):
         # Only push from one node.
...
tests/trainer/test_trainer.py
@@ -123,7 +123,6 @@ if is_torch_available():
         Trainer,
         TrainerState,
     )
-    from transformers.modeling_utils import unwrap_model
     from transformers.trainer_pt_utils import AcceleratorConfig

 if is_safetensors_available():
@@ -2468,8 +2467,10 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         trainer = get_regression_trainer(learning_rate=0.1)

         def assert_flos_extraction(trainer, wrapped_model_to_check):
-            self.assertEqual(trainer.model, unwrap_model(wrapped_model_to_check))
-            self.assertGreaterEqual(getattr(unwrap_model(wrapped_model_to_check).config, "total_flos", 0), 0)
+            self.assertEqual(trainer.model, trainer.accelerator.unwrap_model(wrapped_model_to_check))
+            self.assertGreaterEqual(
+                getattr(trainer.accelerator.unwrap_model(wrapped_model_to_check).config, "total_flos", 0), 0
+            )

         # with plain model
         assert_flos_extraction(trainer, trainer.model)
...