Commit 5c45db4a authored by Jared Casper's avatar Jared Casper Committed by Deepak Narayanan
Browse files

Initial implementation of pipelined text generation

parent caa9dca5
......@@ -34,7 +34,8 @@ from .initialize import get_tensor_model_parallel_rank, set_tensor_model_paralle
from .initialize import get_pipeline_model_parallel_rank, set_pipeline_model_parallel_rank
from .initialize import is_pipeline_first_stage, is_pipeline_last_stage
from .initialize import get_tensor_model_parallel_src_rank
from .initialize import get_pipeline_model_parallel_src_rank
from .initialize import get_pipeline_model_parallel_first_rank
from .initialize import get_pipeline_model_parallel_last_rank
from .initialize import get_tensor_model_parallel_world_size, set_tensor_model_parallel_world_size
from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_model_parallel_world_size
from .initialize import initialize_model_parallel
......
......@@ -38,6 +38,7 @@ _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
_PIPELINE_GLOBAL_RANKS = None
def is_unitialized():
"""Useful for code segments that may be accessed with or without mpu initialization"""
......@@ -131,6 +132,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
# Build the pipeline model-parallel groups and embedding groups
# (first and last rank in each pipeline model-parallel group).
global _PIPELINE_MODEL_PARALLEL_GROUP
global _PIPELINE_GLOBAL_RANKS
assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \
'pipeline model parallel group is already initialized'
global _EMBEDDING_GROUP
......@@ -142,6 +144,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
group = torch.distributed.new_group(ranks)
if rank in ranks:
_PIPELINE_MODEL_PARALLEL_GROUP = group
_PIPELINE_GLOBAL_RANKS = ranks
# Setup embedding group (to exchange gradients between
# first and last stages).
if len(ranks) > 1:
......@@ -265,21 +268,22 @@ def is_pipeline_last_stage():
def get_tensor_model_parallel_src_rank():
"""Calculate the global rank corresponding to a local rank
"""Calculate the global rank corresponding to the first local rank
in the tensor model parallel group."""
global_rank = torch.distributed.get_rank()
local_world_size = get_tensor_model_parallel_world_size()
return (global_rank // local_world_size) * local_world_size
def get_pipeline_model_parallel_last_rank():
assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized"
last_rank_local = get_pipeline_model_parallel_world_size() - 1
return _PIPELINE_GLOBAL_RANKS[last_rank_local]
def get_pipeline_model_parallel_src_rank():
"""Calculate the global rank corresponding to a local rank
in the pipeline model parallel group."""
global_rank = torch.distributed.get_rank()
global_world_size = torch.distributed.get_world_size()
local_world_size = get_pipeline_model_parallel_world_size()
return global_rank % (global_world_size // local_world_size)
def get_pipeline_model_parallel_first_rank():
assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized"
return _PIPELINE_GLOBAL_RANKS[0]
def get_data_parallel_world_size():
"""Return world size for the data parallel group."""
......
......@@ -26,6 +26,7 @@ import torch.nn.functional as F
from megatron import get_args
from megatron import get_tokenizer
from megatron import mpu
from megatron.training import communicate
from megatron.utils import get_ltor_masks_and_position_ids
......@@ -88,14 +89,14 @@ def generate_samples_input_from_file(model):
# Read the sample file and open the output file.
assert args.sample_input_file is not None, \
'sample input file is not provided.'
if mpu.get_tensor_model_parallel_rank() == 0:
if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
fname = open(args.sample_input_file, "r")
all_raw_text = fname.readlines()
input_count = len(all_raw_text)
input_pos = 0
if args.sample_output_file is None:
sample_output_file = args.sample_input_file + ".out"
print('could not find `sample-output-file`, setting '
print('`sample-output-file` not specified, setting '
'it to {}'.format(sample_output_file))
else:
sample_output_file = args.sample_output_file
......@@ -105,14 +106,16 @@ def generate_samples_input_from_file(model):
model.eval()
with torch.no_grad():
while True:
torch.distributed.barrier(group=mpu.get_tensor_model_parallel_group())
terminate_runs = 0
raw_text_len = 0
if mpu.get_tensor_model_parallel_rank() == 0:
if mpu.is_pipeline_first_stage() \
and mpu.get_tensor_model_parallel_rank() == 0:
raw_text = all_raw_text[input_pos]
input_pos += 1
if input_pos == input_count:
raw_text = "stop"
raw_text_len = len(raw_text)
if "stop" in raw_text:
terminate_runs = 1
......@@ -127,38 +130,60 @@ def generate_samples_input_from_file(model):
continue
else:
context_tokens = tokenizer.tokenize("EMPTY TEXT")
context_length = len(context_tokens)
context_length = 0
terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
torch.distributed.broadcast(terminate_runs_tensor,
mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group())
terminate_runs = terminate_runs_tensor[0].item()
input_info = [terminate_runs, raw_text_len, context_length]
input_info_tensor = torch.cuda.LongTensor(input_info)
torch.distributed.all_reduce(input_info_tensor,
group=mpu.get_model_parallel_group())
terminate_runs = input_info_tensor[0].item()
raw_text_len = input_info_tensor[1].item()
if terminate_runs == 1:
return
# For pipeline parallel we send context tokens to last stage
# so it knows when to start overwriting
if mpu.get_tensor_model_parallel_rank() == 0 \
and args.pipeline_model_parallel_size > 1:
if mpu.is_pipeline_first_stage():
src = mpu.get_pipeline_model_parallel_first_rank()
group = mpu.get_embedding_group()
context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
torch.distributed.broadcast(context_tokens_tensor, src, group)
if mpu.is_pipeline_last_stage():
src = mpu.get_pipeline_model_parallel_first_rank()
group = mpu.get_embedding_group()
context_length = input_info_tensor[2].item()
context_tokens_tensor = torch.empty(context_length,
dtype=torch.int64,
device=torch.device("cuda"))
torch.distributed.broadcast(context_tokens_tensor, src, group)
context_tokens = context_tokens_tensor.cpu().numpy().tolist()
token_stream = get_token_stream(model, [context_tokens])
for _, decode_tokens in enumerate(token_stream):
decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
pass
if mpu.get_tensor_model_parallel_rank() == 0:
os.system('clear')
print("\nContext:", raw_text, flush=True)
trim_decode_tokens = tokenizer.detokenize(
decode_tokens)[len(raw_text):]
print("\nMegatron-LM:", trim_decode_tokens, flush=True)
if mpu.is_pipeline_first_stage():
os.system('clear')
print("\nContext:", raw_text, flush=True)
fname_out.write("\nContext:")
fname_out.write(raw_text)
fname_out.write("\n\nMegatron-LM:")
fname_out.write(trim_decode_tokens)
fname_out.write("\n")
fname_out.write("\nContext:")
fname_out.write(raw_text)
raw_text = None
decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
trim_decode_tokens = tokenizer.detokenize(
decode_tokens)[raw_text_len:]
print("\nMegatron-LM:", trim_decode_tokens, flush=True)
fname_out.write("\n\nMegatron-LM:")
fname_out.write(trim_decode_tokens)
fname_out.write("\n")
torch.distributed.barrier(group=mpu.get_tensor_model_parallel_group())
raw_text = None
context_count += 1
......@@ -171,15 +196,17 @@ def generate_samples_interactive(model, print_frequency=24):
model.eval()
with torch.no_grad():
while True:
torch.distributed.barrier(group=mpu.get_tensor_model_parallel_group())
terminate_runs = 0
raw_text_len = 0
if mpu.get_tensor_model_parallel_rank() == 0:
if mpu.is_pipeline_first_stage() \
and mpu.get_tensor_model_parallel_rank() == 0:
os.system('clear')
raw_text = input("\nContext prompt (stop to exit) >>> ")
while not raw_text:
print('Prompt should not be empty!')
raw_text = input("\nContext prompt (stop to exit) >>> ")
raw_text_len = len(raw_text)
if "stop" in raw_text:
terminate_runs = 1
......@@ -194,43 +221,70 @@ def generate_samples_interactive(model, print_frequency=24):
continue
else:
context_tokens = tokenizer.tokenize("EMPTY TEXT")
context_length = len(context_tokens)
context_length = 0
terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
torch.distributed.broadcast(terminate_runs_tensor,
mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group())
terminate_runs = terminate_runs_tensor[0].item()
input_info = [terminate_runs, raw_text_len, context_length]
input_info_tensor = torch.cuda.LongTensor(input_info)
torch.distributed.all_reduce(input_info_tensor,
group=mpu.get_model_parallel_group())
terminate_runs = input_info_tensor[0].item()
raw_text_len = input_info_tensor[1].item()
if terminate_runs == 1:
return
# For pipeline parallel we send context tokens to last stage
# so it knows when to start overwriting
if mpu.get_tensor_model_parallel_rank() == 0 \
and args.pipeline_model_parallel_size > 1:
if mpu.is_pipeline_first_stage():
src = mpu.get_pipeline_model_parallel_first_rank()
group = mpu.get_embedding_group()
context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
torch.distributed.broadcast(context_tokens_tensor, src, group)
if mpu.is_pipeline_last_stage():
src = mpu.get_pipeline_model_parallel_first_rank()
group = mpu.get_embedding_group()
context_length = input_info_tensor[2].item()
context_tokens_tensor = torch.empty(context_length,
dtype=torch.int64,
device=torch.device("cuda"))
torch.distributed.broadcast(context_tokens_tensor, src, group)
context_tokens = context_tokens_tensor.cpu().numpy().tolist()
token_stream = get_token_stream(model, [context_tokens])
for counter, decode_tokens in enumerate(token_stream):
if counter % print_frequency != 0 \
or mpu.get_tensor_model_parallel_rank() != 0 \
or not mpu.is_pipeline_first_stage():
continue
os.system('clear')
print("\nContext:", raw_text, flush=True)
decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
trim_decode_tokens = tokenizer.detokenize(
decode_tokens)[raw_text_len:]
print("\nMegatron-LM:", trim_decode_tokens, flush=True)
if mpu.get_tensor_model_parallel_rank() == 0 and \
counter % print_frequency == 0:
os.system('clear')
print("\nContext:", raw_text, flush=True)
trim_decode_tokens = tokenizer.detokenize(
decode_tokens)[len(raw_text):]
print("\nMegatron-LM:", trim_decode_tokens, flush=True)
if mpu.get_tensor_model_parallel_rank() == 0:
if mpu.is_pipeline_first_stage() \
and mpu.get_tensor_model_parallel_rank() == 0:
os.system('clear')
print("\nContext:", raw_text, flush=True)
if not isinstance(decode_tokens, list):
decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
trim_decode_tokens = tokenizer.detokenize(
decode_tokens)[len(raw_text):]
decode_tokens)[raw_text_len:]
print("\nMegatron-LM:", trim_decode_tokens, flush=True)
input("\nPress Enter to continue >>>")
raw_text = None
torch.distributed.barrier(group=mpu.get_tensor_model_parallel_group())
context_count += 1
if mpu.get_tensor_model_parallel_rank() == 0:
input("\nPress any key to continue >>>")
def generate_samples_unconditional(model):
......@@ -247,22 +301,31 @@ def generate_samples_unconditional(model):
for token_stream in get_token_stream(model,
copy.deepcopy(context_tokens)):
pass
if ctr % args.log_interval == 0:
print('Avg s/batch:',
(time.time() - start_time) / min(args.log_interval, ctr + 1))
start_time = time.time()
length = len(token_stream)
token_batch = token_stream[0].cpu().numpy().tolist()
length_batch = token_stream[1].cpu().numpy().tolist()
for tokens, length in zip(token_batch, length_batch):
tokens = tokens[1:length - 1]
text = tokenizer.detokenize(tokens)
is_finished = length < args.seq_length - 1
datum = {'text': text, 'length': length - 1, 'finished': is_finished}
yield datum
ctr += 1
if ctr >= num_samples:
break
if mpu.is_pipeline_last_stage() and \
mpu.get_tensor_model_parallel_rank() == 0:
if ctr % args.log_interval == 0:
print('Avg s/batch:',
(time.time() - start_time) / min(args.log_interval, ctr + 1))
start_time = time.time()
length = len(token_stream)
token_batch = token_stream[0].cpu().numpy().tolist()
length_batch = token_stream[1].cpu().numpy().tolist()
assert len(length_batch) == args.batch_size
for tokens, length in zip(token_batch, length_batch):
tokens = tokens[1:length - 1]
text = tokenizer.detokenize(tokens)
is_finished = length < args.seq_length - 1
datum = {'text': text, 'length': length - 1, 'finished': is_finished}
yield datum
ctr += 1
if ctr >= num_samples:
break
else:
for _ in range(args.batch_size):
yield None
ctr += 1
if ctr >= num_samples:
break
if ctr >= num_samples:
break
......@@ -273,7 +336,9 @@ def generate_and_write_samples_unconditional(model):
assert args.genfile is not None
with open(args.genfile, 'w') as f:
for datum in generate_samples_unconditional(model):
f.write(json.dumps(datum) + '\n')
if mpu.is_pipeline_last_stage() and \
mpu.get_tensor_model_parallel_rank() == 0:
f.write(json.dumps(datum) + '\n')
def pad_batch(batch, pad_id, args):
......@@ -313,7 +378,10 @@ def get_token_stream(model, context_tokens):
attention_mask, position_ids)
for tokens, lengths in batch_token_iterator:
context_length += 1
yield tokens[:, :context_length], lengths
if tokens is not None:
yield tokens[:, :context_length], lengths
else:
yield None, None
def switch(val1, val2, boolean):
......@@ -322,6 +390,60 @@ def switch(val1, val2, boolean):
return (1 - boolean) * val1 + boolean * val2
def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
layer_past=None, get_key_value=None,
forward_method_parallel_output=None):
if not mpu.is_pipeline_first_stage():
input_tensor, _ = communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_forward=True,
recv_backward=False)
else:
input_tensor = None
# Forward pass through the model.
if mpu.is_pipeline_first_stage():
assert input_tensor is None
if mpu.is_pipeline_last_stage():
output_tensor = model(tokens, position_ids, attention_mask,
tokentype_ids=tokentype_ids,
layer_past=layer_past,
get_key_value=get_key_value,
forward_method_parallel_output=forward_method_parallel_output)
else:
output_tensor = model(tokens, position_ids, attention_mask,
tokentype_ids=tokentype_ids,
layer_past=layer_past,
get_key_value=get_key_value)
elif mpu.is_pipeline_last_stage():
assert input_tensor is not None
output_tensor = model(input_tensor, attention_mask,
layer_past=layer_past,
get_key_value=get_key_value,
forward_method_parallel_output=forward_method_parallel_output)
else:
assert input_tensor is not None
output_tensor = model(input_tensor, attention_mask,
layer_past=layer_past,
get_key_value=get_key_value)
if get_key_value:
output_tensor, layer_past = output_tensor
if not mpu.is_pipeline_last_stage():
communicate(tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_forward=False,
recv_backward=False)
return None
if get_key_value:
return output_tensor, layer_past
return output_tensor
def sample_sequence_batch(model, context_tokens, context_lengths,
attention_mask, position_ids,
maxlen=None, type_ids=None):
......@@ -349,14 +471,15 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
lengths = torch.ones([batch_size]).long().cuda() * maxlen
while context_length <= (maxlen):
if args.recompute:
logits = model(tokens,
position_ids,
attention_mask,
tokentype_ids=type_ids,
forward_method_parallel_output=False)
logits = logits[:, context_length - 1, :]
output = forward_step(model, tokens,
position_ids,
attention_mask,
tokentype_ids=type_ids,
forward_method_parallel_output=False)
if mpu.is_pipeline_last_stage():
assert output is not None
logits = output[:, context_length - 1, :]
else:
types2use = None
if counter == 0:
......@@ -372,41 +495,65 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
if type_ids is not None:
types2use = type_ids[:, context_length - 1].view(
batch_size, -1)
logits, layer_past = model(tokens2use,
positions2use,
attention_mask,
layer_past=layer_past,
get_key_value=True,
tokentype_ids=types2use,
forward_method_parallel_output=False)
logits = logits[:, -1].view(batch_size, -1).contiguous()
if args.greedy:
prev = torch.argmax(logits, dim=-1).view(-1)
logits, layer_past = forward_step(model, tokens2use,
positions2use,
attention_mask,
layer_past=layer_past,
get_key_value=True,
tokentype_ids=types2use,
forward_method_parallel_output=False)
if mpu.is_pipeline_last_stage():
assert output is not None
logits = logits[:, -1].view(batch_size, -1).contiguous()
if mpu.is_pipeline_last_stage():
if args.greedy:
prev = torch.argmax(logits, dim=-1).view(-1)
else:
logits = logits.float()
logits /= args.temperature
logits = top_k_logits(logits, top_k=args.top_k,
top_p=args.top_p)
log_probs = F.softmax(logits, dim=-1)
prev = torch.multinomial(log_probs, num_samples=1).view(-1)
started = context_lengths <= context_length
new_tokens = switch(
tokens[:, context_length].view(-1), prev, started)
tokens[:, context_length] = new_tokens
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
torch.distributed.broadcast(new_tokens, src, group)
done_token = (prev == eos_id).byte() & started.byte()
just_finished = (done_token & ~is_done).bool()
lengths[just_finished.view(-1)] = context_length
is_done = is_done | done_token
done = torch.all(is_done)
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_pipeline_model_parallel_group()
torch.distributed.broadcast(done, src, group)
yield tokens, lengths
else:
logits = logits.float()
logits /= args.temperature
logits = top_k_logits(logits, top_k=args.top_k,
top_p=args.top_p)
log_probs = F.softmax(logits, dim=-1)
prev = torch.multinomial(log_probs, num_samples=1).view(-1)
print_logits = []
for p in prev:
print_logits.append([logits[i, p].item()
for i in range(batch_size)])
started = context_lengths <= context_length
tokens[:, context_length] = switch(
tokens[:, context_length].view(-1), prev, started)
context_length += 1
counter += 1
if mpu.is_pipeline_first_stage():
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
new_tokens = torch.empty_like(tokens[:, context_length])
torch.distributed.broadcast(new_tokens, src, group)
tokens[:, context_length] = new_tokens
yield tokens, None
else:
yield None, None
done_token = (prev == eos_id).byte() & started.byte()
just_finished = (done_token & ~is_done).bool()
lengths[just_finished.view(-1)] = context_length
is_done = is_done | done_token
done = torch.all(is_done)
done = torch.cuda.ByteTensor([0])
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_pipeline_model_parallel_group()
torch.distributed.broadcast(done, src, group)
yield tokens, lengths
context_length += 1
counter += 1
if done:
break
......@@ -23,9 +23,10 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron import mpu
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.model import GPT2Model
from megatron.model import GPT2Model, GPT2ModelFirstStage, GPT2ModelLastStage, GPT2ModelIntermediateStage
from megatron.training import get_model
from megatron.text_generation_utils import generate_and_write_samples_unconditional
from megatron.text_generation_utils import generate_samples_input_from_file
......@@ -36,7 +37,19 @@ def model_provider():
"""Build the model."""
print_rank_0('building GPT2 model ...')
model = GPT2Model(num_tokentypes=0, parallel_output=False)
args = get_args()
if mpu.get_pipeline_model_parallel_world_size() > 1:
# Determine model based on position of stage in pipeline.
if mpu.is_pipeline_first_stage():
model = GPT2ModelFirstStage(num_tokentypes=0)
elif mpu.is_pipeline_last_stage():
model = GPT2ModelLastStage(
num_tokentypes=0, parallel_output=False)
else:
model = GPT2ModelIntermediateStage(
num_tokentypes=0)
else:
model = GPT2Model(num_tokentypes=0, parallel_output=False)
return model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment