Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
7140363e
Commit
7140363e
authored
Dec 14, 2019
by
thomwolf
Browse files
update bertabs
parent
a52d56c8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
34 deletions
+14
-34
examples/summarization/configuration_bertabs.py
examples/summarization/configuration_bertabs.py
+14
-34
No files found.
examples/summarization/configuration_bertabs.py
View file @
7140363e
...
...
@@ -33,6 +33,8 @@ class BertAbsConfig(PretrainedConfig):
r
""" Class to store the configuration of the BertAbs model.
Arguments:
vocab_size: int
Number of tokens in the vocabulary.
max_pos: int
The maximum sequence length that this model will be used with.
enc_layer: int
...
...
@@ -81,39 +83,17 @@ class BertAbsConfig(PretrainedConfig):
):
super
(
BertAbsConfig
,
self
).
__init__
(
**
kwargs
)
if
self
.
_input_is_path_to_json
(
vocab_size
):
path_to_json
=
vocab_size
with
open
(
path_to_json
,
"r"
,
encoding
=
"utf-8"
)
as
reader
:
json_config
=
json
.
loads
(
reader
.
read
())
for
key
,
value
in
json_config
.
items
():
self
.
__dict__
[
key
]
=
value
elif
isinstance
(
vocab_size
,
int
):
self
.
vocab_size
=
vocab_size
self
.
max_pos
=
max_pos
self
.
vocab_size
=
vocab_size
self
.
max_pos
=
max_pos
self
.
enc_layers
=
enc_layers
self
.
enc_hidden_size
=
enc_hidden_size
self
.
enc_heads
=
enc_heads
self
.
enc_ff_size
=
enc_ff_size
self
.
enc_dropout
=
enc_dropout
self
.
enc_layers
=
enc_layers
self
.
enc_hidden_size
=
enc_hidden_size
self
.
enc_heads
=
enc_heads
self
.
enc_ff_size
=
enc_ff_size
self
.
enc_dropout
=
enc_dropout
self
.
dec_layers
=
dec_layers
self
.
dec_hidden_size
=
dec_hidden_size
self
.
dec_heads
=
dec_heads
self
.
dec_ff_size
=
dec_ff_size
self
.
dec_dropout
=
dec_dropout
else
:
raise
ValueError
(
"First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)"
)
def
_input_is_path_to_json
(
self
,
first_argument
):
""" Checks whether the first argument passed to config
is the path to a JSON file that contains the config.
"""
is_python_2
=
sys
.
version_info
[
0
]
==
2
if
is_python_2
:
return
isinstance
(
first_argument
,
unicode
)
else
:
return
isinstance
(
first_argument
,
str
)
self
.
dec_layers
=
dec_layers
self
.
dec_hidden_size
=
dec_hidden_size
self
.
dec_heads
=
dec_heads
self
.
dec_ff_size
=
dec_ff_size
self
.
dec_dropout
=
dec_dropout
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment