Unverified Commit 1cd9be2a authored by alexorona, committed by GitHub

gpt2 and t5 parallel modeling (#8696)



* gpt2 and t5 parallel modeling

* model_parallel utils update

* adding missing model_parallel_utils

Adds missing model_parallel_utils and reverses the changes to code in modeling_gpt2 and modeling_t5

* training_args reformat

Reformatted training_args

* style formatting

Style formatting doc string length on training_args and model_parallel_utils

* style changes

make style && make quality for training_args and model_parallel_utils.

* adding tests

* minor change in trainer

reverts loss calculation

* Update training_args.py

* Update training_args.py

added back docstring language for adam_beta1 and adam_beta2

* Update trainer.py

* Update src/transformers/trainer.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Fix style & rebase
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: LysandreJik <lysandre.debut@reseau.eseo.fr>
parent 1e45bef0
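As a quick orientation before the diff, here is a minimal, illustrative sketch of the API this commit introduces. It is hedged: it assumes a machine with at least two CUDA GPUs and enough memory for gpt2-xl, and the prompt text is arbitrary.

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2-xl")
model = GPT2LMHeadModel.from_pretrained("gpt2-xl")

# Spread the 48 attention blocks of gpt2-xl across all visible GPUs.
# With no device_map argument, the blocks are split evenly by get_device_map().
model.parallelize()

# Inputs go to the first device; intermediate activations are moved
# between devices inside the forward pass.
inputs = tokenizer("Model parallelism lets large models", return_tensors="pt").to("cuda:0")
outputs = model.generate(**inputs, max_length=30)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

# Move everything back to CPU; deparallelize() also calls torch.cuda.empty_cache().
model.deparallelize()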
......@@ -44,6 +44,7 @@ from ...modeling_utils import (
prune_conv1d_layer,
)
from ...utils import logging
from ...utils.model_parallel_utils import assert_device_map, get_device_map
from .configuration_gpt2 import GPT2Config
......@@ -474,6 +475,46 @@ GPT2_INPUTS_DOCSTRING = r"""
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""
PARALLELIZE_DOCSTRING = r"""
Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
it will evenly distribute blocks across all devices.
Args:
device_map (:obj:`Dict[int, list]`, `optional`, defaults to :obj:`None`):
A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
automatically mapped to the first device (for esoteric reasons). That means that the first device should
have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the
following number of attention modules:
- gpt2: 12
- gpt2-medium: 24
- gpt2-large: 36
- gpt2-xl: 48
Example::
Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules:
model = GPT2LMHeadModel.from_pretrained('gpt2-xl')
device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]}
model.parallelize(device_map)
"""
DEPARALLELIZE_DOCSTRING = r"""
Moves the model to CPU from a model parallel state.
Example::
On a 4 GPU machine with gpt2-large:
model = GPT2LMHeadModel.from_pretrained('gpt2-large')
device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7],
1: [8, 9, 10, 11, 12, 13, 14, 15],
2: [16, 17, 18, 19, 20, 21, 22, 23],
3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35]}
model.parallelize(device_map) # Splits the model across several devices
model.deparallelize() # Puts the model back on CPU and cleans memory by calling torch.cuda.empty_cache()
"""
@add_start_docstrings(
......@@ -491,6 +532,42 @@ class GPT2Model(GPT2PreTrainedModel):
self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.init_weights()
# Model parallel
self.model_parallel = False
self.device_map = None
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
# Check validity of device_map
self.device_map = (
get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
)
assert_device_map(self.device_map, len(self.h))
self.model_parallel = True
self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
self.last_device = "cuda:" + str(max(self.device_map.keys()))
self.wte = self.wte.to(self.first_device)
self.wpe = self.wpe.to(self.first_device)
# Load onto devices
for k, v in self.device_map.items():
for block in v:
cuda_device = "cuda:" + str(k)
self.h[block] = self.h[block].to(cuda_device)
# ln_f to last
self.ln_f = self.ln_f.to(self.last_device)
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
def deparallelize(self):
self.model_parallel = False
self.device_map = None
self.first_device = "cpu"
self.last_device = "cpu"
self.wte = self.wte.to("cpu")
self.wpe = self.wpe.to("cpu")
for index in range(len(self.h)):
self.h[index] = self.h[index].to("cpu")
self.ln_f = self.ln_f.to("cpu")
torch.cuda.empty_cache()
def get_input_embeddings(self):
return self.wte
......@@ -616,6 +693,18 @@ class GPT2Model(GPT2PreTrainedModel):
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
all_hidden_states = () if output_hidden_states else None
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
# Model parallel
if self.model_parallel:
torch.cuda.set_device(hidden_states.device)
# Ensure layer_past is on same device as hidden_states (might not be correct)
if layer_past is not None:
layer_past = layer_past.to(hidden_states.device)
# Ensure that attention_mask is always on the same device as hidden_states
if attention_mask is not None:
attention_mask = attention_mask.to(hidden_states.device)
if isinstance(head_mask, torch.Tensor):
head_mask = head_mask.to(hidden_states.device)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
......@@ -658,6 +747,12 @@ class GPT2Model(GPT2PreTrainedModel):
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (outputs[3],)
# Model Parallel: If it's the last layer for that device, put things on the next device
if self.model_parallel:
for k, v in self.device_map.items():
if i == v[-1] and "cuda:" + str(k) != self.last_device:
hidden_states = hidden_states.to("cuda:" + str(k + 1))
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(*output_shape)
......@@ -694,6 +789,28 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
self.init_weights()
self.model_parallel = False
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
self.device_map = (
get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
if device_map is None
else device_map
)
assert_device_map(self.device_map, len(self.transformer.h))
self.transformer.parallelize(self.device_map)
self.lm_head = self.lm_head.to(self.transformer.first_device)
self.model_parallel = True
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
def deparallelize(self):
self.transformer.deparallelize()
self.transformer = self.transformer.to("cpu")
self.lm_head = self.lm_head.to("cpu")
self.model_parallel = False
torch.cuda.empty_cache()
def get_output_embeddings(self):
return self.lm_head
......@@ -774,6 +891,11 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
)
hidden_states = transformer_outputs[0]
# Set device for model parallelism
if self.model_parallel:
torch.cuda.set_device(self.transformer.first_device)
hidden_states = hidden_states.to(self.lm_head.weight.device)
lm_logits = self.lm_head(hidden_states)
loss = None
......
......@@ -40,6 +40,7 @@ from ...modeling_outputs import (
)
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import logging
from ...utils.model_parallel_utils import assert_device_map, get_device_map
from .configuration_t5 import T5Config
......@@ -177,6 +178,47 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
# - torch.nn.Module for the layers and
# - PreTrainedModel for the models (it-self a sub-class of torch.nn.Module)
####################################################
PARALLELIZE_DOCSTRING = r"""
Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
it will evenly distribute blocks across all devices.
Args:
device_map (:obj:`Dict[int, list]`, `optional`, defaults to :obj:`None`):
A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
automatically mapped to the first device (for esoteric reasons). That means that the first device should
have fewer attention modules mapped to it than other devices. For reference, the t5 models have the
following number of attention modules:
- t5-small: 6
- t5-base: 12
- t5-large: 24
- t5-3b: 24
- t5-11b: 24
Example::
Here is an example of a device map on a machine with 4 GPUs using t5-3b, which has a total of 24 attention modules:
model = T5ForConditionalGeneration.from_pretrained('t5-3b')
device_map = {0: [0, 1, 2],
1: [3, 4, 5, 6, 7, 8, 9],
2: [10, 11, 12, 13, 14, 15, 16],
3: [17, 18, 19, 20, 21, 22, 23]}
model.parallelize(device_map)
"""
DEPARALLELIZE_DOCSTRING = r"""
Moves the model to CPU from a model parallel state.
Example::
On a 4 GPU machine with t5-3b:
model = T5ForConditionalGeneration.from_pretrained('t5-3b')
device_map = {0: [0, 1, 2],
1: [3, 4, 5, 6, 7, 8, 9],
2: [10, 11, 12, 13, 14, 15, 16],
3: [17, 18, 19, 20, 21, 22, 23]}
model.parallelize(device_map) # Splits the model across several devices
model.deparallelize() # Puts the model back on CPU and cleans memory by calling torch.cuda.empty_cache()
"""
class T5LayerNorm(nn.Module):
......@@ -729,6 +771,42 @@ class T5Stack(T5PreTrainedModel):
self.dropout = nn.Dropout(config.dropout_rate)
self.init_weights()
# Model parallel
self.model_parallel = False
self.device_map = None
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
# Check validity of device_map
self.device_map = (
get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
)
assert_device_map(self.device_map, len(self.block))
self.model_parallel = True
self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
self.last_device = "cuda:" + str(max(self.device_map.keys()))
# Load onto devices
for k, v in self.device_map.items():
for layer in v:
cuda_device = "cuda:" + str(k)
self.block[layer] = self.block[layer].to(cuda_device)
# Set embed_tokens to first layer
self.embed_tokens = self.embed_tokens.to(self.first_device)
# Set final layer norm to last device
self.final_layer_norm = self.final_layer_norm.to(self.last_device)
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
def deparallelize(self):
self.model_parallel = False
self.device_map = None
self.first_device = "cpu"
self.last_device = "cpu"
for i in range(len(self.block)):
self.block[i] = self.block[i].to("cpu")
self.embed_tokens = self.embed_tokens.to("cpu")
self.final_layer_norm = self.final_layer_norm.to("cpu")
torch.cuda.empty_cache()
def get_input_embeddings(self):
return self.embed_tokens
......@@ -753,7 +831,10 @@ class T5Stack(T5PreTrainedModel):
output_hidden_states=None,
return_dict=None,
):
# Model parallel
if self.model_parallel:
torch.cuda.set_device(self.first_device)
self.embed_tokens = self.embed_tokens.to(self.first_device)
use_cache = use_cache if use_cache is not None else self.config.use_cache
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
......@@ -821,6 +902,20 @@ class T5Stack(T5PreTrainedModel):
hidden_states = self.dropout(inputs_embeds)
for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
# Model parallel
if self.model_parallel:
torch.cuda.set_device(hidden_states.device)
# Ensure that attention_mask is always on the same device as hidden_states
if attention_mask is not None:
attention_mask = attention_mask.to(hidden_states.device)
if position_bias is not None:
position_bias = position_bias.to(hidden_states.device)
if encoder_hidden_states is not None:
encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
if encoder_extended_attention_mask is not None:
encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
if encoder_decoder_position_bias is not None:
encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
......@@ -855,6 +950,12 @@ class T5Stack(T5PreTrainedModel):
if self.is_decoder:
all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
# Model Parallel: If it's the last layer for that device, put things on the next device
if self.model_parallel:
for k, v in self.device_map.items():
if i == v[-1] and "cuda:" + str(k) != self.last_device:
hidden_states = hidden_states.to("cuda:" + str(k + 1))
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states)
......@@ -1008,6 +1109,32 @@ class T5Model(T5PreTrainedModel):
self.init_weights()
# Model parallel
self.model_parallel = False
self.device_map = None
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
self.device_map = (
get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
if device_map is None
else device_map
)
assert_device_map(self.device_map, len(self.encoder.block))
self.encoder.parallelize(self.device_map)
self.decoder.parallelize(self.device_map)
self.model_parallel = True
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
def deparallelize(self):
self.encoder.deparallelize()
self.decoder.deparallelize()
self.encoder = self.encoder.to("cpu")
self.decoder = self.decoder.to("cpu")
self.model_parallel = False
self.device_map = None
torch.cuda.empty_cache()
def get_input_embeddings(self):
return self.shared
......@@ -1086,6 +1213,18 @@ class T5Model(T5PreTrainedModel):
)
hidden_states = encoder_outputs[0]
# Set device for model parallelism
if self.model_parallel:
torch.cuda.set_device(self.decoder.first_device)
hidden_states = hidden_states.to(self.decoder.first_device)
if decoder_input_ids is not None:
decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
if attention_mask is not None:
attention_mask = attention_mask.to(self.decoder.first_device)
if decoder_attention_mask is not None:
decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
# Decode
decoder_outputs = self.decoder(
......@@ -1147,6 +1286,34 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
self.init_weights()
# Model parallel
self.model_parallel = False
self.device_map = None
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
self.device_map = (
get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
if device_map is None
else device_map
)
assert_device_map(self.device_map, len(self.encoder.block))
self.encoder.parallelize(self.device_map)
self.decoder.parallelize(self.device_map)
self.lm_head = self.lm_head.to(self.decoder.first_device)
self.model_parallel = True
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
def deparallelize(self):
self.encoder.deparallelize()
self.decoder.deparallelize()
self.encoder = self.encoder.to("cpu")
self.decoder = self.decoder.to("cpu")
self.lm_head = self.lm_head.to("cpu")
self.model_parallel = False
self.device_map = None
torch.cuda.empty_cache()
def get_input_embeddings(self):
return self.shared
......@@ -1231,6 +1398,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
hidden_states = encoder_outputs[0]
if self.model_parallel:
torch.cuda.set_device(self.decoder.first_device)
if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
# get decoder inputs from shifting lm labels to the right
decoder_input_ids = self._shift_right(labels)
......@@ -1244,6 +1414,17 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
if decoder_inputs_embeds is not None:
decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]
# Set device for model parallelism
if self.model_parallel:
torch.cuda.set_device(self.decoder.first_device)
hidden_states = hidden_states.to(self.decoder.first_device)
if decoder_input_ids is not None:
decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
if attention_mask is not None:
attention_mask = attention_mask.to(self.decoder.first_device)
if decoder_attention_mask is not None:
decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
# Decode
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
......@@ -1261,6 +1442,12 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
sequence_output = decoder_outputs[0]
# Set device for model parallelism
if self.model_parallel:
torch.cuda.set_device(self.encoder.first_device)
self.lm_head = self.lm_head.to(self.encoder.first_device)
sequence_output = sequence_output.to(self.lm_head.weight.device)
if self.config.tie_word_embeddings:
# Rescale output before projecting on vocab
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
......
......@@ -241,7 +241,11 @@ class Trainer:
self.hp_name = None
if model is None and model_init is not None:
model = self.call_model_init()
# Model parallel
if not self.args.model_parallel:
self.model = model.to(args.device) if model is not None else None
else:
self.model = model if model is not None else None
default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
self.data_collator = data_collator if data_collator is not None else default_collator
self.train_dataset = train_dataset
......@@ -578,6 +582,7 @@ class Trainer:
model = self.call_model_init(trial)
if not self.args.model_parallel:
self.model = model.to(self.args.device)
# Reinitializes optimizer and scheduler
......@@ -625,7 +630,7 @@ class Trainer:
model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
# Multi-gpu training (should be after apex fp16 initialization)
if self.args.n_gpu > 1:
if self.args.n_gpu > 1 and not self.args.model_parallel:
model = torch.nn.DataParallel(model)
# Distributed training (should be after apex fp16 initialization)
......@@ -805,6 +810,7 @@ class Trainer:
)
if isinstance(model, PreTrainedModel):
self.model = model.from_pretrained(self.state.best_model_checkpoint)
if not self.args.model_parallel:
self.model = self.model.to(self.args.device)
else:
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
......@@ -1323,7 +1329,7 @@ class Trainer:
model = self.model
# multi-gpu eval
if self.args.n_gpu > 1:
if self.args.n_gpu > 1 and not self.args.model_parallel:
model = torch.nn.DataParallel(model)
# Note: in torch.distributed mode, there's no point in wrapping the model
# inside a DistributedDataParallel as we'll be under `no_grad` anyways.
......
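To make the intent of these Trainer changes concrete, a hedged usage sketch follows. The dataset is a toy placeholder and the output directory is a made-up path; the Trainer itself does not call parallelize(), so the caller does it before constructing the Trainer.

import torch
from torch.utils.data import Dataset
from transformers import GPT2LMHeadModel, Trainer, TrainingArguments

class ToyLMDataset(Dataset):
    """Tiny placeholder dataset: random token ids, labels equal to inputs."""
    def __init__(self, n=8, seq_len=16, vocab_size=50257):
        self.data = torch.randint(0, vocab_size, (n, seq_len))
    def __len__(self):
        return len(self.data)
    def __getitem__(self, i):
        return {"input_ids": self.data[i], "labels": self.data[i]}

model = GPT2LMHeadModel.from_pretrained("gpt2-large")
model.parallelize()  # split the attention blocks across all visible GPUs first

args = TrainingArguments(
    output_dir="./gpt2-large-mp",   # made-up path
    model_parallel=True,            # new flag: skip .to(device) and the DataParallel wrapper
    per_device_train_batch_size=2,
    num_train_epochs=1,
)

trainer = Trainer(model=model, args=args, train_dataset=ToyLMDataset())
trainer.train()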
......@@ -40,6 +40,9 @@ class TrainingArguments:
Using :class:`~transformers.HfArgumentParser` we can turn this class into argparse arguments to be able to specify
them on the command line.
Parameters:
output_dir (:obj:`str`):
The output directory where the model predictions and checkpoints will be written.
......@@ -201,6 +204,15 @@ class TrainingArguments:
do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
do_eval: bool = field(default=None, metadata={"help": "Whether to run eval on the dev set."})
do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
model_parallel: bool = field(
default=False,
metadata={
"help": (
"If there are more than one devices, whether to use model parallelism to distribute the "
"model's modules across devices."
)
},
)
evaluation_strategy: EvaluationStrategy = field(
default="no",
metadata={"help": "Run evaluation during training at each logging step."},
......@@ -366,7 +378,11 @@ class TrainingArguments:
"version. Using `--per_device_train_batch_size` is preferred."
)
per_device_batch_size = self.per_gpu_train_batch_size or self.per_device_train_batch_size
return per_device_batch_size * max(1, self.n_gpu)
if not self.model_parallel:
train_batch_size = per_device_batch_size * max(1, self.n_gpu)
else:
train_batch_size = per_device_batch_size
return train_batch_size
@property
def eval_batch_size(self) -> int:
......@@ -379,7 +395,11 @@ class TrainingArguments:
"version. Using `--per_device_eval_batch_size` is preferred."
)
per_device_batch_size = self.per_gpu_eval_batch_size or self.per_device_eval_batch_size
return per_device_batch_size * max(1, self.n_gpu)
if not self.model_parallel:
eval_batch_size = per_device_batch_size * max(1, self.n_gpu)
else:
eval_batch_size = per_device_batch_size
return eval_batch_size
@cached_property
@torch_required
......
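As a small worked example of the batch-size arithmetic above (illustrative only; the numbers are arbitrary): with 4 GPUs, DataParallel keeps one model replica per GPU, so the effective batch is multiplied by n_gpu, while model parallelism spreads a single replica across the GPUs, so the per-device value is used unchanged.

# Mirrors the two branches of train_batch_size / eval_batch_size above.
n_gpu = 4
per_device_train_batch_size = 8

data_parallel_batch = per_device_train_batch_size * max(1, n_gpu)  # 32 samples per step
model_parallel_batch = per_device_train_batch_size                 # 8 samples per step
print(data_parallel_batch, model_parallel_batch)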
# coding=utf-8
from math import ceil
def assert_device_map(device_map, num_blocks):
blocks = list(range(0, num_blocks))
device_map_blocks = [item for sublist in list(device_map.values()) for item in sublist]
# Duplicate check
duplicate_blocks = []
for i in device_map_blocks:
if device_map_blocks.count(i) > 1 and i not in duplicate_blocks:
duplicate_blocks.append(i)
# Missing blocks
missing_blocks = [i for i in blocks if i not in device_map_blocks]
extra_blocks = [i for i in device_map_blocks if i not in blocks]
assert len(duplicate_blocks) == 0, (
"Duplicate attention blocks specified in device_map. Attention blocks must be specified to one device. These "
"attention blocks were specified more than once: " + str(duplicate_blocks)
)
assert len(missing_blocks) == 0, (
"There are attention blocks for this model that are not specified in the device_map. Add these attention "
"blocks to a device on the device_map: " + str(missing_blocks)
)
assert (
len(extra_blocks) == 0
), "The device_map contains more attention blocks than this model has. Remove these from the device_map:" + str(
extra_blocks
)
def get_device_map(n_layers, devices):
"""Returns a dictionary of layers distributed evenly across all devices."""
layers = list(range(n_layers))
n_blocks = int(ceil(n_layers / len(devices)))
layers_list = list(layers[i : i + n_blocks] for i in range(0, n_layers, n_blocks))
return dict(zip(devices, layers_list))
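For illustration, a short sketch of what these helpers return and check; the values follow directly from the code above, using gpt2-xl's 48 attention blocks as the example.

from transformers.utils.model_parallel_utils import assert_device_map, get_device_map

# Evenly split the 48 attention blocks of gpt2-xl over 4 devices.
device_map = get_device_map(48, range(4))
# -> {0: [0, ..., 11], 1: [12, ..., 23], 2: [24, ..., 35], 3: [36, ..., 47]}
assert_device_map(device_map, 48)  # passes: no duplicates, nothing missing or extra

# A hand-written map that leaves out block 47 fails the missing-blocks check:
bad_map = {0: list(range(0, 24)), 1: list(range(24, 47))}
try:
    assert_device_map(bad_map, 48)
except AssertionError as err:
    print(err)  # reports that block 47 is not assigned to any device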
......@@ -68,6 +68,7 @@ class ModelTesterMixin:
test_resize_embeddings = True
test_head_masking = True
test_missing_keys = True
test_model_parallel = False
is_encoder_decoder = False
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
......@@ -953,6 +954,97 @@ class ModelTesterMixin:
with torch.no_grad():
_ = model(**self._prepare_for_class(inputs_dict, model_class))
@require_torch_multi_gpu
def test_model_parallelization(self):
if not self.test_model_parallel:
return
import subprocess
def get_current_gpu_memory_use():
run_process = subprocess.Popen(
"nvidia-smi --query-gpu=memory.used --format=csv,nounits,noheader", shell=True, stdout=subprocess.PIPE
)
memory_usage = run_process.stdout.read().decode("utf-8").strip()
per_device_memory = [int(memory) for memory in memory_usage.split("\n")]
return per_device_memory
# Needs a large model to see the difference.
config = self.model_tester.get_large_model_config()
for model_class in self.all_parallelizable_model_classes:
torch.cuda.empty_cache()
# Retrieve initial memory usage (should be close to 0)
initial_memory = get_current_gpu_memory_use()
# Put model on device
model = model_class(config)
model.to("cuda:0")
# Retrieve the memory after the model is put on the device
memory_after_model_load = get_current_gpu_memory_use()
del model
torch.cuda.empty_cache()
# The memory use on that device should be higher than it was initially.
self.assertGreater(memory_after_model_load[0], initial_memory[0])
# Spread model layers over multiple devices
model = model_class(config)
model.parallelize()
memory_after_parallelization = get_current_gpu_memory_use()
# Assert that the memory use on all devices is higher than it was when loaded only on CPU
for n in range(torch.cuda.device_count()):
self.assertGreater(memory_after_parallelization[n], initial_memory[n])
# Assert that the memory use of the first device is lower than it was when the entire model was loaded on it
self.assertLess(memory_after_parallelization[0], memory_after_model_load[0])
# Assert that the memory use of the second device is higher than it was when the entire model was loaded
# on the first device.
self.assertGreater(memory_after_parallelization[1], memory_after_model_load[1])
del model
torch.cuda.empty_cache()
@require_torch_multi_gpu
def test_model_parallel_equal_results(self):
if not self.test_model_parallel:
return
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_parallelizable_model_classes:
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
output = model(**inputs_dict)
model.parallelize()
def cast_to_gpu(dictionary):
output = {}
for k, v in dictionary.items():
if isinstance(v, torch.Tensor):
output[k] = v.to("cuda:0")
else:
output[k] = v
return output
parallel_output = model(**cast_to_gpu(inputs_dict))
for value, parallel_value in zip(output, parallel_output):
if isinstance(value, torch.Tensor):
self.assertTrue(torch.allclose(value, parallel_value.to("cpu"), atol=1e-7))
elif isinstance(value, (tuple, list)):
for value_, parallel_value_ in zip(value, parallel_value):
self.assertTrue(torch.allclose(value_, parallel_value_.to("cpu"), atol=1e-7))
global_rng = random.Random()
......
......@@ -92,6 +92,9 @@ class GPT2ModelTester:
self.eos_token_id = vocab_size - 1
self.pad_token_id = vocab_size - 1
def get_large_model_config(self):
return GPT2Config.from_pretrained("gpt2")
def prepare_config_and_inputs(self, gradient_checkpointing=False):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
......@@ -389,7 +392,9 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
else ()
)
all_generative_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
all_parallelizable_model_classes = (GPT2LMHeadModel,) if is_torch_available() else ()
test_missing_keys = False
test_model_parallel = True
# special case for DoubleHeads model
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
......
......@@ -85,6 +85,9 @@ class T5ModelTester:
self.scope = None
self.decoder_layers = decoder_layers
def get_large_model_config(self):
return T5Config.from_pretrained("t5-base")
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
decoder_input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
......@@ -470,9 +473,18 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
all_parallelizable_model_classes = (
(
T5Model,
T5ForConditionalGeneration,
)
if is_torch_available()
else ()
)
test_pruning = False
test_torchscript = True
test_resize_embeddings = False
test_model_parallel = True
is_encoder_decoder = True
def setUp(self):
......