Unverified Commit a9c8bff7 authored by Ahmed Elnaggar, committed by GitHub

Add parallelization support for T5EncoderModel (#9082)



* add model parallelism to T5EncoderModel

* remove decoder from T5EncoderModel parallelize

* update T5EncoderModel docs

* Extend T5ModelTest for T5EncoderModel

* fix T5Stack to use range for get_device_map

* fix style
Co-authored-by: Ahmed Elnaggar <elnaggar@rostlab.informatik.tu-muenchen.de>
parent b00eb4fb
@@ -131,7 +131,7 @@ T5EncoderModel
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

 .. autoclass:: transformers.T5EncoderModel
-    :members: forward
+    :members: forward, parallelize, deparallelize

 TFT5Model
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -781,7 +781,7 @@ class T5Stack(T5PreTrainedModel):
     def parallelize(self, device_map=None):
         # Check validity of device_map
         self.device_map = (
-            get_device_map(len(self.block), torch.cuda.device_count()) if device_map is None else device_map
+            get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
         )
         assert_device_map(self.device_map, len(self.block))
         self.model_parallel = True
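For context on the range(...) fix: the get_device_map helper (in transformers' model_parallel_utils) builds the default block-to-GPU assignment by zipping device ids against even chunks of block indices, so it needs an iterable of device ids rather than a bare device count. A minimal standalone sketch of that behaviour, where the helper name sketch_get_device_map and the 24-block / 4-GPU figures are illustrative rather than taken from the library:

from math import ceil

def sketch_get_device_map(n_layers, devices):
    # Split layer indices 0..n_layers-1 into even chunks, one chunk per device.
    # Passing a bare count (an int) instead of an iterable of device ids fails
    # at len(devices), which is what range(torch.cuda.device_count()) fixes.
    n_blocks = int(ceil(n_layers / len(devices)))
    layers = list(range(n_layers))
    chunks = [layers[i : i + n_blocks] for i in range(0, n_layers, n_blocks)]
    return dict(zip(devices, chunks))

# e.g. 24 encoder blocks spread over 4 GPUs:
print(sketch_get_device_map(24, range(4)))
# -> {0: [0, 1, 2, 3, 4, 5], 1: [6, ..., 11], 2: [12, ..., 17], 3: [18, ..., 23]}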
@@ -1579,6 +1579,25 @@ class T5EncoderModel(T5PreTrainedModel):
         self.init_weights()

+    @add_start_docstrings(PARALLELIZE_DOCSTRING)
+    def parallelize(self, device_map=None):
+        self.device_map = (
+            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
+            if device_map is None
+            else device_map
+        )
+        assert_device_map(self.device_map, len(self.encoder.block))
+        self.encoder.parallelize(self.device_map)
+        self.model_parallel = True
+
+    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
+    def deparallelize(self):
+        self.encoder.deparallelize()
+        self.encoder = self.encoder.to("cpu")
+        self.model_parallel = False
+        self.device_map = None
+        torch.cuda.empty_cache()
+
     def get_input_embeddings(self):
         return self.shared
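The new methods mirror the existing parallelize/deparallelize API of T5Model and T5ForConditionalGeneration. A hedged usage sketch, assuming a machine with at least two visible GPUs and the public t5-small checkpoint; the explicit two-GPU device_map below is illustrative, and omitting it splits the encoder blocks evenly over all visible GPUs:

import torch
from transformers import T5Tokenizer, T5EncoderModel

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5EncoderModel.from_pretrained("t5-small")

# t5-small has 6 encoder blocks; place the first three on GPU 0, the rest on GPU 1.
model.parallelize({0: [0, 1, 2], 1: [3, 4, 5]})

inputs = tokenizer("Studies have shown that owning a dog is good for you.", return_tensors="pt")
outputs = model(input_ids=inputs.input_ids.to("cuda:0"))
print(outputs.last_hidden_state.shape)

# Move everything back to CPU and release the cached GPU memory.
model.deparallelize()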
@@ -485,12 +485,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
     all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
     all_parallelizable_model_classes = (
-        (
-            T5Model,
-            T5ForConditionalGeneration,
-        )
-        if is_torch_available()
-        else ()
+        (T5Model, T5ForConditionalGeneration, T5EncoderModel) if is_torch_available() else ()
     )
     test_pruning = False
     test_torchscript = True
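Adding T5EncoderModel to all_parallelizable_model_classes opts it into the shared model-parallel tests in ModelTesterMixin. A standalone sketch of the invariant those tests rely on, namely that the parallelized encoder produces the same hidden states as the single-device run; the checkpoint, tolerance, and two-GPU guard are assumptions, not part of this commit:

import torch
from transformers import T5Tokenizer, T5EncoderModel

def check_parallel_matches_single_device(checkpoint="t5-small"):
    tokenizer = T5Tokenizer.from_pretrained(checkpoint)
    input_ids = tokenizer("translate English to German: Hello", return_tensors="pt").input_ids

    model = T5EncoderModel.from_pretrained(checkpoint).eval()
    with torch.no_grad():
        reference = model(input_ids=input_ids).last_hidden_state

    model.parallelize()  # default device_map over all visible GPUs
    with torch.no_grad():
        parallel = model(input_ids=input_ids.to("cuda:0")).last_hidden_state
    model.deparallelize()

    assert torch.allclose(reference, parallel.cpu(), atol=1e-4)

if torch.cuda.device_count() >= 2:
    check_parallel_matches_single_device()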