Unverified Commit dee876ce authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[trainer] refactor place_model_on_device logic, add deepspeed (#10243)

* refactor place_model_on_device logic, add deepspeed

* doc

* style
parent d1eb88f4
...@@ -214,6 +214,10 @@ class Trainer: ...@@ -214,6 +214,10 @@ class Trainer:
inner model hasn't been wrapped, then ``self.model_wrapped`` is the same as ``self.model``. inner model hasn't been wrapped, then ``self.model_wrapped`` is the same as ``self.model``.
- **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
data parallelism, this means some of the model layers are split on different GPUs). data parallelism, this means some of the model layers are split on different GPUs).
- **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set
to :obj:`False` if model parallel or deepspeed is used, or if the default
``TrainingArguments.place_model_on_device`` is overridden to return :obj:`False` .
""" """
def __init__( def __init__(
...@@ -262,6 +266,11 @@ class Trainer: ...@@ -262,6 +266,11 @@ class Trainer:
else: else:
self.is_model_parallel = False self.is_model_parallel = False
# one place to sort out whether to place the model on device or not
self.place_model_on_device = args.place_model_on_device
if self.is_model_parallel or (args.deepspeed and args.do_train):
self.place_model_on_device = False
default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer) default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
self.data_collator = data_collator if data_collator is not None else default_collator self.data_collator = data_collator if data_collator is not None else default_collator
self.train_dataset = train_dataset self.train_dataset = train_dataset
...@@ -272,7 +281,7 @@ class Trainer: ...@@ -272,7 +281,7 @@ class Trainer:
# 1. MP - since we are trying to fit a much bigger than 1 gpu model # 1. MP - since we are trying to fit a much bigger than 1 gpu model
# 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway, # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
# and we only use deepspeed for training at the moment # and we only use deepspeed for training at the moment
if not (self.is_model_parallel or (args.deepspeed and args.do_train)) and self.args.place_model_on_device: if self.place_model_on_device:
model = model.to(args.device) model = model.to(args.device)
# Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
...@@ -780,7 +789,7 @@ class Trainer: ...@@ -780,7 +789,7 @@ class Trainer:
# If model was re-initialized, put it on the right device and update self.model_wrapped # If model was re-initialized, put it on the right device and update self.model_wrapped
if model_reloaded: if model_reloaded:
if not self.is_model_parallel and self.args.place_model_on_device: if self.place_model_on_device:
self.model = self.model.to(self.args.device) self.model = self.model.to(self.args.device)
self.model_wrapped = self.model self.model_wrapped = self.model
...@@ -1033,7 +1042,7 @@ class Trainer: ...@@ -1033,7 +1042,7 @@ class Trainer:
) )
if isinstance(self.model, PreTrainedModel): if isinstance(self.model, PreTrainedModel):
self.model = self.model.from_pretrained(self.state.best_model_checkpoint) self.model = self.model.from_pretrained(self.state.best_model_checkpoint)
if not self.is_model_parallel and self.args.place_model_on_device: if self.place_model_on_device:
self.model = self.model.to(self.args.device) self.model = self.model.to(self.args.device)
else: else:
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)) state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment