Commit 6708bbcb authored by Casper Hansen's avatar Casper Hansen
Browse files

Fix #69 - sanitize NaN and infs.

parent 321d74ff
...@@ -30,6 +30,14 @@ def prepare_logits_processor( ...@@ -30,6 +30,14 @@ def prepare_logits_processor(
processor_list.append(TopKLogitsWarper(top_k)) processor_list.append(TopKLogitsWarper(top_k))
return processor_list return processor_list
def sanitize_tensor(tensor):
# Replace positive infinity with a large finite number
tensor[tensor == float('inf')] = 1e20
# Replace negative infinity with a small finite number
tensor[tensor == float('-inf')] = -1e20
# Replace NaNs with zero
tensor[torch.isnan(tensor)] = 0.0
return tensor
@torch.inference_mode() @torch.inference_mode()
def StreamGenerator(model, def StreamGenerator(model,
...@@ -82,6 +90,7 @@ def StreamGenerator(model, ...@@ -82,6 +90,7 @@ def StreamGenerator(model,
else: else:
tmp_output_ids = None tmp_output_ids = None
last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0] last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
last_token_logits = sanitize_tensor(last_token_logits)
else: else:
last_token_logits = logits[0, -1, :] last_token_logits = logits[0, -1, :]
if gen_params.temp < 1e-5 or gen_params.top_p < 1e-8: # greedy if gen_params.temp < 1e-5 or gen_params.top_p < 1e-8: # greedy
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment