"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "f2c1df93f5bb13a57de21e355836b7aa7c820d63"
Unverified Commit 0f443436 authored by Lysandre Debut, committed by GitHub

Actual fix (#9787)

parent fac7cfb1
@@ -541,6 +541,7 @@ class GPT2Model(GPT2PreTrainedModel):
         self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
 
         self.init_weights()
+
         # Model parallel
         self.model_parallel = False
         self.device_map = None
@@ -805,7 +806,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         self.init_weights()
 
+        # Model parallel
         self.model_parallel = False
+        self.device_map = None
 
     @add_start_docstrings(PARALLELIZE_DOCSTRING)
     def parallelize(self, device_map=None):
@@ -971,6 +974,10 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         self.init_weights()
 
+        # Model parallel
+        self.model_parallel = False
+        self.device_map = None
+
     def get_output_embeddings(self):
         return self.lm_head
@@ -1153,6 +1160,10 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
         self.init_weights()
 
+        # Model parallel
+        self.model_parallel = False
+        self.device_map = None
+
     @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
......
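For context on the attributes being initialized above: they are the flags that parallelize() and deparallelize() toggle at runtime. A hedged usage sketch of that API follows (not part of this commit; the checkpoint name and the 0-23 / 24-47 block split are illustrative for gpt2-xl's 48 blocks, and actually running it needs two visible GPUs):

from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2-xl")

# Keys are GPU ids, values are the transformer block indices placed on that GPU.
device_map = {
    0: list(range(0, 24)),
    1: list(range(24, 48)),
}

model.parallelize(device_map)   # sets model_parallel = True and stores device_map
# ... run generation or training across the two GPUs ...
model.deparallelize()           # moves the model back to CPU and resets the flags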
@@ -1651,6 +1651,10 @@ class T5EncoderModel(T5PreTrainedModel):
         self.init_weights()
 
+        # Model parallel
+        self.model_parallel = False
+        self.device_map = None
+
     @add_start_docstrings(PARALLELIZE_DOCSTRING)
     def parallelize(self, device_map=None):
         self.device_map = (
......
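Why the defaults matter: the forward passes of these classes branch on self.model_parallel, so an instance that never calls parallelize() still needs the attribute to exist with a safe single-device value. A minimal self-contained sketch of that pattern (the class below is hypothetical, not taken from the library):

import torch
from torch import nn


class TinyHeadModel(nn.Module):
    """Toy stand-in for a head model that optionally supports model parallelism."""

    def __init__(self):
        super().__init__()
        self.lm_head = nn.Linear(8, 8)
        # The defaults this commit adds: safe values for the single-device case.
        self.model_parallel = False
        self.device_map = None

    def parallelize(self, device_map=None):
        # Actually moving submodules to devices per device_map is elided here.
        self.device_map = device_map
        self.model_parallel = True

    def forward(self, hidden_states):
        if self.model_parallel:
            # Only taken after parallelize(); inputs must follow the weights' device.
            hidden_states = hidden_states.to(self.lm_head.weight.device)
        return self.lm_head(hidden_states)


model = TinyHeadModel()          # never parallelized
out = model(torch.randn(2, 8))   # works because model_parallel defaults to False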