Unverified commit ab7af058 authored by Rishi Puri, committed by GitHub

Minimal gpt pipeline parallel (builds off of minimal_bert_pipeline_parallel) including cpu-offloading (#1222)

* minimal bert pipeline parallel test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* first draft of gpt minimal test

* framework to scale up the gpt2 test for variety of distributed setups

* adding gpt_minimal_test to list of multigpu tests
Co-authored-by: Eddie Yan <eddiey@nvidia.com>
Co-authored-by: riship <riship@nvidia.com>
parent 1203099a
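The cpu-offloading added here hinges on PyTorch's torch.autograd.graph.save_on_cpu() context manager: when --cpu-offload is set, the model's forward pass runs under it, so tensors saved for backward are kept in host memory and copied back to the GPU on demand during the backward pass. A minimal standalone sketch of the pattern (the toy Block module and sizes below are illustrative, not part of this diff):

import contextlib
import torch

class Block(torch.nn.Module):
    # Toy stand-in for a transformer block; illustrative only.
    def __init__(self, hidden=128):
        super().__init__()
        self.linear = torch.nn.Linear(hidden, hidden)

    def forward(self, x, cpu_offload=False):
        # With offloading enabled, activations saved for backward live on the CPU
        # and are copied back to the GPU during the backward pass.
        ctx = torch.autograd.graph.save_on_cpu(pin_memory=True) if cpu_offload else contextlib.nullcontext()
        with ctx:
            return torch.relu(self.linear(x))

if __name__ == "__main__":
    block = Block().cuda()
    x = torch.randn(4, 128, device="cuda", requires_grad=True)
    block(x, cpu_offload=True).sum().backward()

This trades GPU memory for host-device traffic, which is why the scaling script below sweeps layer counts both with and without --cpu-offload to measure the runtime cost.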
@@ -640,6 +640,8 @@ def _add_distributed_args(parser):
    group.add_argument('--use-cpu-initialization', action='store_true',
                       default=None, help='If set, affine parallel weights '
                       'initialization uses CPU')
    group.add_argument('--cpu-offload', action='store_true',
                       default=False, help='Turns on CPU offloading')
    group.add_argument('--empty-unused-memory-level', default=0, type=int,
                       choices=[0, 1, 2],
                       help='Call torch.cuda.empty_cache() each iteration '
......
import torch
import contextlib
from apex.normalization import FusedLayerNorm as LayerNorm
from apex.transformer import tensor_parallel
from apex.transformer.enums import AttnMaskType
@@ -106,7 +106,8 @@ class BertModel(MegatronModule):
                 add_binary_head=True,
                 parallel_output=True,
                 pre_process=True,
                 post_process=True):
                 post_process=True,
                 cpu_offload=False):
        super(BertModel, self).__init__()
        args = get_args()
@@ -115,7 +116,7 @@ class BertModel(MegatronModule):
        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process
        self.cpu_offload = cpu_offload
        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)
@@ -147,31 +148,31 @@ class BertModel(MegatronModule):
    def forward(self, bert_model_input, attention_mask,
                tokentype_ids=None, lm_labels=None):

        extended_attention_mask = bert_extended_attention_mask(attention_mask)
        input_ids = bert_model_input
        position_ids = bert_position_ids(input_ids)

        lm_output = self.language_model(
            input_ids,
            position_ids,
            extended_attention_mask,
            tokentype_ids=tokentype_ids
        )

        if self.post_process and self.add_binary_head:
            lm_output, pooled_output = lm_output
        else:
            pooled_output = None

        if self.post_process:
            return post_language_model_processing(lm_output, pooled_output,
                                                  self.lm_head, self.binary_head,
                                                  lm_labels,
                                                  self.word_embeddings_weight(),
                                                  self.fp16_lm_cross_entropy)
        else:
            return lm_output

        with torch.autograd.graph.save_on_cpu() if self.cpu_offload else contextlib.nullcontext():
            extended_attention_mask = bert_extended_attention_mask(attention_mask)
            input_ids = bert_model_input
            position_ids = bert_position_ids(input_ids)

            lm_output = self.language_model(
                input_ids,
                position_ids,
                extended_attention_mask,
                tokentype_ids=tokentype_ids
            )

            if self.post_process and self.add_binary_head:
                lm_output, pooled_output = lm_output
            else:
                pooled_output = None

            if self.post_process:
                return post_language_model_processing(lm_output, pooled_output,
                                                      self.lm_head, self.binary_head,
                                                      lm_labels,
                                                      self.word_embeddings_weight(),
                                                      self.fp16_lm_cross_entropy)
            else:
                return lm_output
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
@@ -212,6 +213,6 @@ class BertModel(MegatronModule):
            self.word_embeddings.load_state_dict(
                state_dict[self._word_embeddings_for_head_key], strict=strict)
def bert_model_provider(pre_process=True, post_process=True):
    model = BertModel(num_tokentypes=0, add_binary_head=False, pre_process=pre_process, post_process=post_process)

def bert_model_provider(pre_process=True, post_process=True, cpu_offload=False):
    model = BertModel(num_tokentypes=0, add_binary_head=False, pre_process=pre_process, post_process=post_process, cpu_offload=cpu_offload)
    return model
@@ -16,6 +16,8 @@
"""GPT-2 model."""
import enum
import math
import contextlib
import json
import torch
import torch.nn.functional as F
@@ -1046,7 +1048,8 @@ class Embedding(MegatronModule):
        # Embeddings dropout
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

        print("FINISH WORD EMBEDDING", self.word_embeddings)
        if torch.distributed.get_rank() == 0:
            print("FINISH WORD EMBEDDING", self.word_embeddings)

    def zero_parameters(self):
        """Zero out all parameters in embedding."""
@@ -1422,18 +1425,29 @@ def post_language_model_processing(lm_output, labels, logit_weights, parallel_ou
loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels)
return loss
def module_size(m: torch.nn.Module, only_trainable: bool = False):
    """
    Returns the total number of parameters used by `m` (counting shared
    parameters only once); if `only_trainable` is True, only parameters
    with `requires_grad = True` are included.
    """
    parameters = list(m.parameters())
    if only_trainable:
        parameters = [p for p in parameters if p.requires_grad]
    unique = {p.data_ptr(): p for p in parameters}.values()
    return sum(p.numel() for p in unique)
class GPTModel(MegatronModule):
    """GPT-2 Language model."""

    def __init__(self, num_tokentypes=0, parallel_output=True, pre_process=True, post_process=True):
    def __init__(self, num_tokentypes=0, parallel_output=True, pre_process=True, post_process=True, cpu_offload=False):
        super(GPTModel, self).__init__()
        args = get_args()

        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process
        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.cpu_offload = cpu_offload

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=False,
@@ -1461,20 +1475,21 @@ class GPTModel(MegatronModule):
        inference_max_sequence_len=None,
    ):
        lm_output = self.language_model(
            input_ids,
            position_ids,
            attention_mask,
            set_inference_key_value_memory=set_inference_key_value_memory,
            inference_max_sequence_len=inference_max_sequence_len,
        )

        if self.post_process:
            return post_language_model_processing(
                lm_output, labels, self.word_embeddings_weight(), self.parallel_output, self.fp16_lm_cross_entropy
            )
        else:
            return lm_output

        with torch.autograd.graph.save_on_cpu() if self.cpu_offload else contextlib.nullcontext():
            lm_output = self.language_model(
                input_ids,
                position_ids,
                attention_mask,
                set_inference_key_value_memory=set_inference_key_value_memory,
                inference_max_sequence_len=inference_max_sequence_len,
            )

            if self.post_process:
                return post_language_model_processing(
                    lm_output, labels, self.word_embeddings_weight(), self.parallel_output, self.fp16_lm_cross_entropy
                )
            else:
                return lm_output
def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False):
@@ -1499,6 +1514,11 @@ class GPTModel(MegatronModule):
            state_dict = state_dict[self._language_model_key]
        self.language_model.load_state_dict(state_dict, strict=strict)

def gpt_model_provider(pre_process=True, post_process=True):
    model = GPTModel(num_tokentypes=0, parallel_output=True, pre_process=pre_process, post_process=post_process)

def gpt_model_provider(pre_process=True, post_process=False, cpu_offload=False):
    model = GPTModel(num_tokentypes=0, parallel_output=True, pre_process=pre_process, post_process=post_process, cpu_offload=cpu_offload)
    if torch.distributed.get_rank() == 0:
        init_dict = {"pre_process": pre_process, "post_process": post_process, "cpu_offload": cpu_offload}
        print("Initialized GPT-2 w/:", json.dumps(init_dict))
        n_params = module_size(model) * parallel_state.get_tensor_model_parallel_world_size() * parallel_state.get_pipeline_model_parallel_world_size()
        print("Number of Parameters:", n_params)
    return model
import subprocess
import os
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
def run_gpt(cmd):
    """Launch one run_gpt_minimal_test.py job and parse its stdout."""
    args = cmd.split(' ')
    p = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    outs, errs = p.communicate()
    outs = outs.decode('utf-8').splitlines()
    success = False
    runtime = 0
    num_params = 0
    for out in outs:
        out = str(out)
        if "Average Iteration Time:" in out:
            slicey = out[out.find(':') + 2:]
            try:
                runtime = float(slicey)
            except ValueError:
                print(slicey)
                quit()
        if "Number of Parameters:" in out:
            slicey = out[out.find(':') + 2:]
            try:
                num_params = int(slicey)
            except ValueError:
                print(slicey)
                quit()
        if out == str(TEST_SUCCESS_MESSAGE):
            success = True
    # Report parameter count in billions.
    return runtime, round(num_params / 10.0**9, 3), success, errs


def plot(runtimes):
    import matplotlib.pyplot as plt
    for distributed_setting in runtimes.keys():
        plt.scatter(runtimes[distributed_setting].keys(),
                    runtimes[distributed_setting].values(),
                    label=distributed_setting)
    plt.legend()
    plt.xlabel('Parameters (Billions)')
    plt.ylabel('Training Iteration time (s)')
    plt.title("GPT Scaling w/ Offloading")
    plt.savefig('offload_gpt_scaling.png')
    plt.close()
    if not os.path.exists('/my_workspace/'):
        os.system('mkdir /my_workspace/')
    os.system('cp *.png /my_workspace/')


def main():
    runtimes = {}
    # Layer counts to sweep: 2k-8k by 2k, 10k-45k by 5k, 50k-90k by 10k.
    nlist = list(range(2000, 10000, 2000)) + list(range(10000, 50000, 5000)) + list(range(50000, 100000, 10000))
    print("N-List:", nlist)
    for data_parr, tens_parr, pipe_parr in [(8, 1, 1), (4, 2, 1), (2, 1, 4), (1, 2, 4)]:
        for offload in [True, False]:
            dist_setting = ('ddp=' + str(data_parr) + ', tensor_parr=' + str(tens_parr)
                            + ', pipe_parr=' + str(pipe_parr) + ', offload=' + str(offload))
            runtimes[dist_setting] = {}
            print("Beginning Testing for", dist_setting)
            for n in nlist:
                cmd = "python3 -m torch.distributed.launch --nproc_per_node=8 run_gpt_minimal_test.py"
                cmd += " --micro-batch-size 1 --num-layers " + str(n) + " --hidden-size 128 --num-attention-heads 16"
                cmd += ' --max-position-embeddings 128 --seq-length 128 --tensor-model-parallel-size ' + str(tens_parr)
                cmd += " --pipeline-model-parallel-size " + str(pipe_parr) + (' --cpu-offload' if offload else '')
                print(cmd)
                runtime, bill_params, success, errs = run_gpt(cmd)
                if success:
                    runtimes[dist_setting][bill_params] = runtime
                    print(str(runtime) + 's per training iter for', str(bill_params) + 'B parameter GPT-2')
                    if n >= 10000:
                        plot(runtimes)
                else:
                    print("GPT-2 w/", n, "layers failed using", dist_setting)
                    print("Moving on to the next distributed setting...")
                    print("#" * 25)
                    print()
                    plot(runtimes)
                    break
            print(runtimes)
            plot(runtimes)


if __name__ == "__main__":
    main()
\ No newline at end of file
@@ -148,7 +148,7 @@ if __name__ == '__main__':
        bert_model_provider,
        wrap_with_ddp=True,
        virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
    )
        cpu_offload=args.cpu_offload)
    assert isinstance(model, list)
    assert len(model) == (1 if virtual_pipeline_model_parallel_size is None else virtual_pipeline_model_parallel_size)
    _param_groups = _get_params_for_weight_decay_optimization(model)
......
import torch
import os
from typing import List
import time
from functools import partial
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import model_parallel_cuda_manual_seed
from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator
from apex.transformer.pipeline_parallel.utils import average_losses_across_data_parallel_group
from apex.transformer.pipeline_parallel.utils import get_ltor_masks_and_position_ids
from apex.transformer.pipeline_parallel.schedules.common import build_model
from apex.transformer.pipeline_parallel.schedules.common import _get_params_for_weight_decay_optimization
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import forward_backward_pipelining_without_interleaving
from apex.transformer.testing.standalone_gpt import gpt_model_provider
from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
from apex.transformer.testing.commons import initialize_distributed
MANUAL_SEED = 42
inds = None
data_idx = 0
N_VOCAB = 128
# download a public domain book as corpus
def download_fancy_data():
    if not os.path.exists('data.txt'):
        import requests
        response = requests.get('https://www.gutenberg.org/files/1342/1342-0.txt')
        text = ' '.join(response.text.split())
        with open('data.txt', 'w+') as f:
            print(text, file=f)
    else:
        text = open('data.txt', 'r').read()
    encoded = text.encode('ascii', 'replace')
    ints = [int(encoded[i]) for i in range(len(encoded))]
    return torch.tensor(ints)


# build a batch given sequence_len and batch_size
def generate_fancy_data_labels(sequence_len, batch_size):
    global data_idx
    global inds
    global MANUAL_SEED
    temps = list()
    for i in range(batch_size):
        if inds is None or data_idx >= len(inds):
            # hack: reseed here, otherwise RNG use falls out of sync because
            # pipeline stages differ across ranks
            model_parallel_cuda_manual_seed(MANUAL_SEED)
            inds = torch.randperm(effective_length, device='cuda')
            MANUAL_SEED += 1
            data_idx = 0
        data_idx_ = data_idx
        offset = inds[data_idx_]
        data_idx += 1
        curr = fancy_data[offset:offset + sequence_len + 1].clone().detach()
        temps.append(curr)
    temp = torch.stack(temps, dim=0).cuda()
    return temp


easy_data = None


def get_batch(int_tensors: List[torch.Tensor]):
    data = int_tensors[0]
    # Unpack.
    tokens_ = data.long()
    labels = tokens_[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()
    # Get the masks and position ids.
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        N_VOCAB,  # tokenizer.eod
        False,    # args.reset_position_ids
        False,    # args.reset_attention_mask
        False,    # args.eod_mask_loss
    )
    return tokens, labels, loss_mask, attention_mask, position_ids


# Ref: https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/pretrain_gpt.py#L75
def loss_func(loss_mask, output_tensor):
    losses = output_tensor.float()
    loss_mask = loss_mask.view(-1).float()
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
    # Reduce loss for logging.
    averaged_loss = average_losses_across_data_parallel_group([loss])
    return loss, {'lm loss': averaged_loss[0]}


# Ref: https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/pretrain_gpt.py#L86
def fwd_step_func(batch, model):
    """Forward step."""
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(batch)
    output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
    return output_tensor, partial(loss_func, loss_mask)


def train(model, optim, pipeline_model_parallel_size):
    args = global_vars.get_args()
    sequence_len = args.seq_length
    micro_batch_size = args.micro_batch_size
    hidden_size = args.hidden_size
    fwd_bwd_func = forward_backward_pipelining_without_interleaving
    tensor_shape = (sequence_len, micro_batch_size, hidden_size)
    runtime = 0
    # training loop: three timed iterations
    for i in range(3):
        since = time.time()
        if torch.distributed.get_rank() == 0:
            print('begin iter', i)
        batch = [generate_fancy_data_labels(sequence_len, args.global_batch_size)
                 for _ in range(pipeline_model_parallel_size)]
        if torch.distributed.get_rank() == 0:
            print("finished making batch...")
        optim.zero_grad()
        fwd_bwd_func(fwd_step_func, batch, model, forward_only=False, tensor_shape=tensor_shape)
        if torch.distributed.get_rank() == 0:
            print('finished forward step')
        optim.step()
        if torch.distributed.get_rank() == 0:
            print('finished iter', i)
        runtime += time.time() - since
    return runtime / 3.0


if __name__ == '__main__':
    global_vars.set_global_variables()
    fancy_data = download_fancy_data()
    args = global_vars.get_args()
    effective_length = fancy_data.size(0) - args.seq_length

    initialize_distributed()
    world_size = torch.distributed.get_world_size()
    failure = None
    args.padded_vocab_size = 128
    batch_size = args.global_batch_size
    micro_batch_size = args.micro_batch_size
    setup_microbatch_calculator(
        args.rank,
        args.rampup_batch_size,
        args.global_batch_size,
        args.micro_batch_size,
        args.data_parallel_size,
    )
    parallel_state.initialize_model_parallel(
        tensor_model_parallel_size_=args.tensor_model_parallel_size,
        pipeline_model_parallel_size_=args.pipeline_model_parallel_size)
    pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()

    model_parallel_cuda_manual_seed(0)
    model = build_model(
        gpt_model_provider,
        wrap_with_ddp=True,
        virtual_pipeline_model_parallel_size=None,
        cpu_offload=args.cpu_offload,
    )
    assert isinstance(model, list), model
    _param_groups = _get_params_for_weight_decay_optimization(model)
    optim = torch.optim.Adam(_param_groups)
    runtime = train(model, optim, args.pipeline_model_parallel_size)

    parallel_state.destroy_model_parallel()
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
        print("Average Iteration Time:", runtime)
\ No newline at end of file
@@ -14,6 +14,7 @@ MULTIGPU_TEST = [
]
SEVERALGPU_TEST = [
    "bert_minimal_test",
    "gpt_minimal_test",
]
def get_multigpu_launch_option(min_gpu):
......