"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "f284089ec4cf6da2d659512b09f2cf526dc6b308"
Unverified Commit d99f11e8 authored by Xinyu Yang's avatar Xinyu Yang Committed by GitHub
Browse files

ensure banned_mask and indices in same device (#23901)

* ensure banned_mask and indices in same device

* ensure banned_mask and indices in same device

switch the order in which indices and banned_mask are created and create banned_mask on the proper device
parent d68d6665
...@@ -649,8 +649,8 @@ class NoBadWordsLogitsProcessor(LogitsProcessor): ...@@ -649,8 +649,8 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
else: else:
if banned_mask_list: if banned_mask_list:
banned_mask = torch.LongTensor(banned_mask_list) indices = torch.ones(len(banned_mask_list))
indices = torch.ones(len(banned_mask)) banned_mask = torch.LongTensor(banned_mask_list, device=indices.device)
# A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates: # A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates:
# [ 0 1 1 ] # [ 0 1 1 ]
# [ 0 0 0 ] # [ 0 0 0 ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment