Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
5a34d8d9
Unverified
Commit
5a34d8d9
authored
Apr 19, 2021
by
e
Committed by
GitHub
Apr 19, 2021
Browse files
move device statements outside if statements (#11292)
parent
d9c62047
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
8 additions
and
5 deletions
+8
-5
src/transformers/models/ctrl/modeling_ctrl.py
src/transformers/models/ctrl/modeling_ctrl.py
+4
-3
src/transformers/models/gpt2/modeling_gpt2.py
src/transformers/models/gpt2/modeling_gpt2.py
+2
-1
src/transformers/models/gpt_neo/modeling_gpt_neo.py
src/transformers/models/gpt_neo/modeling_gpt_neo.py
+2
-1
No files found.
src/transformers/models/ctrl/modeling_ctrl.py
View file @
5a34d8d9
...
@@ -394,13 +394,14 @@ class CTRLModel(CTRLPreTrainedModel):
...
@@ -394,13 +394,14 @@ class CTRLModel(CTRLPreTrainedModel):
else
:
else
:
raise
ValueError
(
"You have to specify either input_ids or inputs_embeds"
)
raise
ValueError
(
"You have to specify either input_ids or inputs_embeds"
)
device
=
input_ids
.
device
if
input_ids
is
not
None
else
inputs_embeds
.
device
if
past_key_values
is
None
:
if
past_key_values
is
None
:
past_length
=
0
past_length
=
0
past_key_values
=
tuple
([
None
]
*
len
(
self
.
h
))
past_key_values
=
tuple
([
None
]
*
len
(
self
.
h
))
else
:
else
:
past_length
=
past_key_values
[
0
][
0
].
size
(
-
2
)
past_length
=
past_key_values
[
0
][
0
].
size
(
-
2
)
if
position_ids
is
None
:
if
position_ids
is
None
:
device
=
input_ids
.
device
if
input_ids
is
not
None
else
inputs_embeds
.
device
position_ids
=
torch
.
arange
(
past_length
,
input_shape
[
-
1
]
+
past_length
,
dtype
=
torch
.
long
,
device
=
device
)
position_ids
=
torch
.
arange
(
past_length
,
input_shape
[
-
1
]
+
past_length
,
dtype
=
torch
.
long
,
device
=
device
)
position_ids
=
position_ids
.
unsqueeze
(
0
).
view
(
-
1
,
input_shape
[
-
1
])
position_ids
=
position_ids
.
unsqueeze
(
0
).
view
(
-
1
,
input_shape
[
-
1
])
...
@@ -438,11 +439,11 @@ class CTRLModel(CTRLPreTrainedModel):
...
@@ -438,11 +439,11 @@ class CTRLModel(CTRLPreTrainedModel):
inputs_embeds
=
self
.
w
(
input_ids
)
inputs_embeds
=
self
.
w
(
input_ids
)
# inputs_embeds = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
# inputs_embeds = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
seq_len
=
input_shape
[
-
1
]
seq_len
=
input_shape
[
-
1
]
mask
=
torch
.
triu
(
torch
.
ones
(
seq_len
+
past_length
,
seq_len
+
past_length
),
1
).
to
(
inputs_embeds
.
device
)
mask
=
torch
.
triu
(
torch
.
ones
(
seq_len
+
past_length
,
seq_len
+
past_length
),
1
).
to
(
device
)
inputs_embeds
*=
np
.
sqrt
(
self
.
d_model_size
)
inputs_embeds
*=
np
.
sqrt
(
self
.
d_model_size
)
pos_embeds
=
self
.
pos_encoding
[
position_ids
,
:].
to
(
inputs_embeds
.
device
)
pos_embeds
=
self
.
pos_encoding
[
position_ids
,
:].
to
(
device
)
hidden_states
=
inputs_embeds
+
pos_embeds
+
token_type_embeds
hidden_states
=
inputs_embeds
+
pos_embeds
+
token_type_embeds
...
...
src/transformers/models/gpt2/modeling_gpt2.py
View file @
5a34d8d9
...
@@ -675,6 +675,8 @@ class GPT2Model(GPT2PreTrainedModel):
...
@@ -675,6 +675,8 @@ class GPT2Model(GPT2PreTrainedModel):
else
:
else
:
raise
ValueError
(
"You have to specify either input_ids or inputs_embeds"
)
raise
ValueError
(
"You have to specify either input_ids or inputs_embeds"
)
device
=
input_ids
.
device
if
input_ids
is
not
None
else
inputs_embeds
.
device
if
token_type_ids
is
not
None
:
if
token_type_ids
is
not
None
:
token_type_ids
=
token_type_ids
.
view
(
-
1
,
input_shape
[
-
1
])
token_type_ids
=
token_type_ids
.
view
(
-
1
,
input_shape
[
-
1
])
if
position_ids
is
not
None
:
if
position_ids
is
not
None
:
...
@@ -686,7 +688,6 @@ class GPT2Model(GPT2PreTrainedModel):
...
@@ -686,7 +688,6 @@ class GPT2Model(GPT2PreTrainedModel):
else
:
else
:
past_length
=
past_key_values
[
0
][
0
].
size
(
-
2
)
past_length
=
past_key_values
[
0
][
0
].
size
(
-
2
)
if
position_ids
is
None
:
if
position_ids
is
None
:
device
=
input_ids
.
device
if
input_ids
is
not
None
else
inputs_embeds
.
device
position_ids
=
torch
.
arange
(
past_length
,
input_shape
[
-
1
]
+
past_length
,
dtype
=
torch
.
long
,
device
=
device
)
position_ids
=
torch
.
arange
(
past_length
,
input_shape
[
-
1
]
+
past_length
,
dtype
=
torch
.
long
,
device
=
device
)
position_ids
=
position_ids
.
unsqueeze
(
0
).
view
(
-
1
,
input_shape
[
-
1
])
position_ids
=
position_ids
.
unsqueeze
(
0
).
view
(
-
1
,
input_shape
[
-
1
])
...
...
src/transformers/models/gpt_neo/modeling_gpt_neo.py
View file @
5a34d8d9
...
@@ -755,6 +755,8 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
...
@@ -755,6 +755,8 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
else
:
else
:
raise
ValueError
(
"You have to specify either input_ids or inputs_embeds"
)
raise
ValueError
(
"You have to specify either input_ids or inputs_embeds"
)
device
=
input_ids
.
device
if
input_ids
is
not
None
else
inputs_embeds
.
device
if
token_type_ids
is
not
None
:
if
token_type_ids
is
not
None
:
token_type_ids
=
token_type_ids
.
view
(
-
1
,
input_shape
[
-
1
])
token_type_ids
=
token_type_ids
.
view
(
-
1
,
input_shape
[
-
1
])
if
position_ids
is
not
None
:
if
position_ids
is
not
None
:
...
@@ -766,7 +768,6 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
...
@@ -766,7 +768,6 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
else
:
else
:
past_length
=
past_key_values
[
0
][
0
].
size
(
-
2
)
past_length
=
past_key_values
[
0
][
0
].
size
(
-
2
)
if
position_ids
is
None
:
if
position_ids
is
None
:
device
=
input_ids
.
device
if
input_ids
is
not
None
else
inputs_embeds
.
device
position_ids
=
torch
.
arange
(
past_length
,
input_shape
[
-
1
]
+
past_length
,
dtype
=
torch
.
long
,
device
=
device
)
position_ids
=
torch
.
arange
(
past_length
,
input_shape
[
-
1
]
+
past_length
,
dtype
=
torch
.
long
,
device
=
device
)
position_ids
=
position_ids
.
unsqueeze
(
0
).
view
(
-
1
,
input_shape
[
-
1
])
position_ids
=
position_ids
.
unsqueeze
(
0
).
view
(
-
1
,
input_shape
[
-
1
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment