"examples/vscode:/vscode.git/clone" did not exist on "93c26d6330c94270c0f0e8c561dd4e7dd3c800d1"
Unverified Commit 344b9fb0 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Limit the use of PreTrainedModel.device (#16935)

* Limit the use of PreTrainedModel.device

* Fix
parent 65687520
......@@ -502,7 +502,7 @@ class GenerationMixin:
if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
return inputs.ne(pad_token_id).long()
else:
return torch.ones(inputs.shape[:2], dtype=torch.long, device=self.device)
return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
def _prepare_encoder_decoder_kwargs_for_generation(
self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
......@@ -532,13 +532,16 @@ class GenerationMixin:
decoder_start_token_id: int = None,
bos_token_id: int = None,
model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
device: torch.device = None,
) -> torch.LongTensor:
if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
return model_kwargs.pop("decoder_input_ids")
else:
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * decoder_start_token_id
if device is None:
device = self.device
return torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
decoder_start_token_id = (
......@@ -1177,6 +1180,7 @@ class GenerationMixin:
decoder_start_token_id=decoder_start_token_id,
bos_token_id=bos_token_id,
model_kwargs=model_kwargs,
device=inputs_tensor.device,
)
else:
# if decoder-only then inputs_tensor has to be `input_ids`
......@@ -1327,7 +1331,7 @@ class GenerationMixin:
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=num_beams,
device=self.device,
device=inputs_tensor.device,
length_penalty=length_penalty,
do_early_stopping=early_stopping,
num_beam_hyps_to_keep=num_return_sequences,
......@@ -1367,7 +1371,7 @@ class GenerationMixin:
beam_scorer = BeamSearchScorer(
batch_size=batch_size * num_return_sequences,
num_beams=num_beams,
device=self.device,
device=inputs_tensor.device,
length_penalty=length_penalty,
do_early_stopping=early_stopping,
)
......@@ -1410,7 +1414,7 @@ class GenerationMixin:
batch_size=batch_size,
num_beams=num_beams,
max_length=stopping_criteria.max_length,
device=self.device,
device=inputs_tensor.device,
length_penalty=length_penalty,
do_early_stopping=early_stopping,
num_beam_hyps_to_keep=num_return_sequences,
......@@ -1492,7 +1496,7 @@ class GenerationMixin:
constraints=final_constraints,
batch_size=batch_size,
num_beams=num_beams,
device=self.device,
device=inputs_tensor.device,
length_penalty=length_penalty,
do_early_stopping=early_stopping,
num_beam_hyps_to_keep=num_return_sequences,
......
......@@ -1157,7 +1157,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Build new embeddings
new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
new_embeddings.to(self.device, dtype=old_embeddings.weight.dtype)
new_embeddings.to(old_embeddings.weight.device, dtype=old_embeddings.weight.dtype)
# initialize all new embeddings (in particular added tokens)
self._init_weights(new_embeddings)
......@@ -1228,7 +1228,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim)
has_new_lm_head_bias = old_lm_head.bias is not None
new_lm_head = nn.Linear(*new_lm_head_shape, bias=has_new_lm_head_bias)
new_lm_head = new_lm_head.to(self.device, dtype=old_lm_head.weight.dtype)
new_lm_head = new_lm_head.to(old_lm_head.weight.device, dtype=old_lm_head.weight.dtype)
# initialize new lm head (in particular added tokens)
self._init_weights(new_lm_head)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment