Commit 1b8613ac authored by thomwolf

updating t5 config class

parent 7140363e
@@ -66,7 +66,7 @@ class T5Config(PretrainedConfig):
     pretrained_config_archive_map = T5_PRETRAINED_CONFIG_ARCHIVE_MAP

     def __init__(self,
-                 vocab_size_or_config_json_file=32128,
+                 vocab_size=32128,
                  n_positions=512,
                  d_model=512,
                  d_kv=64,
@@ -79,7 +79,7 @@ class T5Config(PretrainedConfig):
                  initializer_factor=1.0,
                  **kwargs):
         super(T5Config, self).__init__(**kwargs)
-        self.vocab_size = vocab_size_or_config_json_file if isinstance(vocab_size_or_config_json_file, int) else -1
+        self.vocab_size = vocab_size
         self.n_positions = n_positions
         self.d_model = d_model
         self.d_kv = d_kv
@@ -91,17 +91,6 @@ class T5Config(PretrainedConfig):
         self.layer_norm_epsilon = layer_norm_epsilon
         self.initializer_factor = initializer_factor
-
-        if isinstance(vocab_size_or_config_json_file, six.string_types):
-            with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
-                json_config = json.loads(reader.read())
-            for key, value in json_config.items():
-                self.__dict__[key] = value
-        elif not isinstance(vocab_size_or_config_json_file, int):
-            raise ValueError(
-                "First argument must be either a vocabulary size (int)"
-                "or the path to a pretrained model config file (str)"
-            )

     @property
     def max_position_embeddings(self):
         return self.n_positions
...
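For reference, a minimal sketch (not part of the commit) of how T5Config is constructed after this change, assuming the public transformers import path. The vocabulary size is now passed directly as an integer, and JSON config files are assumed to go through the shared PretrainedConfig loaders rather than the removed constructor branch; the file name below is hypothetical.

# Sketch only: illustrates the updated T5Config constructor, not code from this commit.
from transformers import T5Config

# The vocabulary size is now a plain integer keyword argument.
config = T5Config(vocab_size=32128, n_positions=512, d_model=512, d_kv=64)

# JSON-based configs are assumed to be handled by the generic PretrainedConfig helpers
# instead of the removed isinstance() branch; "t5-config.json" is a hypothetical local file.
config_from_json = T5Config.from_json_file("t5-config.json")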
@@ -93,7 +93,7 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
         decoder_lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)

         config = T5Config(
-            vocab_size_or_config_json_file=self.vocab_size,
+            vocab_size=self.vocab_size,
             n_positions=self.n_positions,
             d_model=self.hidden_size,
             d_ff=self.d_ff,
...
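Similarly, a hedged sketch of a test-style construction with the renamed keyword; the small dimensions are illustrative values, not taken from the test file, and T5Model is assumed to be importable from transformers (requires PyTorch).

# Sketch only: tiny config built with the renamed vocab_size keyword, as the updated test does.
from transformers import T5Config, T5Model

tiny_config = T5Config(
    vocab_size=99,          # previously passed as vocab_size_or_config_json_file
    n_positions=14,
    d_model=32,
    d_kv=8,
    d_ff=37,
    layer_norm_epsilon=1e-6,
    initializer_factor=1.0,
)
model = T5Model(tiny_config)  # assumes PyTorch is installed; shown only to exercise the config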