Commit 1b8613ac authored by thomwolf's avatar thomwolf
Browse files

Update the T5 config class: replace `vocab_size_or_config_json_file` with a plain `vocab_size` argument and drop the JSON-file loading branch from `__init__`.

parent 7140363e
......@@ -66,7 +66,7 @@ class T5Config(PretrainedConfig):
pretrained_config_archive_map = T5_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self,
vocab_size_or_config_json_file=32128,
vocab_size=32128,
n_positions=512,
d_model=512,
d_kv=64,
......@@ -79,7 +79,7 @@ class T5Config(PretrainedConfig):
initializer_factor=1.0,
**kwargs):
super(T5Config, self).__init__(**kwargs)
self.vocab_size = vocab_size_or_config_json_file if isinstance(vocab_size_or_config_json_file, int) else -1
self.vocab_size = vocab_size
self.n_positions = n_positions
self.d_model = d_model
self.d_kv = d_kv
......@@ -91,17 +91,6 @@ class T5Config(PretrainedConfig):
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_factor = initializer_factor
if isinstance(vocab_size_or_config_json_file, six.string_types):
with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif not isinstance(vocab_size_or_config_json_file, int):
raise ValueError(
"First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)"
)
@property
def max_position_embeddings(self):
    """Maximum sequence length the model positions cover.

    Compatibility alias: simply exposes the ``n_positions`` attribute
    under the name other config classes use.
    """
    return self.n_positions
......
......@@ -93,7 +93,7 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
decoder_lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
config = T5Config(
vocab_size_or_config_json_file=self.vocab_size,
vocab_size=self.vocab_size,
n_positions=self.n_positions,
d_model=self.hidden_size,
d_ff=self.d_ff,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment