Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
4626df50
Unverified
Commit
4626df50
authored
Jun 14, 2023
by
Joao Gante
Committed by
GitHub
Jun 14, 2023
Browse files
TF: CTRL with native embedding layers (#23456)
parent
eac8dede
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
70 additions
and
55 deletions
+70
-55
src/transformers/models/ctrl/modeling_tf_ctrl.py
src/transformers/models/ctrl/modeling_tf_ctrl.py
+69
-55
tests/models/ctrl/test_modeling_tf_ctrl.py
tests/models/ctrl/test_modeling_tf_ctrl.py
+1
-0
No files found.
src/transformers/models/ctrl/modeling_tf_ctrl.py
View file @
4626df50
...
@@ -15,10 +15,8 @@
...
@@ -15,10 +15,8 @@
# limitations under the License.
# limitations under the License.
""" TF 2.0 CTRL model."""
""" TF 2.0 CTRL model."""
from
__future__
import
annotations
from
__future__
import
annotations
import
warnings
from
typing
import
Optional
,
Tuple
,
Union
from
typing
import
Optional
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -30,7 +28,6 @@ from ...modeling_tf_utils import (
...
@@ -30,7 +28,6 @@ from ...modeling_tf_utils import (
TFModelInputType
,
TFModelInputType
,
TFPreTrainedModel
,
TFPreTrainedModel
,
TFSequenceClassificationLoss
,
TFSequenceClassificationLoss
,
TFSharedEmbeddings
,
get_initializer
,
get_initializer
,
keras_serializable
,
keras_serializable
,
unpack_inputs
,
unpack_inputs
,
...
@@ -224,8 +221,11 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
...
@@ -224,8 +221,11 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
self
.
pos_encoding
=
positional_encoding
(
config
.
n_positions
,
self
.
d_model_size
)
self
.
pos_encoding
=
positional_encoding
(
config
.
n_positions
,
self
.
d_model_size
)
self
.
w
=
TFSharedEmbeddings
(
self
.
w
=
tf
.
keras
.
layers
.
Embedding
(
config
.
vocab_size
,
config
.
n_embd
,
initializer_range
=
config
.
initializer_range
,
name
=
"w"
input_dim
=
config
.
vocab_size
,
output_dim
=
config
.
n_embd
,
embeddings_initializer
=
get_initializer
(
config
.
initializer_range
),
name
=
"w"
,
)
)
self
.
dropout
=
tf
.
keras
.
layers
.
Dropout
(
config
.
embd_pdrop
)
self
.
dropout
=
tf
.
keras
.
layers
.
Dropout
(
config
.
embd_pdrop
)
...
@@ -246,9 +246,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
...
@@ -246,9 +246,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
def
get_input_embeddings
(
self
):
def
get_input_embeddings
(
self
):
return
self
.
w
return
self
.
w
def
set_input_embeddings
(
self
,
value
):
def
set_input_embeddings
(
self
,
new_embeddings
):
self
.
w
.
weight
=
value
self
.
w
=
new_embeddings
self
.
w
.
vocab_size
=
shape_list
(
value
)[
0
]
def
_prune_heads
(
self
,
heads_to_prune
):
def
_prune_heads
(
self
,
heads_to_prune
):
"""
"""
...
@@ -308,7 +307,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
...
@@ -308,7 +307,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask
=
tf
.
reshape
(
attention_mask
,
(
input_shape
[
0
],
1
,
1
,
input_shape
[
1
]))
attention_mask
=
tf
.
reshape
(
attention_mask
,
(
input_shape
[
0
],
1
,
1
,
input_shape
[
1
]
+
past_length
))
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
...
@@ -332,15 +331,15 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
...
@@ -332,15 +331,15 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
if
token_type_ids
is
not
None
:
if
token_type_ids
is
not
None
:
token_type_ids
=
tf
.
reshape
(
token_type_ids
,
[
-
1
,
shape_list
(
token_type_ids
)[
-
1
]])
token_type_ids
=
tf
.
reshape
(
token_type_ids
,
[
-
1
,
shape_list
(
token_type_ids
)[
-
1
]])
token_type_embeds
=
self
.
w
(
token_type_ids
,
mode
=
"embedding"
)
token_type_embeds
=
self
.
w
(
token_type_ids
)
token_type_embeds
*=
tf
.
math
.
sqrt
(
tf
.
cast
(
self
.
d_model_size
,
dtype
=
token_type_embeds
.
dtype
))
token_type_embeds
*=
tf
.
math
.
sqrt
(
tf
.
cast
(
self
.
d_model_size
,
dtype
=
token_type_embeds
.
dtype
))
else
:
else
:
token_type_embeds
=
tf
.
constant
(
0.0
)
token_type_embeds
=
tf
.
constant
(
0.0
)
position_ids
=
tf
.
reshape
(
position_ids
,
[
-
1
,
shape_list
(
position_ids
)[
-
1
]])
position_ids
=
tf
.
reshape
(
position_ids
,
[
-
1
,
shape_list
(
position_ids
)[
-
1
]])
if
inputs_embeds
is
None
:
if
inputs_embeds
is
None
:
check_embeddings_within_bounds
(
input_ids
,
self
.
w
.
vocab_size
)
check_embeddings_within_bounds
(
input_ids
,
self
.
w
.
input_dim
)
inputs_embeds
=
self
.
w
(
input_ids
,
mode
=
"embedding"
)
inputs_embeds
=
self
.
w
(
input_ids
)
seq_len
=
input_shape
[
-
1
]
seq_len
=
input_shape
[
-
1
]
mask
=
1
-
tf
.
linalg
.
band_part
(
tf
.
ones
((
seq_len
,
seq_len
)),
-
1
,
0
)
mask
=
1
-
tf
.
linalg
.
band_part
(
tf
.
ones
((
seq_len
,
seq_len
)),
-
1
,
0
)
...
@@ -565,39 +564,26 @@ class TFCTRLModel(TFCTRLPreTrainedModel):
...
@@ -565,39 +564,26 @@ class TFCTRLModel(TFCTRLPreTrainedModel):
return
outputs
return
outputs
class
TFCTRLLMHead
(
tf
.
keras
.
layers
.
Layer
):
class
TFCTRLBiasLayer
(
tf
.
keras
.
layers
.
Layer
):
def
__init__
(
self
,
config
,
input_embeddings
,
**
kwargs
):
"""
super
().
__init__
(
**
kwargs
)
Bias as a layer. It is used for serialization purposes: `tf.keras.Model.save_weights` stores on a per-layer basis,
self
.
config
=
config
so all weights have to be registered in a layer.
# CTRL has numerical issues in XLA generate
"""
self
.
supports_xla_generation
=
False
# The output weights are the same as the input embeddings, but there is
def
__init__
(
self
,
shape
,
initializer
,
trainable
,
name
,
**
kwargs
):
# an output-only bias for each token.
super
().
__init__
(
name
=
name
,
**
kwargs
)
self
.
input_embeddings
=
input_embeddings
self
.
shape
=
shape
self
.
initializer
=
initializer
self
.
trainable
=
trainable
def
build
(
self
,
input_shape
=
None
):
def
build
(
self
,
input_shape
):
self
.
bias
=
self
.
add_weight
(
shape
=
(
self
.
config
.
vocab_size
,),
initializer
=
"zeros"
,
trainable
=
True
,
name
=
"bias"
)
self
.
bias
=
self
.
add_weight
(
name
=
"bias"
,
shape
=
self
.
shape
,
initializer
=
self
.
initializer
,
trainable
=
self
.
trainable
)
super
().
build
(
input_shape
)
super
().
build
(
input_shape
)
def
get_output_embeddings
(
self
):
def
call
(
self
,
x
):
return
self
.
input_embeddings
return
x
+
self
.
bias
def
set_output_embeddings
(
self
,
value
):
self
.
input_embeddings
.
weight
=
value
self
.
input_embeddings
.
vocab_size
=
shape_list
(
value
)[
0
]
def
get_bias
(
self
):
return
{
"bias"
:
self
.
bias
}
def
set_bias
(
self
,
value
):
self
.
bias
=
value
[
"bias"
]
self
.
config
.
vocab_size
=
shape_list
(
value
[
"bias"
])[
0
]
def
call
(
self
,
hidden_states
):
hidden_states
=
self
.
input_embeddings
(
hidden_states
,
mode
=
"linear"
)
hidden_states
=
hidden_states
+
self
.
bias
return
hidden_states
@
add_start_docstrings
(
@
add_start_docstrings
(
...
@@ -611,24 +597,53 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
...
@@ -611,24 +597,53 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
super
().
__init__
(
config
,
*
inputs
,
**
kwargs
)
super
().
__init__
(
config
,
*
inputs
,
**
kwargs
)
self
.
transformer
=
TFCTRLMainLayer
(
config
,
name
=
"transformer"
)
self
.
transformer
=
TFCTRLMainLayer
(
config
,
name
=
"transformer"
)
self
.
bias_layer
=
TFCTRLBiasLayer
(
name
=
"lm_head"
,
shape
=
[
1
,
config
.
vocab_size
],
initializer
=
"zeros"
,
trainable
=
True
)
self
.
lm_head
=
TFCTRLLMHead
(
config
,
self
.
transformer
.
w
,
name
=
"lm_head"
)
def
get_output_embeddings
(
self
):
# CTRL has numerical issues in XLA generate
return
self
.
get_input_embeddings
()
self
.
supports_xla_generation
=
False
def
get_lm_head
(
self
):
def
set_output_embeddings
(
self
,
value
):
return
self
.
lm_head
self
.
set_input_embeddings
(
value
)
def
get_prefix_bias_name
(
self
):
def
get_bias
(
self
):
warnings
.
warn
(
"The method get_prefix_bias_name is deprecated. Please use `get_bias` instead."
,
FutureWarning
)
return
{
"lm_head.bias"
:
self
.
bias_layer
.
bias
}
return
self
.
name
+
"/"
+
self
.
lm_head
.
name
def
prepare_inputs_for_generation
(
self
,
input_ids
,
past_key_values
=
None
,
use_cache
=
None
,
**
kwargs
):
def
set_bias
(
self
,
value
):
# Replaces the existing layers containing bias for correct (de)serialization.
vocab_size
=
value
[
"lm_head.bias"
].
shape
[
-
1
]
self
.
bias_layer
=
TFCTRLBiasLayer
(
name
=
"final_logits_bias"
,
shape
=
[
1
,
vocab_size
],
initializer
=
"zeros"
,
trainable
=
True
)
self
.
bias_layer
.
build
(
None
)
self
.
bias_layer
.
bias
.
assign
(
value
[
"lm_head.bias"
])
# Copied from transformers.models.gpt2.modeling_tf_gpt2.TFGPT2LMHeadModel.prepare_inputs_for_generation
def
prepare_inputs_for_generation
(
self
,
inputs
,
past_key_values
=
None
,
use_cache
=
None
,
**
kwargs
):
token_type_ids
=
kwargs
.
get
(
"token_type_ids"
,
None
)
# only last token for inputs_ids if past is defined in kwargs
# only last token for inputs_ids if past is defined in kwargs
if
past_key_values
:
if
past_key_values
:
input_ids
=
tf
.
expand_dims
(
input_ids
[:,
-
1
],
-
1
)
inputs
=
tf
.
expand_dims
(
inputs
[:,
-
1
],
-
1
)
if
token_type_ids
is
not
None
:
token_type_ids
=
tf
.
expand_dims
(
token_type_ids
[:,
-
1
],
-
1
)
position_ids
=
kwargs
.
get
(
"position_ids"
,
None
)
attention_mask
=
kwargs
.
get
(
"attention_mask"
,
None
)
return
{
"input_ids"
:
input_ids
,
"past_key_values"
:
past_key_values
,
"use_cache"
:
use_cache
}
if
attention_mask
is
not
None
and
position_ids
is
None
:
position_ids
=
tf
.
math
.
cumsum
(
attention_mask
,
axis
=-
1
,
exclusive
=
True
)
if
past_key_values
:
position_ids
=
tf
.
expand_dims
(
position_ids
[:,
-
1
],
-
1
)
return
{
"input_ids"
:
inputs
,
"attention_mask"
:
attention_mask
,
"position_ids"
:
position_ids
,
"past_key_values"
:
past_key_values
,
"use_cache"
:
use_cache
,
"token_type_ids"
:
token_type_ids
,
}
@
unpack_inputs
@
unpack_inputs
@
add_start_docstrings_to_model_forward
(
CTRL_INPUTS_DOCSTRING
)
@
add_start_docstrings_to_model_forward
(
CTRL_INPUTS_DOCSTRING
)
...
@@ -672,10 +687,9 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
...
@@ -672,10 +687,9 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
return_dict
=
return_dict
,
return_dict
=
return_dict
,
training
=
training
,
training
=
training
,
)
)
hidden_states
=
transformer_outputs
[
0
]
hidden_states
=
transformer_outputs
[
0
]
logits
=
tf
.
matmul
(
hidden_states
,
self
.
transformer
.
w
.
weights
,
transpose_b
=
True
)
logits
=
self
.
lm_head
(
hidden_states
)
logits
=
self
.
bias_layer
(
logits
)
loss
=
None
loss
=
None
if
labels
is
not
None
:
if
labels
is
not
None
:
...
...
tests/models/ctrl/test_modeling_tf_ctrl.py
View file @
4626df50
...
@@ -225,6 +225,7 @@ class TFCTRLModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase
...
@@ -225,6 +225,7 @@ class TFCTRLModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase
for
model_class
in
self
.
all_model_classes
:
for
model_class
in
self
.
all_model_classes
:
model
=
model_class
(
config
)
model
=
model_class
(
config
)
model
.
build
()
# may be needed for the get_bias() call below
assert
isinstance
(
model
.
get_input_embeddings
(),
tf
.
keras
.
layers
.
Layer
)
assert
isinstance
(
model
.
get_input_embeddings
(),
tf
.
keras
.
layers
.
Layer
)
if
model_class
in
list_lm_models
:
if
model_class
in
list_lm_models
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment