Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
1b8613ac
Commit
1b8613ac
authored
Dec 16, 2019
by
thomwolf
Browse files
updating t5 config class
parent
7140363e
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
3 additions
and
14 deletions
+3
-14
transformers/configuration_t5.py
transformers/configuration_t5.py
+2
-13
transformers/tests/modeling_t5_test.py
transformers/tests/modeling_t5_test.py
+1
-1
No files found.
transformers/configuration_t5.py
View file @
1b8613ac
...
...
@@ -66,7 +66,7 @@ class T5Config(PretrainedConfig):
pretrained_config_archive_map
=
T5_PRETRAINED_CONFIG_ARCHIVE_MAP
def
__init__
(
self
,
vocab_size
_or_config_json_file
=
32128
,
vocab_size
=
32128
,
n_positions
=
512
,
d_model
=
512
,
d_kv
=
64
,
...
...
@@ -79,7 +79,7 @@ class T5Config(PretrainedConfig):
initializer_factor
=
1.0
,
**
kwargs
):
super
(
T5Config
,
self
).
__init__
(
**
kwargs
)
self
.
vocab_size
=
vocab_size
_or_config_json_file
if
isinstance
(
vocab_size_or_config_json_file
,
int
)
else
-
1
self
.
vocab_size
=
vocab_size
self
.
n_positions
=
n_positions
self
.
d_model
=
d_model
self
.
d_kv
=
d_kv
...
...
@@ -91,17 +91,6 @@ class T5Config(PretrainedConfig):
self
.
layer_norm_epsilon
=
layer_norm_epsilon
self
.
initializer_factor
=
initializer_factor
if
isinstance
(
vocab_size_or_config_json_file
,
six
.
string_types
):
with
open
(
vocab_size_or_config_json_file
,
"r"
,
encoding
=
"utf-8"
)
as
reader
:
json_config
=
json
.
loads
(
reader
.
read
())
for
key
,
value
in
json_config
.
items
():
self
.
__dict__
[
key
]
=
value
elif
not
isinstance
(
vocab_size_or_config_json_file
,
int
):
raise
ValueError
(
"First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)"
)
@
property
def
max_position_embeddings
(
self
):
return
self
.
n_positions
...
...
transformers/tests/modeling_t5_test.py
View file @
1b8613ac
...
...
@@ -93,7 +93,7 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
decoder_lm_labels
=
ids_tensor
([
self
.
batch_size
,
self
.
decoder_seq_length
],
self
.
vocab_size
)
config
=
T5Config
(
vocab_size
_or_config_json_file
=
self
.
vocab_size
,
vocab_size
=
self
.
vocab_size
,
n_positions
=
self
.
n_positions
,
d_model
=
self
.
hidden_size
,
d_ff
=
self
.
d_ff
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment