transformers · Commit 4626df50 (Unverified)
Authored Jun 14, 2023 by Joao Gante; committed by GitHub, Jun 14, 2023
TF: CTRL with native embedding layers (#23456)
Parent: eac8dede
Showing 2 changed files with 70 additions and 55 deletions:
src/transformers/models/ctrl/modeling_tf_ctrl.py   +69 -55
tests/models/ctrl/test_modeling_tf_ctrl.py         +1 -0
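Note on the overall change: the commit retires the custom `TFSharedEmbeddings` layer in favor of Keras' native `tf.keras.layers.Embedding`. The old `TFCTRLLMHead`, which reused the embedding matrix via `mode="linear"`, is replaced by a plain matmul against the embedding weights plus a standalone bias layer (`TFCTRLBiasLayer`). A minimal sketch of the resulting weight-tying pattern, with placeholder sizes rather than CTRL's real config:

import tensorflow as tf

vocab_size, n_embd = 100, 16  # placeholder sizes, not CTRL's actual config
w = tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=n_embd, name="w")

input_ids = tf.constant([[1, 2, 3]])
hidden_states = w(input_ids)  # embedding lookup: shape (1, 3, n_embd)

# Tied LM head: project hidden states back onto the vocabulary with the same matrix.
logits = tf.matmul(hidden_states, w.embeddings, transpose_b=True)  # (1, 3, vocab_size)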
src/transformers/models/ctrl/modeling_tf_ctrl.py
@@ -15,10 +15,8 @@
 # limitations under the License.
 """ TF 2.0 CTRL model."""
 
 from __future__ import annotations
 
-import warnings
-
 from typing import Optional, Tuple, Union
 
 import numpy as np
@@ -30,7 +28,6 @@ from ...modeling_tf_utils import (
     TFModelInputType,
     TFPreTrainedModel,
     TFSequenceClassificationLoss,
-    TFSharedEmbeddings,
     get_initializer,
     keras_serializable,
     unpack_inputs,
@@ -224,8 +221,11 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
         self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size)
 
-        self.w = TFSharedEmbeddings(
-            config.vocab_size, config.n_embd, initializer_range=config.initializer_range, name="w"
-        )
+        self.w = tf.keras.layers.Embedding(
+            input_dim=config.vocab_size,
+            output_dim=config.n_embd,
+            embeddings_initializer=get_initializer(config.initializer_range),
+            name="w",
+        )
         self.dropout = tf.keras.layers.Dropout(config.embd_pdrop)
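Note: `get_initializer` (from `modeling_tf_utils`, already imported above) returns a `tf.keras.initializers.TruncatedNormal(stddev=initializer_range)`, so the native layer keeps the initialization that `TFSharedEmbeddings` applied through `initializer_range`. A standalone sketch of the new constructor call, with placeholder values instead of a real `CTRLConfig`:

import tensorflow as tf

def get_initializer(initializer_range=0.02):
    # mirrors transformers.modeling_tf_utils.get_initializer
    return tf.keras.initializers.TruncatedNormal(stddev=initializer_range)

w = tf.keras.layers.Embedding(
    input_dim=1000,  # stands in for config.vocab_size
    output_dim=64,  # stands in for config.n_embd
    embeddings_initializer=get_initializer(0.02),  # config.initializer_range
    name="w",
)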
@@ -246,9 +246,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
     def get_input_embeddings(self):
         return self.w
 
-    def set_input_embeddings(self, value):
-        self.w.weight = value
-        self.w.vocab_size = shape_list(value)[0]
+    def set_input_embeddings(self, new_embeddings):
+        self.w = new_embeddings
 
     def _prune_heads(self, heads_to_prune):
         """
@@ -308,7 +307,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
             # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
             # this attention mask is more simple than the triangular masking of causal attention
             # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-            attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1]))
+            attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1] + past_length))
 
             # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
             # masked positions, this operation will create a tensor which is 0.0 for
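Note: with cached key/values, the keys cover `past_length + input_shape[1]` positions, so the broadcastable padding mask must span the full key length; the old reshape only covered the freshly fed tokens. A toy shape check (sizes are made up):

import tensorflow as tf

batch_size, new_tokens, past_length = 2, 1, 5
# one mask entry per key position: cached tokens plus the new token
attention_mask = tf.ones((batch_size, past_length + new_tokens))
mask_4d = tf.reshape(attention_mask, (batch_size, 1, 1, new_tokens + past_length))
print(mask_4d.shape)  # (2, 1, 1, 6), broadcastable to [batch, heads, from_seq, to_seq]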
@@ -332,15 +331,15 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
 
         if token_type_ids is not None:
             token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
-            token_type_embeds = self.w(token_type_ids, mode="embedding")
+            token_type_embeds = self.w(token_type_ids)
             token_type_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, dtype=token_type_embeds.dtype))
         else:
             token_type_embeds = tf.constant(0.0)
         position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
 
         if inputs_embeds is None:
-            check_embeddings_within_bounds(input_ids, self.w.vocab_size)
-            inputs_embeds = self.w(input_ids, mode="embedding")
+            check_embeddings_within_bounds(input_ids, self.w.input_dim)
+            inputs_embeds = self.w(input_ids)
 
         seq_len = input_shape[-1]
         mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
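Note: the causal mask in the surrounding context (unchanged by this commit) uses `tf.linalg.band_part` to keep the lower triangle (all subdiagonals, zero superdiagonals) and then inverts it, so the positions with `mask[i, j] == 1` are exactly the future positions `j > i`. A quick check:

import tensorflow as tf

seq_len = 4
mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
print(mask.numpy())
# [[0. 1. 1. 1.]
#  [0. 0. 1. 1.]
#  [0. 0. 0. 1.]
#  [0. 0. 0. 0.]]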
@@ -565,39 +564,26 @@ class TFCTRLModel(TFCTRLPreTrainedModel):
         return outputs
 
 
-class TFCTRLLMHead(tf.keras.layers.Layer):
-    def __init__(self, config, input_embeddings, **kwargs):
-        super().__init__(**kwargs)
-        self.config = config
-        # CTRL has numerical issues in XLA generate
-        self.supports_xla_generation = False
-
-        # The output weights are the same as the input embeddings, but there is
-        # an output-only bias for each token.
-        self.input_embeddings = input_embeddings
-
-    def build(self, input_shape):
-        self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
-        super().build(input_shape)
-
-    def get_output_embeddings(self):
-        return self.input_embeddings
-
-    def set_output_embeddings(self, value):
-        self.input_embeddings.weight = value
-        self.input_embeddings.vocab_size = shape_list(value)[0]
-
-    def get_bias(self):
-        return {"bias": self.bias}
-
-    def set_bias(self, value):
-        self.bias = value["bias"]
-        self.config.vocab_size = shape_list(value["bias"])[0]
-
-    def call(self, hidden_states):
-        hidden_states = self.input_embeddings(hidden_states, mode="linear")
-        hidden_states = hidden_states + self.bias
-        return hidden_states
+class TFCTRLBiasLayer(tf.keras.layers.Layer):
+    """
+    Bias as a layer. It is used for serialization purposes: `tf.keras.Model.save_weights` stores on a per-layer basis,
+    so all weights have to be registered in a layer.
+    """
+
+    def __init__(self, shape, initializer, trainable, name, **kwargs):
+        super().__init__(name=name, **kwargs)
+        self.shape = shape
+        self.initializer = initializer
+        self.trainable = trainable
+
+    def build(self, input_shape=None):
+        self.bias = self.add_weight(
+            name="bias", shape=self.shape, initializer=self.initializer, trainable=self.trainable
+        )
+        super().build(input_shape)
+
+    def call(self, x):
+        return x + self.bias
 
 
 @add_start_docstrings(
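Note on why the bias lives in its own layer: `tf.keras.Model.save_weights` serializes weights tracked by layers, so registering the bias through `add_weight` inside a dedicated layer makes it storable and restorable on its own. A standalone sketch of the same pattern (reimplemented here for illustration, not imported from the diff):

import tensorflow as tf

class BiasLayer(tf.keras.layers.Layer):
    def __init__(self, shape, initializer, trainable, name, **kwargs):
        super().__init__(name=name, **kwargs)
        self.shape = shape
        self.initializer = initializer
        self.trainable = trainable

    def build(self, input_shape=None):
        self.bias = self.add_weight(
            name="bias", shape=self.shape, initializer=self.initializer, trainable=self.trainable
        )
        super().build(input_shape)

    def call(self, x):
        return x + self.bias

layer = BiasLayer(shape=[1, 8], initializer="zeros", trainable=True, name="lm_head")
layer.build(None)
print([v.shape for v in layer.weights])  # [TensorShape([1, 8])], tracked by the layer
print(layer(tf.zeros((2, 8))).shape)  # (2, 8): the (1, 8) bias broadcasts over the batch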
@@ -611,24 +597,53 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
         self.transformer = TFCTRLMainLayer(config, name="transformer")
-
-        self.lm_head = TFCTRLLMHead(config, self.transformer.w, name="lm_head")
-        # CTRL has numerical issues in XLA generate
-        self.supports_xla_generation = False
+        self.bias_layer = TFCTRLBiasLayer(
+            name="lm_head", shape=[1, config.vocab_size], initializer="zeros", trainable=True
+        )
+
+    def get_output_embeddings(self):
+        return self.get_input_embeddings()
+
+    def set_output_embeddings(self, value):
+        self.set_input_embeddings(value)
 
-    def get_lm_head(self):
-        return self.lm_head
+    def get_bias(self):
+        return {"lm_head.bias": self.bias_layer.bias}
 
-    def get_prefix_bias_name(self):
-        warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
-        return self.name + "/" + self.lm_head.name
+    def set_bias(self, value):
+        # Replaces the existing layers containing bias for correct (de)serialization.
+        vocab_size = value["lm_head.bias"].shape[-1]
+        self.bias_layer = TFCTRLBiasLayer(
+            name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=True
+        )
+        self.bias_layer.build(None)
+        self.bias_layer.bias.assign(value["lm_head.bias"])
 
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_cache=None, **kwargs):
+    # Copied from transformers.models.gpt2.modeling_tf_gpt2.TFGPT2LMHeadModel.prepare_inputs_for_generation
+    def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache=None, **kwargs):
+        token_type_ids = kwargs.get("token_type_ids", None)
         # only last token for inputs_ids if past is defined in kwargs
         if past_key_values:
-            input_ids = tf.expand_dims(input_ids[:, -1], -1)
+            inputs = tf.expand_dims(inputs[:, -1], -1)
+            if token_type_ids is not None:
+                token_type_ids = tf.expand_dims(token_type_ids[:, -1], -1)
+
+        position_ids = kwargs.get("position_ids", None)
+        attention_mask = kwargs.get("attention_mask", None)
+
+        if attention_mask is not None and position_ids is None:
+            position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
+            if past_key_values:
+                position_ids = tf.expand_dims(position_ids[:, -1], -1)
 
-        return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": use_cache}
+        return {
+            "input_ids": inputs,
+            "attention_mask": attention_mask,
+            "position_ids": position_ids,
+            "past_key_values": past_key_values,
+            "use_cache": use_cache,
+            "token_type_ids": token_type_ids,
+        }
 
     @unpack_inputs
     @add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
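Note: the new `prepare_inputs_for_generation` (copied from the TF GPT-2 model, as the comment says) derives `position_ids` from the padding mask with an exclusive cumulative sum, so left-padding keeps position 0 and real tokens are numbered from 0 upward. A standalone check of that trick:

import tensorflow as tf

# batch of 2, left-padded: 0 marks padding, 1 marks real tokens
attention_mask = tf.constant([[0, 0, 1, 1, 1],
                              [1, 1, 1, 1, 1]])
position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
print(position_ids.numpy())
# [[0 0 0 1 2]
#  [0 1 2 3 4]]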
@@ -672,10 +687,9 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
             return_dict=return_dict,
             training=training,
         )
-
         hidden_states = transformer_outputs[0]
-
-        logits = self.lm_head(hidden_states)
+        logits = tf.matmul(hidden_states, self.transformer.w.weights, transpose_b=True)
+        logits = self.bias_layer(logits)
 
         loss = None
         if labels is not None:
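Note: the LM head is now a plain projection onto the tied embedding matrix followed by the bias layer. `self.transformer.w.weights` is the Keras layer's weight list (a single embedding matrix); `tf.matmul` converts that list to a tensor with a leading axis of 1 and broadcasts it over the batch. A shape sketch with made-up dimensions:

import tensorflow as tf

batch, seq, n_embd, vocab_size = 2, 5, 16, 100  # made-up sizes
w = tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=n_embd, name="w")
_ = w(tf.constant([[0]]))  # call once so the layer creates its weights

hidden_states = tf.random.normal((batch, seq, n_embd))
# w.weights == [embeddings]; as a tensor it is (1, vocab_size, n_embd), so with
# transpose_b the matmul broadcasts to (batch, seq, vocab_size).
logits = tf.matmul(hidden_states, w.weights, transpose_b=True)
print(logits.shape)  # (2, 5, 100)

Using `w.weights[0]` (or the layer's `embeddings` attribute) would be equivalent here.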
tests/models/ctrl/test_modeling_tf_ctrl.py
@@ -225,6 +225,7 @@ class TFCTRLModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase
         for model_class in self.all_model_classes:
             model = model_class(config)
+            model.build()  # may be needed for the get_bias() call below
             assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
 
             if model_class in list_lm_models:
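Note on the added `model.build()`: Keras creates weights lazily, so a freshly constructed model owns no variables until it is built (either by being called on inputs or via `build()`); without it, `get_bias()` could run before the bias variable exists. The same effect with a plain layer:

import tensorflow as tf

dense = tf.keras.layers.Dense(4)
print(dense.weights)  # [] (no variables exist before the layer is built)
dense.build((None, 8))
print([v.name for v in dense.weights])  # kernel and bias now exist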