Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
965f172d
Commit
965f172d
authored
Jun 17, 2019
by
thomwolf
Browse files
output all hidden layers states in GPT/GPT-2
parent
f12007e4
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
43 additions
and
12 deletions
+43
-12
pytorch_pretrained_bert/modeling_gpt2.py
pytorch_pretrained_bert/modeling_gpt2.py
+12
-3
pytorch_pretrained_bert/modeling_openai.py
pytorch_pretrained_bert/modeling_openai.py
+11
-3
tests/modeling_gpt2_test.py
tests/modeling_gpt2_test.py
+10
-3
tests/modeling_openai_test.py
tests/modeling_openai_test.py
+10
-3
No files found.
pytorch_pretrained_bert/modeling_gpt2.py
View file @
965f172d
...
@@ -720,9 +720,13 @@ class GPT2Model(GPT2PreTrainedModel):
...
@@ -720,9 +720,13 @@ class GPT2Model(GPT2PreTrainedModel):
hidden_states
=
inputs_embeds
+
position_embeds
+
token_type_embeds
hidden_states
=
inputs_embeds
+
position_embeds
+
token_type_embeds
hidden_states
=
self
.
drop
(
hidden_states
)
hidden_states
=
self
.
drop
(
hidden_states
)
output_shape
=
input_shape
+
(
hidden_states
.
size
(
-
1
),)
presents
=
[]
presents
=
[]
all_attentions
=
[]
all_attentions
=
[]
all_hidden_states
=
[]
for
block
,
layer_past
in
zip
(
self
.
h
,
past
):
for
block
,
layer_past
in
zip
(
self
.
h
,
past
):
all_hidden_states
.
append
(
hidden_states
.
view
(
*
output_shape
))
outputs
=
block
(
hidden_states
,
layer_past
,
head_mask
)
outputs
=
block
(
hidden_states
,
layer_past
,
head_mask
)
if
self
.
output_attentions
:
if
self
.
output_attentions
:
attentions
,
hidden_states
,
present
=
outputs
attentions
,
hidden_states
,
present
=
outputs
...
@@ -731,10 +735,11 @@ class GPT2Model(GPT2PreTrainedModel):
...
@@ -731,10 +735,11 @@ class GPT2Model(GPT2PreTrainedModel):
hidden_states
,
present
=
outputs
hidden_states
,
present
=
outputs
presents
.
append
(
present
)
presents
.
append
(
present
)
hidden_states
=
self
.
ln_f
(
hidden_states
)
hidden_states
=
self
.
ln_f
(
hidden_states
)
output_shape
=
input_shape
+
(
hidden_states
.
size
(
-
1
),)
all_hidden_states
.
append
(
hidden_states
.
view
(
*
output_shape
))
if
self
.
output_attentions
:
if
self
.
output_attentions
:
return
all_attentions
,
hidden_states
.
view
(
*
output_shape
)
,
presents
return
all_attentions
,
all_
hidden_states
,
presents
return
hidden_states
.
view
(
*
output_shape
)
,
presents
return
all_
hidden_states
,
presents
class
GPT2LMHeadModel
(
GPT2PreTrainedModel
):
class
GPT2LMHeadModel
(
GPT2PreTrainedModel
):
...
@@ -802,6 +807,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
...
@@ -802,6 +807,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
all_attentions
,
hidden_states
,
presents
=
transformer_output
all_attentions
,
hidden_states
,
presents
=
transformer_output
else
:
else
:
hidden_states
,
presents
=
transformer_output
hidden_states
,
presents
=
transformer_output
hidden_states
=
hidden_states
[
-
1
]
lm_logits
=
self
.
lm_head
(
hidden_states
)
lm_logits
=
self
.
lm_head
(
hidden_states
)
if
lm_labels
is
not
None
:
if
lm_labels
is
not
None
:
# Shift so that tokens < n predict n
# Shift so that tokens < n predict n
...
@@ -889,6 +896,8 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
...
@@ -889,6 +896,8 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
all_attentions
,
hidden_states
,
presents
=
transformer_output
all_attentions
,
hidden_states
,
presents
=
transformer_output
else
:
else
:
hidden_states
,
presents
=
transformer_output
hidden_states
,
presents
=
transformer_output
hidden_states
=
hidden_states
[
-
1
]
lm_logits
=
self
.
lm_head
(
hidden_states
)
lm_logits
=
self
.
lm_head
(
hidden_states
)
mc_logits
=
self
.
multiple_choice_head
(
hidden_states
,
mc_token_ids
)
mc_logits
=
self
.
multiple_choice_head
(
hidden_states
,
mc_token_ids
)
losses
=
[]
losses
=
[]
...
...
pytorch_pretrained_bert/modeling_openai.py
View file @
965f172d
...
@@ -716,7 +716,10 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
...
@@ -716,7 +716,10 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
hidden_states
=
inputs_embeds
+
position_embeds
+
token_type_embeds
hidden_states
=
inputs_embeds
+
position_embeds
+
token_type_embeds
hidden_states
=
self
.
drop
(
hidden_states
)
hidden_states
=
self
.
drop
(
hidden_states
)
output_shape
=
input_shape
+
(
hidden_states
.
size
(
-
1
),)
all_attentions
=
[]
all_attentions
=
[]
all_hidden_states
=
[
hidden_states
.
view
(
*
output_shape
)]
for
block
in
self
.
h
:
for
block
in
self
.
h
:
outputs
=
block
(
hidden_states
,
head_mask
)
outputs
=
block
(
hidden_states
,
head_mask
)
if
self
.
output_attentions
:
if
self
.
output_attentions
:
...
@@ -724,10 +727,11 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
...
@@ -724,10 +727,11 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
all_attentions
.
append
(
attentions
)
all_attentions
.
append
(
attentions
)
else
:
else
:
hidden_states
=
outputs
hidden_states
=
outputs
output_shape
=
input_shape
+
(
hidden_states
.
size
(
-
1
),)
all_hidden_states
.
append
(
hidden_states
.
view
(
*
output_shape
))
if
self
.
output_attentions
:
if
self
.
output_attentions
:
return
all_attentions
,
hidden_states
.
view
(
*
output_shape
)
return
all_attentions
,
all_
hidden_states
return
hidden_states
.
view
(
*
output_shape
)
return
all_
hidden_states
class
OpenAIGPTLMHeadModel
(
OpenAIGPTPreTrainedModel
):
class
OpenAIGPTLMHeadModel
(
OpenAIGPTPreTrainedModel
):
...
@@ -805,6 +809,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
...
@@ -805,6 +809,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
hidden_states
=
self
.
transformer
(
input_ids
,
position_ids
,
token_type_ids
,
head_mask
)
hidden_states
=
self
.
transformer
(
input_ids
,
position_ids
,
token_type_ids
,
head_mask
)
if
self
.
transformer
.
output_attentions
:
if
self
.
transformer
.
output_attentions
:
all_attentions
,
hidden_states
=
hidden_states
all_attentions
,
hidden_states
=
hidden_states
hidden_states
=
hidden_states
[
-
1
]
lm_logits
=
self
.
lm_head
(
hidden_states
)
lm_logits
=
self
.
lm_head
(
hidden_states
)
if
lm_labels
is
not
None
:
if
lm_labels
is
not
None
:
# Shift so that tokens < n predict n
# Shift so that tokens < n predict n
...
@@ -902,6 +908,8 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
...
@@ -902,6 +908,8 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
hidden_states
=
self
.
transformer
(
input_ids
,
position_ids
,
token_type_ids
,
head_mask
)
hidden_states
=
self
.
transformer
(
input_ids
,
position_ids
,
token_type_ids
,
head_mask
)
if
self
.
transformer
.
output_attentions
:
if
self
.
transformer
.
output_attentions
:
all_attentions
,
hidden_states
=
hidden_states
all_attentions
,
hidden_states
=
hidden_states
hidden_states
=
hidden_states
[
-
1
]
lm_logits
=
self
.
lm_head
(
hidden_states
)
lm_logits
=
self
.
lm_head
(
hidden_states
)
mc_logits
=
self
.
multiple_choice_head
(
hidden_states
,
mc_token_ids
)
mc_logits
=
self
.
multiple_choice_head
(
hidden_states
,
mc_token_ids
)
losses
=
[]
losses
=
[]
...
...
tests/modeling_gpt2_test.py
View file @
965f172d
...
@@ -115,8 +115,9 @@ class GPT2ModelTest(unittest.TestCase):
...
@@ -115,8 +115,9 @@ class GPT2ModelTest(unittest.TestCase):
return
outputs
return
outputs
def
check_gpt2_model_output
(
self
,
result
):
def
check_gpt2_model_output
(
self
,
result
):
self
.
parent
.
assertEqual
(
len
(
result
[
"hidden_states"
]),
self
.
n_layer
+
1
)
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"hidden_states"
].
size
()),
list
(
result
[
"hidden_states"
]
[
0
]
.
size
()),
[
self
.
batch_size
,
self
.
n_choices
,
self
.
seq_length
,
self
.
n_embd
])
[
self
.
batch_size
,
self
.
n_choices
,
self
.
seq_length
,
self
.
n_embd
])
...
@@ -222,7 +223,10 @@ class GPT2ModelTest(unittest.TestCase):
...
@@ -222,7 +223,10 @@ class GPT2ModelTest(unittest.TestCase):
else
:
else
:
output
=
model
(
input_ids
,
head_mask
=
head_mask
)
output
=
model
(
input_ids
,
head_mask
=
head_mask
)
output
=
sum
(
t
.
sum
()
for
t
in
output
[:
-
1
])
if
isinstance
(
model
,
GPT2Model
):
output
=
sum
(
t
.
sum
()
for
t
in
output
[
0
])
elif
isinstance
(
output
,
(
list
,
tuple
)):
output
=
sum
(
t
.
sum
()
for
t
in
output
[:
-
1
])
output
=
output
.
sum
()
output
=
output
.
sum
()
output
.
backward
()
output
.
backward
()
multihead_outputs
=
(
model
if
isinstance
(
model
,
GPT2Model
)
else
model
.
transformer
).
get_multihead_outputs
()
multihead_outputs
=
(
model
if
isinstance
(
model
,
GPT2Model
)
else
model
.
transformer
).
get_multihead_outputs
()
...
@@ -256,7 +260,10 @@ class GPT2ModelTest(unittest.TestCase):
...
@@ -256,7 +260,10 @@ class GPT2ModelTest(unittest.TestCase):
else
:
else
:
output
=
model
(
input_ids
)
output
=
model
(
input_ids
)
output
=
sum
(
t
.
sum
()
for
t
in
output
[:
-
1
])
if
isinstance
(
model
,
GPT2Model
):
output
=
sum
(
t
.
sum
()
for
t
in
output
[
0
])
elif
isinstance
(
output
,
(
list
,
tuple
)):
output
=
sum
(
t
.
sum
()
for
t
in
output
[:
-
1
])
output
=
output
.
sum
()
output
=
output
.
sum
()
output
.
backward
()
output
.
backward
()
multihead_outputs
=
transformer
.
get_multihead_outputs
()
multihead_outputs
=
transformer
.
get_multihead_outputs
()
...
...
tests/modeling_openai_test.py
View file @
965f172d
...
@@ -125,8 +125,9 @@ class OpenAIGPTModelTest(unittest.TestCase):
...
@@ -125,8 +125,9 @@ class OpenAIGPTModelTest(unittest.TestCase):
return
outputs
return
outputs
def
check_openai_model_output
(
self
,
result
):
def
check_openai_model_output
(
self
,
result
):
self
.
parent
.
assertEqual
(
len
(
result
[
"hidden_states"
]),
self
.
n_layer
+
1
)
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"hidden_states"
].
size
()),
list
(
result
[
"hidden_states"
]
[
0
]
.
size
()),
[
self
.
batch_size
,
self
.
n_choices
,
self
.
seq_length
,
self
.
n_embd
])
[
self
.
batch_size
,
self
.
n_choices
,
self
.
seq_length
,
self
.
n_embd
])
...
@@ -195,7 +196,10 @@ class OpenAIGPTModelTest(unittest.TestCase):
...
@@ -195,7 +196,10 @@ class OpenAIGPTModelTest(unittest.TestCase):
else
:
else
:
output
=
model
(
input_ids
,
head_mask
=
head_mask
)
output
=
model
(
input_ids
,
head_mask
=
head_mask
)
output
=
sum
(
t
.
sum
()
for
t
in
output
[:
-
1
])
if
isinstance
(
model
,
OpenAIGPTModel
):
output
=
sum
(
t
.
sum
()
for
t
in
output
[
0
])
elif
isinstance
(
output
,
(
list
,
tuple
)):
output
=
sum
(
t
.
sum
()
for
t
in
output
)
output
=
output
.
sum
()
output
=
output
.
sum
()
output
.
backward
()
output
.
backward
()
multihead_outputs
=
(
model
if
isinstance
(
model
,
OpenAIGPTModel
)
else
model
.
transformer
).
get_multihead_outputs
()
multihead_outputs
=
(
model
if
isinstance
(
model
,
OpenAIGPTModel
)
else
model
.
transformer
).
get_multihead_outputs
()
...
@@ -229,7 +233,10 @@ class OpenAIGPTModelTest(unittest.TestCase):
...
@@ -229,7 +233,10 @@ class OpenAIGPTModelTest(unittest.TestCase):
else
:
else
:
output
=
model
(
input_ids
)
output
=
model
(
input_ids
)
output
=
sum
(
t
.
sum
()
for
t
in
output
[:
-
1
])
if
isinstance
(
model
,
OpenAIGPTModel
):
output
=
sum
(
t
.
sum
()
for
t
in
output
[
0
])
elif
isinstance
(
output
,
(
list
,
tuple
)):
output
=
sum
(
t
.
sum
()
for
t
in
output
)
output
=
output
.
sum
()
output
=
output
.
sum
()
output
.
backward
()
output
.
backward
()
multihead_outputs
=
transformer
.
get_multihead_outputs
()
multihead_outputs
=
transformer
.
get_multihead_outputs
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment