Commit 1d4e8760 authored by Jared Casper's avatar Jared Casper Committed by Deepak Narayanan
Browse files

Fix text generation without recompute

parent 2623551d
......@@ -138,23 +138,23 @@ def generate_samples_input_from_file(model):
group=mpu.get_model_parallel_group())
terminate_runs = input_info_tensor[0].item()
raw_text_len = input_info_tensor[1].item()
context_length = input_info_tensor[2].item()
if terminate_runs == 1:
return
# For pipeline parallel we send context tokens to last stage
# so it knows when to start overwriting
# For pipeline parallel we send context tokens to other stages
# so they get the lengths correct
if mpu.get_tensor_model_parallel_rank() == 0 \
and args.pipeline_model_parallel_size > 1:
if mpu.is_pipeline_first_stage():
src = mpu.get_pipeline_model_parallel_first_rank()
group = mpu.get_embedding_group()
group = mpu.get_pipeline_model_parallel_group()
context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
torch.distributed.broadcast(context_tokens_tensor, src, group)
if mpu.is_pipeline_last_stage():
else:
src = mpu.get_pipeline_model_parallel_first_rank()
group = mpu.get_embedding_group()
context_length = input_info_tensor[2].item()
group = mpu.get_pipeline_model_parallel_group()
context_tokens_tensor = torch.empty(context_length,
dtype=torch.int64,
device=torch.device("cuda"))
......@@ -229,23 +229,23 @@ def generate_samples_interactive(model, print_frequency=24):
group=mpu.get_model_parallel_group())
terminate_runs = input_info_tensor[0].item()
raw_text_len = input_info_tensor[1].item()
context_length = input_info_tensor[2].item()
if terminate_runs == 1:
return
# For pipeline parallel we send context tokens to last stage
# so it knows when to start overwriting
# For pipeline parallel we send context tokens to other stages
# so they get the lengths correct
if mpu.get_tensor_model_parallel_rank() == 0 \
and args.pipeline_model_parallel_size > 1:
if mpu.is_pipeline_first_stage():
src = mpu.get_pipeline_model_parallel_first_rank()
group = mpu.get_embedding_group()
group = mpu.get_pipeline_model_parallel_group()
context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
torch.distributed.broadcast(context_tokens_tensor, src, group)
if mpu.is_pipeline_last_stage():
else:
src = mpu.get_pipeline_model_parallel_first_rank()
group = mpu.get_embedding_group()
context_length = input_info_tensor[2].item()
group = mpu.get_pipeline_model_parallel_group()
context_tokens_tensor = torch.empty(context_length,
dtype=torch.int64,
device=torch.device("cuda"))
......@@ -253,6 +253,7 @@ def generate_samples_interactive(model, print_frequency=24):
context_tokens = context_tokens_tensor.cpu().numpy().tolist()
token_stream = get_token_stream(model, [context_tokens])
for counter, decode_tokens in enumerate(token_stream):
if counter % print_frequency != 0 \
or mpu.get_tensor_model_parallel_rank() != 0 \
......@@ -394,6 +395,12 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
layer_past=None, get_key_value=None,
forward_method_parallel_output=None):
# Hidden size changes when not using recompute, need to tell communicate()
# the correct size
args = get_args()
orig_seq_length = args.seq_length
args.seq_length = tokens.shape[1]
if not mpu.is_pipeline_first_stage():
input_tensor, _ = communicate(
tensor_send_next=None,
......@@ -437,8 +444,8 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
tensor_send_prev=None,
recv_forward=False,
recv_backward=False)
return None
args.seq_length = orig_seq_length
if get_key_value:
return output_tensor, layer_past
return output_tensor
......@@ -495,7 +502,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
if type_ids is not None:
types2use = type_ids[:, context_length - 1].view(
batch_size, -1)
logits, layer_past = forward_step(model, tokens2use,
output, layer_past = forward_step(model, tokens2use,
positions2use,
attention_mask,
layer_past=layer_past,
......@@ -504,7 +511,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
forward_method_parallel_output=False)
if mpu.is_pipeline_last_stage():
assert output is not None
logits = logits[:, -1].view(batch_size, -1).contiguous()
logits = output[:, -1].view(batch_size, -1).contiguous()
if mpu.is_pipeline_last_stage():
if args.greedy:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment