Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
db136341
Unverified
Commit
db136341
authored
May 18, 2023
by
Joao Gante
Committed by
GitHub
May 18, 2023
Browse files
TF: GPT2 with native embedding layers (#23436)
parent
c618ab4f
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
22 additions
and
24 deletions
+22
-24
docs/source/en/internal/modeling_utils.mdx
docs/source/en/internal/modeling_utils.mdx
+0
-3
src/transformers/modeling_tf_utils.py
src/transformers/modeling_tf_utils.py
+4
-0
src/transformers/models/gpt2/modeling_tf_gpt2.py
src/transformers/models/gpt2/modeling_tf_gpt2.py
+18
-21
No files found.
docs/source/en/internal/modeling_utils.mdx
View file @
db136341
...
...
@@ -54,9 +54,6 @@ Most of those are only useful if you are studying the code of the models in the
[[autodoc]] modeling_tf_utils.TFConv1D
[[autodoc]] modeling_tf_utils.TFSharedEmbeddings
- call
[[autodoc]] modeling_tf_utils.TFSequenceSummary
## TensorFlow loss functions
...
...
src/transformers/modeling_tf_utils.py
View file @
db136341
...
...
@@ -3132,6 +3132,10 @@ class TFSharedEmbeddings(tf.keras.layers.Layer):
self
.
vocab_size
=
vocab_size
self
.
hidden_size
=
hidden_size
self
.
initializer_range
=
hidden_size
**-
0.5
if
initializer_range
is
None
else
initializer_range
warnings
.
warn
(
"`TFSharedEmbeddings` is scheduled for deletion in v4.32, use `tf.keras.layers.Embedding` instead."
,
DeprecationWarning
,
)
def
build
(
self
,
input_shape
):
"""
...
...
src/transformers/models/gpt2/modeling_tf_gpt2.py
View file @
db136341
...
...
@@ -34,7 +34,6 @@ from ...modeling_tf_utils import (
TFPreTrainedModel
,
TFSequenceClassificationLoss
,
TFSequenceSummary
,
TFSharedEmbeddings
,
get_initializer
,
keras_serializable
,
unpack_inputs
,
...
...
@@ -315,29 +314,27 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
self
.
n_positions
=
config
.
n_positions
self
.
initializer_range
=
config
.
initializer_range
self
.
wte
=
TFSharedEmbeddings
(
config
.
vocab_size
,
config
.
hidden_size
,
initializer_range
=
config
.
initializer_range
,
name
=
"wte"
self
.
wte
=
tf
.
keras
.
layers
.
Embedding
(
input_dim
=
config
.
vocab_size
,
output_dim
=
config
.
hidden_size
,
embeddings_initializer
=
get_initializer
(
config
.
initializer_range
),
name
=
"wte"
,
)
self
.
wpe
=
tf
.
keras
.
layers
.
Embedding
(
input_dim
=
config
.
n_positions
,
output_dim
=
config
.
n_embd
,
embeddings_initializer
=
get_initializer
(
config
.
initializer_range
),
name
=
"wpe"
,
)
self
.
drop
=
tf
.
keras
.
layers
.
Dropout
(
config
.
embd_pdrop
)
self
.
h
=
[
TFBlock
(
config
,
scale
=
True
,
name
=
f
"h_._
{
i
}
"
)
for
i
in
range
(
config
.
n_layer
)]
self
.
ln_f
=
tf
.
keras
.
layers
.
LayerNormalization
(
epsilon
=
config
.
layer_norm_epsilon
,
name
=
"ln_f"
)
def
build
(
self
,
input_shape
):
with
tf
.
name_scope
(
"wpe"
):
self
.
wpe
=
self
.
add_weight
(
name
=
"embeddings"
,
shape
=
[
self
.
n_positions
,
self
.
n_embd
],
initializer
=
get_initializer
(
self
.
initializer_range
),
)
super
().
build
(
input_shape
)
def
get_input_embeddings
(
self
):
return
self
.
wte
def
set_input_embeddings
(
self
,
value
):
self
.
wte
.
weight
=
value
self
.
wte
.
vocab_size
=
shape_list
(
value
)[
0
]
def
set_input_embeddings
(
self
,
new_embeddings
):
self
.
wte
=
new_embeddings
def
_prune_heads
(
self
,
heads_to_prune
):
"""
...
...
@@ -438,13 +435,13 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
if
inputs_embeds
is
None
:
check_embeddings_within_bounds
(
input_ids
,
self
.
config
.
vocab_size
)
inputs_embeds
=
self
.
wte
(
input_ids
,
mode
=
"embedding"
)
inputs_embeds
=
self
.
wte
(
input_ids
)
position_embeds
=
tf
.
gather
(
self
.
wpe
,
position_ids
)
position_embeds
=
self
.
wpe
(
position_ids
)
if
token_type_ids
is
not
None
:
token_type_ids
=
tf
.
reshape
(
token_type_ids
,
[
-
1
,
shape_list
(
token_type_ids
)[
-
1
]])
token_type_embeds
=
self
.
wte
(
token_type_ids
,
mode
=
"embedding"
)
token_type_embeds
=
self
.
wte
(
token_type_ids
)
else
:
token_type_embeds
=
tf
.
constant
(
0.0
)
...
...
@@ -904,7 +901,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
training
=
training
,
)
hidden_states
=
transformer_outputs
[
0
]
logits
=
self
.
transformer
.
wte
(
hidden_states
,
mode
=
"linear"
)
logits
=
tf
.
matmul
(
hidden_states
,
self
.
transformer
.
wte
.
weights
,
transpose_b
=
True
)
loss
=
None
if
labels
is
not
None
:
...
...
@@ -1048,7 +1045,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
all_hidden_states
=
transformer_outputs
.
hidden_states
[:
-
1
]
+
(
hidden_states
,)
else
:
all_hidden_states
=
None
lm_logits
=
self
.
transformer
.
wte
(
hidden_states
,
mode
=
"linear"
)
lm_logits
=
tf
.
matmul
(
hidden_states
,
self
.
transformer
.
wte
.
weights
,
transpose_b
=
True
)
mc_logits
=
self
.
multiple_choice_head
(
hidden_states
,
mc_token_ids
,
training
=
training
)
mc_logits
=
tf
.
squeeze
(
mc_logits
,
axis
=-
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment