Commit 1d4e8760 authored by Jared Casper, committed by Deepak Narayanan

Fix text generation without recompute

parent 2623551d
@@ -138,23 +138,23 @@ def generate_samples_input_from_file(model):
                                          group=mpu.get_model_parallel_group())
             terminate_runs = input_info_tensor[0].item()
             raw_text_len = input_info_tensor[1].item()
-            context_length = input_info_tensor[2].item()

             if terminate_runs == 1:
                 return

-            # For pipeline parallel we send context tokens to last stage
-            # so it knows when to start overwriting
+            # For pipeline parallel we send context tokens to other stages
+            # so they get the lengths correct
             if mpu.get_tensor_model_parallel_rank() == 0 \
                and args.pipeline_model_parallel_size > 1:
                 if mpu.is_pipeline_first_stage():
                     src = mpu.get_pipeline_model_parallel_first_rank()
-                    group = mpu.get_embedding_group()
+                    group = mpu.get_pipeline_model_parallel_group()
                     context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
                     torch.distributed.broadcast(context_tokens_tensor, src, group)
-                if mpu.is_pipeline_last_stage():
+                else:
                     src = mpu.get_pipeline_model_parallel_first_rank()
-                    group = mpu.get_embedding_group()
+                    group = mpu.get_pipeline_model_parallel_group()
+                    context_length = input_info_tensor[2].item()
                     context_tokens_tensor = torch.empty(context_length,
                                                         dtype=torch.int64,
                                                         device=torch.device("cuda"))
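
As an aside, here is a minimal standalone sketch of the length-then-tokens broadcast pattern this hunk relies on (the same change is repeated in generate_samples_interactive below): the sending rank shares the length first so receivers can allocate a correctly sized buffer before the token broadcast. It assumes torch.distributed is already initialized with a CUDA-capable backend and uses the default process group; broadcast_context_tokens is a hypothetical helper for illustration, not part of the diff.

# Illustrative only -- not the Megatron code path. Assumes torch.distributed
# has been initialized (e.g. with the NCCL backend) and every rank calls this.
import torch
import torch.distributed as dist

def broadcast_context_tokens(context_tokens, src=0):
    """Rank `src` shares its token list; every rank returns an identical copy."""
    if dist.get_rank() == src:
        # Sender: share the length first so receivers can size their buffer.
        length = torch.cuda.LongTensor([len(context_tokens)])
        dist.broadcast(length, src)
        tokens = torch.cuda.LongTensor(context_tokens)
    else:
        # Receiver: learn the length, then allocate an empty buffer for it.
        length = torch.cuda.LongTensor([0])
        dist.broadcast(length, src)
        tokens = torch.empty(length.item(), dtype=torch.int64,
                             device=torch.device("cuda"))
    # Everyone participates in the broadcast of the tokens themselves.
    dist.broadcast(tokens, src)
    return tokens.cpu().numpy().tolist()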
@@ -229,23 +229,23 @@ def generate_samples_interactive(model, print_frequency=24):
                                          group=mpu.get_model_parallel_group())
             terminate_runs = input_info_tensor[0].item()
             raw_text_len = input_info_tensor[1].item()
-            context_length = input_info_tensor[2].item()

             if terminate_runs == 1:
                 return

-            # For pipeline parallel we send context tokens to last stage
-            # so it knows when to start overwriting
+            # For pipeline parallel we send context tokens to other stages
+            # so they get the lengths correct
             if mpu.get_tensor_model_parallel_rank() == 0 \
                and args.pipeline_model_parallel_size > 1:
                 if mpu.is_pipeline_first_stage():
                     src = mpu.get_pipeline_model_parallel_first_rank()
-                    group = mpu.get_embedding_group()
+                    group = mpu.get_pipeline_model_parallel_group()
                     context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
                     torch.distributed.broadcast(context_tokens_tensor, src, group)
-                if mpu.is_pipeline_last_stage():
+                else:
                     src = mpu.get_pipeline_model_parallel_first_rank()
-                    group = mpu.get_embedding_group()
+                    group = mpu.get_pipeline_model_parallel_group()
+                    context_length = input_info_tensor[2].item()
                     context_tokens_tensor = torch.empty(context_length,
                                                         dtype=torch.int64,
                                                         device=torch.device("cuda"))
@@ -253,6 +253,7 @@ def generate_samples_interactive(model, print_frequency=24):
                     context_tokens = context_tokens_tensor.cpu().numpy().tolist()

             token_stream = get_token_stream(model, [context_tokens])
+
             for counter, decode_tokens in enumerate(token_stream):
                 if counter % print_frequency != 0 \
                    or mpu.get_tensor_model_parallel_rank() != 0 \
@@ -394,6 +395,12 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
                  layer_past=None, get_key_value=None,
                  forward_method_parallel_output=None):

+    # Hidden size changes when not using recompute, need to tell communicate()
+    # the correct size
+    args = get_args()
+    orig_seq_length = args.seq_length
+    args.seq_length = tokens.shape[1]
+
     if not mpu.is_pipeline_first_stage():
         input_tensor, _ = communicate(
             tensor_send_next=None,
@@ -437,8 +444,8 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
                     tensor_send_prev=None,
                     recv_forward=False,
                     recv_backward=False)
-        return None

+    args.seq_length = orig_seq_length
     if get_key_value:
         return output_tensor, layer_past
     return output_tensor
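
The two forward_step hunks above save args.seq_length, override it with the actual number of incoming tokens so communicate() sizes its receive buffers correctly, then restore the original value before returning. As a hedged illustration of that save/override/restore pattern (the diff itself restores the value inline), the same idea can be written as a context manager; `args` here is just any object with a seq_length attribute, a stand-in rather than Megatron's global args.

# A sketch of the temporary-override pattern; not the Megatron implementation.
from contextlib import contextmanager

@contextmanager
def override_seq_length(args, new_seq_length):
    """Temporarily set args.seq_length, restoring the original value on exit."""
    orig_seq_length = args.seq_length
    args.seq_length = new_seq_length
    try:
        yield args
    finally:
        # Restored even if the forward pass raises, mirroring the explicit
        # `args.seq_length = orig_seq_length` in the diff.
        args.seq_length = orig_seq_length

# Usage sketch:
#     with override_seq_length(args, tokens.shape[1]):
#         ...  # run the forward pass / communicate() with the smaller size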
@@ -495,7 +502,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                     if type_ids is not None:
                         types2use = type_ids[:, context_length - 1].view(
                             batch_size, -1)
-                logits, layer_past = forward_step(model, tokens2use,
+                output, layer_past = forward_step(model, tokens2use,
                                                   positions2use,
                                                   attention_mask,
                                                   layer_past=layer_past,
@@ -504,7 +511,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                                                   forward_method_parallel_output=False)
                 if mpu.is_pipeline_last_stage():
-                    logits = logits[:, -1].view(batch_size, -1).contiguous()
+                    assert output is not None
+                    logits = output[:, -1].view(batch_size, -1).contiguous()

             if mpu.is_pipeline_last_stage():
                 if args.greedy:
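
The sample_sequence_batch hunks rename the forward_step result to output and assert it is non-None only on the last pipeline stage, since earlier stages hand their activations to the next stage and never see logits. A rough, hypothetical sketch of that guard follows; pick_next_token and is_last_stage are illustrative stand-ins, not anything in the diff.

# Hypothetical helper for illustration; not part of Megatron.
import torch

def pick_next_token(output, is_last_stage, greedy=True):
    if not is_last_stage:
        # Intermediate pipeline stages produced no logits to sample from.
        return None
    assert output is not None
    logits = output[:, -1]  # logits for the most recent position only
    if greedy:
        return logits.argmax(dim=-1)
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)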