Unverified Commit 25411085 authored by Sebastian Husch Lee's avatar Sebastian Husch Lee Committed by GitHub
Browse files

[`T5`] Adding model_parallel = False to `T5ForQuestionAnswering` and...

[`T5`] Adding model_parallel = False to `T5ForQuestionAnswering` and `MT5ForQuestionAnswering` (#24684)

Adding model_parallel = False
parent 30ed3adf
...@@ -2032,6 +2032,8 @@ class MT5ForQuestionAnswering(MT5PreTrainedModel): ...@@ -2032,6 +2032,8 @@ class MT5ForQuestionAnswering(MT5PreTrainedModel):
# Initialize weights and apply final processing # Initialize weights and apply final processing
self.post_init() self.post_init()
self.model_parallel = False
# Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.get_input_embeddings # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.get_input_embeddings
def get_input_embeddings(self): def get_input_embeddings(self):
return self.shared return self.shared
......
...@@ -1980,6 +1980,8 @@ class T5ForQuestionAnswering(T5PreTrainedModel): ...@@ -1980,6 +1980,8 @@ class T5ForQuestionAnswering(T5PreTrainedModel):
# Initialize weights and apply final processing # Initialize weights and apply final processing
self.post_init() self.post_init()
self.model_parallel = False
def get_input_embeddings(self): def get_input_embeddings(self):
return self.shared return self.shared
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment