Commit c56d921d — "adding TF 2.0 model"
Authored Oct 09, 2019 by thomwolf
Parent: 45dc04f3
Showing 6 changed files with 426 additions and 209 deletions (+426, −209):
  transformers/configuration_ctrl.py            +2    −2
  transformers/modeling_ctrl.py                 +1    −1
  transformers/modeling_roberta.py              +2    −1
  transformers/modeling_tf_ctrl.py              +219  −204
  transformers/tests/modeling_tf_ctrl_test.py   +201  −0
  transformers/tests/modeling_tf_gpt2_test.py   +1    −1
transformers/configuration_ctrl.py

@@ -53,8 +53,8 @@ class CTRLConfig(PretrainedConfig):
     def __init__(self,
                  vocab_size_or_config_json_file=246534,
-                 n_positions=50000,
-                 n_ctx=512,
+                 n_positions=256,
+                 n_ctx=256,
                  n_embd=1280,
                  dff=8192,
                  n_layer=48,
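
For reference, a minimal sketch (not part of the commit) of what the new defaults mean for a caller; it assumes this revision of transformers is importable and that CTRLConfig is exported at the package root, as the new test file below also assumes:

    # Hypothetical check of the new defaults; not part of the diff itself.
    from transformers import CTRLConfig

    config = CTRLConfig()          # picks up the defaults shown in the diff
    print(config.n_positions)      # expected: 256 (was 50000)
    print(config.n_ctx)            # expected: 256 (was 512)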
transformers/modeling_ctrl.py

@@ -351,7 +351,7 @@ class CTRLModel(CTRLPreTrainedModel):
         x = self.w(input_ids)
         # x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
-        seq_len = input_ids.shape[1]
+        seq_len = input_ids.shape[-1]
         mask = torch.triu(torch.ones(seq_len, seq_len), 1).to(x.device)
         x *= np.sqrt(self.d_model_size)
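
The switch from shape[1] to shape[-1] makes the sequence-length lookup robust to the number of leading dimensions. An illustrative sketch (not from the commit, assuming PyTorch is installed and using toy shapes):

    import torch

    batched = torch.zeros(13, 7, dtype=torch.long)   # [batch_size, seq_len]
    single = torch.zeros(7, dtype=torch.long)        # unbatched [seq_len]

    print(batched.shape[-1])   # 7, same as batched.shape[1]
    print(single.shape[-1])    # 7; single.shape[1] would raise IndexError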
transformers/modeling_roberta.py

@@ -172,7 +172,8 @@ class RobertaModel(BertModel):
         if input_ids[:, 0].sum().item() != 0:
             logger.warning("A sequence with no special tokens has been passed to the RoBERTa model. "
                            "This model requires special tokens in order to work. "
-                           "Please specify add_special_tokens=True in your encoding.")
+                           "Please specify add_special_tokens=True in your tokenize.encode()"
+                           "or tokenizer.convert_tokens_to_ids().")
         return super(RobertaModel, self).forward(input_ids,
                                                  attention_mask=attention_mask,
                                                  token_type_ids=token_type_ids,
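
The reworded warning points callers at the tokenizer. A hedged usage sketch (not part of the diff), assuming the roberta-base checkpoint and the encode() signature of this release:

    from transformers import RobertaTokenizer

    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

    # With special tokens the sequence starts with <s> (id 0), so the
    # input_ids[:, 0].sum() check above passes and no warning is logged.
    input_ids = tokenizer.encode("Hello world", add_special_tokens=True)
    print(input_ids[0])   # expected: 0, the <s> token id for RoBERTa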
transformers/modeling_tf_ctrl.py

(Diff collapsed in the review view: +219, −204.)
transformers/tests/modeling_tf_ctrl_test.py (new file, mode 100644)
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest
import shutil
import pytest
import sys

from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester

from transformers import CTRLConfig, is_tf_available

if is_tf_available():
    import tensorflow as tf
    from transformers.modeling_tf_ctrl import (TFCTRLModel,
                                               TFCTRLLMHeadModel,
                                               TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP)
else:
    pytestmark = pytest.mark.skip("Require TensorFlow")


class TFCTRLModelTest(TFCommonTestCases.TFCommonModelTester):

    all_model_classes = (TFCTRLModel, TFCTRLLMHeadModel) if is_tf_available() else ()

    class TFCTRLModelTester(object):

        def __init__(self,
                     parent,
                     batch_size=13,
                     seq_length=7,
                     is_training=True,
                     use_token_type_ids=True,
                     use_input_mask=True,
                     use_labels=True,
                     use_mc_token_ids=True,
                     vocab_size=99,
                     hidden_size=32,
                     num_hidden_layers=5,
                     num_attention_heads=4,
                     intermediate_size=37,
                     hidden_act="gelu",
                     hidden_dropout_prob=0.1,
                     attention_probs_dropout_prob=0.1,
                     max_position_embeddings=512,
                     type_vocab_size=16,
                     type_sequence_label_size=2,
                     initializer_range=0.02,
                     num_labels=3,
                     num_choices=4,
                     scope=None,
                     ):
            self.parent = parent
            self.batch_size = batch_size
            self.seq_length = seq_length
            self.is_training = is_training
            self.use_token_type_ids = use_token_type_ids
            self.use_input_mask = use_input_mask
            self.use_labels = use_labels
            self.use_mc_token_ids = use_mc_token_ids
            self.vocab_size = vocab_size
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.intermediate_size = intermediate_size
            self.hidden_act = hidden_act
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_vocab_size = type_vocab_size
            self.type_sequence_label_size = type_sequence_label_size
            self.initializer_range = initializer_range
            self.num_labels = num_labels
            self.num_choices = num_choices
            self.scope = scope

        def prepare_config_and_inputs(self):
            input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

            input_mask = None
            if self.use_input_mask:
                input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

            token_type_ids = None
            if self.use_token_type_ids:
                token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

            mc_token_ids = None
            if self.use_mc_token_ids:
                mc_token_ids = ids_tensor([self.batch_size, self.num_choices], self.seq_length)

            sequence_labels = None
            token_labels = None
            choice_labels = None
            if self.use_labels:
                sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
                token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
                choice_labels = ids_tensor([self.batch_size], self.num_choices)

            config = CTRLConfig(
                vocab_size_or_config_json_file=self.vocab_size,
                n_embd=self.hidden_size,
                n_layer=self.num_hidden_layers,
                n_head=self.num_attention_heads,
                # intermediate_size=self.intermediate_size,
                # hidden_act=self.hidden_act,
                # hidden_dropout_prob=self.hidden_dropout_prob,
                # attention_probs_dropout_prob=self.attention_probs_dropout_prob,
                n_positions=self.max_position_embeddings,
                n_ctx=self.max_position_embeddings
                # type_vocab_size=self.type_vocab_size,
                # initializer_range=self.initializer_range
            )

            head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)

            return config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, token_labels, choice_labels

        def create_and_check_ctrl_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
            model = TFCTRLModel(config=config)

            inputs = {'input_ids': input_ids,
                      'attention_mask': input_mask,
                      'token_type_ids': token_type_ids}
            sequence_output = model(inputs)[0]

            inputs = [input_ids, None, input_mask]  # None is the input for 'past'
            sequence_output = model(inputs)[0]

            sequence_output = model(input_ids)[0]

            result = {
                "sequence_output": sequence_output.numpy(),
            }
            self.parent.assertListEqual(
                list(result["sequence_output"].shape),
                [self.batch_size, self.seq_length, self.hidden_size])

        def create_and_check_ctrl_lm_head(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
            model = TFCTRLLMHeadModel(config=config)

            inputs = {'input_ids': input_ids,
                      'attention_mask': input_mask,
                      'token_type_ids': token_type_ids}
            prediction_scores = model(inputs)[0]

            result = {
                "prediction_scores": prediction_scores.numpy(),
            }
            self.parent.assertListEqual(
                list(result["prediction_scores"].shape),
                [self.batch_size, self.seq_length, self.vocab_size])

        def prepare_config_and_inputs_for_common(self):
            config_and_inputs = self.prepare_config_and_inputs()
            (config, input_ids, input_mask, head_mask, token_type_ids,
             mc_token_ids, sequence_labels, token_labels, choice_labels) = config_and_inputs
            inputs_dict = {'input_ids': input_ids,
                           'token_type_ids': token_type_ids,
                           'attention_mask': input_mask}
            return config, inputs_dict

    def setUp(self):
        self.model_tester = TFCTRLModelTest.TFCTRLModelTester(self)
        self.config_tester = ConfigTester(self, config_class=CTRLConfig, n_embd=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_ctrl_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_ctrl_model(*config_and_inputs)

    def test_ctrl_lm_head(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_ctrl_lm_head(*config_and_inputs)

    @pytest.mark.slow
    def test_model_from_pretrained(self):
        cache_dir = "/tmp/transformers_test/"
        for model_name in list(TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
            model = TFCTRLModel.from_pretrained(model_name, cache_dir=cache_dir)
            shutil.rmtree(cache_dir)
            self.assertIsNotNone(model)


if __name__ == "__main__":
    unittest.main()
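
A minimal sketch of the forward pass this new test exercises, assuming TensorFlow 2.0 and this revision of transformers are installed; the small sizes mirror TFCTRLModelTester's defaults and the snippet is not part of the commit:

    import tensorflow as tf

    from transformers import CTRLConfig
    from transformers.modeling_tf_ctrl import TFCTRLModel

    # Tiny config mirroring the tester defaults (hidden size 32, 5 layers, 4 heads).
    config = CTRLConfig(vocab_size_or_config_json_file=99, n_embd=32,
                        n_layer=5, n_head=4, n_positions=512, n_ctx=512)
    model = TFCTRLModel(config=config)

    input_ids = tf.constant([[1, 2, 3, 4, 5, 6, 7]])   # [batch=1, seq_len=7]
    sequence_output = model(input_ids)[0]               # last hidden states
    print(sequence_output.shape)                        # expected: (1, 7, 32)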
transformers/tests/modeling_tf_gpt2_test.py

@@ -222,7 +222,7 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
     @pytest.mark.slow
     def test_model_from_pretrained(self):
         cache_dir = "/tmp/transformers_test/"
-        for model_name in list(TF_gpt2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+        for model_name in list(TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
             model = TFGPT2Model.from_pretrained(model_name, cache_dir=cache_dir)
             shutil.rmtree(cache_dir)
             self.assertIsNotNone(model)