Commit a97d676b authored by Jared Casper

Merge branch 'fix_initial_broadcasting' into 'main'

Fixes a bug in broadcasting that was causing hanging

See merge request ADLR/megatron-lm!327
parents 230633f8 b46482e8
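
Both sides of the diff below replace a broadcast over the tensor-model-parallel group with a broadcast from global rank 0 over the default group, so every rank joins the same collective instead of waiting on mismatched ones. The following sketch is not part of the commit; the helper names are illustrative and the GENERATE_NUM value is a placeholder. It shows the resulting control-signal pattern, assuming torch.distributed has already been initialized with a CUDA-capable backend:

import torch

GENERATE_NUM = 0  # hypothetical command code agreed upon by all ranks

def send_do_generate():
    # Called only on global rank 0: announce the next action to every rank.
    choice = torch.cuda.LongTensor([GENERATE_NUM])
    torch.distributed.broadcast(choice, 0)

def wait_for_command():
    # Called on every other rank: join the matching broadcast from rank 0,
    # then act on the received command code.
    choice = torch.cuda.LongTensor(1)
    torch.distributed.broadcast(choice, 0)
    return choice[0].item()

Because sender and receivers now name the same source rank and the same (default) process group, the broadcast matches on every rank and the hang disappears.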
@@ -30,9 +30,7 @@ class MegatronGenerate(Resource):
     @staticmethod
     def send_do_generate():
         choice = torch.cuda.LongTensor([GENERATE_NUM])
-        torch.distributed.broadcast(choice,
-                                    mpu.get_tensor_model_parallel_src_rank(),
-                                    group=mpu.get_tensor_model_parallel_group())
+        torch.distributed.broadcast(choice, 0)
 
     def put(self):
         args = get_args()
@@ -141,7 +141,6 @@ def receive_generate_info():
 def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs):
     context_length = context_length_tensor.min().item()
     tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
-
     batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                  context_length_tensor,
                                                  attention_mask, position_ids,
@@ -172,7 +171,6 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_
         group = mpu.get_embedding_group()
         full_logits = torch.empty(tokens.size(0), context_length, args.padded_vocab_size, dtype=torch.float32, device=torch.device("cuda"))
         torch.distributed.broadcast(full_logits, src, group)
-
     if tokens is not None:
         return tokens[:, :context_length], output_logits, full_logits
@@ -310,7 +308,6 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
             if type_ids is not None:
                 types2use = type_ids[:, context_length - 1].view(
                     batch_size, -1)
-
             output, layer_past = forward_step(model, tokens2use,
                                               positions2use,
                                               attention_mask,
@@ -332,7 +329,6 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                                          top_p=args.top_p)
                 log_probs = F.softmax(logits, dim=-1)
                 prev = torch.multinomial(log_probs, num_samples=1).view(-1)
-
                 started = context_lengths <= context_length
                 new_tokens = switch(
@@ -78,8 +78,6 @@ if __name__ == "__main__":
     while True:
         choice = torch.cuda.LongTensor(1)
-        torch.distributed.broadcast(choice,
-                                    mpu.get_tensor_model_parallel_src_rank(),
-                                    group=mpu.get_tensor_model_parallel_group())
+        torch.distributed.broadcast(choice, 0)
         if choice[0].item() == 0:
             generate(model)