Commit b53a9be2 authored by Casper Hansen

Implemented fp16 and fp32 sanitization.

parent 6708bbcb
@@ -30,13 +30,21 @@ def prepare_logits_processor(
         processor_list.append(TopKLogitsWarper(top_k))
     return processor_list
 
-def sanitize_tensor(tensor):
+def sanitize_tensor(tensor: torch.Tensor):
+    if tensor.dtype == torch.float16:
+        replacement_value = 65504
+    elif tensor.dtype == torch.float32:
+        replacement_value = 1e20
+    else:
+        return tensor
     # Replace positive infinity with a large finite number
-    tensor[tensor == float('inf')] = 1e20
+    tensor[tensor == float('inf')] = replacement_value
     # Replace negative infinity with a small finite number
-    tensor[tensor == float('-inf')] = -1e20
+    tensor[tensor == float('-inf')] = -replacement_value
     # Replace NaNs with zero
     tensor[torch.isnan(tensor)] = 0.0
     return tensor
 
 @torch.inference_mode()
...
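For context, here is the patched sanitize_tensor reconstructed from the hunk above as a self-contained sketch, followed by a small illustrative check (the example logits tensor is not part of the commit). 65504 is the largest finite fp16 value (torch.finfo(torch.float16).max), so clamping infinities there keeps the tensor finite without overflowing back to inf; 1e20 plays the same role for fp32.

import torch

def sanitize_tensor(tensor: torch.Tensor):
    # Pick a large finite replacement value based on dtype; other
    # dtypes are returned unchanged.
    if tensor.dtype == torch.float16:
        replacement_value = 65504
    elif tensor.dtype == torch.float32:
        replacement_value = 1e20
    else:
        return tensor
    # Replace positive infinity with a large finite number
    tensor[tensor == float('inf')] = replacement_value
    # Replace negative infinity with a small finite number
    tensor[tensor == float('-inf')] = -replacement_value
    # Replace NaNs with zero
    tensor[torch.isnan(tensor)] = 0.0
    return tensor

# Illustrative usage (not from the commit): fp16 values containing
# inf/-inf/NaN become finite, in place.
logits = torch.tensor([float('inf'), float('-inf'), float('nan'), 1.5],
                      dtype=torch.float16)
print(sanitize_tensor(logits))
# tensor([ 65504., -65504., 0., 1.5], dtype=torch.float16)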