Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
4626df50
Unverified
Commit
4626df50
authored
Jun 14, 2023
by
Joao Gante
Committed by
GitHub
Jun 14, 2023
Browse files
TF: CTRL with native embedding layers (#23456)
parent
eac8dede
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
70 additions
and
55 deletions
+70
-55
src/transformers/models/ctrl/modeling_tf_ctrl.py
src/transformers/models/ctrl/modeling_tf_ctrl.py
+69
-55
tests/models/ctrl/test_modeling_tf_ctrl.py
tests/models/ctrl/test_modeling_tf_ctrl.py
+1
-0
No files found.
src/transformers/models/ctrl/modeling_tf_ctrl.py
View file @
4626df50
...
...
@@ -15,10 +15,8 @@
# limitations under the License.
""" TF 2.0 CTRL model."""
from
__future__
import
annotations
import
warnings
from
typing
import
Optional
,
Tuple
,
Union
import
numpy
as
np
...
...
@@ -30,7 +28,6 @@ from ...modeling_tf_utils import (
TFModelInputType
,
TFPreTrainedModel
,
TFSequenceClassificationLoss
,
TFSharedEmbeddings
,
get_initializer
,
keras_serializable
,
unpack_inputs
,
...
...
@@ -224,8 +221,11 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
self
.
pos_encoding
=
positional_encoding
(
config
.
n_positions
,
self
.
d_model_size
)
self
.
w
=
TFSharedEmbeddings
(
config
.
vocab_size
,
config
.
n_embd
,
initializer_range
=
config
.
initializer_range
,
name
=
"w"
self
.
w
=
tf
.
keras
.
layers
.
Embedding
(
input_dim
=
config
.
vocab_size
,
output_dim
=
config
.
n_embd
,
embeddings_initializer
=
get_initializer
(
config
.
initializer_range
),
name
=
"w"
,
)
self
.
dropout
=
tf
.
keras
.
layers
.
Dropout
(
config
.
embd_pdrop
)
...
...
@@ -246,9 +246,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
def
get_input_embeddings
(
self
):
return
self
.
w
def
set_input_embeddings
(
self
,
value
):
self
.
w
.
weight
=
value
self
.
w
.
vocab_size
=
shape_list
(
value
)[
0
]
def
set_input_embeddings
(
self
,
new_embeddings
):
self
.
w
=
new_embeddings
def
_prune_heads
(
self
,
heads_to_prune
):
"""
...
...
@@ -308,7 +307,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask
=
tf
.
reshape
(
attention_mask
,
(
input_shape
[
0
],
1
,
1
,
input_shape
[
1
]))
attention_mask
=
tf
.
reshape
(
attention_mask
,
(
input_shape
[
0
],
1
,
1
,
input_shape
[
1
]
+
past_length
))
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
...
...
@@ -332,15 +331,15 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
if
token_type_ids
is
not
None
:
token_type_ids
=
tf
.
reshape
(
token_type_ids
,
[
-
1
,
shape_list
(
token_type_ids
)[
-
1
]])
token_type_embeds
=
self
.
w
(
token_type_ids
,
mode
=
"embedding"
)
token_type_embeds
=
self
.
w
(
token_type_ids
)
token_type_embeds
*=
tf
.
math
.
sqrt
(
tf
.
cast
(
self
.
d_model_size
,
dtype
=
token_type_embeds
.
dtype
))
else
:
token_type_embeds
=
tf
.
constant
(
0.0
)
position_ids
=
tf
.
reshape
(
position_ids
,
[
-
1
,
shape_list
(
position_ids
)[
-
1
]])
if
inputs_embeds
is
None
:
check_embeddings_within_bounds
(
input_ids
,
self
.
w
.
vocab_size
)
inputs_embeds
=
self
.
w
(
input_ids
,
mode
=
"embedding"
)
check_embeddings_within_bounds
(
input_ids
,
self
.
w
.
input_dim
)
inputs_embeds
=
self
.
w
(
input_ids
)
seq_len
=
input_shape
[
-
1
]
mask
=
1
-
tf
.
linalg
.
band_part
(
tf
.
ones
((
seq_len
,
seq_len
)),
-
1
,
0
)
...
...
@@ -565,39 +564,26 @@ class TFCTRLModel(TFCTRLPreTrainedModel):
return
outputs
class
TFCTRLLMHead
(
tf
.
keras
.
layers
.
Layer
):
def
__init__
(
self
,
config
,
input_embeddings
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
config
=
config
# CTRL has numerical issues in XLA generate
self
.
supports_xla_generation
=
False
class
TFCTRLBiasLayer
(
tf
.
keras
.
layers
.
Layer
):
"""
Bias as a layer. It is used for serialization purposes: `tf.keras.Model.save_weights` stores on a per-layer basis,
so all weights have to be registered in a layer.
"""
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self
.
input_embeddings
=
input_embeddings
def
__init__
(
self
,
shape
,
initializer
,
trainable
,
name
,
**
kwargs
):
super
().
__init__
(
name
=
name
,
**
kwargs
)
self
.
shape
=
shape
self
.
initializer
=
initializer
self
.
trainable
=
trainable
def
build
(
self
,
input_shape
=
None
):
self
.
bias
=
self
.
add_weight
(
shape
=
(
self
.
config
.
vocab_size
,),
initializer
=
"zeros"
,
trainable
=
True
,
name
=
"bias"
)
def
build
(
self
,
input_shape
):
self
.
bias
=
self
.
add_weight
(
name
=
"bias"
,
shape
=
self
.
shape
,
initializer
=
self
.
initializer
,
trainable
=
self
.
trainable
)
super
().
build
(
input_shape
)
def
get_output_embeddings
(
self
):
return
self
.
input_embeddings
def
set_output_embeddings
(
self
,
value
):
self
.
input_embeddings
.
weight
=
value
self
.
input_embeddings
.
vocab_size
=
shape_list
(
value
)[
0
]
def
get_bias
(
self
):
return
{
"bias"
:
self
.
bias
}
def
set_bias
(
self
,
value
):
self
.
bias
=
value
[
"bias"
]
self
.
config
.
vocab_size
=
shape_list
(
value
[
"bias"
])[
0
]
def
call
(
self
,
hidden_states
):
hidden_states
=
self
.
input_embeddings
(
hidden_states
,
mode
=
"linear"
)
hidden_states
=
hidden_states
+
self
.
bias
return
hidden_states
def
call
(
self
,
x
):
return
x
+
self
.
bias
@
add_start_docstrings
(
...
...
@@ -611,24 +597,53 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
super
().
__init__
(
config
,
*
inputs
,
**
kwargs
)
self
.
transformer
=
TFCTRLMainLayer
(
config
,
name
=
"transformer"
)
self
.
bias_layer
=
TFCTRLBiasLayer
(
name
=
"lm_head"
,
shape
=
[
1
,
config
.
vocab_size
],
initializer
=
"zeros"
,
trainable
=
True
)
self
.
lm_head
=
TFCTRLLMHead
(
config
,
self
.
transformer
.
w
,
name
=
"lm_head"
)
# CTRL has numerical issues in XLA generate
self
.
supports_xla_generation
=
False
def
get_output_embeddings
(
self
):
return
self
.
get_input_embeddings
()
def
g
et_
lm_head
(
self
):
return
self
.
lm_head
def
s
et_
output_embeddings
(
self
,
value
):
self
.
set_input_embeddings
(
value
)
def
get_prefix_bias_name
(
self
):
warnings
.
warn
(
"The method get_prefix_bias_name is deprecated. Please use `get_bias` instead."
,
FutureWarning
)
return
self
.
name
+
"/"
+
self
.
lm_head
.
name
def
get_bias
(
self
):
return
{
"lm_head.bias"
:
self
.
bias_layer
.
bias
}
def
prepare_inputs_for_generation
(
self
,
input_ids
,
past_key_values
=
None
,
use_cache
=
None
,
**
kwargs
):
def
set_bias
(
self
,
value
):
# Replaces the existing layers containing bias for correct (de)serialization.
vocab_size
=
value
[
"lm_head.bias"
].
shape
[
-
1
]
self
.
bias_layer
=
TFCTRLBiasLayer
(
name
=
"final_logits_bias"
,
shape
=
[
1
,
vocab_size
],
initializer
=
"zeros"
,
trainable
=
True
)
self
.
bias_layer
.
build
(
None
)
self
.
bias_layer
.
bias
.
assign
(
value
[
"lm_head.bias"
])
# Copied from transformers.models.gpt2.modeling_tf_gpt2.TFGPT2LMHeadModel.prepare_inputs_for_generation
def
prepare_inputs_for_generation
(
self
,
inputs
,
past_key_values
=
None
,
use_cache
=
None
,
**
kwargs
):
token_type_ids
=
kwargs
.
get
(
"token_type_ids"
,
None
)
# only last token for inputs_ids if past is defined in kwargs
if
past_key_values
:
input_ids
=
tf
.
expand_dims
(
input_ids
[:,
-
1
],
-
1
)
inputs
=
tf
.
expand_dims
(
inputs
[:,
-
1
],
-
1
)
if
token_type_ids
is
not
None
:
token_type_ids
=
tf
.
expand_dims
(
token_type_ids
[:,
-
1
],
-
1
)
position_ids
=
kwargs
.
get
(
"position_ids"
,
None
)
attention_mask
=
kwargs
.
get
(
"attention_mask"
,
None
)
return
{
"input_ids"
:
input_ids
,
"past_key_values"
:
past_key_values
,
"use_cache"
:
use_cache
}
if
attention_mask
is
not
None
and
position_ids
is
None
:
position_ids
=
tf
.
math
.
cumsum
(
attention_mask
,
axis
=-
1
,
exclusive
=
True
)
if
past_key_values
:
position_ids
=
tf
.
expand_dims
(
position_ids
[:,
-
1
],
-
1
)
return
{
"input_ids"
:
inputs
,
"attention_mask"
:
attention_mask
,
"position_ids"
:
position_ids
,
"past_key_values"
:
past_key_values
,
"use_cache"
:
use_cache
,
"token_type_ids"
:
token_type_ids
,
}
@
unpack_inputs
@
add_start_docstrings_to_model_forward
(
CTRL_INPUTS_DOCSTRING
)
...
...
@@ -672,10 +687,9 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
return_dict
=
return_dict
,
training
=
training
,
)
hidden_states
=
transformer_outputs
[
0
]
logits
=
self
.
lm_head
(
hidden_state
s
)
logits
=
tf
.
matmul
(
hidden_states
,
self
.
transformer
.
w
.
weights
,
transpose_b
=
True
)
logits
=
self
.
bias_layer
(
logit
s
)
loss
=
None
if
labels
is
not
None
:
...
...
tests/models/ctrl/test_modeling_tf_ctrl.py
View file @
4626df50
...
...
@@ -225,6 +225,7 @@ class TFCTRLModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase
for
model_class
in
self
.
all_model_classes
:
model
=
model_class
(
config
)
model
.
build
()
# may be needed for the get_bias() call below
assert
isinstance
(
model
.
get_input_embeddings
(),
tf
.
keras
.
layers
.
Layer
)
if
model_class
in
list_lm_models
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment