Commit a33e1b35 authored by rprenger

Fixing bug where temperature was never actually broadcast

parent 5ab64637
@@ -108,13 +108,13 @@ def tokenize_batch(sentences, max_len, add_BOS):
     context_length_tensor = torch.cuda.LongTensor(context_lengths)
     return context_tokens_tensor, context_length_tensor


-def send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs):
+def send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature):
     """
     Needs to be synced up with receive_generate_info
     """
     # Send the sizes of the tensors
-    input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), tokens_to_generate, all_probs]
-    input_info_tensor = torch.cuda.LongTensor(input_info)
+    input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), tokens_to_generate, all_probs, temperature]
+    input_info_tensor = torch.cuda.FloatTensor(input_info)
     torch.distributed.broadcast(input_info_tensor, 0)
     # Send variables to all ranks
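The core of the fix is in this hunk: `temperature` is a float, so the packed metadata switches from `LongTensor` to `FloatTensor`; a `LongTensor` built from that list would silently truncate the temperature to an integer. A minimal sketch of the send-side pattern, assuming an initialized `torch.distributed` process group with a CUDA backend; `send_params` and its argument names are illustrative, not part of the patch:

```python
import torch

def send_params(batch_size, seq_len, tokens_to_generate, all_probs, temperature):
    # Pack every scalar into one float32 tensor so a single broadcast
    # carries all of them; the ints and the bool survive the round trip
    # because they are exactly representable in float32 (values < 2**24).
    info = torch.cuda.FloatTensor(
        [batch_size, seq_len, tokens_to_generate, all_probs, temperature]
    )
    torch.distributed.broadcast(info, 0)  # rank 0 is the source
```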
@@ -125,12 +125,13 @@ def receive_generate_info():
     """
     Needs to be synced up with send_generate_info
     """
-    input_info_tensor = torch.empty(4, dtype=torch.int64, device=torch.cuda.current_device())
+    input_info_tensor = torch.empty(5, dtype=torch.float32, device=torch.cuda.current_device())
     torch.distributed.broadcast(input_info_tensor, 0)
-    batch_size = input_info_tensor[0].item()
-    seq_len = input_info_tensor[1].item()
-    tokens_to_generate = input_info_tensor[2].item()
-    all_probs = input_info_tensor[3].item()
+    batch_size = int(input_info_tensor[0].item())
+    seq_len = int(input_info_tensor[1].item())
+    tokens_to_generate = int(input_info_tensor[2].item())
+    all_probs = int(input_info_tensor[3].item())
+    temperature = float(input_info_tensor[4].item())

     context_length_tensor = torch.empty(batch_size, dtype=torch.int64, device=torch.cuda.current_device())
     context_tokens_tensor = torch.empty(batch_size, seq_len, dtype=torch.int64, device=torch.cuda.current_device())
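The receive side mirrors the sender: the buffer grows from 4 to 5 elements and changes dtype, and the integer fields are cast back with `int(...)` because everything now arrives as float32. A matching sketch under the same assumptions (names illustrative); element order and count must agree exactly with the sender, since the broadcast carries raw values with no schema:

```python
import torch

def receive_params():
    # Buffer shape and dtype must match what rank 0 broadcasts.
    info = torch.empty(5, dtype=torch.float32, device=torch.cuda.current_device())
    torch.distributed.broadcast(info, 0)
    batch_size = int(info[0].item())          # sizes arrive as floats,
    seq_len = int(info[1].item())             # so cast back to int
    tokens_to_generate = int(info[2].item())
    all_probs = int(info[3].item())           # bool was sent as 0.0/1.0
    temperature = float(info[4].item())       # the float that forced the dtype change
    return batch_size, seq_len, tokens_to_generate, all_probs, temperature
```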
@@ -139,7 +140,7 @@ def receive_generate_info():
     torch.distributed.broadcast(context_length_tensor, 0)
     torch.distributed.broadcast(context_tokens_tensor, 0)
-    return context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs
+    return context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs, temperature


 def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature):
     context_length = context_length_tensor.min().item()
@@ -182,7 +183,7 @@ def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, tempe
     model.eval()
     if torch.distributed.get_rank() == 0:
         context_tokens_tensor, context_length_tensor = tokenize_batch(sentences, tokens_to_generate, add_BOS)
-        send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs)
+        send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature)
     else:
-        context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs = receive_generate_info()
+        context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs, temperature = receive_generate_info()
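In `generate`, only rank 0 has the tokenized batch, so it sends while every other rank blocks on the matching receive; both calls must be updated together or the tuple unpacking breaks. A compressed sketch of that dispatch, built on the illustrative helpers above (again an assumption-laden sketch, not the library's API):

```python
def sync_generation_params(batch_size, seq_len, tokens_to_generate, all_probs, temperature):
    # Rank 0 passes in the real values; other ranks can pass dummies,
    # which are overwritten by the broadcast result.
    if torch.distributed.get_rank() == 0:
        send_params(batch_size, seq_len, tokens_to_generate, all_probs, temperature)
    else:
        (batch_size, seq_len, tokens_to_generate,
         all_probs, temperature) = receive_params()
    # From here on, every rank holds identical generation parameters.
    return batch_size, seq_len, tokens_to_generate, all_probs, temperature
```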