Commit b46482e8 authored by rprenger

Fixes a bug in broadcasting that was causing hangs

parent 593b47b4
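
The change replaces a broadcast scoped to the tensor-model-parallel group with a broadcast from global rank 0 over the default (world) process group. A plausible reading of the hang, which the commit message does not spell out: not every rank ended up posting a matching collective (same source rank, same group) for the control signal, and a broadcast whose participants disagree blocks forever. The sketch below shows the send/receive pair after the fix. It is only an illustration: it assumes torch.distributed is already initialized with a CUDA-capable backend, send_do_generate mirrors the method in the diff, wait_for_signal is a made-up name for the loop in the server script, and the value of GENERATE_NUM is a placeholder.

    import torch
    import torch.distributed as dist

    GENERATE_NUM = 0  # placeholder; the real constant is defined in the server module

    def send_do_generate():
        # Coordinator (global rank 0) signals every rank to run a generation step.
        choice = torch.cuda.LongTensor([GENERATE_NUM])
        dist.broadcast(choice, 0)  # src = global rank 0, default (world) group

    def wait_for_signal():
        # Every other rank blocks here. It must post the same collective
        # (same src, same group) as the sender, otherwise the job hangs.
        choice = torch.cuda.LongTensor(1)
        dist.broadcast(choice, 0)
        return choice[0].item()

Broadcasting from rank 0 over the world group keeps the control path independent of how the tensor- and pipeline-parallel groups are laid out, which is presumably why the group argument was dropped.
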
@@ -30,9 +30,7 @@ class MegatronGenerate(Resource):
    @staticmethod
    def send_do_generate():
        choice = torch.cuda.LongTensor([GENERATE_NUM])
-       torch.distributed.broadcast(choice,
-                                   mpu.get_tensor_model_parallel_src_rank(),
-                                   group=mpu.get_tensor_model_parallel_group())
+       torch.distributed.broadcast(choice, 0)

    def put(self):
        args = get_args()
@@ -141,7 +141,6 @@ def receive_generate_info():
def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs):
    context_length = context_length_tensor.min().item()
    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                 context_length_tensor,
                                                 attention_mask, position_ids,
@@ -172,7 +171,6 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_
            group = mpu.get_embedding_group()
            full_logits = torch.empty(tokens.size(0), context_length, args.padded_vocab_size, dtype=torch.float32, device=torch.device("cuda"))
            torch.distributed.broadcast(full_logits, src, group)
    if tokens is not None:
        return tokens[:, :context_length], output_logits, full_logits
@@ -310,7 +308,6 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                    if type_ids is not None:
                        types2use = type_ids[:, context_length - 1].view(
                            batch_size, -1)
                output, layer_past = forward_step(model, tokens2use,
                                                  positions2use,
                                                  attention_mask,
@@ -332,7 +329,6 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                                          top_p=args.top_p)
                    log_probs = F.softmax(logits, dim=-1)
                    prev = torch.multinomial(log_probs, num_samples=1).view(-1)
                started = context_lengths <= context_length
                new_tokens = switch(
@@ -78,8 +78,6 @@ if __name__ == "__main__":
    while True:
        choice = torch.cuda.LongTensor(1)
-       torch.distributed.broadcast(choice,
-                                   mpu.get_tensor_model_parallel_src_rank(),
-                                   group=mpu.get_tensor_model_parallel_group())
+       torch.distributed.broadcast(choice, 0)
        if choice[0].item() == 0:
            generate(model)