Unverified Commit 222dbdb2 authored by Jethro Kuan's avatar Jethro Kuan Committed by GitHub
Browse files

allow integer device for BatchEncoding (#9271)



Fixes #9244
Co-authored-by: default avatarJethro Kuan <jethro.kuan@bytedance.com>
parent 6c091abe
...@@ -781,7 +781,7 @@ class BatchEncoding(UserDict): ...@@ -781,7 +781,7 @@ class BatchEncoding(UserDict):
# This check catches things like APEX blindly calling "to" on all inputs to a module # This check catches things like APEX blindly calling "to" on all inputs to a module
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs # Otherwise it passes the casts down and casts the LongTensor containing the token idxs
# into a HalfTensor # into a HalfTensor
if isinstance(device, str) or isinstance(device, torch.device): if isinstance(device, str) or isinstance(device, torch.device) or isinstance(device, int):
self.data = {k: v.to(device=device) for k, v in self.data.items()} self.data = {k: v.to(device=device) for k, v in self.data.items()}
else: else:
logger.warning( logger.warning(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment