Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
fa7c575a
Commit
fa7c575a
authored
Apr 12, 2018
by
Myle Ott
Browse files
Fix preprocess.py
parent
f607d9e8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
55 additions
and
47 deletions
+55
-47
fairseq/dictionary.py
fairseq/dictionary.py
+40
-27
preprocess.py
preprocess.py
+15
-20
No files found.
fairseq/dictionary.py
View file @
fa7c575a
...
@@ -5,8 +5,9 @@
...
@@ -5,8 +5,9 @@
# the root directory of this source tree. An additional grant of patent rights
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
# can be found in the PATENTS file in the same directory.
import
math
from
collections
import
Counter
import
os
import
os
import
torch
import
torch
...
@@ -81,26 +82,43 @@ class Dictionary(object):
...
@@ -81,26 +82,43 @@ class Dictionary(object):
self
.
count
.
append
(
n
)
self
.
count
.
append
(
n
)
return
idx
return
idx
def
update
(
self
,
new_dict
):
def
finalize
(
self
,
threshold
=
1
,
nwords
=-
1
,
padding_factor
=
8
):
"""Updates counts from new dictionary."""
"""Sort symbols by frequency in descending order, ignoring special ones.
for
word
in
new_dict
.
symbols
:
idx2
=
new_dict
.
indices
[
word
]
Args:
if
word
in
self
.
indices
:
- threshold defines the minimum word count
idx
=
self
.
indices
[
word
]
- nwords defines the total number of words in the final dictionary,
self
.
count
[
idx
]
=
self
.
count
[
idx
]
+
new_dict
.
count
[
idx2
]
including special symbols
- padding_factor can be used to pad the dictionary size to be a
multiple of 8, which is important on some hardware (e.g., Nvidia
Tensor Cores).
"""
if
padding_factor
>
1
:
if
nwords
==
-
1
:
nwords
=
len
(
self
)
i
=
0
while
nwords
%
padding_factor
!=
0
:
if
nwords
>=
len
(
self
):
self
.
add_symbol
(
'madeupword{:04d}'
.
format
(
i
))
i
+=
1
nwords
+=
1
new_symbols
=
self
.
symbols
[:
self
.
nspecial
]
new_count
=
self
.
count
[:
self
.
nspecial
]
c
=
Counter
(
dict
(
zip
(
self
.
symbols
[
self
.
nspecial
:],
self
.
count
[
self
.
nspecial
:])))
for
symbol
,
count
in
c
.
most_common
(
nwords
-
self
.
nspecial
):
if
count
>=
threshold
:
new_symbols
.
append
(
symbol
)
new_count
.
append
(
count
)
else
:
else
:
idx
=
len
(
self
.
symbols
)
break
self
.
indices
[
word
]
=
idx
assert
min
(
new_count
[
self
.
nspecial
:])
>=
threshold
self
.
symbols
.
append
(
word
)
assert
len
(
new_symbols
)
<=
nwords
self
.
count
.
append
(
new_dict
.
count
[
idx2
])
assert
len
(
new_symbols
)
%
padding_factor
==
0
def
finalize
(
self
):
self
.
count
=
tuple
(
new_count
)
"""Sort symbols by frequency in descending order, ignoring special ones."""
self
.
symbols
=
tuple
(
new_symbols
)
self
.
count
,
self
.
symbols
=
zip
(
*
sorted
(
zip
(
self
.
count
,
self
.
symbols
),
key
=
(
lambda
x
:
math
.
inf
if
self
.
indices
[
x
[
1
]]
<
self
.
nspecial
else
x
[
0
]),
reverse
=
True
)
)
def
pad
(
self
):
def
pad
(
self
):
"""Helper to get index of pad symbol"""
"""Helper to get index of pad symbol"""
...
@@ -124,7 +142,6 @@ class Dictionary(object):
...
@@ -124,7 +142,6 @@ class Dictionary(object):
...
...
```
```
"""
"""
if
isinstance
(
f
,
str
):
if
isinstance
(
f
,
str
):
try
:
try
:
if
not
ignore_utf_errors
:
if
not
ignore_utf_errors
:
...
@@ -155,9 +172,5 @@ class Dictionary(object):
...
@@ -155,9 +172,5 @@ class Dictionary(object):
os
.
makedirs
(
os
.
path
.
dirname
(
f
),
exist_ok
=
True
)
os
.
makedirs
(
os
.
path
.
dirname
(
f
),
exist_ok
=
True
)
with
open
(
f
,
'w'
,
encoding
=
'utf-8'
)
as
fd
:
with
open
(
f
,
'w'
,
encoding
=
'utf-8'
)
as
fd
:
return
self
.
save
(
fd
,
threshold
,
nwords
)
return
self
.
save
(
fd
,
threshold
,
nwords
)
cnt
=
0
for
symbol
,
count
in
zip
(
self
.
symbols
[
self
.
nspecial
:],
self
.
count
[
self
.
nspecial
:]):
for
i
,
t
in
enumerate
(
zip
(
self
.
symbols
,
self
.
count
)):
print
(
'{} {}'
.
format
(
symbol
,
count
),
file
=
f
)
if
i
>=
self
.
nspecial
and
t
[
1
]
>=
threshold
\
and
(
nwords
<=
0
or
cnt
<
nwords
):
print
(
'{} {}'
.
format
(
t
[
0
],
t
[
1
]),
file
=
f
)
cnt
+=
1
preprocess.py
View file @
fa7c575a
...
@@ -38,7 +38,8 @@ def get_parser():
...
@@ -38,7 +38,8 @@ def get_parser():
help
=
'output format (optional)'
)
help
=
'output format (optional)'
)
parser
.
add_argument
(
'--joined-dictionary'
,
action
=
'store_true'
,
help
=
'Generate joined dictionary'
)
parser
.
add_argument
(
'--joined-dictionary'
,
action
=
'store_true'
,
help
=
'Generate joined dictionary'
)
parser
.
add_argument
(
'--only-source'
,
action
=
'store_true'
,
help
=
'Only process the source language'
)
parser
.
add_argument
(
'--only-source'
,
action
=
'store_true'
,
help
=
'Only process the source language'
)
parser
.
add_argument
(
'--padding-factor'
,
metavar
=
'N'
,
default
=
8
,
help
=
'Pad dictionary size to be multiple of N'
)
parser
.
add_argument
(
'--padding-factor'
,
metavar
=
'N'
,
default
=
8
,
type
=
int
,
help
=
'Pad dictionary size to be multiple of N'
)
return
parser
return
parser
...
@@ -47,25 +48,10 @@ def main(args):
...
@@ -47,25 +48,10 @@ def main(args):
os
.
makedirs
(
args
.
destdir
,
exist_ok
=
True
)
os
.
makedirs
(
args
.
destdir
,
exist_ok
=
True
)
target
=
not
args
.
only_source
target
=
not
args
.
only_source
def
pad_dictionary
(
d
):
"""Pad dictionary to be a multiple of args.padding_factor.
Keeping the dictionary size a multiple of 8 improves performance on some
architectures, e.g., Nvidia Tensor Cores.
"""
if
args
.
padding_factor
>
1
:
i
=
0
while
len
(
d
)
%
args
.
padding_factor
!=
0
:
d
.
add_symbol
(
'madeupword{:04d}'
.
format
(
i
))
i
+=
1
assert
len
(
d
)
%
args
.
padding_factor
==
0
def
build_dictionary
(
filenames
):
def
build_dictionary
(
filenames
):
d
=
dictionary
.
Dictionary
()
d
=
dictionary
.
Dictionary
()
for
filename
in
filenames
:
for
filename
in
filenames
:
Tokenizer
.
add_file_to_dictionary
(
filename
,
d
,
tokenize_line
)
Tokenizer
.
add_file_to_dictionary
(
filename
,
d
,
tokenize_line
)
pad_dictionary
(
d
)
d
.
finalize
()
return
d
return
d
if
args
.
joined_dictionary
:
if
args
.
joined_dictionary
:
...
@@ -89,11 +75,20 @@ def main(args):
...
@@ -89,11 +75,20 @@ def main(args):
assert
args
.
trainpref
,
"--trainpref must be set if --tgtdict is not specified"
assert
args
.
trainpref
,
"--trainpref must be set if --tgtdict is not specified"
tgt_dict
=
build_dictionary
([
'{}.{}'
.
format
(
args
.
trainpref
,
args
.
target_lang
)])
tgt_dict
=
build_dictionary
([
'{}.{}'
.
format
(
args
.
trainpref
,
args
.
target_lang
)])
src_dict
.
save
(
os
.
path
.
join
(
args
.
destdir
,
'dict.{}.txt'
.
format
(
args
.
source_lang
)),
src_dict
.
finalize
(
threshold
=
args
.
thresholdsrc
,
nwords
=
args
.
nwordssrc
)
threshold
=
args
.
thresholdsrc
,
nwords
=
args
.
nwordssrc
,
padding_factor
=
args
.
padding_factor
,
)
src_dict
.
save
(
os
.
path
.
join
(
args
.
destdir
,
'dict.{}.txt'
.
format
(
args
.
source_lang
)))
if
target
:
if
target
:
tgt_dict
.
save
(
os
.
path
.
join
(
args
.
destdir
,
'dict.{}.txt'
.
format
(
args
.
target_lang
)),
if
not
args
.
joined_dictionary
:
threshold
=
args
.
thresholdtgt
,
nwords
=
args
.
nwordstgt
)
tgt_dict
.
finalize
(
threshold
=
args
.
thresholdtgt
,
nwords
=
args
.
nwordstgt
,
padding_factor
=
args
.
padding_factor
,
)
tgt_dict
.
save
(
os
.
path
.
join
(
args
.
destdir
,
'dict.{}.txt'
.
format
(
args
.
target_lang
)))
def
make_binary_dataset
(
input_prefix
,
output_prefix
,
lang
):
def
make_binary_dataset
(
input_prefix
,
output_prefix
,
lang
):
dict
=
dictionary
.
Dictionary
.
load
(
os
.
path
.
join
(
args
.
destdir
,
'dict.{}.txt'
.
format
(
lang
)))
dict
=
dictionary
.
Dictionary
.
load
(
os
.
path
.
join
(
args
.
destdir
,
'dict.{}.txt'
.
format
(
lang
)))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment