Unverified commit aa756cec authored by eqy, committed by GitHub

minimal bert pipeline parallel test (#1216)

* minimal bert pipeline parallel test

* fix global and cleanup

* use get_forward_backward_func

* cleanup and fix some tests
parent fcae8fa3
import torch
from apex.normalization import FusedLayerNorm as LayerNorm
from apex.transformer import tensor_parallel
from apex.transformer.enums import AttnMaskType
from apex.transformer.testing.global_vars import get_args
from .standalone_gpt import get_language_model, get_linear_layer, init_method_normal, parallel_lm_logits, scaled_init_method_normal
from .standalone_gpt import MegatronModule
def bert_extended_attention_mask(attention_mask):
# We create a 3D attention mask from a 2D tensor mask.
# [b, 1, s]
attention_mask_b1s = attention_mask.unsqueeze(1)
# [b, s, 1]
attention_mask_bs1 = attention_mask.unsqueeze(2)
# [b, s, s]
attention_mask_bss = attention_mask_b1s * attention_mask_bs1
# [b, 1, s, s]
extended_attention_mask = attention_mask_bss.unsqueeze(1)
# Convert the mask to boolean: True marks positions that should be masked out
# (i.e., not attended to).
extended_attention_mask = (extended_attention_mask < 0.5)
return extended_attention_mask
def bert_position_ids(token_ids):
# Create position ids
seq_length = token_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long,
device=token_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(token_ids)
return position_ids
class BertLMHead(MegatronModule):
"""Masked LM head for Bert
Arguments:
mpu_vocab_size: model parallel size of vocabulary.
hidden_size: hidden size
init_method: init method for weight initialization
layernorm_epsilon: tolerance for layer norm divisions
parallel_output: whether output logits being distributed or not.
"""
def __init__(self, mpu_vocab_size, hidden_size, init_method,
layernorm_epsilon, parallel_output):
super(BertLMHead, self).__init__()
args = get_args()
self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
# TODO: do we need this?
# mpu.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
self.parallel_output = parallel_output
self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
self.gelu = torch.nn.functional.gelu
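# NOTE: openai_gelu / erf_gelu are not imported in this standalone file, so these
# branches assume args.openai_gelu and args.onnx_safe are left at their defaults (off).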
if args.openai_gelu:
self.gelu = openai_gelu
elif args.onnx_safe:
self.gelu = erf_gelu
def forward(self, hidden_states, word_embeddings_weight):
hidden_states = self.dense(hidden_states)
hidden_states = self.gelu(hidden_states)
hidden_states = self.layernorm(hidden_states)
output = parallel_lm_logits(hidden_states,
word_embeddings_weight,
self.parallel_output,
bias=self.bias)
return output
def post_language_model_processing(lm_output, pooled_output,
lm_head, binary_head,
lm_labels,
logit_weights,
fp16_lm_cross_entropy):
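# Compute the masked-LM logits from the final hidden states (and, when lm_labels is
# given, the vocab-parallel cross-entropy loss), plus the optional binary/NSP logits.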
# Output.
lm_logits = lm_head(
lm_output, logit_weights)
binary_logits = None
if binary_head is not None:
binary_logits = binary_head(pooled_output)
if lm_labels is None:
return lm_logits, binary_logits
else:
if fp16_lm_cross_entropy:
assert lm_logits.dtype == torch.half
lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels)
else:
lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(),
lm_labels)
return lm_loss, binary_logits
class BertModel(MegatronModule):
"""Bert Language model."""
def __init__(self,
num_tokentypes=2,
add_binary_head=True,
parallel_output=True,
pre_process=True,
post_process=True):
super(BertModel, self).__init__()
args = get_args()
self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
self.add_binary_head = add_binary_head
self.parallel_output = parallel_output
self.pre_process = pre_process
self.post_process = post_process
init_method = init_method_normal(args.init_method_std)
scaled_init_method = scaled_init_method_normal(args.init_method_std,
args.num_layers)
self.language_model, self._language_model_key = get_language_model(
num_tokentypes=num_tokentypes,
add_pooler=self.add_binary_head,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method,
pre_process=self.pre_process,
post_process=self.post_process)
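# Tie input and output word embeddings across pipeline stages: on the last stage this
# creates a local copy of the word embeddings and all-reduces it with the first stage
# so both start from identical values (see MegatronModule.initialize_word_embeddings).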
self.initialize_word_embeddings(init_method_normal)
if self.post_process:
self.lm_head = BertLMHead(
self.word_embeddings_weight().size(0),
args.hidden_size, init_method, args.layernorm_epsilon, parallel_output)
self._lm_head_key = 'lm_head'
self.binary_head = None
if self.add_binary_head:
self.binary_head = get_linear_layer(args.hidden_size, 2,
init_method)
self._binary_head_key = 'binary_head'
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
self.language_model.set_input_tensor(input_tensor)
def forward(self, bert_model_input, attention_mask,
tokentype_ids=None, lm_labels=None):
extended_attention_mask = bert_extended_attention_mask(attention_mask)
input_ids = bert_model_input
position_ids = bert_position_ids(input_ids)
lm_output = self.language_model(
input_ids,
position_ids,
extended_attention_mask,
tokentype_ids=tokentype_ids
)
if self.post_process and self.add_binary_head:
lm_output, pooled_output = lm_output
else:
pooled_output = None
if self.post_process:
return post_language_model_processing(lm_output, pooled_output,
self.lm_head, self.binary_head,
lm_labels,
self.word_embeddings_weight(),
self.fp16_lm_cross_entropy)
else:
return lm_output
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load when model is combined with other heads,
add an extra key."""
state_dict_ = {}
state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if self.post_process:
state_dict_[self._lm_head_key] \
= self.lm_head.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if self.post_process and self.add_binary_head:
state_dict_[self._binary_head_key] \
= self.binary_head.state_dict(destination, prefix, keep_vars)
# Save word_embeddings.
if self.post_process and not self.pre_process:
state_dict_[self._word_embeddings_for_head_key] \
= self.word_embeddings.state_dict(destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
self.language_model.load_state_dict(
state_dict[self._language_model_key], strict=strict)
if self.post_process:
self.lm_head.load_state_dict(
state_dict[self._lm_head_key], strict=strict)
if self.post_process and self.add_binary_head:
self.binary_head.load_state_dict(
state_dict[self._binary_head_key], strict=strict)
# Load word_embeddings.
if self.post_process and not self.pre_process:
self.word_embeddings.load_state_dict(
state_dict[self._word_embeddings_for_head_key], strict=strict)
def bert_model_provider(pre_process=True, post_process=True):
model = BertModel(num_tokentypes=0, add_binary_head=False, pre_process=pre_process, post_process=post_process)
return model
@@ -67,7 +67,6 @@ def bias_gelu_back(g, bias, y):
     ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
     return ff * g
 class MegatronModule(torch.nn.Module):
     """Megatron specific extensions of torch Module with support
     for pipelining."""
@@ -76,26 +75,30 @@ class MegatronModule(torch.nn.Module):
         super(MegatronModule, self).__init__()
         self.share_word_embeddings = share_word_embeddings
-    def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False):
+    def state_dict_for_save_checkpoint(self, destination=None, prefix="",
+                                       keep_vars=False):
         """Use this function to override the state dict for
         saving checkpoints."""
         return self.state_dict(destination, prefix, keep_vars)
     def word_embeddings_weight(self):
-        if (
-            not parallel_state.is_pipeline_last_stage(ignore_virtual=True)
-            or parallel_state.get_pipeline_model_parallel_world_size() == 1
-        ):
+        if not parallel_state.is_pipeline_last_stage(ignore_virtual=True) or \
+                parallel_state.get_pipeline_model_parallel_world_size() == 1:
             return self.language_model.embedding.word_embeddings.weight
         else:
             if not self.share_word_embeddings:
-                raise Exception("word_embeddings_weight() called for last " "stage, but share_word_embeddings is false")
+                raise Exception("word_embeddings_weight() called for last "
+                                "stage, but share_word_embeddings is false")
             return self.word_embeddings.weight
     def initialize_word_embeddings(self, init_method_normal):
         args = get_args()
         if not self.share_word_embeddings:
-            raise Exception("initialize_word_embeddings() was called but " "share_word_embeddings is false")
+            raise Exception("initialize_word_embeddings() was called but "
+                            "share_word_embeddings is false")
         # This function just initializes the word embeddings in the final stage
         # when we are using pipeline parallelism. Nothing to do if we aren't
@@ -121,27 +124,25 @@ class MegatronModule(torch.nn.Module):
         # set word_embeddings weights to 0 here, then copy first
         # stage's weights using all_reduce below.
         self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
-            args.padded_vocab_size, args.hidden_size, init_method=init_method_normal(args.init_method_std)
-        )
+            args.padded_vocab_size, args.hidden_size,
+            init_method=init_method_normal(args.init_method_std),
+            use_cpu_initialization=args.use_cpu_initialization)
         self.word_embeddings.weight.data.fill_(0)
         self.word_embeddings.weight.shared = True
         # Zero out initial weights for decoder embedding.
         # NOTE: We don't currently support T5 with the interleaved schedule.
-        if (
-            not parallel_state.is_pipeline_first_stage(ignore_virtual=True)
-            and not parallel_state.is_pipeline_last_stage(ignore_virtual=True)
-            and parallel_state.is_rank_in_embedding_group()
-        ):
+        if not parallel_state.is_pipeline_first_stage(ignore_virtual=True) and \
+                not parallel_state.is_pipeline_last_stage(ignore_virtual=True) and \
+                parallel_state.is_rank_in_embedding_group():
             self.language_model.embedding.zero_parameters()
         # Ensure that first and last stages have the same initial parameter
         # values.
         if torch.distributed.is_initialized():
             if parallel_state.is_rank_in_embedding_group():
-                torch.distributed.all_reduce(
-                    self.word_embeddings_weight().data, group=parallel_state.get_embedding_group()
-                )
+                torch.distributed.all_reduce(self.word_embeddings_weight().data,
+                                             group=parallel_state.get_embedding_group())
         # All-reduce other embeddings as well as necessary. The last stage
         # does not have these other embeddings, so just create placeholder
         # tensors of the right shape with all zeros.
@@ -155,18 +156,14 @@ class MegatronModule(torch.nn.Module):
                 else:
                     self.language_model.embedding.cuda()
                     position_embeddings = self.language_model.embedding.position_embeddings
-                torch.distributed.all_reduce(
-                    position_embeddings.weight.data, group=parallel_state.get_embedding_group()
-                )
+                torch.distributed.all_reduce(position_embeddings.weight.data,
+                                             group=parallel_state.get_embedding_group())
         else:
-            print(
-                "WARNING! Distributed processes aren't initialized, so "
-                "word embeddings in the last layer are not initialized. "
-                "If you are just manipulating a model this is fine, but "
-                "this needs to be handled manually. If you are training "
-                "something is definitely wrong."
-            )
+            print("WARNING! Distributed processes aren't initialized, so "
+                  "word embeddings in the last layer are not initialized. "
+                  "If you are just manipulating a model this is fine, but "
+                  "this needs to be handled manually. If you are training "
+                  "something is definitely wrong.")
 class GeLUFunction(torch.autograd.Function):
     @staticmethod
@@ -248,8 +245,8 @@ class ParallelMLP(MegatronModule):
         # Project to 4h.
         self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
-            args.hidden_size, args.ffn_hidden_size, gather_output=False, init_method=init_method, skip_bias_add=True
-        )
+            args.hidden_size, args.ffn_hidden_size, gather_output=False, init_method=init_method, skip_bias_add=True,
+            use_cpu_initialization=args.use_cpu_initialization)
         self.bias_gelu_fusion = args.bias_gelu_fusion
         self.activation_func = F.gelu
@@ -265,6 +262,7 @@ class ParallelMLP(MegatronModule):
             input_is_parallel=True,
             init_method=output_layer_init_method,
             skip_bias_add=True,
+            use_cpu_initialization=args.use_cpu_initialization
         )
     def forward(self, hidden_states):
@@ -322,17 +320,14 @@ class ParallelAttention(MegatronModule):
         # Strided linear layer.
         if attention_type == AttnType.self_attn:
             self.query_key_value = tensor_parallel.ColumnParallelLinear(
-                args.hidden_size, 3 * projection_size, gather_output=False, init_method=init_method
-            )
+                args.hidden_size, 3 * projection_size, gather_output=False, init_method=init_method, use_cpu_initialization=args.use_cpu_initialization)
         else:
             assert attention_type == AttnType.cross_attn
             self.query = tensor_parallel.ColumnParallelLinear(
-                args.hidden_size, projection_size, gather_output=False, init_method=init_method
-            )
+                args.hidden_size, projection_size, gather_output=False, init_method=init_method, use_cpu_initialization=args.use_cpu_initialization)
             self.key_value = tensor_parallel.ColumnParallelLinear(
-                args.hidden_size, 2 * projection_size, gather_output=False, init_method=init_method
-            )
+                args.hidden_size, 2 * projection_size, gather_output=False, init_method=init_method, use_cpu_initialization=args.use_cpu_initialization)
         coeff = None
         self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
@@ -362,6 +357,7 @@ class ParallelAttention(MegatronModule):
             input_is_parallel=True,
             init_method=output_layer_init_method,
             skip_bias_add=True,
+            use_cpu_initialization=args.use_cpu_initialization
         )
         # Inference key-value memory
@@ -1021,10 +1017,12 @@ class Embedding(MegatronModule):
         self.hidden_size = hidden_size
         self.init_method = init_method
         self.num_tokentypes = num_tokentypes
+        args = get_args()
         # Word embeddings (parallel).
         self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
-            vocab_size, self.hidden_size, init_method=self.init_method
+            vocab_size, self.hidden_size, init_method=self.init_method,
+            use_cpu_initialization=args.use_cpu_initialization
         )
         self._word_embeddings_key = "word_embeddings"
@@ -1048,6 +1046,7 @@ class Embedding(MegatronModule):
         # Embeddings dropout
         self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
+        print("FINISH WORD EMBEDDING", self.word_embeddings)
     def zero_parameters(self):
         """Zero out all parameters in embedding."""
@@ -1500,7 +1499,6 @@ class GPTModel(MegatronModule):
             state_dict = state_dict[self._language_model_key]
         self.language_model.load_state_dict(state_dict, strict=strict)
 def gpt_model_provider(pre_process=True, post_process=True):
     model = GPTModel(num_tokentypes=0, parallel_output=True, pre_process=pre_process, post_process=post_process)
     return model
import random
import torch
from apex.transformer import tensor_parallel
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import vocab_parallel_cross_entropy
from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator
from apex.transformer.pipeline_parallel.utils import average_losses_across_data_parallel_group
from apex.transformer.pipeline_parallel.schedules import get_forward_backward_func
from apex.transformer.pipeline_parallel.schedules.common import build_model
from apex.transformer.pipeline_parallel.schedules.common import _get_params_for_weight_decay_optimization
from apex.transformer.testing.standalone_bert import bert_model_provider
from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
from apex.transformer.testing.commons import initialize_distributed
from apex.transformer.testing.commons import print_separator
mode = None
MANUAL_SEED = 42
inds = None
masks = None
data_idx = 0
MASK_PROB = 0.1
EASY_MODE = False
EASY_MODE_SIZ = 32
ONCE = False
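# Module-level state used by generate_fancy_data_labels():
#   MANUAL_SEED - reseeded identically on every rank so data generation stays in sync
#                 across pipeline stages
#   MASK_PROB   - fraction of token positions replaced by the mask id (124)
#   EASY_MODE   - if True, cycle over only the first EASY_MODE_SIZ shuffled offsets
#   ONCE        - gate so the "LOSS OK" message is printed only once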
# Build a small synthetic corpus (the public-domain book download below is stubbed out).
def download_fancy_data():
#import requests
#response = requests.get('https://internet.com/book.txt')
#text = ' '.join(response.text.split())
text = """
An original sentence not subject to any license restrictions, copyright, or royalty payments. Nothing to see here. Commercial or non-commercial use. Research or non-research purposes. The quick brown fox jumps over the lazy dog. Lorem ipsum.
"""
text = text*1024
encoded = text.encode('ascii', 'replace')
ints = [int(encoded[i]) for i in range(len(encoded))]
return torch.tensor(ints)
# build a batch given sequence_len and batch size
def generate_fancy_data_labels(sequence_len, batch_size):
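# Returns (data, label, loss_mask), each of shape [batch_size, sequence_len]:
# data is a corpus slice with masked positions replaced by token id 124, label is the
# unmasked slice, and loss_mask is 1 exactly at the masked positions.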
global data_idx
global inds
global masks
global MANUAL_SEED
temps = list()
for i in range(batch_size):
if inds is None or data_idx >= len(inds):
# hack as use of RNG will fall out of sync due to pipelines being different
torch.manual_seed(MANUAL_SEED)
inds = torch.randperm(effective_length, device='cuda')
masks = (torch.rand(len(inds)//batch_size + 1, batch_size, sequence_len, device='cuda') >= MASK_PROB).long()
MANUAL_SEED += 1
print("new epoch", len(inds))
data_idx = 0
print("my start", inds[0:5])
print("masks_checksum:", torch.sum(masks))
if EASY_MODE:
data_idx_ = data_idx % EASY_MODE_SIZ
else:
data_idx_ = data_idx
offset = inds[data_idx_] #* SEQUENCE_LEN
data_idx += 1
curr = fancy_data[offset:offset+sequence_len].clone().detach()
temps.append(curr)
temp = torch.stack(temps, dim=0).cuda()
mask = masks[data_idx//batch_size]
mask_not = torch.logical_not(mask)
data = mask * temp + mask_not*124
label = temp
return (data, label, mask_not)
easy_data = None
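# Forward-step callback in the form expected by apex's forward/backward schedules:
# it receives (batch, model) and returns (output_tensor, loss_func), where loss_func
# maps the model output to (loss, {metric_name: value}).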
def fwd_step_func(batch, model):
data, label, loss_mask = batch
data = data.cuda()
label = label.cuda()
loss_mask = loss_mask.cuda()
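# All-ones attention mask: there is no padding, so every token may attend to every other.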
y = model(data, torch.ones_like(data), lm_labels=label)
def loss_func(output_tensor):
global ONCE
output_tensor, _ = output_tensor
lm_loss_ = output_tensor.float()
lm_loss = torch.sum(
lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
averaged_loss = average_losses_across_data_parallel_group([lm_loss])
if data_idx >= 1536:
assert lm_loss < 4.8
if not ONCE:
print("LOSS OK")
ONCE = True
return lm_loss, {'avg': averaged_loss}
return y, loss_func
def train(model, optim, virtual_pipeline_model_parallel_size, pipeline_model_parallel_size):
sequence_len = global_vars.get_args().seq_length
micro_batch_size = global_vars.get_args().micro_batch_size
hidden_size = global_vars.get_args().hidden_size
forward_backward_func = get_forward_backward_func(virtual_pipeline_model_parallel_size, pipeline_model_parallel_size)
tensor_shape = (sequence_len, micro_batch_size, hidden_size)
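# tensor_shape above is the shape of the activations exchanged between pipeline stages
# (sequence length, micro batch size, hidden size); each forward_backward_func call
# below runs forward and backward over the micro-batches and leaves the accumulated
# gradients on the model chunks.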
for _ in range(8):
batch = generate_fancy_data_labels(sequence_len, batch_size)
optim.zero_grad()
forward_backward_func(fwd_step_func, batch, model, forward_only=False, tensor_shape=tensor_shape)
optim.step()
if __name__ == '__main__':
global fancy_data
global effective_length
global_vars.set_global_variables()
fancy_data = download_fancy_data()
effective_length = fancy_data.size(0) - global_vars.get_args().seq_length
initialize_distributed()
world_size = torch.distributed.get_world_size()
failure = None
try:
args = global_vars.get_args()
args.padded_vocab_size = 128 # needed in standalone gpt
batch_size = args.global_batch_size
micro_batch_size = args.micro_batch_size
setup_microbatch_calculator(
args.rank,
args.rampup_batch_size,
args.global_batch_size,
args.micro_batch_size,
1, # args.data_parallel_size,
)
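# The micro-batch calculator above derives the number of micro-batches per step from
# global_batch_size / (micro_batch_size * data_parallel_size); data-parallel size is 1 here.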
virtual_pipeline_model_parallel_size = 2
world_size = torch.distributed.get_world_size()
pipeline_model_parallel_size = world_size
parallel_state.initialize_model_parallel(
1, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size)
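# Tensor-model-parallel size is 1 and the pipeline spans the whole world size, so the
# data-parallel size is also 1; virtual_pipeline_model_parallel_size=2 gives each rank
# two model chunks, exercising the interleaved schedule.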
pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
tensor_parallel.random.model_parallel_cuda_manual_seed(0)
model = build_model(
bert_model_provider,
wrap_with_ddp=True,
virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
)
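# build_model above returns one model chunk per virtual pipeline stage on this rank,
# each wrapped in DistributedDataParallel.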
assert isinstance(model, list)
assert len(model) == (1 if virtual_pipeline_model_parallel_size is None else virtual_pipeline_model_parallel_size)
_param_groups = _get_params_for_weight_decay_optimization(model)
optim = torch.optim.Adam(_param_groups)
print(effective_length)
print(fancy_data.size(0))
train(model, optim, virtual_pipeline_model_parallel_size, pipeline_model_parallel_size)
except Exception as e:
failure = str(e)
finally:
parallel_state.destroy_model_parallel()
if failure is not None:
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(f"Minimal BERT Pipeline Parallel Failed with: {failure}")
else:
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(TEST_SUCCESS_MESSAGE)
@@ -125,7 +125,7 @@ def forward_backward_func_template(
     assert isinstance(model, list)
     assert len(model) == (1 if virtual_pipeline_model_parallel_size is None else virtual_pipeline_model_parallel_size)
     _param_groups = _get_params_for_weight_decay_optimization(model)
-    torch.optim.Adam(_param_groups)
+    torch.optim.Adam(_param_groups, lr=1e-4)
     tensor_shape = [batch_size // parallel_state.get_data_parallel_world_size(), hidden_size]
     batch = (torch.randn(tensor_shape).cuda(),)
...
@@ -12,18 +12,27 @@ MULTIGPU_TEST = [
     "pipeline_parallel_test",
     "dynamic_batchsize_test",
 ]
+SEVERALGPU_TEST = [
+    "bert_minimal_test",
+]
+def get_multigpu_launch_option(min_gpu):
+    should_skip = False
+    import torch
+    num_devices = torch.cuda.device_count()
+    if num_devices < min_gpu:
+        should_skip = True
+    distributed_run_options = f"-m torch.distributed.run --nproc_per_node={num_devices}"
+    return should_skip, distributed_run_options
 def get_launch_option(test_filename) -> Tuple[bool, str]:
     should_skip = False
     for multigpu_test in MULTIGPU_TEST:
         if multigpu_test in test_filename:
-            import torch
-            num_devices = torch.cuda.device_count()
-            if num_devices < 2:
-                should_skip = True
-            distributed_run_options = f"-m torch.distributed.run --nproc_per_node={num_devices}"
-            return should_skip, distributed_run_options
+            return get_multigpu_launch_option(2)
+    for severalgpu_test in SEVERALGPU_TEST:
+        if severalgpu_test in test_filename:
+            return get_multigpu_launch_option(3)
     return should_skip, ""
@@ -55,9 +64,15 @@ def run_transformer_tests():
             continue
         test_run_cmd = (
             f"{python_executable_path} {launch_option} {test_file} "
-            "--micro-batch-size 2 --num-layers 1 --hidden-size 256 --num-attention-heads 8 --max-position-embeddings "
-            "32 --encoder-seq-length 32 --use-cpu-initialization"
+            "--micro-batch-size 4 --num-layers 16 --hidden-size 768 --num-attention-heads 8 --max-position-embeddings "
+            "512 --seq-length 512 --global-batch-size 256"
         )
+        if 'bert' in test_file:
+            import torch
+            num_devices = torch.cuda.device_count()
+            test_run_cmd += f" --pipeline-model-parallel-size {num_devices}"
+        else:
+            test_run_cmd += f" --use-cpu-initialization"
         print(f"### {i} / {len(files)}: cmd: {test_run_cmd}")
         try:
             output = subprocess.check_output(
...