Unverified Commit 0f443436 authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Actual fix (#9787)

parent fac7cfb1
......@@ -541,6 +541,7 @@ class GPT2Model(GPT2PreTrainedModel):
self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.init_weights()
# Model parallel
self.model_parallel = False
self.device_map = None
......@@ -805,7 +806,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
self.init_weights()
# Model parallel
self.model_parallel = False
self.device_map = None
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
......@@ -971,6 +974,10 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
self.init_weights()
# Model parallel
self.model_parallel = False
self.device_map = None
def get_output_embeddings(self):
return self.lm_head
......@@ -1153,6 +1160,10 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
self.init_weights()
# Model parallel
self.model_parallel = False
self.device_map = None
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
......
......@@ -1651,6 +1651,10 @@ class T5EncoderModel(T5PreTrainedModel):
self.init_weights()
# Model parallel
self.model_parallel = False
self.device_map = None
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
self.device_map = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment