Commit 3d718bfc authored by rprenger's avatar rprenger
Browse files

Fixing merge conflict

parent 9939fb58
...@@ -175,21 +175,11 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len ...@@ -175,21 +175,11 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len
if tokens is not None: if tokens is not None:
return tokens[:, :context_length], output_logits, full_logits return tokens[:, :context_length], output_logits, full_logits
<<<<<<< HEAD
def generate(model, sentences=None, max_len=0, all_probs=False): def generate(model, sentences=None, max_len=0, all_probs=False):
if torch.distributed.get_rank() == 0:
context_tokens_tensor, context_length_tensor = tokenize_batch(sentences)
c = context_length_tensor[0]
b = context_tokens_tensor.size(0)
start = time.time()
send_generate_info(context_tokens_tensor, context_length_tensor, max_len, all_probs)
=======
def generate(model, sentences=None, max_len=0):
model.eval() model.eval()
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
context_tokens_tensor, context_length_tensor = tokenize_batch(sentences) context_tokens_tensor, context_length_tensor = tokenize_batch(sentences)
send_generate_info(context_tokens_tensor, context_length_tensor, max_len) send_generate_info(context_tokens_tensor, context_length_tensor, max_len, all_probs)
>>>>>>> server
else: else:
context_length_tensor, context_tokens_tensor, max_len, all_probs = receive_generate_info() context_length_tensor, context_tokens_tensor, max_len, all_probs = receive_generate_info()
...@@ -206,7 +196,6 @@ def generate(model, sentences=None, max_len=0): ...@@ -206,7 +196,6 @@ def generate(model, sentences=None, max_len=0):
decode_tokens = decode_tokens.cpu().numpy().tolist() decode_tokens = decode_tokens.cpu().numpy().tolist()
for decode_token in decode_tokens: for decode_token in decode_tokens:
resp_sentences.append(tokenizer.detokenize(decode_token)) resp_sentences.append(tokenizer.detokenize(decode_token))
<<<<<<< HEAD
words = [] words = []
for token in decode_token: for token in decode_token:
word = tokenizer.tokenizer.decoder[token] word = tokenizer.tokenizer.decoder[token]
...@@ -218,12 +207,7 @@ def generate(model, sentences=None, max_len=0): ...@@ -218,12 +207,7 @@ def generate(model, sentences=None, max_len=0):
if all_probs: if all_probs:
full_logits = full_logits.cpu().numpy().tolist() full_logits = full_logits.cpu().numpy().tolist()
end = time.time()
print(str(b)+","+str(c)+","+str(len(decode_tokens[0]))+","+str(end-start), flush=True)
return resp_sentences, resp_sentences_seg, output_logits, full_logits, decode_tokens return resp_sentences, resp_sentences_seg, output_logits, full_logits, decode_tokens
=======
return resp_sentences
>>>>>>> server
def generate_samples_eval(model, context, max_gen_length, eos_token_id): def generate_samples_eval(model, context, max_gen_length, eos_token_id):
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment