Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
1b8613ac
Commit
1b8613ac
authored
Dec 16, 2019
by
thomwolf
Browse files
updating t5 config class
parent
7140363e
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
3 additions
and
14 deletions
+3
-14
transformers/configuration_t5.py
transformers/configuration_t5.py
+2
-13
transformers/tests/modeling_t5_test.py
transformers/tests/modeling_t5_test.py
+1
-1
No files found.
transformers/configuration_t5.py
View file @
1b8613ac
...
...
@@ -66,7 +66,7 @@ class T5Config(PretrainedConfig):
pretrained_config_archive_map
=
T5_PRETRAINED_CONFIG_ARCHIVE_MAP
def
__init__
(
self
,
vocab_size
_or_config_json_file
=
32128
,
vocab_size
=
32128
,
n_positions
=
512
,
d_model
=
512
,
d_kv
=
64
,
...
...
@@ -79,7 +79,7 @@ class T5Config(PretrainedConfig):
initializer_factor
=
1.0
,
**
kwargs
):
super
(
T5Config
,
self
).
__init__
(
**
kwargs
)
self
.
vocab_size
=
vocab_size
_or_config_json_file
if
isinstance
(
vocab_size_or_config_json_file
,
int
)
else
-
1
self
.
vocab_size
=
vocab_size
self
.
n_positions
=
n_positions
self
.
d_model
=
d_model
self
.
d_kv
=
d_kv
...
...
@@ -91,17 +91,6 @@ class T5Config(PretrainedConfig):
self
.
layer_norm_epsilon
=
layer_norm_epsilon
self
.
initializer_factor
=
initializer_factor
if
isinstance
(
vocab_size_or_config_json_file
,
six
.
string_types
):
with
open
(
vocab_size_or_config_json_file
,
"r"
,
encoding
=
"utf-8"
)
as
reader
:
json_config
=
json
.
loads
(
reader
.
read
())
for
key
,
value
in
json_config
.
items
():
self
.
__dict__
[
key
]
=
value
elif
not
isinstance
(
vocab_size_or_config_json_file
,
int
):
raise
ValueError
(
"First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)"
)
@
property
def
max_position_embeddings
(
self
):
return
self
.
n_positions
...
...
transformers/tests/modeling_t5_test.py
View file @
1b8613ac
...
...
@@ -93,7 +93,7 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
decoder_lm_labels
=
ids_tensor
([
self
.
batch_size
,
self
.
decoder_seq_length
],
self
.
vocab_size
)
config
=
T5Config
(
vocab_size
_or_config_json_file
=
self
.
vocab_size
,
vocab_size
=
self
.
vocab_size
,
n_positions
=
self
.
n_positions
,
d_model
=
self
.
hidden_size
,
d_ff
=
self
.
d_ff
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment