chenpangpang / transformers · Commits

Commit 344b9fb0 (unverified)
Authored Apr 25, 2022 by Sylvain Gugger, committed via GitHub on Apr 25, 2022

Limit the use of PreTrainedModel.device (#16935)

* Limit the use of PreTrainedModel.device
* Fix
parent 65687520

Showing 2 changed files, with 12 additions and 8 deletions:

  src/transformers/generation_utils.py   +10 −6
  src/transformers/modeling_utils.py       +2 −2
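For context: PreTrainedModel.device resolves to the device of the model's first parameter, so once a model's weights are spread over several devices it no longer says where a given input or intermediate tensor lives. The hunks below therefore derive devices from tensors that are already in hand (inputs, inputs_tensor, old weights) instead of self.device. A minimal sketch of the pattern, using a hypothetical inputs tensor rather than code from the commit:

import torch

inputs = torch.randint(0, 100, (2, 5))  # hypothetical batch of input_ids

# Before this commit: helper tensors were allocated on self.device, i.e.
# wherever the model's first parameter happens to be.
# After: they follow the tensor they accompany, which also behaves correctly
# when the model's parameters span several devices.
attention_mask = torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
assert attention_mask.device == inputs.device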
src/transformers/generation_utils.py  (view file @ 344b9fb0)
@@ -502,7 +502,7 @@ class GenerationMixin:
         if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
             return inputs.ne(pad_token_id).long()
         else:
-            return torch.ones(inputs.shape[:2], dtype=torch.long, device=self.device)
+            return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
 
     def _prepare_encoder_decoder_kwargs_for_generation(
         self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
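A small self-contained sketch of the default attention-mask logic in the hunk above, with hypothetical example values (the real method on GenerationMixin receives these as arguments):

import torch

# Hypothetical example inputs; 0 is used as the padding token here.
inputs = torch.tensor([[5, 6, 0], [7, 0, 0]])
pad_token_id = 0
eos_token_id = 2

is_input_ids = inputs.dim() == 2 and inputs.dtype in (torch.int, torch.long)
is_pad_token_in_inputs = pad_token_id is not None and (pad_token_id in inputs)
is_pad_token_not_equal_to_eos_token_id = eos_token_id is None or pad_token_id != eos_token_id

if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
    # Padded input_ids: mask out the padding positions.
    attention_mask = inputs.ne(pad_token_id).long()
else:
    # No usable pad token: attend everywhere, allocating on the inputs' device
    # (rather than self.device, which is the point of the change above).
    attention_mask = torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)

print(attention_mask)  # tensor([[1, 1, 0], [1, 0, 0]])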
@@ -532,13 +532,16 @@ class GenerationMixin:
         decoder_start_token_id: int = None,
         bos_token_id: int = None,
         model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+        device: torch.device = None,
     ) -> torch.LongTensor:
 
         if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
             return model_kwargs.pop("decoder_input_ids")
         else:
             decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
-            return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * decoder_start_token_id
+            if device is None:
+                device = self.device
+            return torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
 
     def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
         decoder_start_token_id = (
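A minimal sketch of the fallback behaviour of the new optional device argument, using a simplified stand-in function (the real method also consults model_kwargs and lives on GenerationMixin):

import torch

def prepare_decoder_input_ids(batch_size, decoder_start_token_id, self_device, device=None):
    # Simplified stand-in: callers may pass the device of their encoder inputs;
    # when they don't, fall back to the model-wide device as before.
    if device is None:
        device = self_device
    return torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id

inputs_tensor = torch.randint(0, 100, (4, 7))  # hypothetical encoder inputs
decoder_input_ids = prepare_decoder_input_ids(
    batch_size=inputs_tensor.shape[0],
    decoder_start_token_id=0,
    self_device=torch.device("cpu"),
    device=inputs_tensor.device,  # what generate() now passes (see the later hunk)
)
print(decoder_input_ids.shape)  # torch.Size([4, 1])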
@@ -1177,6 +1180,7 @@ class GenerationMixin:
                 decoder_start_token_id=decoder_start_token_id,
                 bos_token_id=bos_token_id,
                 model_kwargs=model_kwargs,
+                device=inputs_tensor.device,
             )
         else:
             # if decoder-only then inputs_tensor has to be `input_ids`
@@ -1327,7 +1331,7 @@ class GenerationMixin:
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size,
                 num_beams=num_beams,
-                device=self.device,
+                device=inputs_tensor.device,
                 length_penalty=length_penalty,
                 do_early_stopping=early_stopping,
                 num_beam_hyps_to_keep=num_return_sequences,
@@ -1367,7 +1371,7 @@ class GenerationMixin:
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size * num_return_sequences,
                 num_beams=num_beams,
-                device=self.device,
+                device=inputs_tensor.device,
                 length_penalty=length_penalty,
                 do_early_stopping=early_stopping,
             )
@@ -1410,7 +1414,7 @@ class GenerationMixin:
                 batch_size=batch_size,
                 num_beams=num_beams,
                 max_length=stopping_criteria.max_length,
-                device=self.device,
+                device=inputs_tensor.device,
                 length_penalty=length_penalty,
                 do_early_stopping=early_stopping,
                 num_beam_hyps_to_keep=num_return_sequences,
@@ -1492,7 +1496,7 @@ class GenerationMixin:
                 constraints=final_constraints,
                 batch_size=batch_size,
                 num_beams=num_beams,
-                device=self.device,
+                device=inputs_tensor.device,
                 length_penalty=length_penalty,
                 do_early_stopping=early_stopping,
                 num_beam_hyps_to_keep=num_return_sequences,
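The remaining hunks all make the same substitution when generate() builds a beam-search scorer. A hedged usage sketch of constructing BeamSearchScorer on the inputs' device outside of generate(); the keyword arguments mirror the diff above, and beam search itself is omitted:

import torch
from transformers import BeamSearchScorer

inputs_tensor = torch.randint(0, 100, (2, 6))  # hypothetical input_ids

beam_scorer = BeamSearchScorer(
    batch_size=inputs_tensor.shape[0],
    num_beams=4,
    device=inputs_tensor.device,  # was self.device before this commit
    length_penalty=1.0,
    do_early_stopping=False,
    num_beam_hyps_to_keep=1,
)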
src/transformers/modeling_utils.py  (view file @ 344b9fb0)
@@ -1157,7 +1157,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
 
         # Build new embeddings
         new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
-        new_embeddings.to(self.device, dtype=old_embeddings.weight.dtype)
+        new_embeddings.to(old_embeddings.weight.device, dtype=old_embeddings.weight.dtype)
 
         # initialize all new embeddings (in particular added tokens)
         self._init_weights(new_embeddings)
@@ -1228,7 +1228,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim)
         has_new_lm_head_bias = old_lm_head.bias is not None
         new_lm_head = nn.Linear(*new_lm_head_shape, bias=has_new_lm_head_bias)
-        new_lm_head = new_lm_head.to(self.device, dtype=old_lm_head.weight.dtype)
+        new_lm_head = new_lm_head.to(old_lm_head.weight.device, dtype=old_lm_head.weight.dtype)
 
         # initialize new lm head (in particular added tokens)
         self._init_weights(new_lm_head)
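In modeling_utils.py the same idea applies to resizing: the freshly built embedding and LM-head modules follow the device of the old weights rather than self.device. A minimal standalone sketch of that pattern with a plain nn.Embedding (not the library code, which also handles LM heads and transposed weights):

import torch
from torch import nn

old_embeddings = nn.Embedding(10, 8)  # in real use this may sit on a GPU
new_num_tokens = 12

# Build the resized module, then move it to the *old weight's* device and dtype,
# not to a model-wide self.device, so a model split over several devices keeps
# each resized module where its predecessor lived.
new_embeddings = nn.Embedding(new_num_tokens, old_embeddings.embedding_dim)
new_embeddings.to(old_embeddings.weight.device, dtype=old_embeddings.weight.dtype)

# Copy the existing rows; rows for newly added tokens keep their fresh init.
num_to_copy = min(old_embeddings.num_embeddings, new_num_tokens)
new_embeddings.weight.data[:num_to_copy, :] = old_embeddings.weight.data[:num_to_copy, :]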