chenpangpang / transformers
Commit a1994a71, authored Dec 05, 2019 by Rémi Louf, committed by Julien Chaumond on Dec 09, 2019

simplified model and configuration

parent 3a9a9f78

Showing 3 changed files with 10 additions and 59 deletions (+10, -59)
examples/summarization/configuration_bertabs.py   +0  -22
examples/summarization/modeling_bertabs.py        +7  -34
examples/summarization/run_summarization.py       +3  -3
examples/summarization/configuration_bertabs.py (view file @ a1994a71)

@@ -33,17 +33,6 @@ class BertAbsConfig(PretrainedConfig):
     r""" Class to store the configuration of the BertAbs model.
 
     Arguments:
-        temp_dir: string
-            Unused in the current situation. Kept for compatibility but will be removed.
-        finetune_bert: bool
-            Whether to fine-tune the model or not. Will be kept for reference
-            in case we want to add the possibility to fine-tune the model.
-        large: bool
-            Whether to use bert-large as a base.
-        share_emb: book
-            Whether the embeddings are shared between the encoder and decoder.
-        encoder: string
-            Not clear what this does. Leave to "bert" for pre-trained weights.
         max_pos: int
             The maximum sequence length that this model will be used with.
         enc_layer: int

@@ -77,11 +66,6 @@ class BertAbsConfig(PretrainedConfig):
     def __init__(
         self,
         vocab_size_or_config_json_file=30522,
-        temp_dir=".",
-        finetune_bert=False,
-        large=False,
-        share_emb=True,
-        encoder="bert",
         max_pos=512,
         enc_layers=6,
         enc_hidden_size=512,

@@ -104,21 +88,15 @@ class BertAbsConfig(PretrainedConfig):
             for key, value in json_config.items():
                 self.__dict__[key] = value
         elif isinstance(vocab_size_or_config_json_file, int):
-            self.temp_dir = temp_dir
-            self.finetune_bert = finetune_bert
-            self.large = large
             self.vocab_size = vocab_size_or_config_json_file
             self.max_pos = max_pos
 
-            self.encoder = encoder
             self.enc_layers = enc_layers
             self.enc_hidden_size = enc_hidden_size
             self.enc_heads = enc_heads
             self.enc_ff_size = enc_ff_size
             self.enc_dropout = enc_dropout
 
-            self.share_emb = share_emb
             self.dec_layers = dec_layers
             self.dec_hidden_size = dec_hidden_size
             self.dec_heads = dec_heads
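For orientation, a minimal sketch of how the trimmed-down configuration can be constructed after this change. The import path assumes the script is run from examples/summarization/; the keyword values are the defaults visible in the signature above, and any argument not shown keeps its own default.

from configuration_bertabs import BertAbsConfig

# The removed arguments (temp_dir, finetune_bert, large, share_emb, encoder)
# no longer exist; only the architecture hyper-parameters remain.
config = BertAbsConfig(
    vocab_size_or_config_json_file=30522,
    max_pos=512,
    enc_layers=6,
    enc_hidden_size=512,
)
print(config.max_pos)  # 512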
examples/summarization/modeling_bertabs.py (view file @ a1994a71)

@@ -53,7 +53,7 @@ class BertAbs(BertAbsPreTrainedModel):
     def __init__(self, args, checkpoint=None, bert_extractive_checkpoint=None):
         super(BertAbs, self).__init__(args)
         self.args = args
-        self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)
+        self.bert = Bert()
 
         # If pre-trained weights are passed for Bert, load these.
         load_bert_pretrained_extractive = True if bert_extractive_checkpoint else False

@@ -69,18 +69,6 @@ class BertAbs(BertAbsPreTrainedModel):
                 strict=True,
             )
 
-        if args.encoder == "baseline":
-            bert_config = BertConfig(
-                self.bert.model.config.vocab_size,
-                hidden_size=args.enc_hidden_size,
-                num_hidden_layers=args.enc_layers,
-                num_attention_heads=8,
-                intermediate_size=args.enc_ff_size,
-                hidden_dropout_prob=args.enc_dropout,
-                attention_probs_dropout_prob=args.enc_dropout,
-            )
-            self.bert.model = BertModel(bert_config)
-
         self.vocab_size = self.bert.model.config.vocab_size
 
         if args.max_pos > 512:
@@ -101,7 +89,7 @@ class BertAbs(BertAbsPreTrainedModel):
         tgt_embeddings = nn.Embedding(
             self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0
         )
-        if self.args.share_emb:
-            tgt_embeddings.weight = copy.deepcopy(
-                self.bert.model.embeddings.word_embeddings.weight
-            )
+
+        tgt_embeddings.weight = copy.deepcopy(
+            self.bert.model.embeddings.word_embeddings.weight
+        )
@@ -141,16 +129,6 @@ class BertAbs(BertAbsPreTrainedModel):
             else:
                 p.data.zero_()
 
-    def maybe_tie_embeddings(self, args):
-        if args.use_bert_emb:
-            tgt_embeddings = nn.Embedding(
-                self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0
-            )
-            tgt_embeddings.weight = copy.deepcopy(
-                self.bert.model.embeddings.word_embeddings.weight
-            )
-            self.decoder.embeddings = tgt_embeddings
-
     def forward(
         self,
         encoder_input_ids,

@@ -178,14 +156,9 @@ class Bert(nn.Module):
     """ This class is not really necessary and should probably disappear.
     """
 
-    def __init__(self, large, temp_dir, finetune=False):
+    def __init__(self):
         super(Bert, self).__init__()
-        if large:
-            self.model = BertModel.from_pretrained("bert-large-uncased", cache_dir=temp_dir)
-        else:
-            self.model = BertModel.from_pretrained("bert-base-uncased", cache_dir=temp_dir)
-        self.finetune = finetune
+        self.model = BertModel.from_pretrained("bert-base-uncased")
 
     def forward(self, input_ids, attention_mask=None, token_type_ids=None, **kwargs):
         self.eval()
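The Bert helper now always wraps bert-base-uncased and drops the large/temp_dir/finetune options. As a point of reference, a self-contained sketch of the equivalent behaviour using only the public transformers API of that era; the example text and the printed shape are illustrative and not part of the commit.

import torch
from transformers import BertModel, BertTokenizer

# Same effect as the simplified wrapper: a fixed bert-base-uncased encoder run in eval mode.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
model = BertModel.from_pretrained("bert-base-uncased")
model.eval()

input_ids = torch.tensor([tokenizer.encode("hello world", add_special_tokens=True)])
with torch.no_grad():
    last_hidden_state = model(input_ids)[0]
print(last_hidden_state.shape)  # torch.Size([1, 4, 768]) for bert-base-uncased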
examples/summarization/run_summarization.py (view file @ a1994a71)

@@ -31,9 +31,9 @@ Batch = namedtuple(
 def evaluate(args):
     tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
-    model = bertabs = BertAbs.from_pretrained("bertabs-finetuned-cnndm")
-    bertabs.to(args.device)
-    bertabs.eval()
+    model = BertAbs.from_pretrained("bertabs-finetuned-cnndm")
+    model.to(args.device)
+    model.eval()
 
     symbols = {
         "BOS": tokenizer.vocab["[unused0]"],
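For reference, a minimal sketch of the loading pattern this last hunk settles on, with an explicit device pick added for illustration. The checkpoint identifier is copied verbatim from the diff, and BertAbs is the local class from examples/summarization/modeling_bertabs.py, so the snippet assumes it is run from that directory.

import torch
from transformers import BertTokenizer
from modeling_bertabs import BertAbs  # local module, not a transformers import

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)

# One handle to the model, as in the corrected code (no separate `bertabs` alias).
model = BertAbs.from_pretrained("bertabs-finetuned-cnndm")
model.to(device)
model.eval()

symbols = {"BOS": tokenizer.vocab["[unused0]"]}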