Commit 87d08f4b authored by rprenger's avatar rprenger
Browse files

Fixing a bug in merge

parent 803ae5ee
...@@ -93,7 +93,7 @@ def generate(model, ...@@ -93,7 +93,7 @@ def generate(model,
# Make sure input params are avaialble to all ranks. # Make sure input params are avaialble to all ranks.
values = [tokens_to_generate, values = [tokens_to_generate,
return_output_log_probs, return_output_log_probs,
greedy_sampling, top_k_sampling, top_p_sampling, top_k_sampling, top_p_sampling,
temperature, add_BOS, use_eod_token_for_early_termination, just_score] temperature, add_BOS, use_eod_token_for_early_termination, just_score]
values_float_tensor = broadcast_float_list(8, float_list=values) values_float_tensor = broadcast_float_list(8, float_list=values)
tokens_to_generate = int(values_float_tensor[0].item()) tokens_to_generate = int(values_float_tensor[0].item())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment