Unverified Commit 5b493762 authored by Sylvain Gugger, committed by GitHub

Deprecate parallelize API (#21448)

* Deprecate parallelize API

* Add documentation

* Fix copies
parent cc840752
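
For reference, the replacement flow the new deprecation messages point to is loading the model with a `device_map` in `from_pretrained`. The snippet below is a hedged sketch, not part of this commit: it assumes a machine with more than one CUDA device and the `accelerate` package installed, and it uses the public `gpt2` checkpoint as a stand-in for whatever model used to be parallelized by hand.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Instead of loading the model and then calling model.parallelize(),
# let from_pretrained spread the layers over the visible GPUs.
model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="balanced")

tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer("Model parallelism without parallelize():", return_tensors="pt")

# Inputs go to the device that holds the first layers (GPU 0 for a balanced map);
# intermediate activations are moved between devices automatically.
outputs = model.generate(**inputs.to(0), max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

# There is no deparallelize() counterpart here: to go back to a single device,
# reload the model without a device_map.
```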
@@ -89,15 +89,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
[[autodoc]] GPT2Model
- forward
- parallelize
- deparallelize
## GPT2LMHeadModel
[[autodoc]] GPT2LMHeadModel
- forward
- parallelize
- deparallelize
## GPT2DoubleHeadsModel
...
@@ -360,22 +360,16 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
[[autodoc]] T5Model
- forward
- parallelize
- deparallelize
## T5ForConditionalGeneration
[[autodoc]] T5ForConditionalGeneration
- forward
- parallelize
- deparallelize
## T5EncoderModel
[[autodoc]] T5EncoderModel
- forward
- parallelize
- deparallelize
## TFT5Model
...
@@ -17,6 +17,7 @@
import math
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union
@@ -689,6 +690,13 @@ class GPT2Model(GPT2PreTrainedModel):
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
# Check validity of device_map
warnings.warn(
"`GPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your"
" model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
" `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1,"
" ...}",
FutureWarning,
)
self.device_map = (
get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
)
@@ -708,6 +716,10 @@ class GPT2Model(GPT2PreTrainedModel):
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
def deparallelize(self):
warnings.warn(
"Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
FutureWarning,
)
self.model_parallel = False
self.device_map = None
self.first_device = "cpu"
@@ -955,6 +967,13 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
warnings.warn(
"`GPT2LMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
" your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
" `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':"
" 0, 'transformer.h.1': 1, ...}",
FutureWarning,
)
self.device_map = (
get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
if device_map is None
@@ -967,6 +986,10 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
def deparallelize(self):
warnings.warn(
"Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
FutureWarning,
)
self.transformer.deparallelize()
self.transformer = self.transformer.to("cpu")
self.lm_head = self.lm_head.to("cpu")
@@ -1134,6 +1157,13 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
warnings.warn(
"`GPT2DoubleHeadsModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should"
" load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your"
" own `device_map` but it needs to be a dictionary module_name to device, so for instance"
" {'transformer.h.0': 0, 'transformer.h.1': 1, ...}",
FutureWarning,
)
self.device_map = (
get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
if device_map is None
@@ -1147,6 +1177,10 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
def deparallelize(self):
warnings.warn(
"Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
FutureWarning,
)
self.transformer.deparallelize()
self.transformer = self.transformer.to("cpu")
self.lm_head = self.lm_head.to("cpu")
...
@@ -14,6 +14,7 @@
# limitations under the License.
""" PyTorch GPT-J model."""
import warnings
from typing import Optional, Tuple, Union
import torch
@@ -489,6 +490,13 @@ class GPTJModel(GPTJPreTrainedModel):
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
warnings.warn(
"`GPTJModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your"
" model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
" `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1,"
" ...}",
FutureWarning,
)
# Check validity of device_map
self.device_map = (
get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
@@ -508,6 +516,10 @@ class GPTJModel(GPTJPreTrainedModel):
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
def deparallelize(self):
warnings.warn(
"Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
FutureWarning,
)
self.model_parallel = False
self.device_map = None
self.first_device = "cpu"
@@ -724,6 +736,13 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
warnings.warn(
"`GPTJForCausalLM.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
" your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
" `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':"
" 0, 'transformer.h.1': 1, ...}",
FutureWarning,
)
self.device_map = (
get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
if device_map is None
@@ -736,6 +755,10 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
def deparallelize(self):
warnings.warn(
"Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
FutureWarning,
)
self.transformer.deparallelize()
self.transformer = self.transformer.to("cpu")
self.lm_head = self.lm_head.to("cpu")
...
@@ -843,6 +843,13 @@ class MT5Stack(MT5PreTrainedModel):
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
warnings.warn(
"`MT5Stack.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model"
" with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
" `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0,"
" 'block.1': 1, ...}",
FutureWarning,
)
# Check validity of device_map
self.device_map = (
get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
@@ -864,6 +871,10 @@ class MT5Stack(MT5PreTrainedModel):
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
def deparallelize(self):
warnings.warn(
"Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
FutureWarning,
)
self.model_parallel = False
self.device_map = None
self.first_device = "cpu"
@@ -1314,6 +1325,13 @@ class MT5Model(MT5PreTrainedModel):
@add_start_docstrings(PARALLELIZE_DOCSTRING)
# Copied from transformers.models.t5.modeling_t5.T5Model.parallelize
def parallelize(self, device_map=None):
warnings.warn(
"`T5Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model"
" with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
" `device_map` but it needs to be a dictionary module_name to device, so for instance {'encoder.block.0':"
" 0, 'encoder.block.1': 1, ...}",
FutureWarning,
)
self.device_map = (
get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
if device_map is None
@@ -1327,6 +1345,10 @@ class MT5Model(MT5PreTrainedModel):
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
# Copied from transformers.models.t5.modeling_t5.T5Model.deparallelize
def deparallelize(self):
warnings.warn(
"Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
FutureWarning,
)
self.encoder.deparallelize()
self.decoder.deparallelize()
self.encoder = self.encoder.to("cpu")
@@ -1539,6 +1561,13 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel):
@add_start_docstrings(PARALLELIZE_DOCSTRING)
# Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.parallelize
def parallelize(self, device_map=None):
warnings.warn(
"`T5ForConditionalGeneration.parallelize` is deprecated and will be removed in v5 of Transformers, you"
" should load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also"
" provide your own `device_map` but it needs to be a dictionary module_name to device, so for instance"
" {'encoder.block.0': 0, 'encoder.block.1': 1, ...}",
FutureWarning,
)
self.device_map = (
get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
if device_map is None
@@ -1553,6 +1582,10 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel):
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
# Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.deparallelize
def deparallelize(self):
warnings.warn(
"Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
FutureWarning,
)
self.encoder.deparallelize()
self.decoder.deparallelize()
self.encoder = self.encoder.to("cpu")
@@ -1849,6 +1882,13 @@ class MT5EncoderModel(MT5PreTrainedModel):
@add_start_docstrings(PARALLELIZE_DOCSTRING)
# Copied from transformers.models.t5.modeling_t5.T5EncoderModel.parallelize
def parallelize(self, device_map=None):
warnings.warn(
"`T5EncoderModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
" your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
" `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0,"
" 'block.1': 1, ...}",
FutureWarning,
)
self.device_map = (
get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
if device_map is None
@@ -1861,6 +1901,10 @@ class MT5EncoderModel(MT5PreTrainedModel):
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
# Copied from transformers.models.t5.modeling_t5.T5EncoderModel.deparallelize
def deparallelize(self):
warnings.warn(
"Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
FutureWarning,
)
self.encoder.deparallelize()
self.encoder = self.encoder.to("cpu")
self.model_parallel = False
...
@@ -872,6 +872,13 @@ class T5Stack(T5PreTrainedModel):
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
warnings.warn(
"`T5Stack.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model"
" with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
" `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0,"
" 'block.1': 1, ...}",
FutureWarning,
)
# Check validity of device_map
self.device_map = (
get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
@@ -893,6 +900,10 @@ class T5Stack(T5PreTrainedModel):
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
def deparallelize(self):
warnings.warn(
"Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
FutureWarning,
)
self.model_parallel = False
self.device_map = None
self.first_device = "cpu"
@@ -1318,6 +1329,13 @@ class T5Model(T5PreTrainedModel):
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
warnings.warn(
"`T5Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model"
" with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
" `device_map` but it needs to be a dictionary module_name to device, so for instance {'encoder.block.0':"
" 0, 'encoder.block.1': 1, ...}",
FutureWarning,
)
self.device_map = (
get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
if device_map is None
@@ -1330,6 +1348,10 @@ class T5Model(T5PreTrainedModel):
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
def deparallelize(self):
warnings.warn(
"Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
FutureWarning,
)
self.encoder.deparallelize()
self.decoder.deparallelize()
self.encoder = self.encoder.to("cpu")
@@ -1515,6 +1537,13 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
warnings.warn(
"`T5ForConditionalGeneration.parallelize` is deprecated and will be removed in v5 of Transformers, you"
" should load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also"
" provide your own `device_map` but it needs to be a dictionary module_name to device, so for instance"
" {'encoder.block.0': 0, 'encoder.block.1': 1, ...}",
FutureWarning,
)
self.device_map = (
get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
if device_map is None
@@ -1528,6 +1557,10 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
def deparallelize(self):
warnings.warn(
"Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
FutureWarning,
)
self.encoder.deparallelize()
self.decoder.deparallelize()
self.encoder = self.encoder.to("cpu")
@@ -1790,6 +1823,13 @@ class T5EncoderModel(T5PreTrainedModel):
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
warnings.warn(
"`T5EncoderModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
" your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
" `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0,"
" 'block.1': 1, ...}",
FutureWarning,
)
self.device_map = (
get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
if device_map is None
@@ -1801,6 +1841,10 @@ class T5EncoderModel(T5PreTrainedModel):
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
def deparallelize(self):
warnings.warn(
"Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
FutureWarning,
)
self.encoder.deparallelize()
self.encoder = self.encoder.to("cpu")
self.model_parallel = False
...
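
The warnings also describe passing a hand-written `device_map` dictionary (module name to device) instead of `'balanced'`. A minimal, purely illustrative sketch of that format for the 12-layer `gpt2` checkpoint on two GPUs might look like the following; the split below is an assumption, not something prescribed by this commit:

```python
from transformers import GPT2LMHeadModel

# Module-name-to-device mapping in the format the deprecation warnings describe,
# e.g. {'transformer.h.0': 0, 'transformer.h.1': 1, ...}. Every module that owns
# weights needs an entry.
device_map = {
    "transformer.wte": 0,
    "transformer.wpe": 0,
    **{f"transformer.h.{i}": 0 for i in range(6)},      # first half of the blocks
    **{f"transformer.h.{i}": 1 for i in range(6, 12)},  # second half of the blocks
    "transformer.ln_f": 1,
    "lm_head": 0,  # tied to transformer.wte, so keep it on the same device
}

model = GPT2LMHeadModel.from_pretrained("gpt2", device_map=device_map)
print(model.hf_device_map)  # the placement actually used after loading
```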