Unverified Commit 7b0e2cfd authored by 浮躁的小螃蟹's avatar 浮躁的小螃蟹 Committed by GitHub
Browse files

Fix: unfinished_sequences with correct device (#22184)

Fix: unfinished_sequences with correct device 

The original code was causing errors when running torch.jit.trace due to the tensor options being incorrect. I fixed this by using torch.ones to create a tensor with the correct device and dtype. This should resolve the issue with running torch.jit.trace.
parent f7329751
...@@ -1805,7 +1805,7 @@ class GenerationMixin: ...@@ -1805,7 +1805,7 @@ class GenerationMixin:
) )
# keep track of which sequences are already finished # keep track of which sequences are already finished
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
this_peer_finished = False # used by synced_gpus only this_peer_finished = False # used by synced_gpus only
batch_size = input_ids.shape[0] batch_size = input_ids.shape[0]
...@@ -2180,7 +2180,7 @@ class GenerationMixin: ...@@ -2180,7 +2180,7 @@ class GenerationMixin:
) )
# keep track of which sequences are already finished # keep track of which sequences are already finished
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
this_peer_finished = False # used by synced_gpus only this_peer_finished = False # used by synced_gpus only
while True: while True:
...@@ -2446,7 +2446,7 @@ class GenerationMixin: ...@@ -2446,7 +2446,7 @@ class GenerationMixin:
) )
# keep track of which sequences are already finished # keep track of which sequences are already finished
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
this_peer_finished = False # used by synced_gpus only this_peer_finished = False # used by synced_gpus only
# auto-regressive generation # auto-regressive generation
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment