Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
c6fe9fc5
Commit
c6fe9fc5
authored
Jun 24, 2018
by
Myle Ott
Browse files
Fix for Dictionary.finalize
parent
7bcb487a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
83 additions
and
4 deletions
+83
-4
fairseq/data/dictionary.py
fairseq/data/dictionary.py
+10
-4
tests/test_dictionary.py
tests/test_dictionary.py
+73
-0
No files found.
fairseq/data/dictionary.py
View file @
c6fe9fc5
...
...
@@ -95,7 +95,7 @@ class Dictionary(object):
self
.
symbols
.
append
(
word
)
self
.
count
.
append
(
new_dict
.
count
[
idx2
])
def
finalize
(
self
,
threshold
=
1
,
nwords
=-
1
,
padding_factor
=
8
):
def
finalize
(
self
,
threshold
=
-
1
,
nwords
=-
1
,
padding_factor
=
8
):
"""Sort symbols by frequency in descending order, ignoring special ones.
Args:
...
...
@@ -109,12 +109,14 @@ class Dictionary(object):
if
nwords
==
-
1
:
nwords
=
len
(
self
)
new_indices
=
dict
(
zip
(
self
.
symbols
[:
self
.
nspecial
],
range
(
self
.
nspecial
)))
new_symbols
=
self
.
symbols
[:
self
.
nspecial
]
new_count
=
self
.
count
[:
self
.
nspecial
]
c
=
Counter
(
dict
(
zip
(
self
.
symbols
[
self
.
nspecial
:],
self
.
count
[
self
.
nspecial
:])))
for
symbol
,
count
in
c
.
most_common
(
nwords
-
self
.
nspecial
):
if
count
>=
threshold
:
new_indices
[
symbol
]
=
len
(
new_symbols
)
new_symbols
.
append
(
symbol
)
new_count
.
append
(
count
)
else
:
...
...
@@ -124,16 +126,20 @@ class Dictionary(object):
if
padding_factor
>
1
:
i
=
0
while
threshold_nwords
%
padding_factor
!=
0
:
new_symbols
.
append
(
'madeupword{:04d}'
.
format
(
i
))
symbol
=
'madeupword{:04d}'
.
format
(
i
)
new_indices
[
symbol
]
=
len
(
new_symbols
)
new_symbols
.
append
(
symbol
)
new_count
.
append
(
0
)
i
+=
1
threshold_nwords
+=
1
assert
min
(
new_count
[
self
.
nspecial
:])
>=
threshold
assert
len
(
new_symbols
)
%
padding_factor
==
0
assert
len
(
new_symbols
)
==
len
(
new_indices
)
self
.
count
=
tuple
(
new_count
)
self
.
symbols
=
tuple
(
new_symbols
)
self
.
count
=
list
(
new_count
)
self
.
symbols
=
list
(
new_symbols
)
self
.
indices
=
new_indices
def
pad
(
self
):
"""Helper to get index of pad symbol"""
...
...
tests/test_dictionary.py
0 → 100644
View file @
c6fe9fc5
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import
tempfile
import
unittest
import
torch
from
fairseq.data
import
Dictionary
from
fairseq.tokenizer
import
Tokenizer
,
tokenize_line
class
TestDictionary
(
unittest
.
TestCase
):
def
test_finalize
(
self
):
txt
=
[
'A B C D'
,
'B C D'
,
'C D'
,
'D'
,
]
ref_ids1
=
list
(
map
(
torch
.
IntTensor
,
[
[
4
,
5
,
6
,
7
,
2
],
[
5
,
6
,
7
,
2
],
[
6
,
7
,
2
],
[
7
,
2
],
]))
ref_ids2
=
list
(
map
(
torch
.
IntTensor
,
[
[
7
,
6
,
5
,
4
,
2
],
[
6
,
5
,
4
,
2
],
[
5
,
4
,
2
],
[
4
,
2
],
]))
# build dictionary
d
=
Dictionary
()
for
line
in
txt
:
Tokenizer
.
tokenize
(
line
,
d
,
add_if_not_exist
=
True
)
def
get_ids
(
dictionary
):
ids
=
[]
for
line
in
txt
:
ids
.
append
(
Tokenizer
.
tokenize
(
line
,
dictionary
,
add_if_not_exist
=
False
))
return
ids
def
assertMatch
(
ids
,
ref_ids
):
for
toks
,
ref_toks
in
zip
(
ids
,
ref_ids
):
self
.
assertEqual
(
toks
.
size
(),
ref_toks
.
size
())
self
.
assertEqual
(
0
,
(
toks
!=
ref_toks
).
sum
().
item
())
ids
=
get_ids
(
d
)
assertMatch
(
ids
,
ref_ids1
)
# check finalized dictionary
d
.
finalize
()
finalized_ids
=
get_ids
(
d
)
assertMatch
(
finalized_ids
,
ref_ids2
)
# write to disk and reload
with
tempfile
.
NamedTemporaryFile
(
mode
=
'w'
)
as
tmp_dict
:
d
.
save
(
tmp_dict
.
name
)
d
=
Dictionary
.
load
(
tmp_dict
.
name
)
reload_ids
=
get_ids
(
d
)
assertMatch
(
reload_ids
,
ref_ids2
)
assertMatch
(
finalized_ids
,
reload_ids
)
if
__name__
==
'__main__'
:
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment