"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "f85acb4d73a84fe9bee5279068b0430fc391fb36"
Unverified Commit bc72b4e2 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: fix candidate device placement (#28493)

* fix candidate device

* this line shouldn't have been in
parent e304f976
......@@ -169,6 +169,8 @@ class AssistedCandidateGenerator(CandidateGenerator):
assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
vocabulary_size)` containing the logits associated to each candidate.
"""
input_ids = input_ids.to(self.assistant_model.device)
# 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
# (which implicitly contains the number of accepted candidates from the previous round)
has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
......
......@@ -4591,11 +4591,10 @@ class GenerationMixin:
cur_len = input_ids.shape[-1]
# 1. Fetch candidate sequences from a `CandidateGenerator`
candidate_input_ids, candidate_logits = candidate_generator.get_candidates(
input_ids.to(candidate_generator.assistant_model.device)
)
candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
candidate_input_ids = candidate_input_ids.to(self.device)
candidate_logits = candidate_logits.to(self.device)
if candidate_logits is not None:
candidate_logits = candidate_logits.to(self.device)
candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
last_assistant_token_is_eos = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment