Commit b53a9be2 authored by Casper Hansen's avatar Casper Hansen
Browse files

Implemented fp16 and fp32 sanitization.

parent 6708bbcb
......@@ -30,13 +30,21 @@ def prepare_logits_processor(
processor_list.append(TopKLogitsWarper(top_k))
return processor_list
def sanitize_tensor(tensor):
def sanitize_tensor(tensor: torch.Tensor):
if tensor.dtype == torch.float16:
replacement_value = 65504
elif tensor.dtype == torch.float32:
replacement_value = 1e20
else:
return tensor
# Replace positive infinity with a large finite number
tensor[tensor == float('inf')] = 1e20
tensor[tensor == float('inf')] = replacement_value
# Replace negative infinity with a small finite number
tensor[tensor == float('-inf')] = -1e20
tensor[tensor == float('-inf')] = -replacement_value
# Replace NaNs with zero
tensor[torch.isnan(tensor)] = 0.0
return tensor
@torch.inference_mode()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment