Commit 7140363e authored by thomwolf's avatar thomwolf
Browse files

update bertabs

parent a52d56c8
...@@ -33,6 +33,8 @@ class BertAbsConfig(PretrainedConfig): ...@@ -33,6 +33,8 @@ class BertAbsConfig(PretrainedConfig):
r""" Class to store the configuration of the BertAbs model. r""" Class to store the configuration of the BertAbs model.
Arguments: Arguments:
vocab_size: int
Number of tokens in the vocabulary.
max_pos: int max_pos: int
The maximum sequence length that this model will be used with. The maximum sequence length that this model will be used with.
enc_layer: int enc_layer: int
...@@ -81,39 +83,17 @@ class BertAbsConfig(PretrainedConfig): ...@@ -81,39 +83,17 @@ class BertAbsConfig(PretrainedConfig):
): ):
super(BertAbsConfig, self).__init__(**kwargs) super(BertAbsConfig, self).__init__(**kwargs)
if self._input_is_path_to_json(vocab_size): self.vocab_size = vocab_size
path_to_json = vocab_size self.max_pos = max_pos
with open(path_to_json, "r", encoding="utf-8") as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif isinstance(vocab_size, int):
self.vocab_size = vocab_size
self.max_pos = max_pos
self.enc_layers = enc_layers self.enc_layers = enc_layers
self.enc_hidden_size = enc_hidden_size self.enc_hidden_size = enc_hidden_size
self.enc_heads = enc_heads self.enc_heads = enc_heads
self.enc_ff_size = enc_ff_size self.enc_ff_size = enc_ff_size
self.enc_dropout = enc_dropout self.enc_dropout = enc_dropout
self.dec_layers = dec_layers self.dec_layers = dec_layers
self.dec_hidden_size = dec_hidden_size self.dec_hidden_size = dec_hidden_size
self.dec_heads = dec_heads self.dec_heads = dec_heads
self.dec_ff_size = dec_ff_size self.dec_ff_size = dec_ff_size
self.dec_dropout = dec_dropout self.dec_dropout = dec_dropout
else:
raise ValueError(
"First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)"
)
def _input_is_path_to_json(self, first_argument):
""" Checks whether the first argument passed to config
is the path to a JSON file that contains the config.
"""
is_python_2 = sys.version_info[0] == 2
if is_python_2:
return isinstance(first_argument, unicode)
else:
return isinstance(first_argument, str)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment