Unverified Commit 344b9fb0 authored by Sylvain Gugger, committed by GitHub

Limit the use of PreTrainedModel.device (#16935)

* Limit the use of PreTrainedModel.device

* Fix
parent 65687520
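For context (not part of the commit message): PreTrainedModel.device reports the device of the model's first parameter, so on a checkpoint whose weights are spread over several devices it need not match the device the inputs live on. A minimal sketch of the mismatch this commit avoids, assuming a machine with two CUDA devices; the module layout and names below are illustrative only, not library code:

    import torch
    import torch.nn as nn

    # Hypothetical layout: the first parameter sits on cuda:0 while the
    # inputs (and the layers that consume them) sit on cuda:1.
    class TwoDeviceModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.embed = nn.Embedding(10, 4).to("cuda:0")  # first parameter
            self.head = nn.Linear(4, 10).to("cuda:1")

    model = TwoDeviceModel()
    first_param_device = next(model.parameters()).device  # cuda:0, roughly what self.device reports
    inputs = torch.ones((2, 5), dtype=torch.long, device="cuda:1")

    # Old behaviour: helper tensors were created on the model-level device ...
    mask_old = torch.ones(inputs.shape[:2], dtype=torch.long, device=first_param_device)
    # ... new behaviour: they follow the tensor they belong to.
    mask_new = torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)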
@@ -502,7 +502,7 @@ class GenerationMixin:
         if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
             return inputs.ne(pad_token_id).long()
         else:
-            return torch.ones(inputs.shape[:2], dtype=torch.long, device=self.device)
+            return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
 
     def _prepare_encoder_decoder_kwargs_for_generation(
         self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
@@ -532,13 +532,16 @@ class GenerationMixin:
         decoder_start_token_id: int = None,
         bos_token_id: int = None,
         model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+        device: torch.device = None,
     ) -> torch.LongTensor:
 
         if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
             return model_kwargs.pop("decoder_input_ids")
         else:
             decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
-            return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * decoder_start_token_id
+            if device is None:
+                device = self.device
+            return torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
 
     def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
         decoder_start_token_id = (
@@ -1177,6 +1180,7 @@ class GenerationMixin:
                 decoder_start_token_id=decoder_start_token_id,
                 bos_token_id=bos_token_id,
                 model_kwargs=model_kwargs,
+                device=inputs_tensor.device,
             )
         else:
             # if decoder-only then inputs_tensor has to be `input_ids`
@@ -1327,7 +1331,7 @@ class GenerationMixin:
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size,
                 num_beams=num_beams,
-                device=self.device,
+                device=inputs_tensor.device,
                 length_penalty=length_penalty,
                 do_early_stopping=early_stopping,
                 num_beam_hyps_to_keep=num_return_sequences,
@@ -1367,7 +1371,7 @@ class GenerationMixin:
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size * num_return_sequences,
                 num_beams=num_beams,
-                device=self.device,
+                device=inputs_tensor.device,
                 length_penalty=length_penalty,
                 do_early_stopping=early_stopping,
             )
@@ -1410,7 +1414,7 @@ class GenerationMixin:
                 batch_size=batch_size,
                 num_beams=num_beams,
                 max_length=stopping_criteria.max_length,
-                device=self.device,
+                device=inputs_tensor.device,
                 length_penalty=length_penalty,
                 do_early_stopping=early_stopping,
                 num_beam_hyps_to_keep=num_return_sequences,
@@ -1492,7 +1496,7 @@ class GenerationMixin:
                 constraints=final_constraints,
                 batch_size=batch_size,
                 num_beams=num_beams,
-                device=self.device,
+                device=inputs_tensor.device,
                 length_penalty=length_penalty,
                 do_early_stopping=early_stopping,
                 num_beam_hyps_to_keep=num_return_sequences,
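Taken together, the GenerationMixin hunks route the device of every helper tensor (attention mask, decoder start tokens, beam-search bookkeeping) through the input tensor rather than through the model. A hedged paraphrase of what the new device argument buys the caller, with the caller passing inputs_tensor.device exactly as the hunk above does; the function below is a standalone sketch, not the actual library helper:

    import torch

    def make_decoder_input_ids(batch_size, decoder_start_token_id, fallback_device, device=None):
        # Mirrors the patched helper: only fall back to a model-level device
        # when the caller does not say where the tensor should live.
        if device is None:
            device = fallback_device
        return torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id

    inputs_tensor = torch.tensor([[5, 6, 7], [8, 9, 10]])
    decoder_input_ids = make_decoder_input_ids(
        batch_size=inputs_tensor.shape[0],
        decoder_start_token_id=2,
        fallback_device=torch.device("cpu"),
        device=inputs_tensor.device,  # what generate() now passes explicitly
    )
    assert decoder_input_ids.device == inputs_tensor.device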
@@ -1157,7 +1157,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         # Build new embeddings
         new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
-        new_embeddings.to(self.device, dtype=old_embeddings.weight.dtype)
+        new_embeddings.to(old_embeddings.weight.device, dtype=old_embeddings.weight.dtype)
 
         # initialize all new embeddings (in particular added tokens)
         self._init_weights(new_embeddings)
@@ -1228,7 +1228,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim)
         has_new_lm_head_bias = old_lm_head.bias is not None
         new_lm_head = nn.Linear(*new_lm_head_shape, bias=has_new_lm_head_bias)
-        new_lm_head = new_lm_head.to(self.device, dtype=old_lm_head.weight.dtype)
+        new_lm_head = new_lm_head.to(old_lm_head.weight.device, dtype=old_lm_head.weight.dtype)
 
         # initialize new lm head (in particular added tokens)
         self._init_weights(new_lm_head)
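The PreTrainedModel hunks apply the same idea to embedding resizing: the replacement embedding table and LM head are created on the device (and dtype) of the weights they replace instead of on the device of the model's first parameter. A minimal standalone sketch of that pattern with plain nn modules, not the actual resize helpers:

    import torch
    import torch.nn as nn

    old_embeddings = nn.Embedding(100, 16)  # may live on any device/dtype in practice
    new_num_tokens = 110

    # Build the enlarged table on the old weight's device and dtype, as the patch does.
    new_embeddings = nn.Embedding(new_num_tokens, old_embeddings.embedding_dim)
    new_embeddings.to(old_embeddings.weight.device, dtype=old_embeddings.weight.dtype)

    # Copy over the rows that already exist; the extra rows keep their fresh initialization.
    with torch.no_grad():
        new_embeddings.weight[: old_embeddings.num_embeddings] = old_embeddings.weight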