"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "5603f78fc46bd117e3f25cc8842eb08046bbff4e"
Unverified commit 748006c0, authored by Stas Bekman, committed by GitHub

[trainer] --model_parallel hasn't been implemented for most models (#9347)

* --model_parallel hasn't been implemented for most models

* make the help clear as well

* implement is_parallelizable; use it

* oops

* remove property
parent 4225740a
@@ -404,6 +404,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
         - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
           derived classes of the same architecture adding modules on top of the base model.
+        - **is_parallelizable** (:obj:`bool`) -- A flag indicating whether this model supports model parallelization.
     """
     config_class = None
     base_model_prefix = ""
@@ -417,6 +418,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
     # trained, but which are deterministic)
     _keys_to_ignore_on_save = None

+    is_parallelizable = False
+
     @property
     def dummy_inputs(self) -> Dict[str, torch.Tensor]:
         """
......
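The hunk above introduces is_parallelizable as a plain class attribute with a False default, so every model class inherits "not parallelizable" unless it explicitly opts in. A minimal sketch of that pattern in isolation (the class names here are simplified stand-ins, not the real transformers classes):

    class BaseModel:
        # Default: model parallelism is not supported.
        is_parallelizable = False

    class ParallelizableModel(BaseModel):
        # A subclass opts in by overriding the class attribute.
        is_parallelizable = True

    # The flag is readable on the class itself, no instance needed.
    assert BaseModel.is_parallelizable is False
    assert ParallelizableModel.is_parallelizable is True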
@@ -337,6 +337,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
     config_class = GPT2Config
     load_tf_weights = load_tf_weights_in_gpt2
     base_model_prefix = "transformer"
+    is_parallelizable = True

     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
......
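GPT-2 can opt in because it already ships parallelize()/deparallelize() methods in transformers of this era. A hedged usage sketch; the two-GPU device_map below is an illustrative assumption, not part of this diff, and running it requires at least two CUDA devices:

    from transformers import GPT2LMHeadModel

    model = GPT2LMHeadModel.from_pretrained("gpt2")
    if model.is_parallelizable:
        # Spread the 12 transformer blocks of "gpt2" across two GPUs.
        device_map = {0: list(range(0, 6)), 1: list(range(6, 12))}
        model.parallelize(device_map)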
@@ -683,6 +683,7 @@ class T5PreTrainedModel(PreTrainedModel):
     config_class = T5Config
     load_tf_weights = load_tf_weights_in_t5
     base_model_prefix = "transformer"
+    is_parallelizable = True

     @property
     def dummy_inputs(self):
......
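With GPT-2 and T5 opted in and every other architecture inheriting the False default, the flag can be inspected on the class itself, before any weights are loaded. A small sketch, assuming a transformers version that contains this commit:

    from transformers import (
        BertForSequenceClassification,
        GPT2LMHeadModel,
        T5ForConditionalGeneration,
    )

    print(GPT2LMHeadModel.is_parallelizable)                # True (set above)
    print(T5ForConditionalGeneration.is_parallelizable)     # True (set above)
    print(BertForSequenceClassification.is_parallelizable)  # False (base-class default)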
@@ -242,6 +242,11 @@ class Trainer:
         if model is None and model_init is not None:
             model = self.call_model_init()

+        if self.args.model_parallel and not model.is_parallelizable:
+            raise ValueError(
+                f"{model.__class__.__name__} implementation currently doesn't support model parallelism, therefore --model_parallel cl arg cannot be used"
+            )
+
         # Model parallel
         if model is not None and not self.args.model_parallel:
             model = model.to(args.device)
......
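The net effect is that the Trainer now fails fast at construction time instead of crashing mid-training with an unsupported model. A hedged sketch of what a caller would see; the checkpoint and output_dir are illustrative placeholders:

    from transformers import BertForSequenceClassification, Trainer, TrainingArguments

    args = TrainingArguments(output_dir="out", model_parallel=True)
    model = BertForSequenceClassification.from_pretrained("bert-base-uncased")

    # BERT's is_parallelizable is False, so per the guard above this raises
    # ValueError before any training starts.
    trainer = Trainer(model=model, args=args)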
@@ -207,8 +207,8 @@ class TrainingArguments:
             :obj:`"eval_loss"`.
             - :obj:`False` if :obj:`metric_for_best_model` is not set, or set to :obj:`"loss"` or :obj:`"eval_loss"`.
         model_parallel (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            If there is more than one device, whether to use model parallelism to distribute the model's modules across
-            devices or not.
+            If the model supports model parallelism and there is more than one device, whether to use model parallelism
+            to distribute the model's modules across devices or not.
         ignore_skip_data (:obj:`bool`, `optional`, defaults to :obj:`False`):
             When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
             stage as in the previous training. If set to :obj:`True`, the training will begin faster (as that skipping
......
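Since the commit message also mentions making the help clear, the clarified docstring is what surfaces as --help text when the flag is parsed from the command line. A hedged sketch of that path using HfArgumentParser, the standard parser in transformers example scripts (the exact bool-flag parsing behavior is an assumption about this era's parser):

    from transformers import HfArgumentParser, TrainingArguments

    parser = HfArgumentParser(TrainingArguments)
    (args,) = parser.parse_args_into_dataclasses(
        args=["--output_dir", "out", "--model_parallel"]
    )
    print(args.model_parallel)  # True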