Unverified Commit 8205b264 authored by jiqing-feng's avatar jiqing-feng Committed by GitHub
Browse files

Assistant model may be on a different device (#27995)

* Assistant model may be on a different device

* fix tensor device
parent cbbe3074
...@@ -96,6 +96,11 @@ class AssistedCandidateGenerator(CandidateGenerator): ...@@ -96,6 +96,11 @@ class AssistedCandidateGenerator(CandidateGenerator):
model_kwargs: Dict, model_kwargs: Dict,
inputs_tensor: Optional[torch.Tensor] = None, inputs_tensor: Optional[torch.Tensor] = None,
): ):
# Make sure all data at the same device as assistant model
device = assistant_model.device
input_ids = input_ids.to(device)
inputs_tensor = inputs_tensor.to(device)
# Prepare the assistant and the starting number of candidate tokens # Prepare the assistant and the starting number of candidate tokens
self.assistant_model = assistant_model self.assistant_model = assistant_model
self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens
...@@ -104,7 +109,9 @@ class AssistedCandidateGenerator(CandidateGenerator): ...@@ -104,7 +109,9 @@ class AssistedCandidateGenerator(CandidateGenerator):
assistant_kwargs = {} assistant_kwargs = {}
for key, value in model_kwargs.items(): # deepcopy crashes if we attempt to copy encoder outputs with grads for key, value in model_kwargs.items(): # deepcopy crashes if we attempt to copy encoder outputs with grads
if key not in ("encoder_outputs", "assistant_encoder_outputs"): if key not in ("encoder_outputs", "assistant_encoder_outputs"):
assistant_kwargs[key] = value.detach() if isinstance(value, torch.Tensor) else copy.deepcopy(value) assistant_kwargs[key] = (
value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value)
)
if "assistant_encoder_outputs" in model_kwargs: if "assistant_encoder_outputs" in model_kwargs:
assistant_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"] assistant_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
......
...@@ -4585,7 +4585,12 @@ class GenerationMixin: ...@@ -4585,7 +4585,12 @@ class GenerationMixin:
cur_len = input_ids.shape[-1] cur_len = input_ids.shape[-1]
# 1. Fetch candidate sequences from a `CandidateGenerator` # 1. Fetch candidate sequences from a `CandidateGenerator`
candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids) candidate_input_ids, candidate_logits = candidate_generator.get_candidates(
input_ids.to(candidate_generator.assistant_model.device)
)
candidate_input_ids = candidate_input_ids.to(self.device)
candidate_logits = candidate_logits.to(self.device)
candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1] candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
last_assistant_token_is_eos = ( last_assistant_token_is_eos = (
~candidate_input_ids[:, -1] ~candidate_input_ids[:, -1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment