"tests/led/test_modeling_led.py" did not exist on "189387e9b2e1d6d1a0fb8355fd01e9a89fdb3e4a"
Unverified Commit a9c8bff7 authored by Ahmed Elnaggar, committed by GitHub

Add parallelization support for T5EncoderModel (#9082)



* add model parallelism to T5EncoderModel

* remove decoder from T5EncoderModel parallelize

* update T5EncoderModel docs

* Extend T5ModelTest for T5EncoderModel

* fix T5Stack using range for get_device_map

* fix style
Co-authored-by: Ahmed Elnaggar <elnaggar@rostlab.informatik.tu-muenchen.de>
parent b00eb4fb
@@ -131,7 +131,7 @@ T5EncoderModel
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 .. autoclass:: transformers.T5EncoderModel
-    :members: forward
+    :members: forward, parallelize, deparallelize
 
 TFT5Model
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -781,7 +781,7 @@ class T5Stack(T5PreTrainedModel):
     def parallelize(self, device_map=None):
         # Check validity of device_map
         self.device_map = (
-            get_device_map(len(self.block), torch.cuda.device_count()) if device_map is None else device_map
+            get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
         )
         assert_device_map(self.device_map, len(self.block))
         self.model_parallel = True
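The one-line change above fixes the default device_map: get_device_map expects an iterable of device ids to spread the blocks over, not a bare device count. A minimal sketch of the corrected behavior (assuming the helper's public path transformers.utils.model_parallel_utils, and a hypothetical 12-block stack on 2 GPUs):

    from transformers.utils.model_parallel_utils import get_device_map

    # Blocks are split evenly across the given device ids.
    device_map = get_device_map(12, range(2))
    print(device_map)
    # {0: [0, 1, 2, 3, 4, 5], 1: [6, 7, 8, 9, 10, 11]}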
@@ -1579,6 +1579,25 @@ class T5EncoderModel(T5PreTrainedModel):
         self.init_weights()
 
+    @add_start_docstrings(PARALLELIZE_DOCSTRING)
+    def parallelize(self, device_map=None):
+        self.device_map = (
+            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
+            if device_map is None
+            else device_map
+        )
+        assert_device_map(self.device_map, len(self.encoder.block))
+        self.encoder.parallelize(self.device_map)
+        self.model_parallel = True
+
+    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
+    def deparallelize(self):
+        self.encoder.deparallelize()
+        self.encoder = self.encoder.to("cpu")
+        self.model_parallel = False
+        self.device_map = None
+        torch.cuda.empty_cache()
+
     def get_input_embeddings(self):
         return self.shared
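For reference, a short usage sketch of the new methods (illustrative only, not part of the diff; assumes at least two visible GPUs and t5-small, whose encoder has 6 blocks):

    from transformers import T5Tokenizer, T5EncoderModel

    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = T5EncoderModel.from_pretrained("t5-small")

    # Spread the 6 encoder blocks over all visible GPUs; an explicit map
    # such as {0: [0, 1, 2], 1: [3, 4, 5]} can be passed instead.
    model.parallelize()

    inputs = tokenizer("Model parallelism for the encoder", return_tensors="pt")
    hidden = model(input_ids=inputs.input_ids.to("cuda:0")).last_hidden_state

    # Move the encoder back to CPU and free cached GPU memory.
    model.deparallelize()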
@@ -485,12 +485,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
     all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
     all_parallelizable_model_classes = (
-        (
-            T5Model,
-            T5ForConditionalGeneration,
-        )
-        if is_torch_available()
-        else ()
+        (T5Model, T5ForConditionalGeneration, T5EncoderModel) if is_torch_available() else ()
     )
     test_pruning = False
     test_torchscript = True