Unverified commit c1cda0ee authored by Dingli Yang, committed by GitHub

Fix Seq2SeqTrainer crash when BatchEncoding data is None (#31418)

Avoid a crash when BatchEncoding data contains None values.
parent 06fd7972
@@ -800,7 +800,7 @@ class BatchEncoding(UserDict):
         # Otherwise it passes the casts down and casts the LongTensor containing the token idxs
         # into a HalfTensor
         if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
-            self.data = {k: v.to(device=device) for k, v in self.data.items()}
+            self.data = {k: v.to(device=device) for k, v in self.data.items() if v is not None}
         else:
             logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
         return self
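For context, here is a minimal, self-contained sketch of the failure mode this patch addresses. It is not the actual transformers implementation; `TinyBatchEncoding` is a hypothetical stand-in for the dict-backed `BatchEncoding` class, used only to show why the `if v is not None` guard matters when moving a batch to a device.

```python
import torch


class TinyBatchEncoding(dict):
    """Hypothetical stand-in for BatchEncoding: a dict of tensors (or None)."""

    def to(self, device):
        # Old behavior: {k: v.to(device=device) for k, v in self.items()}
        # raises AttributeError as soon as any value is None.
        # Patched behavior: skip None entries, move the rest to the device.
        moved = {k: v.to(device=device) for k, v in self.items() if v is not None}
        return TinyBatchEncoding(moved)


batch = TinyBatchEncoding(
    {"input_ids": torch.tensor([[1, 2, 3]]), "decoder_input_ids": None}
)
moved = batch.to("cpu")  # no crash; the None entry is simply dropped
print(list(moved))       # ['input_ids']
```

Without the guard, the None value for `decoder_input_ids` would trigger `AttributeError: 'NoneType' object has no attribute 'to'`; with it, the None entry is dropped and the remaining tensors are moved to the requested device.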