Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
8e11de0e
Commit
8e11de0e
authored
Nov 01, 2019
by
Julien Chaumond
Browse files
model forwards can take an inputs_embeds param
parent
68f7064a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
10 deletions
+20
-10
transformers/modeling_gpt2.py
transformers/modeling_gpt2.py
+20
-10
No files found.
transformers/modeling_gpt2.py
View file @
8e11de0e
...
@@ -370,9 +370,15 @@ class GPT2Model(GPT2PreTrainedModel):
...
@@ -370,9 +370,15 @@ class GPT2Model(GPT2PreTrainedModel):
for
layer
,
heads
in
heads_to_prune
.
items
():
for
layer
,
heads
in
heads_to_prune
.
items
():
self
.
h
[
layer
].
attn
.
prune_heads
(
heads
)
self
.
h
[
layer
].
attn
.
prune_heads
(
heads
)
def
forward
(
self
,
input_ids
,
past
=
None
,
attention_mask
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
head_mask
=
None
):
def
forward
(
self
,
input_ids
=
None
,
past
=
None
,
attention_mask
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
head_mask
=
None
,
inputs_embeds
=
None
):
input_shape
=
input_ids
.
size
()
if
input_ids
is
not
None
:
input_ids
=
input_ids
.
view
(
-
1
,
input_shape
[
-
1
])
input_shape
=
input_ids
.
size
()
input_ids
=
input_ids
.
view
(
-
1
,
input_shape
[
-
1
])
elif
inputs_embeds
is
not
None
:
input_shape
=
inputs_embeds
.
size
()[:
-
1
]
else
:
raise
ValueError
(
"You have to specify either input_ids or inputs_embeds"
)
if
token_type_ids
is
not
None
:
if
token_type_ids
is
not
None
:
token_type_ids
=
token_type_ids
.
view
(
-
1
,
input_shape
[
-
1
])
token_type_ids
=
token_type_ids
.
view
(
-
1
,
input_shape
[
-
1
])
if
position_ids
is
not
None
:
if
position_ids
is
not
None
:
...
@@ -384,8 +390,9 @@ class GPT2Model(GPT2PreTrainedModel):
...
@@ -384,8 +390,9 @@ class GPT2Model(GPT2PreTrainedModel):
else
:
else
:
past_length
=
past
[
0
][
0
].
size
(
-
2
)
past_length
=
past
[
0
][
0
].
size
(
-
2
)
if
position_ids
is
None
:
if
position_ids
is
None
:
position_ids
=
torch
.
arange
(
past_length
,
input_ids
.
size
(
-
1
)
+
past_length
,
dtype
=
torch
.
long
,
device
=
input_ids
.
device
)
device
=
input_ids
.
device
if
input_ids
is
not
None
else
inputs_embeds
.
device
position_ids
=
position_ids
.
unsqueeze
(
0
).
expand_as
(
input_ids
)
position_ids
=
torch
.
arange
(
past_length
,
input_shape
[
-
1
]
+
past_length
,
dtype
=
torch
.
long
,
device
=
device
)
position_ids
=
position_ids
.
unsqueeze
(
0
).
view
(
-
1
,
input_shape
[
-
1
])
# Attention mask.
# Attention mask.
if
attention_mask
is
not
None
:
if
attention_mask
is
not
None
:
...
@@ -419,7 +426,8 @@ class GPT2Model(GPT2PreTrainedModel):
...
@@ -419,7 +426,8 @@ class GPT2Model(GPT2PreTrainedModel):
else
:
else
:
head_mask
=
[
None
]
*
self
.
config
.
n_layer
head_mask
=
[
None
]
*
self
.
config
.
n_layer
inputs_embeds
=
self
.
wte
(
input_ids
)
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
wte
(
input_ids
)
position_embeds
=
self
.
wpe
(
position_ids
)
position_embeds
=
self
.
wpe
(
position_ids
)
if
token_type_ids
is
not
None
:
if
token_type_ids
is
not
None
:
token_type_embeds
=
self
.
wte
(
token_type_ids
)
token_type_embeds
=
self
.
wte
(
token_type_ids
)
...
@@ -520,14 +528,15 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
...
@@ -520,14 +528,15 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
def
get_output_embeddings
(
self
):
def
get_output_embeddings
(
self
):
return
self
.
lm_head
return
self
.
lm_head
def
forward
(
self
,
input_ids
,
past
=
None
,
attention_mask
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
head_mask
=
None
,
def
forward
(
self
,
input_ids
=
None
,
past
=
None
,
attention_mask
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
head_mask
=
None
,
inputs_embeds
=
None
,
labels
=
None
):
labels
=
None
):
transformer_outputs
=
self
.
transformer
(
input_ids
,
transformer_outputs
=
self
.
transformer
(
input_ids
,
past
=
past
,
past
=
past
,
attention_mask
=
attention_mask
,
attention_mask
=
attention_mask
,
token_type_ids
=
token_type_ids
,
token_type_ids
=
token_type_ids
,
position_ids
=
position_ids
,
position_ids
=
position_ids
,
head_mask
=
head_mask
)
head_mask
=
head_mask
,
inputs_embeds
=
inputs_embeds
)
hidden_states
=
transformer_outputs
[
0
]
hidden_states
=
transformer_outputs
[
0
]
lm_logits
=
self
.
lm_head
(
hidden_states
)
lm_logits
=
self
.
lm_head
(
hidden_states
)
...
@@ -623,14 +632,15 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
...
@@ -623,14 +632,15 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
def
get_output_embeddings
(
self
):
def
get_output_embeddings
(
self
):
return
self
.
lm_head
return
self
.
lm_head
def
forward
(
self
,
input_ids
,
past
=
None
,
attention_mask
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
head_mask
=
None
,
def
forward
(
self
,
input_ids
=
None
,
past
=
None
,
attention_mask
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
head_mask
=
None
,
inputs_embeds
=
None
,
mc_token_ids
=
None
,
lm_labels
=
None
,
mc_labels
=
None
):
mc_token_ids
=
None
,
lm_labels
=
None
,
mc_labels
=
None
):
transformer_outputs
=
self
.
transformer
(
input_ids
,
transformer_outputs
=
self
.
transformer
(
input_ids
,
past
=
past
,
past
=
past
,
attention_mask
=
attention_mask
,
attention_mask
=
attention_mask
,
token_type_ids
=
token_type_ids
,
token_type_ids
=
token_type_ids
,
position_ids
=
position_ids
,
position_ids
=
position_ids
,
head_mask
=
head_mask
)
head_mask
=
head_mask
,
inputs_embeds
=
inputs_embeds
)
hidden_states
=
transformer_outputs
[
0
]
hidden_states
=
transformer_outputs
[
0
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment