Unverified Commit 9f675b05 authored by Stas Bekman, committed by GitHub

[trainer] self.model_wrapped + _model_unwrap (#9390)



* model wrapped + model_unwrap

* cleanup

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* style

* deprecation warning

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 453a70d4
@@ -162,6 +162,14 @@ if is_fairscale_available():
logger = logging.get_logger(__name__)
def _model_unwrap(model: nn.Module) -> nn.Module:
# since there could be multiple levels of wrapping, unwrap recursively
if hasattr(model, "module"):
return _model_unwrap(model.module)
else:
return model
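
A minimal usage sketch for the helper above (not part of the commit; torch.nn.DataParallel is used here only because, like DistributedDataParallel and the DeepSpeed engine, it exposes the inner model via a `.module` attribute):

import torch.nn as nn

core = nn.Linear(4, 2)
wrapped = nn.DataParallel(nn.DataParallel(core))  # two levels of wrapping

assert _model_unwrap(wrapped) is core  # the recursion peels every `.module` level
assert _model_unwrap(core) is core     # a bare module is returned unchanged
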
class Trainer:
"""
Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.
@@ -212,6 +220,16 @@ class Trainer:
containing the optimizer and the scheduler to use. Will default to an instance of
:class:`~transformers.AdamW` on your model and a scheduler given by
:func:`~transformers.get_linear_schedule_with_warmup` controlled by :obj:`args`.
Important accessors:
``self.model`` - always points to the core model. If using a transformers model, it will be a
:class:`PreTrainedModel` subclass.
``self.model_wrapped`` - always points to the most external model in case one or more other modules wrap the
original model. This is the model that should be used for the forward pass. For example, under ``DeepSpeed``,
the inner model is wrapped in ``DeepSpeed`` and then again in ``DistributedDataParallel``. If the inner model
hasn't been wrapped, then ``self.model_wrapped`` is the same as ``self.model``.
"""
def __init__(
@@ -234,30 +252,37 @@ class Trainer:
self.args = args
# Seed must be set before instantiating the model when using model_init
set_seed(self.args.seed)
assert (
model is not None or model_init is not None
), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument."
self.model_init = model_init
self.hp_name = None
if model is None and model_init is not None:
model = self.call_model_init()
if self.args.model_parallel and not model.is_parallelizable:
raise ValueError(
f"{model.__class__.__name__} implementation currently doesn't support model parallelism, therefore --model_parallel cl arg cannot be used"
)
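
For context on the `is_parallelizable` check above, a hedged sketch (GPT-2 is an assumed example; in this version only a few architectures such as GPT-2 and T5 implement naive model parallelism):

import torch
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")
if model.is_parallelizable and torch.cuda.device_count() > 1:
    model.parallelize()  # spreads the transformer blocks across the visible GPUs
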
# Model parallel
if model is not None and not self.args.model_parallel:
model = model.to(args.device)
if model is None:
if model_init is not None:
self.model_init = model_init
model = self.call_model_init()
else:
raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
else:
if model_init is not None:
warnings.warn(
"`Trainer` requires either a `model` or `model_init` argument, but not both. "
"`model_init` will overwrite your model when calling the `train` method. This will become a fatal error in the next release.",
FutureWarning,
)
self.model_init = model_init
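
A hedged usage sketch of the two supported ways to hand a model to `Trainer` after this change (`my_model`, `training_args`, `train_ds` and the checkpoint name are placeholders):

from transformers import AutoModelForSequenceClassification, Trainer

# Either pass an instantiated model ...
trainer = Trainer(model=my_model, args=training_args, train_dataset=train_ds)

# ... or a zero-argument factory (useful for hyper-parameter search), but not both:
def model_init():
    return AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")

trainer = Trainer(model_init=model_init, args=training_args, train_dataset=train_ds)
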
self.model = model
default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
self.data_collator = data_collator if data_collator is not None else default_collator
self.train_dataset = train_dataset
self.eval_dataset = eval_dataset
self.tokenizer = tokenizer
# Model parallel
if not self.args.model_parallel:
model = model.to(args.device)
# later use `self.model is self.model_wrapped` to check if it's wrapped or not
self.model_wrapped = model
self.model = model
self.compute_metrics = compute_metrics
self.optimizer, self.lr_scheduler = optimizers
if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
@@ -640,9 +665,11 @@ class Trainer:
set_seed(self.args.seed)
model = self.call_model_init(trial)
if not self.args.model_parallel:
self.model = model.to(self.args.device)
model = model.to(self.args.device)
self.model = model
self.model_wrapped = model
# Reinitializes optimizer and scheduler
self.optimizer, self.lr_scheduler = None, None
@@ -681,8 +708,9 @@ class Trainer:
# Check if saved optimizer or scheduler states exist
self._load_optimizer_and_scheduler(model_path)
model = self.model_wrapped
# Mixed precision training with apex (torch < 1.6)
model = self.model
if self.use_apex:
model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
@@ -707,6 +735,14 @@ class Trainer:
# find_unused_parameters breaks checkpointing as per
# https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
# for the rest of this function `model` is the outside model, whether it was wrapped or not
if model is not self.model:
self.model_wrapped = model
# important: at this point:
# self.model is the Transformers Model
# self.model_wrapped is DDP(Transformers Model), DDP(Deepspeed(Transformers Model)), etc.
# Train!
if is_torch_tpu_available():
total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
@@ -937,12 +973,10 @@ class Trainer:
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
def _save_checkpoint(self, model, trial, metrics=None):
# In all cases (even distributed/parallel), self.model is always a reference
# to the model we want to save.
if hasattr(model, "module"):
assert model.module is self.model, f"Module {model.module} should be a reference to self.model"
else:
assert model is self.model, f"Model {model} should be a reference to self.model"
# In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
# want to save.
assert _model_unwrap(model) is self.model, "internal model should be a reference to self.model"
# Save model checkpoint
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
@@ -1630,30 +1664,7 @@ class Trainer:
Returns:
:obj:`int`: The number of floating-point operations.
"""
model = self._actual_model(self.model)
if hasattr(model, "floating_point_ops"):
return model.floating_point_ops(inputs)
if hasattr(self.model, "floating_point_ops"):
return self.model.floating_point_ops(inputs)
else:
return 0
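
A hedged sketch of the path taken when the model does define `floating_point_ops` (checkpoint name is an assumption; `PreTrainedModel` estimates roughly 6 * num_tokens * num_parameters):

from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

inputs = tokenizer("hello world", return_tensors="pt")
flos = model.floating_point_ops(inputs)  # the Trainer falls back to 0 if this method is missing
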
@staticmethod
def _actual_model(
model: Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module]
) -> torch.nn.modules.Module:
"""
Args:
model: (:obj:`Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module]`):
Model object used during training
Returns:
:obj:`torch.nn.modules.Module`: unwrapped module
"""
if isinstance(model, torch.nn.DataParallel) or isinstance(model, torch.nn.parallel.DistributedDataParallel):
model = model.module
else:
model = model
return model
@@ -53,6 +53,7 @@ if is_torch_available():
Trainer,
TrainerState,
)
from transformers.trainer import _model_unwrap
PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
@@ -850,8 +851,8 @@ class TrainerIntegrationTest(unittest.TestCase):
trainer = get_regression_trainer(learning_rate=0.1)
def assert_flos_extraction(trainer, wrapped_model_to_check):
self.assertEqual(trainer.model, trainer._actual_model(wrapped_model_to_check))
self.assertGreaterEqual(getattr(trainer._actual_model(wrapped_model_to_check).config, "total_flos", 0), 0)
self.assertEqual(trainer.model, _model_unwrap(wrapped_model_to_check))
self.assertGreaterEqual(getattr(_model_unwrap(wrapped_model_to_check).config, "total_flos", 0), 0)
# with plain model
assert_flos_extraction(trainer, trainer.model)