Fairseq · Commit 862cad11

Parallel preprocessing

Authored Sep 12, 2018 by Sergey Edunov; committed Sep 25, 2018 by Myle Ott
Parent: ee46c63b
Showing 3 changed files with 169 additions and 34 deletions:
fairseq/data/indexed_dataset.py  (+29, -3)
fairseq/tokenizer.py  (+71, -13)
preprocess.py  (+69, -18)
fairseq/data/indexed_dataset.py

@@ -52,9 +52,15 @@ def data_file_path(prefix_path):
 class IndexedDataset(torch.utils.data.Dataset):
     """Loader for TorchNet IndexedDataset"""
 
-    def __init__(self, path, fix_lua_indexing=False):
+    def __init__(self, path, fix_lua_indexing=False, read_data=True):
         super().__init__()
         self.fix_lua_indexing = fix_lua_indexing
+        self.read_index(path)
+        self.data_file = None
+        if read_data:
+            self.read_data(path)
+
+    def read_index(self, path):
         with open(index_file_path(path), 'rb') as f:
             magic = f.read(8)
             assert magic == b'TNTIDX\x00\x00'
@@ -66,7 +72,6 @@ class IndexedDataset(torch.utils.data.Dataset):
             self.dim_offsets = read_longs(f, self.size + 1)
             self.data_offsets = read_longs(f, self.size + 1)
             self.sizes = read_longs(f, self.s)
-        self.read_data(path)
 
     def read_data(self, path):
         self.data_file = open(data_file_path(path), 'rb', buffering=0)
@@ -76,7 +81,8 @@ class IndexedDataset(torch.utils.data.Dataset):
             raise IndexError('index out of range')
 
     def __del__(self):
-        self.data_file.close()
+        if self.data_file:
+            self.data_file.close()
 
     def __getitem__(self, i):
         self.check_index(i)
@@ -193,6 +199,26 @@ class IndexedDatasetBuilder(object):
         self.sizes.append(s)
         self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))
 
+    def merge_file_(self, another_file):
+        index = IndexedDataset(another_file, read_data=False)
+        assert index.dtype == self.dtype
+
+        begin = self.data_offsets[-1]
+        for offset in index.data_offsets[1:]:
+            self.data_offsets.append(begin + offset)
+        self.sizes.extend(index.sizes)
+        begin = self.dim_offsets[-1]
+        for dim_offset in index.dim_offsets[1:]:
+            self.dim_offsets.append(begin + dim_offset)
+
+        with open(data_file_path(another_file), 'rb') as f:
+            while True:
+                data = f.read(1024)
+                if data:
+                    self.out_file.write(data)
+                else:
+                    break
+
     def finalize(self, index_file):
         self.out_file.close()
         index = open(index_file, 'wb')
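Note: the new merge_file_ lets a builder absorb a shard written by another worker. It opens the shard's index only (read_data=False), re-bases the shard's offset tables onto the end of the data already written, and then streams the shard's raw bytes into out_file. A minimal, self-contained sketch of the offset re-basing (plain Python lists with made-up values, not fairseq code):

    # Each offsets list starts at 0, so the leading 0 of the second shard is skipped
    # and every remaining offset is shifted by the last offset of the first shard.
    a_data_offsets = [0, 3, 7]      # shard A: two items of 3 and 4 elements
    b_data_offsets = [0, 2, 5]      # shard B: two items of 2 and 3 elements

    begin = a_data_offsets[-1]      # shard B's data now starts after A's 7 elements
    merged = a_data_offsets + [begin + off for off in b_data_offsets[1:]]
    assert merged == [0, 3, 7, 9, 12]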
fairseq/tokenizer.py

@@ -6,10 +6,10 @@
 # can be found in the PATENTS file in the same directory.
 
 from collections import Counter
-import re
+import os, re
 
 import torch
+from multiprocessing import Pool
 
 SPACE_NORMALIZER = re.compile("\s+")
@@ -20,28 +20,74 @@ def tokenize_line(line):
     return line.split()
 
 
+def safe_readline(f):
+    pos = f.tell()
+    while True:
+        try:
+            return f.readline()
+        except UnicodeDecodeError:
+            pos -= 1
+            f.seek(pos)  # search where this character begins
+
+
 class Tokenizer:
 
     @staticmethod
-    def add_file_to_dictionary(filename, dict, tokenize):
+    def add_file_to_dictionary_single_worker(filename, tokenize, eos_word, worker_id=0, num_workers=1):
+        counter = Counter()
         with open(filename, 'r') as f:
-            for line in f:
+            size = os.fstat(f.fileno()).st_size
+            chunk_size = size // num_workers
+            offset = worker_id * chunk_size
+            end = offset + chunk_size
+            f.seek(offset)
+            if offset > 0:
+                safe_readline(f)  # drop first incomplete line
+            line = f.readline()
+            while line:
                 for word in tokenize(line):
-                    dict.add_symbol(word)
-                dict.add_symbol(dict.eos_word)
+                    counter.update([word])
+                counter.update([eos_word])
+                if f.tell() > end:
+                    break
+                line = f.readline()
+        return counter
+
+    @staticmethod
+    def add_file_to_dictionary(filename, dict, tokenize, num_workers):
+        def merge_result(counter):
+            for w, c in counter.items():
+                dict.add_symbol(w, c)
+
+        if num_workers > 1:
+            pool = Pool(processes=num_workers)
+            results = []
+            for worker_id in range(num_workers):
+                results.append(pool.apply_async(
+                    Tokenizer.add_file_to_dictionary_single_worker,
+                    (filename, tokenize, dict.eos_word, worker_id, num_workers)
+                ))
+            pool.close()
+            pool.join()
+            for r in results:
+                merge_result(r.get())
+        else:
+            merge_result(Tokenizer.add_file_to_dictionary_single_worker(filename, tokenize, dict.eos_word))
 
     @staticmethod
     def binarize(filename, dict, consumer, tokenize=tokenize_line,
-                 append_eos=True, reverse_order=False):
+                 append_eos=True, reverse_order=False,
+                 offset=0, end=-1):
         nseq, ntok = 0, 0
         replaced = Counter()
 
         def replaced_consumer(word, idx):
             if idx == dict.unk_index and word != dict.unk_word:
                 replaced.update([word])
 
         with open(filename, 'r') as f:
-            for line in f:
+            f.seek(offset)
+            # next(f) breaks f.tell(), hence readline() must be used
+            line = safe_readline(f)
+            while line:
+                if end > 0 and f.tell() > end:
+                    break
                 ids = Tokenizer.tokenize(
                     line=line,
                     dict=dict,
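Note: the chunking scheme above is worth spelling out. Each worker seeks to a raw byte offset, which may land mid-line or even mid-character in UTF-8 (hence safe_readline), drops the partial line it landed in, and keeps reading until it has passed its chunk boundary, so each line is handled by the worker whose chunk it starts in. A self-contained sketch of the same pattern, with a plain word count standing in for dictionary building and the last chunk extended to the end of the file (illustrative only, not fairseq code):

    import os
    from collections import Counter

    def safe_readline(f):
        # Back up until the read starts on a valid character boundary (as in the patch).
        pos = f.tell()
        while True:
            try:
                return f.readline()
            except UnicodeDecodeError:
                pos -= 1
                f.seek(pos)

    def count_words_in_chunk(filename, worker_id=0, num_workers=1):
        counter = Counter()
        with open(filename, 'r', encoding='utf-8') as f:
            size = os.fstat(f.fileno()).st_size
            chunk_size = size // num_workers
            offset = worker_id * chunk_size
            # Let the last worker run to EOF so trailing bytes are never dropped.
            end = size if worker_id == num_workers - 1 else offset + chunk_size
            f.seek(offset)
            if offset > 0:
                safe_readline(f)    # drop the partial line; the previous worker reads it
            line = f.readline()
            while line:
                counter.update(line.split())
                if f.tell() > end:  # this worker has read past its boundary: stop
                    break
                line = f.readline()
        return counter

Provided no single line is longer than one chunk, summing count_words_in_chunk(filename, i, N) over all i reproduces a single sequential pass; that is the property the per-worker counters rely on when add_file_to_dictionary merges them.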
@@ -52,10 +98,22 @@ class Tokenizer:
                     reverse_order=reverse_order,
                 )
                 nseq += 1
-                consumer(ids)
                 ntok += len(ids)
-        return {'nseq': nseq, 'nunk': sum(replaced.values()), 'ntok': ntok, 'replaced': len(replaced)}
+                consumer(ids)
+                line = f.readline()
+        return {'nseq': nseq, 'nunk': sum(replaced.values()), 'ntok': ntok, 'replaced': replaced}
+
+    @staticmethod
+    def find_offsets(filename, num_chunks):
+        with open(filename, 'r') as f:
+            size = os.fstat(f.fileno()).st_size
+            chunk_size = size // num_chunks
+            offsets = [0 for _ in range(num_chunks + 1)]
+            for i in range(1, num_chunks):
+                f.seek(chunk_size * i)
+                safe_readline(f)
+                offsets[i] = f.tell()
+            return offsets
 
     @staticmethod
     def tokenize(line, dict, tokenize=tokenize_line, add_if_not_exist=True,
preprocess.py

@@ -10,12 +10,16 @@ Data pre-processing: build vocabularies and binarize training data.
 """
 
 import argparse
+from collections import Counter
 from itertools import zip_longest
 import os
 import shutil
 
 from fairseq.data import indexed_dataset, dictionary
 from fairseq.tokenizer import Tokenizer, tokenize_line
+from multiprocessing import Pool, Manager, Process
 
 
 def get_parser():
@@ -41,6 +45,7 @@ def get_parser():
     parser.add_argument('--only-source', action='store_true', help='Only process the source language')
     parser.add_argument('--padding-factor', metavar='N', default=8, type=int,
                         help='Pad dictionary size to be multiple of N')
+    parser.add_argument('--workers', metavar='N', default=1, type=int, help='number of parallel workers')
     return parser
 
@@ -52,7 +57,7 @@ def main(args):
     def build_dictionary(filenames):
         d = dictionary.Dictionary()
         for filename in filenames:
-            Tokenizer.add_file_to_dictionary(filename, d, tokenize_line)
+            Tokenizer.add_file_to_dictionary(filename, d, tokenize_line, args.workers)
         return d
 
     def train_path(lang):
@@ -70,11 +75,6 @@ def main(args):
     def dict_path(lang):
         return dest_path('dict', lang) + '.txt'
 
-    def dataset_dest_path(output_prefix, lang, extension):
-        base = f'{args.destdir}/{output_prefix}'
-        lang_part = f'.{args.source_lang}-{args.target_lang}.{lang}' if lang is not None else ''
-        return f'{base}{lang_part}.{extension}'
-
     if args.joined_dictionary:
         assert not args.srcdict, 'cannot combine --srcdict and --joined-dictionary'
         assert not args.tgtdict, 'cannot combine --tgtdict and --joined-dictionary'
@@ -111,25 +111,54 @@ def main(args):
         )
         tgt_dict.save(dict_path(args.target_lang))
 
-    def make_binary_dataset(input_prefix, output_prefix, lang):
+    def make_binary_dataset(input_prefix, output_prefix, lang, num_workers):
         dict = dictionary.Dictionary.load(dict_path(lang))
         print('| [{}] Dictionary: {} types'.format(lang, len(dict) - 1))
-        ds = indexed_dataset.IndexedDatasetBuilder(dataset_dest_path(output_prefix, lang, 'bin'))
+        n_seq_tok = [0, 0]
+        replaced = Counter()
 
-        def consumer(tensor):
-            ds.add_item(tensor)
+        def merge_result(worker_result):
+            replaced.update(worker_result['replaced'])
+            n_seq_tok[0] += worker_result['nseq']
+            n_seq_tok[1] += worker_result['ntok']
 
         input_file = '{}{}'.format(input_prefix, ('.' + lang) if lang is not None else '')
-        res = Tokenizer.binarize(input_file, dict, consumer)
+        offsets = Tokenizer.find_offsets(input_file, num_workers)
+        pool = None
+        if num_workers > 1:
+            pool = Pool(processes=num_workers - 1)
+            for worker_id in range(1, num_workers):
+                prefix = "{}{}".format(output_prefix, worker_id)
+                pool.apply_async(binarize, (args, input_file, dict, prefix, lang,
+                                            offsets[worker_id], offsets[worker_id + 1]),
+                                 callback=merge_result)
+            pool.close()
+
+        ds = indexed_dataset.IndexedDatasetBuilder(dataset_dest_file(args, output_prefix, lang, 'bin'))
+        merge_result(Tokenizer.binarize(input_file, dict, lambda t: ds.add_item(t),
+                                        offset=0, end=offsets[1]))
+        if num_workers > 1:
+            pool.join()
+            for worker_id in range(1, num_workers):
+                prefix = "{}{}".format(output_prefix, worker_id)
+                temp_file_path = dataset_dest_prefix(args, prefix, lang)
+                ds.merge_file_(temp_file_path)
+                os.remove(indexed_dataset.data_file_path(temp_file_path))
+                os.remove(indexed_dataset.index_file_path(temp_file_path))
+
+        ds.finalize(dataset_dest_file(args, output_prefix, lang, 'idx'))
+
         print('| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}'.format(
-            lang, input_file, res['nseq'], res['ntok'],
-            100 * res['nunk'] / res['ntok'], dict.unk_word))
-        ds.finalize(dataset_dest_path(output_prefix, lang, 'idx'))
+            lang, input_file, n_seq_tok[0], n_seq_tok[1],
+            100 * sum(replaced.values()) / n_seq_tok[1], dict.unk_word))
 
-    def make_dataset(input_prefix, output_prefix, lang):
+    def make_dataset(input_prefix, output_prefix, lang, num_workers=1):
         if args.output_format == 'binary':
-            make_binary_dataset(input_prefix, output_prefix, lang)
+            make_binary_dataset(input_prefix, output_prefix, lang, num_workers)
         elif args.output_format == 'raw':
             # Copy original text file to destination folder
             output_text_file = dest_path(
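Note: the worker layout in make_binary_dataset is deliberately asymmetric. A Pool of num_workers - 1 processes binarizes chunks 1..N-1 into temporary datasets (merged into the main builder and deleted after join()), while the parent process binarizes chunk 0 straight into the final builder, and per-chunk statistics flow back through the apply_async callback. A stripped-down, self-contained sketch of that control flow (process_chunk is a hypothetical placeholder, not fairseq code):

    from multiprocessing import Pool

    def process_chunk(chunk_id):
        # Hypothetical stand-in for the real per-chunk work (binarizing one byte range).
        return {'chunk': chunk_id}

    def run(num_workers):
        results = []

        def merge_result(res):
            # apply_async invokes this in the parent process as each worker finishes.
            results.append(res)

        pool = None
        if num_workers > 1:
            # N - 1 extra processes; the parent handles chunk 0 itself below.
            pool = Pool(processes=num_workers - 1)
            for chunk_id in range(1, num_workers):
                pool.apply_async(process_chunk, (chunk_id,), callback=merge_result)
            pool.close()

        merge_result(process_chunk(0))   # chunk 0 runs in the parent process
        if num_workers > 1:
            pool.join()                  # all callbacks have fired once join() returns
        return results

    if __name__ == '__main__':
        print(run(4))

apply_async has to pickle the worker function, so it must live at module level rather than inside a closure; this is presumably why the patch defines binarize at the top level of preprocess.py instead of nesting it in main().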
@@ -140,7 +169,7 @@ def main(args):
     def make_all(lang):
         if args.trainpref:
-            make_dataset(args.trainpref, 'train', lang)
+            make_dataset(args.trainpref, 'train', lang, num_workers=args.workers)
         if args.validpref:
             for k, validpref in enumerate(args.validpref.split(',')):
                 outprefix = 'valid{}'.format(k) if k > 0 else 'valid'
@@ -196,6 +225,28 @@ def main(args):
             print('{} {}'.format(src_dict[k], tgt_dict[v]), file=f)
 
 
+def binarize(args, filename, dict, output_prefix, lang, offset, end):
+    ds = indexed_dataset.IndexedDatasetBuilder(dataset_dest_file(args, output_prefix, lang, 'bin'))
+
+    def consumer(tensor):
+        ds.add_item(tensor)
+
+    res = Tokenizer.binarize(filename, dict, consumer, offset=offset, end=end)
+    ds.finalize(dataset_dest_file(args, output_prefix, lang, 'idx'))
+    return res
+
+
+def dataset_dest_prefix(args, output_prefix, lang):
+    base = f'{args.destdir}/{output_prefix}'
+    lang_part = f'.{args.source_lang}-{args.target_lang}.{lang}' if lang is not None else ''
+    return f'{base}{lang_part}'
+
+
+def dataset_dest_file(args, output_prefix, lang, extension):
+    base = dataset_dest_prefix(args, output_prefix, lang)
+    return f'{base}.{extension}'
+
+
 if __name__ == '__main__':
     parser = get_parser()
     args = parser.parse_args()
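For reference, both dictionary building and binarization now use the new option. An invocation along the following lines (the data paths and language pair are placeholders; the remaining flags are the pre-existing preprocess.py options) runs preprocessing with 8 parallel workers:

    python preprocess.py --source-lang de --target-lang en \
        --trainpref data/train --validpref data/valid \
        --destdir data-bin --workers 8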