chenpangpang / transformers · Commits · db136341

Unverified commit db136341, authored May 18, 2023 by Joao Gante, committed by GitHub on May 18, 2023
TF: GPT2 with native embedding layers (#23436)
Parent: c618ab4f
Showing 3 changed files with 22 additions and 24 deletions (+22 −24):

- docs/source/en/internal/modeling_utils.mdx (+0 −3)
- src/transformers/modeling_tf_utils.py (+4 −0)
- src/transformers/models/gpt2/modeling_tf_gpt2.py (+18 −21)
docs/source/en/internal/modeling_utils.mdx

```diff
@@ -54,9 +54,6 @@ Most of those are only useful if you are studying the code of the models in the
 [[autodoc]] modeling_tf_utils.TFConv1D
 
-[[autodoc]] modeling_tf_utils.TFSharedEmbeddings
-    - call
-
 [[autodoc]] modeling_tf_utils.TFSequenceSummary
 
 ## TensorFlow loss functions
```
src/transformers/modeling_tf_utils.py

```diff
@@ -3132,6 +3132,10 @@ class TFSharedEmbeddings(tf.keras.layers.Layer):
         self.vocab_size = vocab_size
         self.hidden_size = hidden_size
         self.initializer_range = hidden_size**-0.5 if initializer_range is None else initializer_range
+        warnings.warn(
+            "`TFSharedEmbeddings` is scheduled for deletion in v4.32, use `tf.keras.layers.Embedding` instead.",
+            DeprecationWarning,
+        )
 
     def build(self, input_shape):
         """
```
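The new warning points users of `TFSharedEmbeddings` at `tf.keras.layers.Embedding`. Below is a minimal sketch of that migration for the lookup path; the sizes (`vocab_size`, `hidden_size`, `initializer_range`) are illustrative placeholders, not values from any real config. The old `mode="linear"` path becomes an explicit matmul, as the GPT-2 hunks further down show.

```python
import tensorflow as tf

# Hypothetical sizes for illustration only.
vocab_size, hidden_size, initializer_range = 50257, 768, 0.02

# Drop-in replacement for the lookup ("embedding") mode of TFSharedEmbeddings.
wte = tf.keras.layers.Embedding(
    input_dim=vocab_size,
    output_dim=hidden_size,
    embeddings_initializer=tf.keras.initializers.TruncatedNormal(stddev=initializer_range),
    name="wte",
)

input_ids = tf.constant([[10, 20, 30]])  # (batch=1, seq_len=3)
hidden = wte(input_ids)                  # (1, 3, hidden_size), replaces wte(input_ids, mode="embedding")
```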
src/transformers/models/gpt2/modeling_tf_gpt2.py

```diff
@@ -34,7 +34,6 @@ from ...modeling_tf_utils import (
     TFPreTrainedModel,
     TFSequenceClassificationLoss,
     TFSequenceSummary,
-    TFSharedEmbeddings,
     get_initializer,
     keras_serializable,
     unpack_inputs,
```
```diff
@@ -315,29 +314,27 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
         self.n_positions = config.n_positions
         self.initializer_range = config.initializer_range
 
-        self.wte = TFSharedEmbeddings(
-            config.vocab_size, config.hidden_size, initializer_range=config.initializer_range, name="wte"
-        )
+        self.wte = tf.keras.layers.Embedding(
+            input_dim=config.vocab_size,
+            output_dim=config.hidden_size,
+            embeddings_initializer=get_initializer(config.initializer_range),
+            name="wte",
+        )
+        self.wpe = tf.keras.layers.Embedding(
+            input_dim=config.n_positions,
+            output_dim=config.n_embd,
+            embeddings_initializer=get_initializer(config.initializer_range),
+            name="wpe",
+        )
         self.drop = tf.keras.layers.Dropout(config.embd_pdrop)
         self.h = [TFBlock(config, scale=True, name=f"h_._{i}") for i in range(config.n_layer)]
         self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_f")
 
-    def build(self, input_shape):
-        with tf.name_scope("wpe"):
-            self.wpe = self.add_weight(
-                name="embeddings",
-                shape=[self.n_positions, self.n_embd],
-                initializer=get_initializer(self.initializer_range),
-            )
-
-        super().build(input_shape)
-
     def get_input_embeddings(self):
         return self.wte
 
-    def set_input_embeddings(self, value):
-        self.wte.weight = value
-        self.wte.vocab_size = shape_list(value)[0]
+    def set_input_embeddings(self, new_embeddings):
+        self.wte = new_embeddings
 
     def _prune_heads(self, heads_to_prune):
         """
```
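With both `wte` and `wpe` as native `tf.keras.layers.Embedding` layers, the embedding stack of the main layer reduces to two lookups plus dropout, and swapping the input embeddings means swapping the layer itself. The following is a stripped-down sketch of that structure using made-up hyperparameters rather than a real `GPT2Config`:

```python
import tensorflow as tf

# Toy hyperparameters standing in for config.vocab_size, config.n_positions, config.n_embd.
VOCAB, POSITIONS, EMBD = 1000, 128, 64

wte = tf.keras.layers.Embedding(VOCAB, EMBD, name="wte")      # token embeddings
wpe = tf.keras.layers.Embedding(POSITIONS, EMBD, name="wpe")  # learned position embeddings
drop = tf.keras.layers.Dropout(0.1)

input_ids = tf.constant([[3, 7, 42, 9]])                        # (batch, seq_len)
position_ids = tf.range(tf.shape(input_ids)[1])[tf.newaxis, :]  # (1, seq_len)

# Same shape of computation as the updated main layer: token + position embeddings, then dropout.
hidden_states = wte(input_ids) + wpe(position_ids)              # (batch, seq_len, EMBD)
hidden_states = drop(hidden_states, training=False)
```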
```diff
@@ -438,13 +435,13 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
         if inputs_embeds is None:
             check_embeddings_within_bounds(input_ids, self.config.vocab_size)
-            inputs_embeds = self.wte(input_ids, mode="embedding")
+            inputs_embeds = self.wte(input_ids)
 
-        position_embeds = tf.gather(self.wpe, position_ids)
+        position_embeds = self.wpe(position_ids)
 
         if token_type_ids is not None:
             token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
-            token_type_embeds = self.wte(token_type_ids, mode="embedding")
+            token_type_embeds = self.wte(token_type_ids)
         else:
             token_type_embeds = tf.constant(0.0)
```
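The position-embedding lookup changes from `tf.gather` on a raw weight matrix to a call on the `wpe` layer; both select rows by index, so for identical weights they return identical values. A small equivalence check under assumed toy shapes (not part of the commit):

```python
import tensorflow as tf

n_positions, n_embd = 16, 8
position_ids = tf.constant([[0, 1, 2, 3]])

wpe = tf.keras.layers.Embedding(n_positions, n_embd, name="wpe")
_ = wpe(position_ids)  # first call builds the layer and creates wpe.embeddings

old_style = tf.gather(wpe.embeddings, position_ids)  # previous code: tf.gather(self.wpe, position_ids)
new_style = wpe(position_ids)                         # updated code: self.wpe(position_ids)

print(bool(tf.reduce_all(old_style == new_style)))    # True
```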
```diff
@@ -904,7 +901,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
             training=training,
         )
         hidden_states = transformer_outputs[0]
 
-        logits = self.transformer.wte(hidden_states, mode="linear")
+        logits = tf.matmul(hidden_states, self.transformer.wte.weights, transpose_b=True)
 
         loss = None
         if labels is not None:
```
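The language-modeling head stays tied to the input embedding matrix: instead of the `mode="linear"` call on `TFSharedEmbeddings`, the logits are now an explicit matmul against the token-embedding weights with `transpose_b=True`. The shape-level sketch below uses toy dimensions and reads the matrix via the layer's `embeddings` attribute; it is an illustration of the weight-tying idea, not a copy of the model code.

```python
import tensorflow as tf

batch, seq_len, hidden, vocab = 2, 5, 8, 100

wte = tf.keras.layers.Embedding(vocab, hidden, name="wte")
input_ids = tf.random.uniform((batch, seq_len), maxval=vocab, dtype=tf.int32)
hidden_states = wte(input_ids)  # stand-in for the transformer output, (batch, seq_len, hidden)

# Weight-tied LM head: project hidden states onto the vocabulary with the same
# matrix used for the input lookup. wte.embeddings has shape (vocab, hidden),
# so transpose_b=True yields (batch, seq_len, vocab) logits.
logits = tf.matmul(hidden_states, wte.embeddings, transpose_b=True)
print(logits.shape)  # (2, 5, 100)
```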
```diff
@@ -1048,7 +1045,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
             all_hidden_states = transformer_outputs.hidden_states[:-1] + (hidden_states,)
         else:
             all_hidden_states = None
 
-        lm_logits = self.transformer.wte(hidden_states, mode="linear")
+        lm_logits = tf.matmul(hidden_states, self.transformer.wte.weights, transpose_b=True)
         mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training)
         mc_logits = tf.squeeze(mc_logits, axis=-1)
```