chenpangpang / transformers

Commit 4132a028 (unverified)

    Merge pull request #29 from huggingface/first-release

    First release

Authored Nov 17, 2018 by Thomas Wolf; committed via GitHub on Nov 17, 2018
Parents: 02173a1a and 47a7d4ec
Changes: 22
Showing 2 changed files with 27 additions and 40 deletions:

    tests/optimization_test.py    +2   -2
    tests/tokenization_test.py    +25  -38
tests/optimization_test.py

@@ -20,7 +20,7 @@ import unittest
 import torch
 
-import optimization
+from pytorch_pretrained_bert import BertAdam
 
 class OptimizationTest(unittest.TestCase):
@@ -34,7 +34,7 @@ class OptimizationTest(unittest.TestCase):
         target = torch.tensor([0.4, 0.2, -0.5])
         criterion = torch.nn.MSELoss(reduction='elementwise_mean')
         # No warmup, constant schedule, no gradient clipping
-        optimizer = optimization.BERTAdam(params=[w], lr=2e-1,
+        optimizer = BertAdam(params=[w], lr=2e-1,
                              weight_decay_rate=0.0,
                              max_grad_norm=-1)
         for _ in range(100):
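The substantive change in this file is the import path: the test now pulls the optimizer from the installed pytorch_pretrained_bert package rather than the local optimization module, and the class was renamed from BERTAdam to BertAdam along the way. Below is a minimal sketch of the optimizer loop this test exercises; the zero weight initialization and the default loss reduction are illustrative assumptions, not copied from the diff:

    import torch
    from pytorch_pretrained_bert import BertAdam

    # Illustrative setup (assumed): one 3-element weight fitted to a target.
    w = torch.nn.Parameter(torch.zeros(3))
    target = torch.tensor([0.4, 0.2, -0.5])
    criterion = torch.nn.MSELoss()  # the test pins reduction='elementwise_mean',
                                    # the pre-PyTorch-1.0 spelling of 'mean'

    # No warmup, constant schedule, no gradient clipping (max_grad_norm=-1),
    # matching the arguments in the diff above.
    optimizer = BertAdam(params=[w], lr=2e-1,
                         weight_decay_rate=0.0,
                         max_grad_norm=-1)

    for _ in range(100):
        loss = criterion(w, target)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

BertAdam is the package's Adam variant with optional warmup and a weight-decay fix; with warmup and decay disabled as here, it behaves like plain Adam without bias correction.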
tests/tokenization_test.py

@@ -19,7 +19,8 @@ from __future__ import print_function
 import os
 import unittest
 
-import tokenization
+from pytorch_pretrained_bert.tokenization import (BertTokenizer, BasicTokenizer,
+                                                  WordpieceTokenizer, _is_whitespace, _is_control, _is_punctuation)
 
 class TokenizationTest(unittest.TestCase):
@@ -34,7 +35,7 @@ class TokenizationTest(unittest.TestCase):
         vocab_file = vocab_writer.name
 
-        tokenizer = tokenization.FullTokenizer(vocab_file)
+        tokenizer = BertTokenizer(vocab_file)
         os.remove(vocab_file)
 
         tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
@@ -44,14 +45,14 @@ class TokenizationTest(unittest.TestCase):
             tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
 
     def test_chinese(self):
-        tokenizer = tokenization.BasicTokenizer()
+        tokenizer = BasicTokenizer()
 
         self.assertListEqual(
             tokenizer.tokenize(u"ah\u535A\u63A8zz"),
             [u"ah", u"\u535A", u"\u63A8", u"zz"])
 
     def test_basic_tokenizer_lower(self):
-        tokenizer = tokenization.BasicTokenizer(do_lower_case=True)
+        tokenizer = BasicTokenizer(do_lower_case=True)
 
         self.assertListEqual(
             tokenizer.tokenize(u" \tHeLLo!how  \n Are yoU?  "),
@@ -59,7 +60,7 @@ class TokenizationTest(unittest.TestCase):
         self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])
 
     def test_basic_tokenizer_no_lower(self):
-        tokenizer = tokenization.BasicTokenizer(do_lower_case=False)
+        tokenizer = BasicTokenizer(do_lower_case=False)
 
         self.assertListEqual(
             tokenizer.tokenize(u" \tHeLLo!how  \n Are yoU?  "),
@@ -74,7 +75,7 @@ class TokenizationTest(unittest.TestCase):
         vocab = {}
         for (i, token) in enumerate(vocab_tokens):
             vocab[token] = i
 
-        tokenizer = tokenization.WordpieceTokenizer(vocab=vocab)
+        tokenizer = WordpieceTokenizer(vocab=vocab)
 
         self.assertListEqual(tokenizer.tokenize(""), [])
@@ -85,46 +86,32 @@ class TokenizationTest(unittest.TestCase):
         self.assertListEqual(
             tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
 
-    def test_convert_tokens_to_ids(self):
-        vocab_tokens = [
-            "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un",
-            "runn", "##ing"
-        ]
-
-        vocab = {}
-        for (i, token) in enumerate(vocab_tokens):
-            vocab[token] = i
-
-        self.assertListEqual(
-            tokenization.convert_tokens_to_ids(
-                vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9])
-
     def test_is_whitespace(self):
-        self.assertTrue(tokenization._is_whitespace(u" "))
-        self.assertTrue(tokenization._is_whitespace(u"\t"))
-        self.assertTrue(tokenization._is_whitespace(u"\r"))
-        self.assertTrue(tokenization._is_whitespace(u"\n"))
-        self.assertTrue(tokenization._is_whitespace(u"\u00A0"))
+        self.assertTrue(_is_whitespace(u" "))
+        self.assertTrue(_is_whitespace(u"\t"))
+        self.assertTrue(_is_whitespace(u"\r"))
+        self.assertTrue(_is_whitespace(u"\n"))
+        self.assertTrue(_is_whitespace(u"\u00A0"))
 
-        self.assertFalse(tokenization._is_whitespace(u"A"))
-        self.assertFalse(tokenization._is_whitespace(u"-"))
+        self.assertFalse(_is_whitespace(u"A"))
+        self.assertFalse(_is_whitespace(u"-"))
 
     def test_is_control(self):
-        self.assertTrue(tokenization._is_control(u"\u0005"))
+        self.assertTrue(_is_control(u"\u0005"))
 
-        self.assertFalse(tokenization._is_control(u"A"))
-        self.assertFalse(tokenization._is_control(u" "))
-        self.assertFalse(tokenization._is_control(u"\t"))
-        self.assertFalse(tokenization._is_control(u"\r"))
+        self.assertFalse(_is_control(u"A"))
+        self.assertFalse(_is_control(u" "))
+        self.assertFalse(_is_control(u"\t"))
+        self.assertFalse(_is_control(u"\r"))
 
     def test_is_punctuation(self):
-        self.assertTrue(tokenization._is_punctuation(u"-"))
-        self.assertTrue(tokenization._is_punctuation(u"$"))
-        self.assertTrue(tokenization._is_punctuation(u"`"))
-        self.assertTrue(tokenization._is_punctuation(u"."))
+        self.assertTrue(_is_punctuation(u"-"))
+        self.assertTrue(_is_punctuation(u"$"))
+        self.assertTrue(_is_punctuation(u"`"))
+        self.assertTrue(_is_punctuation(u"."))
 
-        self.assertFalse(tokenization._is_punctuation(u"A"))
-        self.assertFalse(tokenization._is_punctuation(u" "))
+        self.assertFalse(_is_punctuation(u"A"))
+        self.assertFalse(_is_punctuation(u" "))
 
 if __name__ == '__main__':
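The pattern here is the same: everything the tests touch now comes from pytorch_pretrained_bert.tokenization, with FullTokenizer renamed to BertTokenizer, and the standalone tokenization.convert_tokens_to_ids test dropped in favor of the method on the tokenizer itself. Below is a sketch of the round trip the first test performs; the vocabulary is reconstructed from the removed test, with a trailing "," entry assumed at index 10 because the expected ids include 10 but the diff does not show that token directly:

    import os
    import tempfile

    from pytorch_pretrained_bert.tokenization import BertTokenizer

    # Vocabulary reconstructed from the removed test; the "," at index 10 is
    # an assumption needed for convert_tokens_to_ids to yield [7, 4, 5, 10, 8, 9].
    vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed",
                    "wa", "un", "runn", "##ing", ","]

    # Write a throwaway vocab file, one token per line, mirroring the test setup.
    with tempfile.NamedTemporaryFile(mode="w", delete=False) as vocab_writer:
        vocab_writer.write("".join(token + "\n" for token in vocab_tokens))
        vocab_file = vocab_writer.name

    tokenizer = BertTokenizer(vocab_file)
    os.remove(vocab_file)

    tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
    print(tokens)                                   # ["un", "##want", "##ed", ",", "runn", "##ing"]
    print(tokenizer.convert_tokens_to_ids(tokens))  # [7, 4, 5, 10, 8, 9]

BertTokenizer chains the basic tokenizer (lowercasing, accent stripping, punctuation splitting) with the wordpiece tokenizer, which is why "UNwant\u00E9d" comes back as the pieces "un", "##want", "##ed".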