"docs/source/vscode:/vscode.git/clone" did not exist on "be79cd7d8e3911d67230443eef7faf0f486533db"
Unverified commit 9f675b05 authored by Stas Bekman, committed by GitHub

[trainer] self.model_wrapped + _model_unwrap (#9390)



* model wrapped + model_unwrap

* cleanup

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* style

* deprecation warning

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 453a70d4
src/transformers/trainer.py

@@ -162,6 +162,14 @@ if is_fairscale_available():
 logger = logging.get_logger(__name__)
 
 
+def _model_unwrap(model: nn.Module) -> nn.Module:
+    # since there could be multiple levels of wrapping, unwrap recursively
+    if hasattr(model, "module"):
+        return _model_unwrap(model.module)
+    else:
+        return model
+
+
 class Trainer:
     """
     Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.
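Illustration (not part of the diff): a minimal sketch of how the new helper behaves. `Wrapper` is a hypothetical stand-in for DataParallel/DistributedDataParallel-style containers, which expose the inner model as `.module`; the import mirrors what the updated test does.

    import torch.nn as nn
    from transformers.trainer import _model_unwrap

    class Wrapper(nn.Module):
        # hypothetical stand-in for DP/DDP-style wrappers that expose the inner model as `.module`
        def __init__(self, module):
            super().__init__()
            self.module = module

    core = nn.Linear(4, 2)                       # stands in for the core Transformers model
    assert _model_unwrap(core) is core           # an unwrapped model is returned as-is
    assert _model_unwrap(Wrapper(core)) is core  # a single wrapper is peeled off via `.module`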
@@ -212,6 +220,16 @@ class Trainer:
             containing the optimizer and the scheduler to use. Will default to an instance of
             :class:`~transformers.AdamW` on your model and a scheduler given by
             :func:`~transformers.get_linear_schedule_with_warmup` controlled by :obj:`args`.
+
+    Important accessors:
+
+        ``self.model`` - always points to the core model. If using a transformers model, it will be a
+        :class:`PreTrainedModel` subclass.
+
+        ``self.model_wrapped`` - always points to the most external model in case one or more other modules wrap the
+        original model. This is the model that should be used for the forward pass. For example, under ``DeepSpeed``,
+        the inner model is wrapped in ``DeepSpeed`` and then again in ``DistributedDataParallel``. If the inner model
+        hasn't been wrapped, then ``self.model_wrapped`` is the same as ``self.model``.
     """
 
     def __init__(
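Usage sketch of the two accessors described above (illustrative; `model` and `train_ds` are assumed to be any `PreTrainedModel` and a torch dataset, and the `TrainingArguments` are minimal):

    from transformers import Trainer, TrainingArguments

    trainer = Trainer(model=model, args=TrainingArguments(output_dir="out"), train_dataset=train_ds)

    # Right after __init__ nothing is wrapped yet, so both accessors point at the same object;
    # train() only updates self.model_wrapped when it actually wraps the model.
    assert trainer.model is trainer.model_wrapped

    # Forward passes belong on the outermost wrapper, config/saving on the core model:
    #   trainer.model_wrapped(**batch)  -> e.g. DDP(model) under distributed training
    #   trainer.model.config            -> always the PreTrainedModel's config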
@@ -234,30 +252,37 @@ class Trainer:
         self.args = args
         # Seed must be set before instantiating the model when using model
         set_seed(self.args.seed)
-        assert (
-            model is not None or model_init is not None
-        ), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument."
-        self.model_init = model_init
         self.hp_name = None
-        if model is None and model_init is not None:
-            model = self.call_model_init()
-        if self.args.model_parallel and not model.is_parallelizable:
-            raise ValueError(
-                f"{model.__class__.__name__} implementation currently doesn't support model parallelism, therefore --model_parallel cl arg cannot be used"
-            )
+        if model is None:
+            if model_init is not None:
+                self.model_init = model_init
+                model = self.call_model_init()
+            else:
+                raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
+        else:
+            if model_init is not None:
+                warnings.warn(
+                    "`Trainer` requires either a `model` or `model_init` argument, but not both. "
+                    "`model_init` will overwrite your model when calling the `train` method. This will become a fatal error in the next release.",
+                    FutureWarning,
+                )
+            self.model_init = model_init
 
-        # Model parallel
-        if model is not None and not self.args.model_parallel:
-            model = model.to(args.device)
-
-        self.model = model
         default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
         self.data_collator = data_collator if data_collator is not None else default_collator
         self.train_dataset = train_dataset
         self.eval_dataset = eval_dataset
         self.tokenizer = tokenizer
+
+        # Model parallel
+        if not self.args.model_parallel:
+            model = model.to(args.device)
+
+        # later use `self.model is self.model_wrapped` to check if it's wrapped or not
+        self.model_wrapped = model
+        self.model = model
+
         self.compute_metrics = compute_metrics
         self.optimizer, self.lr_scheduler = optimizers
         if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
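For readers skimming the hunk, the new constructor logic reduces to three cases; here is a hypothetical standalone condensation (`resolve_model` is not part of the commit):

    import warnings

    def resolve_model(model, model_init):
        # hypothetical condensation of the new __init__ logic above
        if model is None:
            if model_init is None:
                raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
            return model_init()  # build the model lazily, like call_model_init() does
        if model_init is not None:
            # both were passed: deprecated, model_init() will overwrite `model` when train() runs
            warnings.warn(
                "`Trainer` requires either a `model` or `model_init` argument, but not both.",
                FutureWarning,
            )
        return model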
@@ -640,9 +665,11 @@ class Trainer:
         set_seed(self.args.seed)
         model = self.call_model_init(trial)
         if not self.args.model_parallel:
-            self.model = model.to(self.args.device)
+            model = model.to(self.args.device)
+        self.model = model
+        self.model_wrapped = model
 
         # Reinitializes optimizer and scheduler
         self.optimizer, self.lr_scheduler = None, None
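For context, model_init is re-invoked at the top of train() (and once per hyperparameter-search trial), which is why the fresh model is reassigned to both self.model and self.model_wrapped above. A typical model_init looks like this (model name and label count are illustrative):

    from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

    def model_init():
        # return a fresh, untrained model so every run/trial starts from the same point
        return AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)

    trainer = Trainer(model_init=model_init, args=TrainingArguments(output_dir="out"))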
@@ -681,8 +708,9 @@ class Trainer:
         # Check if saved optimizer or scheduler states exist
         self._load_optimizer_and_scheduler(model_path)
 
+        model = self.model_wrapped
+
         # Mixed precision training with apex (torch < 1.6)
-        model = self.model
         if self.use_apex:
             model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
@@ -707,6 +735,14 @@ class Trainer:
             # find_unused_parameters breaks checkpointing as per
             # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
 
+        # for the rest of this function `model` is the outside model, whether it was wrapped or not
+        if model is not self.model:
+            self.model_wrapped = model
+
+        # important: at this point:
+        # self.model is the Transformers Model
+        # self.model_wrapped is DDP(Transformers Model), DDP(Deepspeed(Transformers Model)), etc.
+
         # Train!
         if is_torch_tpu_available():
             total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
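A condensed sketch of the wrapping flow in train() that this comment documents (hypothetical helper; the real code also handles n_gpu > 1, TPU and fairscale/sharded DDP):

    import torch

    def _wrap_for_training(trainer, local_rank=-1):
        # hypothetical condensation of the wrapping steps in train()
        model = trainer.model_wrapped  # start from whatever is already wrapped
        if trainer.use_apex:
            from apex import amp  # only needed on the apex path

            model, trainer.optimizer = amp.initialize(
                model, trainer.optimizer, opt_level=trainer.args.fp16_opt_level
            )
        if local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[local_rank], output_device=local_rank
            )
        if model is not trainer.model:     # something wrapped the core model
            trainer.model_wrapped = model  # remember the outermost wrapper
        return model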
@@ -937,12 +973,10 @@ class Trainer:
         self.control = self.callback_handler.on_save(self.args, self.state, self.control)
 
     def _save_checkpoint(self, model, trial, metrics=None):
-        # In all cases (even distributed/parallel), self.model is always a reference
-        # to the model we want to save.
-        if hasattr(model, "module"):
-            assert model.module is self.model, f"Module {model.module} should be a reference to self.model"
-        else:
-            assert model is self.model, f"Model {model} should be a reference to self.model"
+        # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
+        # want to save.
+        assert _model_unwrap(model) is self.model, "internal model should be a reference to self.model"
+
         # Save model checkpoint
         checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
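Why the recursive helper matters here (illustrative; `Wrapper` is a hypothetical stand-in for e.g. DDP around a DeepSpeed engine): the old check only looked one `.module` level deep, while the new single assert holds for any nesting depth.

    import torch.nn as nn
    from transformers.trainer import _model_unwrap

    class Wrapper(nn.Module):
        # hypothetical wrapper exposing the inner model as `.module`
        def __init__(self, module):
            super().__init__()
            self.module = module

    core = nn.Linear(2, 2)
    double = Wrapper(Wrapper(core))        # two levels of wrapping
    assert double.module is not core       # the old one-level check stops too early
    assert _model_unwrap(double) is core   # the recursive unwrap still reaches the core model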
@@ -1630,30 +1664,7 @@ class Trainer:
         Returns:
             :obj:`int`: The number of floating-point operations.
         """
-
-        model = self._actual_model(self.model)
-
-        if hasattr(model, "floating_point_ops"):
-            return model.floating_point_ops(inputs)
-
+        if hasattr(self.model, "floating_point_ops"):
+            return self.model.floating_point_ops(inputs)
         else:
             return 0
-
-    @staticmethod
-    def _actual_model(
-        model: Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module]
-    ) -> torch.nn.modules.Module:
-        """
-        Args:
-            model: (:obj:`Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module]`):
-                Model object used during training
-
-        Returns:
-            :obj:`torch.nn.modules.Module`: unwrapped module
-        """
-        if isinstance(model, torch.nn.DataParallel) or isinstance(model, torch.nn.parallel.DistributedDataParallel):
-            model = model.module
-        else:
-            model = model
-
-        return model
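For reference, floating_point_ops is defined on PreTrainedModel and only needs the input token count and the parameter count, which is why calling it on the bare self.model is enough; a small usage sketch (checkpoint name is illustrative):

    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")

    inputs = tokenizer("a short example sentence", return_tensors="pt")
    flos = model.floating_point_ops(inputs)  # FLO estimate from input size and parameter count
    print(flos)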
tests/test_trainer.py

@@ -53,6 +53,7 @@ if is_torch_available():
         Trainer,
         TrainerState,
     )
+    from transformers.trainer import _model_unwrap
 
 
 PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
@@ -850,8 +851,8 @@ class TrainerIntegrationTest(unittest.TestCase):
         trainer = get_regression_trainer(learning_rate=0.1)
 
         def assert_flos_extraction(trainer, wrapped_model_to_check):
-            self.assertEqual(trainer.model, trainer._actual_model(wrapped_model_to_check))
-            self.assertGreaterEqual(getattr(trainer._actual_model(wrapped_model_to_check).config, "total_flos", 0), 0)
+            self.assertEqual(trainer.model, _model_unwrap(wrapped_model_to_check))
+            self.assertGreaterEqual(getattr(_model_unwrap(wrapped_model_to_check).config, "total_flos", 0), 0)
 
         # with plain model
         assert_flos_extraction(trainer, trainer.model)
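Illustrative extension of the test above (assumes the same `trainer` from get_regression_trainer()): wrapping the model in DataParallel mimics what trainer.model_wrapped may point to, and the helper must still recover the original model and its config.

    import torch
    from transformers.trainer import _model_unwrap

    wrapped = torch.nn.DataParallel(trainer.model)
    assert _model_unwrap(wrapped) is trainer.model
    assert getattr(_model_unwrap(wrapped).config, "total_flos", 0) >= 0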