"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "13aa174112f0c2ee794c44188ecf13b241694db0"
Unverified Commit d99f11e8 authored by Xinyu Yang's avatar Xinyu Yang Committed by GitHub
Browse files

ensure banned_mask and indices in same device (#23901)

* ensure banned_mask and indices in same device

* ensure banned_mask and indices in same device

switch the order in which indices and banned_mask are created and create banned_mask on the proper device
parent d68d6665
......@@ -649,8 +649,8 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
else:
if banned_mask_list:
banned_mask = torch.LongTensor(banned_mask_list)
indices = torch.ones(len(banned_mask))
indices = torch.ones(len(banned_mask_list))
banned_mask = torch.LongTensor(banned_mask_list, device=indices.device)
# A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates:
# [ 0 1 1 ]
# [ 0 0 0 ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment