Unverified Commit bc72b4e2 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: fix candidate device placement (#28493)

* fix candidate device

* this line shouldn't have been in
parent e304f976
...@@ -169,6 +169,8 @@ class AssistedCandidateGenerator(CandidateGenerator): ...@@ -169,6 +169,8 @@ class AssistedCandidateGenerator(CandidateGenerator):
assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length, assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
vocabulary_size)` containing the logits associated to each candidate. vocabulary_size)` containing the logits associated to each candidate.
""" """
input_ids = input_ids.to(self.assistant_model.device)
# 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length # 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
# (which implicitly contains the number of accepted candidates from the previous round) # (which implicitly contains the number of accepted candidates from the previous round)
has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
......
...@@ -4591,11 +4591,10 @@ class GenerationMixin: ...@@ -4591,11 +4591,10 @@ class GenerationMixin:
cur_len = input_ids.shape[-1] cur_len = input_ids.shape[-1]
# 1. Fetch candidate sequences from a `CandidateGenerator` # 1. Fetch candidate sequences from a `CandidateGenerator`
candidate_input_ids, candidate_logits = candidate_generator.get_candidates( candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
input_ids.to(candidate_generator.assistant_model.device)
)
candidate_input_ids = candidate_input_ids.to(self.device) candidate_input_ids = candidate_input_ids.to(self.device)
candidate_logits = candidate_logits.to(self.device) if candidate_logits is not None:
candidate_logits = candidate_logits.to(self.device)
candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1] candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
last_assistant_token_is_eos = ( last_assistant_token_is_eos = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment