"src/vscode:/vscode.git/clone" did not exist on "e3bc4aab2ef7b319d2b49e99a25bc2b1b1363bfa"
Commit 5c45db4a authored by Jared Casper, committed by Deepak Narayanan

Initial implementation of pipelined text generation

parent caa9dca5
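In outline, the pipelined generation added here works like this: the first pipeline stage holds the prompt, activations flow stage to stage through communicate(), the last stage turns them into logits and samples the next token, and the sampled token (plus a done flag) is broadcast back so every stage stays in step. A minimal single-process sketch of that control flow (illustrative only; the stage callables and the greedy_sample helper are stand-ins, not code from this commit):

    import torch

    def greedy_sample(logits):
        # Stand-in for the argmax / multinomial step done on the last pipeline stage.
        return int(torch.argmax(logits, dim=-1))

    def pipelined_generate(stages, context_tokens, max_new_tokens):
        # `stages` is a list of callables standing in for pipeline stages; in the
        # real code each stage lives on its own rank and activations move between
        # ranks via communicate() rather than a Python loop.
        tokens = list(context_tokens)
        for _ in range(max_new_tokens):
            activations = torch.tensor(tokens, dtype=torch.float32)
            for stage in stages:                      # first stage -> ... -> last stage
                activations = stage(activations)
            next_token = greedy_sample(activations)   # only the last stage sees logits
            tokens.append(next_token)                 # broadcast back to the first stage
        return tokens

    # Toy usage with two fake "stages" over an 8-token vocabulary.
    if __name__ == "__main__":
        def stage_a(x):
            return x.sum().repeat(8)

        def stage_b(x):
            return x + torch.arange(8, dtype=torch.float32)

        print(pipelined_generate([stage_a, stage_b], [1, 2, 3], max_new_tokens=4))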
@@ -34,7 +34,8 @@ from .initialize import get_tensor_model_parallel_rank, set_tensor_model_parallel_rank
 from .initialize import get_pipeline_model_parallel_rank, set_pipeline_model_parallel_rank
 from .initialize import is_pipeline_first_stage, is_pipeline_last_stage
 from .initialize import get_tensor_model_parallel_src_rank
-from .initialize import get_pipeline_model_parallel_src_rank
+from .initialize import get_pipeline_model_parallel_first_rank
+from .initialize import get_pipeline_model_parallel_last_rank
 from .initialize import get_tensor_model_parallel_world_size, set_tensor_model_parallel_world_size
 from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_model_parallel_world_size
 from .initialize import initialize_model_parallel
...
@@ -38,6 +38,7 @@ _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
 _MPU_TENSOR_MODEL_PARALLEL_RANK = None
 _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
+_PIPELINE_GLOBAL_RANKS = None


 def is_unitialized():
     """Useful for code segments that may be accessed with or without mpu initialization"""
@@ -131,6 +132,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
     # Build the pipeline model-parallel groups and embedding groups
     # (first and last rank in each pipeline model-parallel group).
     global _PIPELINE_MODEL_PARALLEL_GROUP
+    global _PIPELINE_GLOBAL_RANKS
     assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \
         'pipeline model parallel group is already initialized'
     global _EMBEDDING_GROUP
@@ -142,6 +144,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
         group = torch.distributed.new_group(ranks)
         if rank in ranks:
             _PIPELINE_MODEL_PARALLEL_GROUP = group
+            _PIPELINE_GLOBAL_RANKS = ranks
         # Setup embedding group (to exchange gradients between
         # first and last stages).
         if len(ranks) > 1:
@@ -265,21 +268,22 @@ def is_pipeline_last_stage():

 def get_tensor_model_parallel_src_rank():
-    """Calculate the global rank corresponding to a local rank
+    """Calculate the global rank corresponding to the first local rank
     in the tensor model parallel group."""
     global_rank = torch.distributed.get_rank()
     local_world_size = get_tensor_model_parallel_world_size()
     return (global_rank // local_world_size) * local_world_size


+def get_pipeline_model_parallel_last_rank():
+    assert _PIPELINE_GLOBAL_RANKS is not None, \
+        "Pipeline parallel group is not initialized"
+    last_rank_local = get_pipeline_model_parallel_world_size() - 1
+    return _PIPELINE_GLOBAL_RANKS[last_rank_local]


-def get_pipeline_model_parallel_src_rank():
-    """Calculate the global rank corresponding to a local rank
-    in the pipeline model parallel group."""
-    global_rank = torch.distributed.get_rank()
-    global_world_size = torch.distributed.get_world_size()
-    local_world_size = get_pipeline_model_parallel_world_size()
-    return global_rank % (global_world_size // local_world_size)
+def get_pipeline_model_parallel_first_rank():
+    assert _PIPELINE_GLOBAL_RANKS is not None, \
+        "Pipeline parallel group is not initialized"
+    return _PIPELINE_GLOBAL_RANKS[0]


 def get_data_parallel_world_size():
     """Return world size for the data parallel group."""
...
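Below, for orientation, is a hedged usage sketch of the two new rank helpers (illustrative, not part of this commit; it assumes torch.distributed and mpu are already initialized), mirroring the src/group broadcast pattern the generation code further down uses:

    import torch
    from megatron import mpu

    def broadcast_from_last_stage(tensor):
        # Send a tensor that only the last pipeline stage has filled in (for
        # example freshly sampled tokens) to every rank in its pipeline group.
        src = mpu.get_pipeline_model_parallel_last_rank()
        group = mpu.get_pipeline_model_parallel_group()
        torch.distributed.broadcast(tensor, src, group)
        return tensor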
@@ -26,6 +26,7 @@ import torch.nn.functional as F
 from megatron import get_args
 from megatron import get_tokenizer
 from megatron import mpu
+from megatron.training import communicate
 from megatron.utils import get_ltor_masks_and_position_ids
@@ -88,14 +89,14 @@ def generate_samples_input_from_file(model):
     # Read the sample file and open the output file.
     assert args.sample_input_file is not None, \
         'sample input file is not provided.'
-    if mpu.get_tensor_model_parallel_rank() == 0:
+    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
         fname = open(args.sample_input_file, "r")
         all_raw_text = fname.readlines()
         input_count = len(all_raw_text)
         input_pos = 0
         if args.sample_output_file is None:
             sample_output_file = args.sample_input_file + ".out"
-            print('could not find `sample-output-file`, setting '
+            print('`sample-output-file` not specified, setting '
                   'it to {}'.format(sample_output_file))
         else:
             sample_output_file = args.sample_output_file
@@ -105,14 +106,16 @@ def generate_samples_input_from_file(model):
     model.eval()
     with torch.no_grad():
         while True:
-            torch.distributed.barrier(group=mpu.get_tensor_model_parallel_group())
             terminate_runs = 0
+            raw_text_len = 0

-            if mpu.get_tensor_model_parallel_rank() == 0:
+            if mpu.is_pipeline_first_stage() \
+               and mpu.get_tensor_model_parallel_rank() == 0:
                 raw_text = all_raw_text[input_pos]
                 input_pos += 1
                 if input_pos == input_count:
                     raw_text = "stop"
+                raw_text_len = len(raw_text)

                 if "stop" in raw_text:
                     terminate_runs = 1
@@ -127,38 +130,60 @@ def generate_samples_input_from_file(model):
                         continue
             else:
                 context_tokens = tokenizer.tokenize("EMPTY TEXT")
-                context_length = len(context_tokens)
+                context_length = 0

-            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
-            torch.distributed.broadcast(terminate_runs_tensor,
-                                        mpu.get_tensor_model_parallel_src_rank(),
-                                        group=mpu.get_tensor_model_parallel_group())
-            terminate_runs = terminate_runs_tensor[0].item()
+            input_info = [terminate_runs, raw_text_len, context_length]
+            input_info_tensor = torch.cuda.LongTensor(input_info)
+            torch.distributed.all_reduce(input_info_tensor,
+                                         group=mpu.get_model_parallel_group())
+            terminate_runs = input_info_tensor[0].item()
+            raw_text_len = input_info_tensor[1].item()

             if terminate_runs == 1:
                 return

+            # For pipeline parallel we send context tokens to last stage
+            # so it knows when to start overwriting
+            if mpu.get_tensor_model_parallel_rank() == 0 \
+               and args.pipeline_model_parallel_size > 1:
+                if mpu.is_pipeline_first_stage():
+                    src = mpu.get_pipeline_model_parallel_first_rank()
+                    group = mpu.get_embedding_group()
+                    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
+                    torch.distributed.broadcast(context_tokens_tensor, src, group)
+                if mpu.is_pipeline_last_stage():
+                    src = mpu.get_pipeline_model_parallel_first_rank()
+                    group = mpu.get_embedding_group()
+                    context_length = input_info_tensor[2].item()
+                    context_tokens_tensor = torch.empty(context_length,
+                                                        dtype=torch.int64,
+                                                        device=torch.device("cuda"))
+                    torch.distributed.broadcast(context_tokens_tensor, src, group)
+                    context_tokens = context_tokens_tensor.cpu().numpy().tolist()
+
             token_stream = get_token_stream(model, [context_tokens])
             for _, decode_tokens in enumerate(token_stream):
-                decode_tokens, _ = decode_tokens
-                decode_tokens = decode_tokens[0].cpu().numpy().tolist()
+                pass

             if mpu.get_tensor_model_parallel_rank() == 0:
+                if mpu.is_pipeline_first_stage():
                     os.system('clear')
                     print("\nContext:", raw_text, flush=True)
-                trim_decode_tokens = tokenizer.detokenize(
-                    decode_tokens)[len(raw_text):]
-                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                     fname_out.write("\nContext:")
                     fname_out.write(raw_text)
+
+                    decode_tokens, _ = decode_tokens
+                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
+                    trim_decode_tokens = tokenizer.detokenize(
+                        decode_tokens)[raw_text_len:]
+                    print("\nMegatron-LM:", trim_decode_tokens, flush=True)
+
                     fname_out.write("\n\nMegatron-LM:")
                     fname_out.write(trim_decode_tokens)
                     fname_out.write("\n")

             raw_text = None
-            torch.distributed.barrier(group=mpu.get_tensor_model_parallel_group())
             context_count += 1
@@ -171,15 +196,17 @@ def generate_samples_interactive(model, print_frequency=24):
     model.eval()
     with torch.no_grad():
         while True:
-            torch.distributed.barrier(group=mpu.get_tensor_model_parallel_group())
             terminate_runs = 0
+            raw_text_len = 0

-            if mpu.get_tensor_model_parallel_rank() == 0:
+            if mpu.is_pipeline_first_stage() \
+               and mpu.get_tensor_model_parallel_rank() == 0:
                 os.system('clear')
                 raw_text = input("\nContext prompt (stop to exit) >>> ")
                 while not raw_text:
                     print('Prompt should not be empty!')
                     raw_text = input("\nContext prompt (stop to exit) >>> ")
+                raw_text_len = len(raw_text)

                 if "stop" in raw_text:
                     terminate_runs = 1
@@ -194,43 +221,70 @@ def generate_samples_interactive(model, print_frequency=24):
                         continue
             else:
                 context_tokens = tokenizer.tokenize("EMPTY TEXT")
-                context_length = len(context_tokens)
+                context_length = 0

-            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
-            torch.distributed.broadcast(terminate_runs_tensor,
-                                        mpu.get_tensor_model_parallel_src_rank(),
-                                        group=mpu.get_tensor_model_parallel_group())
-            terminate_runs = terminate_runs_tensor[0].item()
+            input_info = [terminate_runs, raw_text_len, context_length]
+            input_info_tensor = torch.cuda.LongTensor(input_info)
+            torch.distributed.all_reduce(input_info_tensor,
+                                         group=mpu.get_model_parallel_group())
+            terminate_runs = input_info_tensor[0].item()
+            raw_text_len = input_info_tensor[1].item()

             if terminate_runs == 1:
                 return

+            # For pipeline parallel we send context tokens to last stage
+            # so it knows when to start overwriting
+            if mpu.get_tensor_model_parallel_rank() == 0 \
+               and args.pipeline_model_parallel_size > 1:
+                if mpu.is_pipeline_first_stage():
+                    src = mpu.get_pipeline_model_parallel_first_rank()
+                    group = mpu.get_embedding_group()
+                    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
+                    torch.distributed.broadcast(context_tokens_tensor, src, group)
+                if mpu.is_pipeline_last_stage():
+                    src = mpu.get_pipeline_model_parallel_first_rank()
+                    group = mpu.get_embedding_group()
+                    context_length = input_info_tensor[2].item()
+                    context_tokens_tensor = torch.empty(context_length,
+                                                        dtype=torch.int64,
+                                                        device=torch.device("cuda"))
+                    torch.distributed.broadcast(context_tokens_tensor, src, group)
+                    context_tokens = context_tokens_tensor.cpu().numpy().tolist()
+
             token_stream = get_token_stream(model, [context_tokens])

             for counter, decode_tokens in enumerate(token_stream):
-                decode_tokens, _ = decode_tokens
-                decode_tokens = decode_tokens[0].cpu().numpy().tolist()
+                if counter % print_frequency != 0 \
+                   or mpu.get_tensor_model_parallel_rank() != 0 \
+                   or not mpu.is_pipeline_first_stage():
+                    continue

-                if mpu.get_tensor_model_parallel_rank() == 0 and \
-                   counter % print_frequency == 0:
                 os.system('clear')
                 print("\nContext:", raw_text, flush=True)
+
+                decode_tokens, _ = decode_tokens
+                decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                 trim_decode_tokens = tokenizer.detokenize(
-                    decode_tokens)[len(raw_text):]
+                    decode_tokens)[raw_text_len:]
                 print("\nMegatron-LM:", trim_decode_tokens, flush=True)

-            if mpu.get_tensor_model_parallel_rank() == 0:
+            if mpu.is_pipeline_first_stage() \
+               and mpu.get_tensor_model_parallel_rank() == 0:
                 os.system('clear')
                 print("\nContext:", raw_text, flush=True)
+
+                if not isinstance(decode_tokens, list):
+                    decode_tokens, _ = decode_tokens
+                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                 trim_decode_tokens = tokenizer.detokenize(
-                    decode_tokens)[len(raw_text):]
+                    decode_tokens)[raw_text_len:]
                 print("\nMegatron-LM:", trim_decode_tokens, flush=True)

+                input("\nPress Enter to continue >>>")
+
             raw_text = None
-            torch.distributed.barrier(group=mpu.get_tensor_model_parallel_group())
             context_count += 1

-            if mpu.get_tensor_model_parallel_rank() == 0:
-                input("\nPress any key to continue >>>")


 def generate_samples_unconditional(model):
@@ -247,6 +301,8 @@ def generate_samples_unconditional(model):
         for token_stream in get_token_stream(model,
                                              copy.deepcopy(context_tokens)):
             pass
+        if mpu.is_pipeline_last_stage() and \
+           mpu.get_tensor_model_parallel_rank() == 0:
             if ctr % args.log_interval == 0:
                 print('Avg s/batch:',
                       (time.time() - start_time) / min(args.log_interval, ctr + 1))
@@ -254,6 +310,7 @@ def generate_samples_unconditional(model):
             length = len(token_stream)
             token_batch = token_stream[0].cpu().numpy().tolist()
             length_batch = token_stream[1].cpu().numpy().tolist()
+            assert len(length_batch) == args.batch_size
             for tokens, length in zip(token_batch, length_batch):
                 tokens = tokens[1:length - 1]
                 text = tokenizer.detokenize(tokens)
@@ -263,6 +320,12 @@ def generate_samples_unconditional(model):
                 ctr += 1
                 if ctr >= num_samples:
                     break
+        else:
+            for _ in range(args.batch_size):
+                yield None
+                ctr += 1
+                if ctr >= num_samples:
+                    break
         if ctr >= num_samples:
             break
@@ -273,6 +336,8 @@ def generate_and_write_samples_unconditional(model):
     assert args.genfile is not None
     with open(args.genfile, 'w') as f:
         for datum in generate_samples_unconditional(model):
+            if mpu.is_pipeline_last_stage() and \
+               mpu.get_tensor_model_parallel_rank() == 0:
                 f.write(json.dumps(datum) + '\n')
@@ -313,7 +378,10 @@ def get_token_stream(model, context_tokens):
                                                  attention_mask, position_ids)
     for tokens, lengths in batch_token_iterator:
         context_length += 1
+        if tokens is not None:
             yield tokens[:, :context_length], lengths
+        else:
+            yield None, None


 def switch(val1, val2, boolean):
@@ -322,6 +390,60 @@ def switch(val1, val2, boolean):
     return (1 - boolean) * val1 + boolean * val2


+def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
+                 layer_past=None, get_key_value=None,
+                 forward_method_parallel_output=None):
+
+    if not mpu.is_pipeline_first_stage():
+        input_tensor, _ = communicate(
+            tensor_send_next=None,
+            tensor_send_prev=None,
+            recv_forward=True,
+            recv_backward=False)
+    else:
+        input_tensor = None
+
+    # Forward pass through the model.
+    if mpu.is_pipeline_first_stage():
+        assert input_tensor is None
+        if mpu.is_pipeline_last_stage():
+            output_tensor = model(tokens, position_ids, attention_mask,
+                                  tokentype_ids=tokentype_ids,
+                                  layer_past=layer_past,
+                                  get_key_value=get_key_value,
+                                  forward_method_parallel_output=forward_method_parallel_output)
+        else:
+            output_tensor = model(tokens, position_ids, attention_mask,
+                                  tokentype_ids=tokentype_ids,
+                                  layer_past=layer_past,
+                                  get_key_value=get_key_value)
+    elif mpu.is_pipeline_last_stage():
+        assert input_tensor is not None
+        output_tensor = model(input_tensor, attention_mask,
+                              layer_past=layer_past,
+                              get_key_value=get_key_value,
+                              forward_method_parallel_output=forward_method_parallel_output)
+    else:
+        assert input_tensor is not None
+        output_tensor = model(input_tensor, attention_mask,
+                              layer_past=layer_past,
+                              get_key_value=get_key_value)
+
+    if get_key_value:
+        output_tensor, layer_past = output_tensor
+
+    if not mpu.is_pipeline_last_stage():
+        communicate(tensor_send_next=output_tensor,
+                    tensor_send_prev=None,
+                    recv_forward=False,
+                    recv_backward=False)
+        return None
+
+    if get_key_value:
+        return output_tensor, layer_past
+    return output_tensor
+
+
 def sample_sequence_batch(model, context_tokens, context_lengths,
                           attention_mask, position_ids,
                           maxlen=None, type_ids=None):
@@ -349,14 +471,15 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
         lengths = torch.ones([batch_size]).long().cuda() * maxlen

         while context_length <= (maxlen):
             if args.recompute:
-                logits = model(tokens,
+                output = forward_step(model, tokens,
                                       position_ids,
                                       attention_mask,
                                       tokentype_ids=type_ids,
                                       forward_method_parallel_output=False)
-                logits = logits[:, context_length - 1, :]
+                if mpu.is_pipeline_last_stage():
+                    assert output is not None
+                    logits = output[:, context_length - 1, :]
             else:
                 types2use = None
                 if counter == 0:
@@ -372,15 +495,18 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                 if type_ids is not None:
                     types2use = type_ids[:, context_length - 1].view(
                         batch_size, -1)
-                logits, layer_past = model(tokens2use,
+                logits, layer_past = forward_step(model, tokens2use,
                                                   positions2use,
                                                   attention_mask,
                                                   layer_past=layer_past,
                                                   get_key_value=True,
                                                   tokentype_ids=types2use,
                                                   forward_method_parallel_output=False)
+                if mpu.is_pipeline_last_stage():
+                    assert output is not None
                     logits = logits[:, -1].view(batch_size, -1).contiguous()

+            if mpu.is_pipeline_last_stage():
                 if args.greedy:
                     prev = torch.argmax(logits, dim=-1).view(-1)
                 else:
@@ -391,22 +517,43 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                     log_probs = F.softmax(logits, dim=-1)
                     prev = torch.multinomial(log_probs, num_samples=1).view(-1)

-                print_logits = []
-                for p in prev:
-                    print_logits.append([logits[i, p].item()
-                                         for i in range(batch_size)])
                 started = context_lengths <= context_length
-                tokens[:, context_length] = switch(
+
+                new_tokens = switch(
                     tokens[:, context_length].view(-1), prev, started)
-                context_length += 1
-                counter += 1
+                tokens[:, context_length] = new_tokens
+                src = mpu.get_pipeline_model_parallel_last_rank()
+                group = mpu.get_embedding_group()
+                torch.distributed.broadcast(new_tokens, src, group)

                 done_token = (prev == eos_id).byte() & started.byte()
                 just_finished = (done_token & ~is_done).bool()
                 lengths[just_finished.view(-1)] = context_length
                 is_done = is_done | done_token
-                done = torch.all(is_done)
+
+                done = torch.all(is_done)
+                src = mpu.get_pipeline_model_parallel_last_rank()
+                group = mpu.get_pipeline_model_parallel_group()
+                torch.distributed.broadcast(done, src, group)

                 yield tokens, lengths
+
+            else:
+                if mpu.is_pipeline_first_stage():
+                    src = mpu.get_pipeline_model_parallel_last_rank()
+                    group = mpu.get_embedding_group()
+                    new_tokens = torch.empty_like(tokens[:, context_length])
+                    torch.distributed.broadcast(new_tokens, src, group)
+                    tokens[:, context_length] = new_tokens
+                    yield tokens, None
+                else:
+                    yield None, None
+
+                done = torch.cuda.ByteTensor([0])
+                src = mpu.get_pipeline_model_parallel_last_rank()
+                group = mpu.get_pipeline_model_parallel_group()
+                torch.distributed.broadcast(done, src, group)
+
+            context_length += 1
+            counter += 1
             if done:
                 break
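For reference, the stop-condition handshake that sample_sequence_batch performs inline above, pulled out as a standalone sketch (hedged: assumes CUDA, torch.distributed and mpu are initialized; the helper name is illustrative, not part of this commit):

    import torch
    from megatron import mpu

    def broadcast_done_flag(is_done):
        # Only the last stage knows whether every sequence has hit the EOS token;
        # it broadcasts the decision over the pipeline group so all stages leave
        # the generation loop on the same step.
        if mpu.is_pipeline_last_stage():
            done = torch.cuda.ByteTensor([int(torch.all(is_done))])
        else:
            done = torch.cuda.ByteTensor([0])
        src = mpu.get_pipeline_model_parallel_last_rank()
        group = mpu.get_pipeline_model_parallel_group()
        torch.distributed.broadcast(done, src, group)
        return bool(done.item())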
@@ -23,9 +23,10 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_tokenizer
+from megatron import mpu
 from megatron.checkpointing import load_checkpoint
 from megatron.initialize import initialize_megatron
-from megatron.model import GPT2Model
+from megatron.model import GPT2Model, GPT2ModelFirstStage, GPT2ModelLastStage, GPT2ModelIntermediateStage
 from megatron.training import get_model
 from megatron.text_generation_utils import generate_and_write_samples_unconditional
 from megatron.text_generation_utils import generate_samples_input_from_file
@@ -36,6 +37,18 @@ def model_provider():
     """Build the model."""

     print_rank_0('building GPT2 model ...')
+    args = get_args()
+    if mpu.get_pipeline_model_parallel_world_size() > 1:
+        # Determine model based on position of stage in pipeline.
+        if mpu.is_pipeline_first_stage():
+            model = GPT2ModelFirstStage(num_tokentypes=0)
+        elif mpu.is_pipeline_last_stage():
+            model = GPT2ModelLastStage(
+                num_tokentypes=0, parallel_output=False)
+        else:
+            model = GPT2ModelIntermediateStage(
+                num_tokentypes=0)
+    else:
         model = GPT2Model(num_tokentypes=0, parallel_output=False)

     return model
...