Unverified Commit 9c18f156 authored by Adam Pocock's avatar Adam Pocock Committed by GitHub
Browse files

Prevent BatchEncoding from blindly passing casts down to the tensors it...


Prevent BatchEncoding from blindly passing casts down to the tensors it contains. Fixes #6582. (#8860)

Update src/transformers/tokenization_utils_base.py with review fix
Co-authored-by: default avatarLysandre Debut <lysandre@huggingface.co>
Co-authored-by: default avatarLysandre Debut <lysandre@huggingface.co>
parent c0df963e
@@ -776,7 +776,16 @@ class BatchEncoding(UserDict):
:class:`~transformers.BatchEncoding`: The same instance of :class:`~transformers.BatchEncoding` after
modification.
""" """
self.data = {k: v.to(device) for k, v in self.data.items()}
# This check catches things like APEX blindly calling "to" on all inputs to a module
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
# into a HalfTensor
if isinstance(device, str) or isinstance(device, torch.device):
self.data = {k: v.to(device=device) for k, v in self.data.items()}
else:
logger.warning(
f"Attempting to cast a BatchEncoding to another type, {str(device)}. This is not supported."
)
return self
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment