dcuai / dlexamples

Commit c0f05c10 (parent c056df78), authored Nov 29, 2022 by hepj
Update transformer code

Showing 20 of 321 changed files, with 2121 additions and 0 deletions (+2121, -0).
PyTorch/NLP/new-Transformer/fairseq/data/encoders/bytes.py                          +34   -0
PyTorch/NLP/new-Transformer/fairseq/data/encoders/characters.py                     +30   -0
PyTorch/NLP/new-Transformer/fairseq/data/encoders/fastbpe.py                        +36   -0
PyTorch/NLP/new-Transformer/fairseq/data/encoders/gpt2_bpe.py                       +45   -0
PyTorch/NLP/new-Transformer/fairseq/data/encoders/gpt2_bpe_utils.py                 +140  -0
PyTorch/NLP/new-Transformer/fairseq/data/encoders/hf_bert_bpe.py                    +50   -0
PyTorch/NLP/new-Transformer/fairseq/data/encoders/hf_byte_bpe.py                    +50   -0
PyTorch/NLP/new-Transformer/fairseq/data/encoders/moses_tokenizer.py                +49   -0
PyTorch/NLP/new-Transformer/fairseq/data/encoders/nltk_tokenizer.py                 +24   -0
PyTorch/NLP/new-Transformer/fairseq/data/encoders/sentencepiece_bpe.py              +65   -0
PyTorch/NLP/new-Transformer/fairseq/data/encoders/space_tokenizer.py                +21   -0
PyTorch/NLP/new-Transformer/fairseq/data/encoders/subword_nmt_bpe.py                +54   -0
PyTorch/NLP/new-Transformer/fairseq/data/encoders/utils.py                          +30   -0
PyTorch/NLP/new-Transformer/fairseq/data/fairseq_dataset.py                         +205  -0
PyTorch/NLP/new-Transformer/fairseq/data/fasta_dataset.py                           +107  -0
PyTorch/NLP/new-Transformer/fairseq/data/huffman/__init__.py                        +21   -0
PyTorch/NLP/new-Transformer/fairseq/data/huffman/huffman_coder.py                   +267  -0
PyTorch/NLP/new-Transformer/fairseq/data/huffman/huffman_mmap_indexed_dataset.py    +287  -0
PyTorch/NLP/new-Transformer/fairseq/data/id_dataset.py                              +19   -0
PyTorch/NLP/new-Transformer/fairseq/data/indexed_dataset.py                         +587  -0
PyTorch/NLP/new-Transformer/fairseq/data/encoders/bytes.py  (new file, 0 → 100644)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


from fairseq.data.encoders import register_bpe
from fairseq.data.encoders.byte_utils import (
    SPACE,
    SPACE_ESCAPE,
    byte_encode,
    smart_byte_decode,
)


@register_bpe("bytes")
class Bytes(object):
    def __init__(self, *unused):
        pass

    @staticmethod
    def add_args(parser):
        pass

    @staticmethod
    def encode(x: str) -> str:
        encoded = byte_encode(x)
        escaped = encoded.replace(SPACE, SPACE_ESCAPE)
        return SPACE.join(list(escaped))

    @staticmethod
    def decode(x: str) -> str:
        unescaped = x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE)
        return smart_byte_decode(unescaped)
PyTorch/NLP/new-Transformer/fairseq/data/encoders/characters.py  (new file, 0 → 100644)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


from fairseq.data.encoders import register_bpe


SPACE = chr(32)
SPACE_ESCAPE = chr(9601)


@register_bpe("characters")
class Characters(object):
    def __init__(self, *unused):
        pass

    @staticmethod
    def add_args(parser):
        pass

    @staticmethod
    def encode(x: str) -> str:
        escaped = x.replace(SPACE, SPACE_ESCAPE)
        return SPACE.join(list(escaped))

    @staticmethod
    def decode(x: str) -> str:
        return x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE)
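For reference, a minimal standalone sketch (not part of the commit) of how the character-level scheme above round-trips a sentence: spaces are first escaped to U+2581, then every character is separated by a space.

# Illustrative only; mirrors the encode/decode logic of the Characters class above.
SPACE = chr(32)
SPACE_ESCAPE = chr(9601)

def encode(x):
    return SPACE.join(list(x.replace(SPACE, SPACE_ESCAPE)))

def decode(x):
    return x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE)

assert decode(encode("hello world")) == "hello world"
print(encode("hi there"))  # "h i ▁ t h e r e"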
PyTorch/NLP/new-Transformer/fairseq/data/encoders/fastbpe.py  (new file, 0 → 100644)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field

from fairseq import file_utils
from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass


@dataclass
class fastBPEConfig(FairseqDataclass):
    bpe_codes: str = field(default="???", metadata={"help": "path to fastBPE BPE"})


@register_bpe("fastbpe", dataclass=fastBPEConfig)
class fastBPE(object):
    def __init__(self, cfg):
        if cfg.bpe_codes is None:
            raise ValueError("--bpe-codes is required for --bpe=fastbpe")
        codes = file_utils.cached_path(cfg.bpe_codes)
        try:
            import fastBPE

            self.bpe = fastBPE.fastBPE(codes)
            self.bpe_symbol = "@@ "
        except ImportError:
            raise ImportError("Please install fastBPE with: pip install fastBPE")

    def encode(self, x: str) -> str:
        return self.bpe.apply([x])[0]

    def decode(self, x: str) -> str:
        return (x + " ").replace(self.bpe_symbol, "").rstrip()
PyTorch/NLP/new-Transformer/fairseq/data/encoders/gpt2_bpe.py  (new file, 0 → 100644)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field

from fairseq import file_utils
from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass

from .gpt2_bpe_utils import get_encoder


DEFAULT_ENCODER_JSON = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json"
DEFAULT_VOCAB_BPE = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe"


@dataclass
class GPT2BPEConfig(FairseqDataclass):
    gpt2_encoder_json: str = field(
        default=DEFAULT_ENCODER_JSON, metadata={"help": "path to encoder.json"}
    )
    gpt2_vocab_bpe: str = field(
        default=DEFAULT_VOCAB_BPE, metadata={"help": "path to vocab.bpe"}
    )


@register_bpe("gpt2", dataclass=GPT2BPEConfig)
class GPT2BPE(object):
    def __init__(self, cfg):
        encoder_json = file_utils.cached_path(cfg.gpt2_encoder_json)
        vocab_bpe = file_utils.cached_path(cfg.gpt2_vocab_bpe)
        self.bpe = get_encoder(encoder_json, vocab_bpe)

    def encode(self, x: str) -> str:
        return " ".join(map(str, self.bpe.encode(x)))

    def decode(self, x: str) -> str:
        return self.bpe.decode(
            [int(tok) if tok not in {"<unk>", "<mask>"} else tok for tok in x.split()]
        )

    def is_beginning_of_word(self, x: str) -> bool:
        return self.decode(x).startswith(" ")
PyTorch/NLP/new-Transformer/fairseq/data/encoders/gpt2_bpe_utils.py  (new file, 0 → 100644)

"""
Byte pair encoding utilities from GPT-2.

Original source: https://github.com/openai/gpt-2/blob/master/src/encoder.py
Original license: MIT
"""

import json
from functools import lru_cache


@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a signficant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2 ** 8):
        if b not in bs:
            bs.append(b)
            cs.append(2 ** 8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


class Encoder:
    def __init__(self, encoder, bpe_merges, errors="replace"):
        self.encoder = encoder
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}

        try:
            import regex as re

            self.re = re
        except ImportError:
            raise ImportError("Please install regex with: pip install regex")

        # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
        self.pat = self.re.compile(
            r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
        )

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        for token in self.re.findall(self.pat, text):
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(
                self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
            )
        return bpe_tokens

    def decode(self, tokens):
        text = "".join([self.decoder.get(token, token) for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode(
            "utf-8", errors=self.errors
        )
        return text


def get_encoder(encoder_json_path, vocab_bpe_path):
    with open(encoder_json_path, "r") as f:
        encoder = json.load(f)
    with open(vocab_bpe_path, "r", encoding="utf-8") as f:
        bpe_data = f.read()
    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
    return Encoder(
        encoder=encoder,
        bpe_merges=bpe_merges,
    )
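A hedged usage sketch of get_encoder above (not part of the commit). The two paths are hypothetical local copies of the GPT-2 BPE files published at the DEFAULT_ENCODER_JSON and DEFAULT_VOCAB_BPE URLs in gpt2_bpe.py, and the Encoder class requires the regex package to be installed.

# Sketch only: the paths below are placeholders for local copies of encoder.json and vocab.bpe.
from fairseq.data.encoders.gpt2_bpe_utils import get_encoder

bpe = get_encoder("/tmp/gpt2_bpe/encoder.json", "/tmp/gpt2_bpe/vocab.bpe")
ids = bpe.encode("Hello world")  # list of integer BPE ids
text = bpe.decode(ids)           # decode() inverts encode() for ordinary UTF-8 text
assert text == "Hello world"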
PyTorch/NLP/new-Transformer/fairseq/data/encoders/hf_bert_bpe.py  (new file, 0 → 100644)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from typing import Optional

from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass


@dataclass
class BertBPEConfig(FairseqDataclass):
    bpe_cased: bool = field(default=False, metadata={"help": "set for cased BPE"})
    bpe_vocab_file: Optional[str] = field(
        default=None, metadata={"help": "bpe vocab file"}
    )


@register_bpe("bert", dataclass=BertBPEConfig)
class BertBPE(object):
    def __init__(self, cfg):
        try:
            from transformers import BertTokenizer
        except ImportError:
            raise ImportError(
                "Please install transformers with: pip install transformers"
            )

        if cfg.bpe_vocab_file:
            self.bert_tokenizer = BertTokenizer(
                cfg.bpe_vocab_file, do_lower_case=not cfg.bpe_cased
            )
        else:
            vocab_file_name = (
                "bert-base-cased" if cfg.bpe_cased else "bert-base-uncased"
            )
            self.bert_tokenizer = BertTokenizer.from_pretrained(vocab_file_name)

    def encode(self, x: str) -> str:
        return " ".join(self.bert_tokenizer.tokenize(x))

    def decode(self, x: str) -> str:
        return self.bert_tokenizer.clean_up_tokenization(
            self.bert_tokenizer.convert_tokens_to_string(x.split(" "))
        )

    def is_beginning_of_word(self, x: str) -> bool:
        return not x.startswith("##")
PyTorch/NLP/new-Transformer/fairseq/data/encoders/hf_byte_bpe.py  (new file, 0 → 100644)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field

from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass
from fairseq import file_utils


@dataclass
class HuggingFaceByteLevelBPEConfig(FairseqDataclass):
    bpe_merges: str = field(default="???", metadata={"help": "path to merges.txt"})
    bpe_vocab: str = field(default="???", metadata={"help": "path to vocab.json"})
    bpe_add_prefix_space: bool = field(
        default=False, metadata={"help": "add prefix space before encoding"}
    )


@register_bpe("hf_byte_bpe", dataclass=HuggingFaceByteLevelBPEConfig)
class HuggingFaceByteLevelBPE(object):
    def __init__(self, cfg):
        try:
            from tokenizers import ByteLevelBPETokenizer
        except ImportError:
            raise ImportError(
                "Please install huggingface/tokenizers with: " "pip install tokenizers"
            )

        bpe_vocab = file_utils.cached_path(cfg.bpe_vocab)
        bpe_merges = file_utils.cached_path(cfg.bpe_merges)

        self.bpe = ByteLevelBPETokenizer(
            bpe_vocab,
            bpe_merges,
            add_prefix_space=cfg.bpe_add_prefix_space,
        )

    def encode(self, x: str) -> str:
        return " ".join(map(str, self.bpe.encode(x).ids))

    def decode(self, x: str) -> str:
        return self.bpe.decode(
            [int(tok) if tok not in {"<unk>", "<mask>"} else tok for tok in x.split()]
        )

    def is_beginning_of_word(self, x: str) -> bool:
        return self.decode(x).startswith(" ")
PyTorch/NLP/new-Transformer/fairseq/data/encoders/moses_tokenizer.py  (new file, 0 → 100644)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field

from fairseq.data.encoders import register_tokenizer
from fairseq.dataclass import FairseqDataclass


@dataclass
class MosesTokenizerConfig(FairseqDataclass):
    source_lang: str = field(default="en", metadata={"help": "source language"})
    target_lang: str = field(default="en", metadata={"help": "target language"})
    moses_no_dash_splits: bool = field(
        default=False, metadata={"help": "don't apply dash split rules"}
    )
    moses_no_escape: bool = field(
        default=False,
        metadata={"help": "don't perform HTML escaping on apostrophe, quotes, etc."},
    )


@register_tokenizer("moses", dataclass=MosesTokenizerConfig)
class MosesTokenizer(object):
    def __init__(self, cfg: MosesTokenizerConfig):
        self.cfg = cfg

        try:
            from sacremoses import MosesTokenizer, MosesDetokenizer

            self.tok = MosesTokenizer(cfg.source_lang)
            self.detok = MosesDetokenizer(cfg.target_lang)
        except ImportError:
            raise ImportError(
                "Please install Moses tokenizer with: pip install sacremoses"
            )

    def encode(self, x: str) -> str:
        return self.tok.tokenize(
            x,
            aggressive_dash_splits=(not self.cfg.moses_no_dash_splits),
            return_str=True,
            escape=(not self.cfg.moses_no_escape),
        )

    def decode(self, x: str) -> str:
        return self.detok.detokenize(x.split())
PyTorch/NLP/new-Transformer/fairseq/data/encoders/nltk_tokenizer.py  (new file, 0 → 100644)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.data.encoders import register_tokenizer
from fairseq.dataclass import FairseqDataclass


@register_tokenizer("nltk", dataclass=FairseqDataclass)
class NLTKTokenizer(object):
    def __init__(self, *unused):
        try:
            from nltk.tokenize import word_tokenize

            self.word_tokenize = word_tokenize
        except ImportError:
            raise ImportError("Please install nltk with: pip install nltk")

    def encode(self, x: str) -> str:
        return " ".join(self.word_tokenize(x))

    def decode(self, x: str) -> str:
        return x
PyTorch/NLP/new-Transformer/fairseq/data/encoders/sentencepiece_bpe.py  (new file, 0 → 100644)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from typing import Optional

from fairseq import file_utils
from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass


@dataclass
class SentencepieceConfig(FairseqDataclass):
    sentencepiece_model: str = field(
        default="???", metadata={"help": "path to sentencepiece model"}
    )
    sentencepiece_enable_sampling: bool = field(
        default=False, metadata={"help": "enable sampling"}
    )
    sentencepiece_alpha: Optional[float] = field(
        default=None,
        metadata={
            "help": "soothing parameter for unigram sampling, "
            "and merge probability for BPE-dropout"
        },
    )


@register_bpe("sentencepiece", dataclass=SentencepieceConfig)
class SentencepieceBPE(object):
    def __init__(self, cfg):
        self.enable_sampling = cfg.sentencepiece_enable_sampling
        self.alpha = cfg.sentencepiece_alpha
        sentencepiece_model = file_utils.cached_path(cfg.sentencepiece_model)
        try:
            import sentencepiece as spm

            self.sp = spm.SentencePieceProcessor()
            self.sp.Load(sentencepiece_model)
        except ImportError:
            raise ImportError(
                "Please install sentencepiece with: pip install sentencepiece"
            )

    def encode(self, x: str) -> str:
        return " ".join(
            self.sp.Encode(
                x, out_type=str, enable_sampling=self.enable_sampling, alpha=self.alpha
            )
        )

    def decode(self, x: str) -> str:
        return x.replace(" ", "").replace("\u2581", " ").strip()

    def is_beginning_of_word(self, x: str) -> bool:
        if x in ["<unk>", "<s>", "</s>", "<pad>"]:
            # special elements are always considered beginnings
            # HACK: this logic is already present in fairseq/tasks/masked_lm.py
            # but these special tokens are also contained in the sentencepiece
            # vocabulary which causes duplicate special tokens. This hack makes
            # sure that they are all taken into account.
            return True
        return x.startswith("\u2581")
PyTorch/NLP/new-Transformer/fairseq/data/encoders/space_tokenizer.py  (new file, 0 → 100644)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import re

from fairseq.data.encoders import register_tokenizer
from fairseq.dataclass import FairseqDataclass


@register_tokenizer("space", dataclass=FairseqDataclass)
class SpaceTokenizer(object):
    def __init__(self, *unused):
        self.space_tok = re.compile(r"\s+")

    def encode(self, x: str) -> str:
        return self.space_tok.sub(" ", x)

    def decode(self, x: str) -> str:
        return x
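As a quick illustration (not part of the commit), the space tokenizer only collapses runs of whitespace on encode and is a no-op on decode:

import re

space_tok = re.compile(r"\s+")           # same pattern as SpaceTokenizer above
print(space_tok.sub(" ", "a\tb   c\n"))  # -> "a b c "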
PyTorch/NLP/new-Transformer/fairseq/data/encoders/subword_nmt_bpe.py  (new file, 0 → 100644)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field

from fairseq import file_utils
from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass


@dataclass
class SubwordNMTBPEConfig(FairseqDataclass):
    bpe_codes: str = field(default="???", metadata={"help": "path to subword NMT BPE"})
    bpe_separator: str = field(default="@@", metadata={"help": "BPE separator"})


@register_bpe("subword_nmt", dataclass=SubwordNMTBPEConfig)
class SubwordNMTBPE(object):
    def __init__(self, cfg):
        if cfg.bpe_codes is None:
            raise ValueError("--bpe-codes is required for --bpe=subword_nmt")
        codes = file_utils.cached_path(cfg.bpe_codes)
        try:
            from subword_nmt import apply_bpe

            bpe_parser = apply_bpe.create_parser()
            bpe_args = bpe_parser.parse_args(
                [
                    "--codes",
                    codes,
                    "--separator",
                    cfg.bpe_separator,
                ]
            )
            self.bpe = apply_bpe.BPE(
                bpe_args.codes,
                bpe_args.merges,
                bpe_args.separator,
                None,
                bpe_args.glossaries,
            )
            self.bpe_symbol = bpe_args.separator + " "
        except ImportError:
            raise ImportError(
                "Please install subword_nmt with: pip install subword-nmt"
            )

    def encode(self, x: str) -> str:
        return self.bpe.process_line(x)

    def decode(self, x: str) -> str:
        return (x + " ").replace(self.bpe_symbol, "").rstrip()
PyTorch/NLP/new-Transformer/fairseq/data/encoders/utils.py  (new file, 0 → 100644)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from fairseq.data import encoders


def get_whole_word_mask(args, dictionary):
    bpe = encoders.build_bpe(args)
    if bpe is not None:

        def is_beginning_of_word(i):
            if i < dictionary.nspecial:
                # special elements are always considered beginnings
                return True
            tok = dictionary[i]
            if tok.startswith("madeupword"):
                return True
            try:
                return bpe.is_beginning_of_word(tok)
            except ValueError:
                return True

        mask_whole_words = torch.ByteTensor(
            list(map(is_beginning_of_word, range(len(dictionary))))
        )
        return mask_whole_words
    return None
PyTorch/NLP/new-Transformer/fairseq/data/fairseq_dataset.py  (new file, 0 → 100644)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import numpy as np
import torch.utils.data
from fairseq.data import data_utils

logger = logging.getLogger(__name__)


class EpochListening:
    """Mixin for receiving updates whenever the epoch increments."""

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        """
        Whether we can reuse the :class:`fairseq.data.EpochBatchIterator` for
        this dataset across epochs.

        This needs to return ``False`` if the sample sizes can change across
        epochs, in which case we may need to regenerate batches at each epoch.
        If your dataset relies in ``set_epoch`` then you should consider setting
        this to ``False``.
        """
        return True

    def set_epoch(self, epoch):
        """Will receive the updated epoch number at the beginning of the epoch."""
        pass


class FairseqDataset(torch.utils.data.Dataset, EpochListening):
    """A dataset that provides helpers for batching."""

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch suitable for forwarding with a Model
        """
        raise NotImplementedError

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        raise NotImplementedError

    def num_tokens_vec(self, indices):
        """Return the number of tokens for a set of positions defined by indices.
        This value is used to enforce ``--max-tokens`` during batching."""
        raise NotImplementedError

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        raise NotImplementedError

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        return np.arange(len(self), dtype=np.int64)

    @property
    def supports_prefetch(self):
        """Whether this dataset supports prefetching."""
        return False

    def attr(self, attr: str, index: int):
        return getattr(self, attr, None)

    def prefetch(self, indices):
        """Prefetch the data required for this epoch."""
        raise NotImplementedError

    def get_batch_shapes(self):
        """
        Return a list of valid batch shapes, for example::

            [(8, 512), (16, 256), (32, 128)]

        The first dimension of each tuple is the batch size and can be ``None``
        to automatically infer the max batch size based on ``--max-tokens``.
        The second dimension of each tuple is the max supported length as given
        by :func:`fairseq.data.FairseqDataset.num_tokens`.

        This will be used by :func:`fairseq.data.FairseqDataset.batch_by_size`
        to restrict batch shapes. This is useful on TPUs to avoid too many
        dynamic shapes (and recompilations).
        """
        return None

    def batch_by_size(
        self,
        indices,
        max_tokens=None,
        max_sentences=None,
        required_batch_size_multiple=1,
    ):
        """
        Given an ordered set of indices, return batches according to
        *max_tokens*, *max_sentences* and *required_batch_size_multiple*.
        """
        from fairseq.data import data_utils

        fixed_shapes = self.get_batch_shapes()
        if fixed_shapes is not None:

            def adjust_bsz(bsz, num_tokens):
                if bsz is None:
                    assert max_tokens is not None, "Must specify --max-tokens"
                    bsz = max_tokens // num_tokens
                if max_sentences is not None:
                    bsz = min(bsz, max_sentences)
                elif (
                    bsz >= required_batch_size_multiple
                    and bsz % required_batch_size_multiple != 0
                ):
                    bsz -= bsz % required_batch_size_multiple
                return bsz

            fixed_shapes = np.array(
                [
                    [adjust_bsz(bsz, num_tokens), num_tokens]
                    for (bsz, num_tokens) in fixed_shapes
                ]
            )

        try:
            num_tokens_vec = self.num_tokens_vec(indices).astype("int64")
        except NotImplementedError:
            num_tokens_vec = None

        return data_utils.batch_by_size(
            indices,
            num_tokens_fn=self.num_tokens,
            num_tokens_vec=num_tokens_vec,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
            fixed_shapes=fixed_shapes,
        )

    def filter_indices_by_size(self, indices, max_sizes):
        """
        Filter a list of sample indices. Remove those that are longer than
        specified in *max_sizes*.

        WARNING: don't update, override method in child classes

        Args:
            indices (np.array): original array of sample indices
            max_sizes (int or list[int] or tuple[int]): max sample size,
                can be defined separately for src and tgt (then list or tuple)

        Returns:
            np.array: filtered sample array
            list: list of removed indices
        """
        if isinstance(max_sizes, float) or isinstance(max_sizes, int):
            if hasattr(self, "sizes") and isinstance(self.sizes, np.ndarray):
                ignored = indices[self.sizes[indices] > max_sizes].tolist()
                indices = indices[self.sizes[indices] <= max_sizes]
            elif (
                hasattr(self, "sizes")
                and isinstance(self.sizes, list)
                and len(self.sizes) == 1
            ):
                ignored = indices[self.sizes[0][indices] > max_sizes].tolist()
                indices = indices[self.sizes[0][indices] <= max_sizes]
            else:
                indices, ignored = data_utils._filter_by_size_dynamic(
                    indices, self.size, max_sizes
                )
        else:
            indices, ignored = data_utils._filter_by_size_dynamic(
                indices, self.size, max_sizes
            )
        return indices, ignored

    @property
    def supports_fetch_outside_dataloader(self):
        """Whether this dataset supports fetching outside the workers of the dataloader."""
        return True


class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening):
    """
    For datasets that need to be read sequentially, usually because the data is
    being streamed or otherwise can't be manipulated on a single machine.
    """

    def __iter__(self):
        raise NotImplementedError
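To make the FairseqDataset contract concrete, here is a minimal hypothetical subclass (not part of the commit) that wraps a list of integer token sequences. It only fills in the methods the base class leaves as NotImplementedError and uses plain right-padding in collater; real fairseq datasets usually return richer batch dictionaries.

import torch
from fairseq.data import FairseqDataset  # the class defined above


class ListDataset(FairseqDataset):
    """Toy dataset over a list of 1-D integer sequences (illustrative only)."""

    def __init__(self, sequences, pad_idx=1):
        self.sequences = [torch.tensor(s, dtype=torch.long) for s in sequences]
        self.pad_idx = pad_idx

    def __getitem__(self, index):
        return self.sequences[index]

    def __len__(self):
        return len(self.sequences)

    def num_tokens(self, index):
        return len(self.sequences[index])

    def size(self, index):
        return len(self.sequences[index])

    def collater(self, samples):
        # right-pad every sample to the longest one in the mini-batch
        max_len = max(len(s) for s in samples)
        batch = torch.full((len(samples), max_len), self.pad_idx, dtype=torch.long)
        for i, s in enumerate(samples):
            batch[i, : len(s)] = s
        return {"tokens": batch}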
PyTorch/NLP/new-Transformer/fairseq/data/fasta_dataset.py  (new file, 0 → 100644)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import subprocess
import threading
from pathlib import Path

import numpy as np
import torch


def fasta_file_path(prefix_path):
    return prefix_path + ".fasta"


class FastaDataset(torch.utils.data.Dataset):
    """
    For loading protein sequence datasets in the common FASTA data format
    """

    def __init__(self, path: str, cache_indices=False):
        self.fn = fasta_file_path(path)
        self.threadlocal = threading.local()
        self.cache = Path(f"{path}.fasta.idx.npy")
        if cache_indices:
            if self.cache.exists():
                self.offsets, self.sizes = np.load(self.cache)
            else:
                self.offsets, self.sizes = self._build_index(path)
                np.save(self.cache, np.stack([self.offsets, self.sizes]))
        else:
            self.offsets, self.sizes = self._build_index(path)

    def _get_file(self):
        if not hasattr(self.threadlocal, "f"):
            self.threadlocal.f = open(self.fn, "r")
        return self.threadlocal.f

    def __getitem__(self, idx):
        f = self._get_file()
        f.seek(self.offsets[idx])
        desc = f.readline().strip()
        line = f.readline()
        seq = ""
        while line != "" and line[0] != ">":
            seq += line.strip()
            line = f.readline()
        return desc, seq

    def __len__(self):
        return self.offsets.size

    def _build_index(self, path: str):
        # Use grep and awk to get 100M/s on local SSD.
        # Should process your enormous 100G fasta in ~10 min single core...
        path = fasta_file_path(path)
        bytes_offsets = subprocess.check_output(
            f"cat {path} | tqdm --bytes --total $(wc -c < {path})"
            "| grep --byte-offset '^>' -o | cut -d: -f1",
            shell=True,
        )
        fasta_lengths = subprocess.check_output(
            f"cat {path} | tqdm --bytes --total $(wc -c < {path})"
            "| awk '/^>/ {print \"\";next;} { printf(\"%s\",$0);}' | tail -n+2 | awk '{print length($1)}'",
            shell=True,
        )
        bytes_np = np.fromstring(bytes_offsets, dtype=np.int64, sep=" ")
        sizes_np = np.fromstring(fasta_lengths, dtype=np.int64, sep=" ")
        return bytes_np, sizes_np

    def __setstate__(self, state):
        self.__dict__ = state
        self.threadlocal = threading.local()

    def __getstate__(self):
        d = {}
        for i, v in self.__dict__.items():
            if i != "threadlocal":
                d[i] = v
        return d

    def __del__(self):
        if hasattr(self.threadlocal, "f"):
            self.threadlocal.f.close()
            del self.threadlocal.f

    @staticmethod
    def exists(path):
        return os.path.exists(fasta_file_path(path))


class EncodedFastaDataset(FastaDataset):
    """
    The FastaDataset returns raw sequences - this allows us to return
    indices with a dictionary instead.
    """

    def __init__(self, path, dictionary):
        super().__init__(path, cache_indices=True)
        self.dictionary = dictionary

    def __getitem__(self, idx):
        desc, seq = super().__getitem__(idx)
        return self.dictionary.encode_line(seq, line_tokenizer=list).long()
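A usage sketch (not part of the commit), assuming a hypothetical prefix /data/proteins with /data/proteins.fasta on disk. Note that _build_index shells out to tqdm, grep, awk, and cut, so those tools must be available on PATH.

# Sketch only: "/data/proteins" is a placeholder prefix ("/data/proteins.fasta" must exist).
from fairseq.data.fasta_dataset import FastaDataset

ds = FastaDataset("/data/proteins", cache_indices=True)  # caches offsets in /data/proteins.fasta.idx.npy
desc, seq = ds[0]  # FASTA header line and the concatenated sequence that follows it
print(desc, len(seq))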
PyTorch/NLP/new-Transformer/fairseq/data/huffman/__init__.py  (new file, 0 → 100644)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .huffman_coder import HuffmanCodeBuilder, HuffmanCoder
from .huffman_mmap_indexed_dataset import (
    HuffmanMMapIndex,
    HuffmanMMapIndexedDataset,
    HuffmanMMapIndexedDatasetBuilder,
    vocab_file_path,
)

__all__ = [
    "HuffmanCoder",
    "HuffmanCodeBuilder",
    "HuffmanMMapIndexedDatasetBuilder",
    "HuffmanMMapIndexedDataset",
    "HuffmanMMapIndex",
    "vocab_file_path",
]
PyTorch/NLP/new-Transformer/fairseq/data/huffman/huffman_coder.py  (new file, 0 → 100644)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import re
import typing as tp
from collections import Counter, deque
from dataclasses import dataclass

from bitarray import bitarray, util
from fairseq.data import Dictionary

# basically we have to write to addressable bytes for the memory mapped
# dataset loader. Sentences that get encoded to a length that is not a
# multiple of BLOCKSIZE (a byte) will be padded to fit. (see _pad in the coder)
BLOCKSIZE = 8


class HuffmanCoder:
    def __init__(
        self, root: "HuffmanNode", bos="<s>", pad="<pad>", eos="</s>", unk="<unk>"
    ):
        self.root = root
        self.table = root.code_table()
        self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos

    def _pad(self, a: bitarray) -> bitarray:
        """
        bitpadding, 1 then 0.

        If the array is already a multiple of blocksize, we add a full block.
        """
        pad_len = BLOCKSIZE - (len(a) % BLOCKSIZE) - 1
        padding = bitarray("1" + "0" * pad_len)
        return a + padding

    def _unpad(self, a: bitarray) -> bitarray:
        """
        remove the bitpadding.

        There will be a set of 0s preceded by a 1 at the end of the bitarray, we remove that
        """
        # count the 0 padding at the end until we find the first 1
        # we want to remove the one too
        remove_cnt = util.rindex(a, 1)
        return a[:remove_cnt]

    def encode(self, iter: tp.List[str]) -> bytes:
        """
        encode a list of tokens a return bytes. We use bitpadding to make sure the encoded bits fit in bytes.
        """
        a = bitarray()
        for token in iter:
            code = self.get_code(token)
            if code is None:
                if self.unk_word is None:
                    raise Exception(f"unknown token {token} cannot be encoded.")
                else:
                    token = self.unk_word
            a = a + self.get_code(token)
        return self._pad(a).tobytes()

    def decode(self, bits: bytes) -> tp.Iterator["HuffmanNode"]:
        """
        take bitpadded bytes and decode it to a set of leaves. You can then use each node to find the symbol/id
        """
        a = bitarray()
        a.frombytes(bits)
        return self.root.decode(self._unpad(a))

    def get_code(self, symbol: str) -> tp.Optional[bitarray]:
        node = self.get_node(symbol)
        return None if node is None else node.code

    def get_node(self, symbol: str) -> "HuffmanNode":
        return self.table.get(symbol)

    @classmethod
    def from_file(
        cls,
        filename: str,
        bos="<s>",
        pad="<pad>",
        eos="</s>",
        unk="<unk>",
    ) -> "HuffmanCoder":
        builder = HuffmanCodeBuilder.from_file(filename)
        return builder.build_code(bos=bos, pad=pad, eos=eos, unk=unk)

    def to_file(self, filename, sep="\t"):
        nodes = list(self.table.values())
        nodes.sort(key=lambda n: n.id)
        with open(filename, "w", encoding="utf-8") as output:
            for n in nodes:
                output.write(f"{n.symbol}{sep}{n.count}\n")

    def __iter__(self):
        for n in self.table.values():
            yield n

    def merge(self, other_coder: "HuffmanCoder") -> "HuffmanCoder":
        builder = HuffmanCodeBuilder()
        for n in self:
            builder.increment(n.symbol, n.count)
        for n in other_coder:
            builder.increment(n.symbol, n.count)
        return builder.build_code()

    def __eq__(self, other: "HuffmanCoder") -> bool:
        return self.table == other.table

    def __len__(self) -> int:
        return len(self.table)

    def __contains__(self, sym: str) -> bool:
        return sym in self.table

    def to_dictionary(self) -> Dictionary:
        dictionary = Dictionary(bos=self.bos, unk=self.unk, pad=self.pad, eos=self.eos)
        for n in self:
            dictionary.add_symbol(n.symbol, n=n.count)
        dictionary.finalize()
        return dictionary


@dataclass
class HuffmanNode:
    """
    a node in a Huffman tree
    """

    id: int
    count: int
    symbol: tp.Optional[str] = None
    left: tp.Optional["HuffmanNode"] = None
    right: tp.Optional["HuffmanNode"] = None
    code: tp.Optional[bitarray] = None

    def is_leaf(self) -> bool:
        return self.left is None and self.right is None

    def code_table(
        self, prefix: tp.Optional[bitarray] = None
    ) -> tp.Dict[str, "HuffmanNode"]:
        defaulted_prefix = prefix if prefix is not None else bitarray()
        if self.is_leaf():
            self.code = (
                defaulted_prefix if len(defaulted_prefix) > 0 else bitarray("0")
            )  # leaf could be the root if there is only one symbol
            return {self.symbol: self}

        codes_right = self.right.code_table(defaulted_prefix + bitarray([0]))
        codes_left = self.left.code_table(defaulted_prefix + bitarray([1]))
        return {**codes_left, **codes_right}

    def decode(self, bits: bitarray) -> tp.Iterator["HuffmanNode"]:
        current_node = self
        for bit in bits:
            if bit == 0:  # go right
                current_node = current_node.right
            else:  # go left
                current_node = current_node.left
            if current_node is None:
                # we shouldn't be on a leaf here
                raise Exception("fell off a leaf")
            if current_node.is_leaf():
                yield current_node
                current_node = self
        if current_node != self:
            raise Exception("couldn't decode all the bits")


class HuffmanCodeBuilder:
    """
    build a dictionary with occurence count and then build the Huffman code for it.
    """

    def __init__(self):
        self.symbols = Counter()

    def add_symbols(self, *syms) -> None:
        self.symbols.update(syms)

    def increment(self, symbol: str, cnt: int) -> None:
        self.symbols[symbol] += cnt

    @classmethod
    def from_file(cls, filename):
        c = cls()
        with open(filename, "r", encoding="utf-8") as input:
            for line in input:
                split = re.split(r"[\s]+", line)
                c.increment(split[0], int(split[1]))
        return c

    def to_file(self, filename, sep="\t"):
        with open(filename, "w", encoding="utf-8") as output:
            for (tok, cnt) in self.symbols.most_common():
                output.write(f"{tok}{sep}{cnt}\n")

    def _smallest(self, q1: deque, q2: deque) -> HuffmanNode:
        if len(q1) == 0:
            return q2.pop()

        if len(q2) == 0:
            return q1.pop()

        if q1[-1].count < q2[-1].count:
            return q1.pop()

        return q2.pop()

    def __add__(self, c: "HuffmanCodeBuilder") -> "HuffmanCodeBuilder":
        new_c = self.symbols + c.symbols
        new_b = HuffmanCodeBuilder()
        new_b.symbols = new_c
        return new_b

    def build_code(
        self,
        bos="<s>",
        pad="<pad>",
        eos="</s>",
        unk="<unk>",
    ) -> HuffmanCoder:
        assert len(self.symbols) > 0, "cannot build code from empty list of symbols"

        if self.symbols[bos] == 0:
            self.add_symbols(bos)
        if self.symbols[pad] == 0:
            self.add_symbols(pad)
        if self.symbols[eos] == 0:
            self.add_symbols(eos)
        if self.symbols[unk] == 0:
            self.add_symbols(unk)

        node_id = 0
        leaves_queue = deque(
            [
                HuffmanNode(symbol=symbol, count=count, id=idx)
                for idx, (symbol, count) in enumerate(self.symbols.most_common())
            ]
        )  # left are the most common, right are the least common

        if len(leaves_queue) == 1:
            root = leaves_queue.pop()
            root.id = 0
            return HuffmanCoder(root)

        nodes_queue = deque()

        while len(leaves_queue) > 0 or len(nodes_queue) != 1:
            # get the lowest two nodes at the head of each queue
            node1 = self._smallest(leaves_queue, nodes_queue)
            node2 = self._smallest(leaves_queue, nodes_queue)

            # add new node
            nodes_queue.appendleft(
                HuffmanNode(
                    count=node1.count + node2.count, left=node1, right=node2, id=node_id
                )
            )
            node_id += 1

        # we are left with the root
        return HuffmanCoder(nodes_queue.pop(), bos=bos, pad=pad, eos=eos, unk=unk)
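A small roundtrip sketch (not part of the commit) using only the classes in this file plus the bitarray package: count symbols, build the code, then encode and decode a token list.

from fairseq.data.huffman import HuffmanCodeBuilder

builder = HuffmanCodeBuilder()
builder.add_symbols("a", "b", "a", "c", "a", "b")  # toy frequency counts
coder = builder.build_code()                       # bos/pad/eos/unk are added automatically

encoded = coder.encode(["a", "b", "a", "c"])       # bit-padded bytes
decoded = [node.symbol for node in coder.decode(encoded)]
assert decoded == ["a", "b", "a", "c"]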
PyTorch/NLP/new-Transformer/fairseq/data/huffman/huffman_mmap_indexed_dataset.py  (new file, 0 → 100644)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import mmap
import os
import shutil
import struct
import typing as tp
from functools import lru_cache

import numpy as np
import torch
from fairseq.data import indexed_dataset
from fairseq.data.huffman import HuffmanCoder
from fairseq.file_io import PathManager


class HuffmanMMapIndex:
    """
    keep an index of the offsets in the huffman binary file.
    First a header, then the list of sizes (num tokens) for each instance and finally
    the addresses of each instance.
    """

    _HDR_MAGIC = b"HUFFIDX\x00\x00"
    _VERSION = 1

    @classmethod
    def writer(cls, path: str, data_len: int):
        class _Writer:
            def __enter__(self):
                self._file = open(path, "wb")

                # write header (magic + version)
                self._file.write(cls._HDR_MAGIC)
                self._file.write(struct.pack("<Q", cls._VERSION))
                self._file.write(struct.pack("<Q", data_len))

                return self

            def write(self, sizes, pointers):
                # add number of items in the index to the header
                self._file.write(struct.pack("<Q", len(sizes)))

                # write sizes
                sizes = np.array(sizes, dtype=np.int32)
                self._file.write(sizes.tobytes(order="C"))
                del sizes

                # write address pointers
                pointers = np.array(pointers, dtype=np.int64)
                self._file.write(pointers.tobytes(order="C"))
                del pointers

            def __exit__(self, exc_type, exc_val, exc_tb):
                self._file.close()

        return _Writer()

    def __init__(self, path):
        with open(path, "rb") as stream:
            # read headers
            magic_test = stream.read(9)
            assert self._HDR_MAGIC == magic_test, (
                "Index file doesn't match expected format. "
                "Make sure that --dataset-impl is configured properly."
            )
            (version,) = struct.unpack("<Q", stream.read(8))
            assert (
                self._VERSION == version
            ), f"Unexpected file version {version} != code version {self._VERSION}"

            # read length of data file
            (self._data_len,) = struct.unpack("<Q", stream.read(8))
            # read number of items in data file/index
            (self._len,) = struct.unpack("<Q", stream.read(8))
            offset = stream.tell()

        indexed_dataset._warmup_mmap_file(path)

        self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
        self._bin_buffer = memoryview(self._bin_buffer_mmap)
        self._sizes = np.frombuffer(
            self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
        )
        self._pointers = np.frombuffer(
            self._bin_buffer,
            dtype=np.int64,
            count=self._len,
            offset=offset + self._sizes.nbytes,
        )

    def __del__(self):
        self._bin_buffer_mmap._mmap.close()
        del self._bin_buffer_mmap

    def __iter__(self):
        for i in range(self._len):
            yield self[i]

    @property
    def data_len(self):
        return self._data_len

    @property
    def sizes(self):
        return self._sizes

    @lru_cache(maxsize=8)
    def __getitem__(self, i):
        return self._pointers[i], self._sizes[i]

    def __len__(self):
        return self._len


def vocab_file_path(prefix_path):
    return prefix_path + ".vocab"


class HuffmanMMapIndexedDataset(torch.utils.data.Dataset):
    """
    an indexed dataset that use mmap and memoryview to access data from disk
    that was compressed with a HuffmanCoder.
    """

    def __init__(self, prefix_path):
        super().__init__()

        self._prefix_path = None
        self._index = None
        self._bin_buffer = None
        self._coder = None
        self._file = None

        self._bin_buffer_mmap = None

        self._do_init(prefix_path)

    def __getstate__(self):
        return self._prefix_path

    def __setstate__(self, state):
        self._do_init(state)

    def _do_init(self, prefix_path):
        self._prefix_path = prefix_path
        self._index = HuffmanMMapIndex(
            indexed_dataset.index_file_path(self._prefix_path)
        )
        self._coder = HuffmanCoder.from_file(vocab_file_path(self._prefix_path))

        indexed_dataset._warmup_mmap_file(
            indexed_dataset.data_file_path(self._prefix_path)
        )
        self._file = os.open(
            indexed_dataset.data_file_path(self._prefix_path), os.O_RDONLY
        )
        self._bin_buffer_mmap = mmap.mmap(
            self._file,
            self._index.data_len,
            access=mmap.ACCESS_READ,
        )
        self._bin_buffer = memoryview(self._bin_buffer_mmap)

    def __del__(self):
        del self._bin_buffer
        if self._file:
            os.close(self._file)
        del self._index

    def __len__(self):
        return len(self._index)

    def _decode(self, i):
        ptr, _ = self._index[i]
        if i == 0:
            raw_bytes = self._bin_buffer[:ptr]
        else:
            (prev_ptr, _) = self._index[i - 1]
            raw_bytes = self._bin_buffer[prev_ptr:ptr]

        return self._coder.decode(raw_bytes.tobytes())

    @lru_cache(maxsize=8)
    def __getitem__(self, i):
        nodes = self._decode(i)
        return torch.tensor([n.id for n in nodes], dtype=torch.int64)

    def __iter__(self):
        for idx in range(len(self)):
            yield self[idx]

    def get_symbols(self, i):
        nodes = self._decode(i)
        for n in nodes:
            yield n.symbol

    @property
    def sizes(self):
        return self._index.sizes

    @property
    def supports_prefetch(self):
        return False

    @property
    def coder(self):
        return self._coder

    @staticmethod
    def exists(prefix_path):
        return (
            PathManager.exists(indexed_dataset.index_file_path(prefix_path))
            and PathManager.exists(indexed_dataset.data_file_path(prefix_path))
            and PathManager.exists(vocab_file_path(prefix_path))
        )


class HuffmanMMapIndexedDatasetBuilder:
    """
    Helper to build a memory mapped datasets with a huffman encoder.
    You can either open/close this manually or use it as a ContextManager.
    Provide your own coder, it will then be stored alongside the dataset.
    The builder will first write the vocab file, then open the binary file so you can stream
    into it, finally the index will be written when the builder is closed (your index should fit in memory).
    """

    def __init__(self, path_prefix: str, coder: HuffmanCoder) -> None:
        self._path_prefix = path_prefix
        self._coder = coder
        self._sizes = []
        self._ptrs = []
        self._data_len = 0

    def open(self):
        self._coder.to_file(vocab_file_path(self._path_prefix))
        self._data_file = open(indexed_dataset.data_file_path(self._path_prefix), "wb")

    def __enter__(self) -> "HuffmanMMapIndexedDatasetBuilder":
        self.open()
        return self

    def add_item(self, tokens: tp.List[str]) -> None:
        """
        add a list of tokens to the dataset, they will compressed with the
        provided coder before being written to file.
        """
        encoded = self._coder.encode(tokens)
        code_len = len(encoded)
        last_ptr = 0
        if len(self._ptrs) > 0:
            last_ptr = self._ptrs[-1]
        self._sizes.append(len(tokens))
        self._ptrs.append(last_ptr + code_len)
        self._data_len += code_len
        self._data_file.write(encoded)

    def append(self, other_dataset_path_prefix: str) -> None:
        """
        append an existing dataset.
        Beware, if it wasn't built with the same coder, you are in trouble.
        """
        other_index = HuffmanMMapIndex(
            indexed_dataset.index_file_path(other_dataset_path_prefix)
        )
        for (ptr, size) in other_index:
            self._ptrs.append(ptr + self._data_len)
            self._sizes.append(size)

        # Concatenate data
        with open(indexed_dataset.data_file_path(other_dataset_path_prefix), "rb") as f:
            shutil.copyfileobj(f, self._data_file)

        self._data_len += other_index.data_len

    def close(self):
        self._data_file.close()
        with HuffmanMMapIndex.writer(
            indexed_dataset.index_file_path(self._path_prefix), self._data_len
        ) as index:
            index.write(self._sizes, self._ptrs)

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.close()
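Tying the two Huffman files together, a hedged end-to-end sketch (not part of the commit; the dataset prefix is a placeholder): build a coder from token counts, stream compressed sentences into a dataset prefix, then read them back.

# Sketch only: "/tmp/huff/train" is a placeholder prefix (the /tmp/huff directory must already exist).
from fairseq.data.huffman import (
    HuffmanCodeBuilder,
    HuffmanMMapIndexedDataset,
    HuffmanMMapIndexedDatasetBuilder,
)

sentences = [["hello", "world"], ["hello", "fairseq"]]

code_builder = HuffmanCodeBuilder()
for sent in sentences:
    code_builder.add_symbols(*sent)
coder = code_builder.build_code()

with HuffmanMMapIndexedDatasetBuilder("/tmp/huff/train", coder) as builder:
    for sent in sentences:
        builder.add_item(sent)  # writes the compressed bits; the index is written on close

dataset = HuffmanMMapIndexedDataset("/tmp/huff/train")
print(list(dataset.get_symbols(1)))  # ["hello", "fairseq"]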
PyTorch/NLP/new-Transformer/fairseq/data/id_dataset.py  (new file, 0 → 100644)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch

from . import FairseqDataset


class IdDataset(FairseqDataset):
    def __getitem__(self, index):
        return index

    def __len__(self):
        return 0

    def collater(self, samples):
        return torch.tensor(samples)
PyTorch/NLP/new-Transformer/fairseq/data/indexed_dataset.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
shutil
import
struct
from
functools
import
lru_cache
import
numpy
as
np
import
torch
from
fairseq.dataclass.constants
import
DATASET_IMPL_CHOICES
from
fairseq.data.fasta_dataset
import
FastaDataset
from
fairseq.file_io
import
PathManager
from
fairseq.data.huffman
import
HuffmanMMapIndexedDataset
,
HuffmanMMapIndex
from
.
import
FairseqDataset
from
typing
import
Union
def
best_fitting_int_dtype
(
max_int_to_represent
,
)
->
Union
[
np
.
uint16
,
np
.
uint32
,
np
.
int64
]:
if
max_int_to_represent
is
None
:
return
np
.
uint32
# Safe guess
elif
max_int_to_represent
<
65500
:
return
np
.
uint16
elif
max_int_to_represent
<
4294967295
:
return
np
.
uint32
else
:
return
np
.
int64
# we avoid np.uint64 because it doesn't save space and its type promotion behaves unexpectedly
# https://github.com/numpy/numpy/issues/5745
def
get_available_dataset_impl
():
return
list
(
map
(
str
,
DATASET_IMPL_CHOICES
))
def
infer_dataset_impl
(
path
):
if
IndexedRawTextDataset
.
exists
(
path
):
return
"raw"
elif
IndexedDataset
.
exists
(
path
):
with
open
(
index_file_path
(
path
),
"rb"
)
as
f
:
magic
=
f
.
read
(
8
)
if
magic
==
IndexedDataset
.
_HDR_MAGIC
:
return
"cached"
elif
magic
==
MMapIndexedDataset
.
Index
.
_HDR_MAGIC
[:
8
]:
return
"mmap"
elif
magic
==
HuffmanMMapIndex
.
_HDR_MAGIC
[:
8
]:
return
"huffman"
else
:
return
None
elif
FastaDataset
.
exists
(
path
):
return
"fasta"
else
:
return
None
def
make_builder
(
out_file
,
impl
,
vocab_size
=
None
):
if
impl
==
"mmap"
:
return
MMapIndexedDatasetBuilder
(
out_file
,
dtype
=
best_fitting_int_dtype
(
vocab_size
)
)
elif
impl
==
"fasta"
:
raise
NotImplementedError
elif
impl
==
"huffman"
:
raise
ValueError
(
"Use HuffmanCodeBuilder directly as it has a different interface."
)
else
:
return
IndexedDatasetBuilder
(
out_file
)
def
make_dataset
(
path
,
impl
,
fix_lua_indexing
=
False
,
dictionary
=
None
):
if
impl
==
"raw"
and
IndexedRawTextDataset
.
exists
(
path
):
assert
dictionary
is
not
None
return
IndexedRawTextDataset
(
path
,
dictionary
)
elif
impl
==
"lazy"
and
IndexedDataset
.
exists
(
path
):
return
IndexedDataset
(
path
,
fix_lua_indexing
=
fix_lua_indexing
)
elif
impl
==
"cached"
and
IndexedDataset
.
exists
(
path
):
return
IndexedCachedDataset
(
path
,
fix_lua_indexing
=
fix_lua_indexing
)
elif
impl
==
"mmap"
and
MMapIndexedDataset
.
exists
(
path
):
return
MMapIndexedDataset
(
path
)
elif
impl
==
"fasta"
and
FastaDataset
.
exists
(
path
):
from
fairseq.data.fasta_dataset
import
EncodedFastaDataset
return
EncodedFastaDataset
(
path
,
dictionary
)
elif
impl
==
"huffman"
and
HuffmanMMapIndexedDataset
.
exists
(
path
):
return
HuffmanMMapIndexedDataset
(
path
)
return
None
def
dataset_exists
(
path
,
impl
):
if
impl
==
"raw"
:
return
IndexedRawTextDataset
.
exists
(
path
)
elif
impl
==
"mmap"
:
return
MMapIndexedDataset
.
exists
(
path
)
elif
impl
==
"huffman"
:
return
HuffmanMMapIndexedDataset
.
exists
(
path
)
else
:
return
IndexedDataset
.
exists
(
path
)
def
read_longs
(
f
,
n
):
a
=
np
.
empty
(
n
,
dtype
=
np
.
int64
)
f
.
readinto
(
a
)
return
a
def
write_longs
(
f
,
a
):
f
.
write
(
np
.
array
(
a
,
dtype
=
np
.
int64
))
_code_to_dtype
=
{
1
:
np
.
uint8
,
2
:
np
.
int8
,
3
:
np
.
int16
,
4
:
np
.
int32
,
5
:
np
.
int64
,
6
:
np
.
float64
,
7
:
np
.
double
,
8
:
np
.
uint16
,
9
:
np
.
uint32
,
10
:
np
.
uint64
,
}
def
_dtype_header_code
(
dtype
)
->
int
:
for
k
in
_code_to_dtype
.
keys
():
if
_code_to_dtype
[
k
]
==
dtype
:
return
k
raise
ValueError
(
dtype
)
def
index_file_path
(
prefix_path
):
return
prefix_path
+
".idx"
def
data_file_path
(
prefix_path
):
return
prefix_path
+
".bin"
class
IndexedDataset
(
FairseqDataset
):
"""Loader for TorchNet IndexedDataset"""
_HDR_MAGIC
=
b
"TNTIDX
\x00\x00
"
def
__init__
(
self
,
path
,
fix_lua_indexing
=
False
):
super
().
__init__
()
self
.
path
=
path
self
.
fix_lua_indexing
=
fix_lua_indexing
self
.
data_file
=
None
self
.
read_index
(
path
)
def
read_index
(
self
,
path
):
with
open
(
index_file_path
(
path
),
"rb"
)
as
f
:
magic
=
f
.
read
(
8
)
assert
magic
==
self
.
_HDR_MAGIC
,
(
"Index file doesn't match expected format. "
"Make sure that --dataset-impl is configured properly."
)
version
=
f
.
read
(
8
)
assert
struct
.
unpack
(
"<Q"
,
version
)
==
(
1
,)
code
,
self
.
element_size
=
struct
.
unpack
(
"<QQ"
,
f
.
read
(
16
))
self
.
dtype
=
_code_to_dtype
[
code
]
self
.
_len
,
self
.
s
=
struct
.
unpack
(
"<QQ"
,
f
.
read
(
16
))
self
.
dim_offsets
=
read_longs
(
f
,
self
.
_len
+
1
)
self
.
data_offsets
=
read_longs
(
f
,
self
.
_len
+
1
)
self
.
sizes
=
read_longs
(
f
,
self
.
s
)
def
read_data
(
self
,
path
):
self
.
data_file
=
open
(
data_file_path
(
path
),
"rb"
,
buffering
=
0
)
def
check_index
(
self
,
i
):
if
i
<
0
or
i
>=
self
.
_len
:
raise
IndexError
(
"index out of range"
)
def
__del__
(
self
):
if
self
.
data_file
:
self
.
data_file
.
close
()
@
lru_cache
(
maxsize
=
8
)
def
__getitem__
(
self
,
i
)
->
torch
.
Tensor
:
if
not
self
.
data_file
:
self
.
read_data
(
self
.
path
)
self
.
check_index
(
i
)
tensor_size
=
self
.
sizes
[
self
.
dim_offsets
[
i
]
:
self
.
dim_offsets
[
i
+
1
]]
a
=
np
.
empty
(
tensor_size
,
dtype
=
self
.
dtype
)
self
.
data_file
.
seek
(
self
.
data_offsets
[
i
]
*
self
.
element_size
)
self
.
data_file
.
readinto
(
a
)
item
=
torch
.
from_numpy
(
a
).
long
()
if
self
.
fix_lua_indexing
:
item
-=
1
# subtract 1 for 0-based indexing
return
item
def
__len__
(
self
):
return
self
.
_len
def
num_tokens
(
self
,
index
):
return
self
.
sizes
[
index
]
def
size
(
self
,
index
):
return
self
.
sizes
[
index
]
@
staticmethod
def
exists
(
path
):
return
PathManager
.
exists
(
index_file_path
(
path
))
and
PathManager
.
exists
(
data_file_path
(
path
)
)
@
property
def
supports_prefetch
(
self
):
return
False
# avoid prefetching to save memory
class
IndexedCachedDataset
(
IndexedDataset
):
def
__init__
(
self
,
path
,
fix_lua_indexing
=
False
):
super
().
__init__
(
path
,
fix_lua_indexing
=
fix_lua_indexing
)
self
.
cache
=
None
self
.
cache_index
=
{}
@
property
def
supports_prefetch
(
self
):
return
True
def
prefetch
(
self
,
indices
):
if
all
(
i
in
self
.
cache_index
for
i
in
indices
):
return
if
not
self
.
data_file
:
self
.
read_data
(
self
.
path
)
indices
=
sorted
(
set
(
indices
))
total_size
=
0
for
i
in
indices
:
total_size
+=
self
.
data_offsets
[
i
+
1
]
-
self
.
data_offsets
[
i
]
self
.
cache
=
np
.
empty
(
total_size
,
dtype
=
self
.
dtype
)
ptx
=
0
self
.
cache_index
.
clear
()
for
i
in
indices
:
self
.
cache_index
[
i
]
=
ptx
size
=
self
.
data_offsets
[
i
+
1
]
-
self
.
data_offsets
[
i
]
a
=
self
.
cache
[
ptx
:
ptx
+
size
]
self
.
data_file
.
seek
(
self
.
data_offsets
[
i
]
*
self
.
element_size
)
self
.
data_file
.
readinto
(
a
)
ptx
+=
size
if
self
.
data_file
:
# close and delete data file after prefetch so we can pickle
self
.
data_file
.
close
()
self
.
data_file
=
None
@
lru_cache
(
maxsize
=
8
)
def
__getitem__
(
self
,
i
):
self
.
check_index
(
i
)
tensor_size
=
self
.
sizes
[
self
.
dim_offsets
[
i
]
:
self
.
dim_offsets
[
i
+
1
]]
a
=
np
.
empty
(
tensor_size
,
dtype
=
self
.
dtype
)
ptx
=
self
.
cache_index
[
i
]
np
.
copyto
(
a
,
self
.
cache
[
ptx
:
ptx
+
a
.
size
])
item
=
torch
.
from_numpy
(
a
).
long
()
if
self
.
fix_lua_indexing
:
item
-=
1
# subtract 1 for 0-based indexing
return
item
class IndexedRawTextDataset(FairseqDataset):
    """Takes a text file as input and binarizes it in memory at instantiation.
    Original lines are also kept in memory"""

    def __init__(self, path, dictionary, append_eos=True, reverse_order=False):
        self.tokens_list = []
        self.lines = []
        self.sizes = []
        self.append_eos = append_eos
        self.reverse_order = reverse_order
        self.read_data(path, dictionary)
        self.size = len(self.tokens_list)

    def read_data(self, path, dictionary):
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                self.lines.append(line.strip("\n"))
                tokens = dictionary.encode_line(
                    line,
                    add_if_not_exist=False,
                    append_eos=self.append_eos,
                    reverse_order=self.reverse_order,
                ).long()
                self.tokens_list.append(tokens)
                self.sizes.append(len(tokens))
        self.sizes = np.array(self.sizes)

    def check_index(self, i):
        if i < 0 or i >= self.size:
            raise IndexError("index out of range")

    @lru_cache(maxsize=8)
    def __getitem__(self, i):
        self.check_index(i)
        return self.tokens_list[i]

    def get_original_text(self, i):
        self.check_index(i)
        return self.lines[i]

    def __del__(self):
        pass

    def __len__(self):
        return self.size

    def num_tokens(self, index):
        return self.sizes[index]

    def size(self, index):
        return self.sizes[index]

    @staticmethod
    def exists(path):
        return PathManager.exists(path)
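

# Illustrative sketch (not part of fairseq): IndexedRawTextDataset binarizes a
# plain text file in memory using a fairseq Dictionary. "corpus.txt" and
# "dict.txt" are hypothetical files; Dictionary.load() is fairseq's public API
# for loading a vocabulary file.
def _demo_raw_text_dataset(text_path="corpus.txt", dict_path="dict.txt"):
    from fairseq.data import Dictionary

    dictionary = Dictionary.load(dict_path)
    dataset = IndexedRawTextDataset(text_path, dictionary, append_eos=True)
    tokens = dataset[0]                      # LongTensor of token ids (incl. EOS)
    original = dataset.get_original_text(0)  # the raw line, kept in memory
    return tokens, original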
class IndexedDatasetBuilder:
    element_sizes = {
        np.uint8: 1,
        np.int8: 1,
        np.int16: 2,
        np.int32: 4,
        np.int64: 8,
        np.float64: 4,
        np.double: 8,
    }

    def __init__(self, out_file, dtype=np.int32):
        self.out_file = open(out_file, "wb")
        self.dtype = dtype
        self.data_offsets = [0]
        self.dim_offsets = [0]
        self.sizes = []
        self.element_size = self.element_sizes[self.dtype]

    def add_item(self, tensor):
        # +1 for Lua compatibility
        bytes = self.out_file.write(np.array(tensor.numpy() + 1, dtype=self.dtype))
        self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size)
        for s in tensor.size():
            self.sizes.append(s)
        self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))

    def merge_file_(self, another_file):
        index = IndexedDataset(another_file)
        assert index.dtype == self.dtype

        begin = self.data_offsets[-1]
        for offset in index.data_offsets[1:]:
            self.data_offsets.append(begin + offset)
        self.sizes.extend(index.sizes)
        begin = self.dim_offsets[-1]
        for dim_offset in index.dim_offsets[1:]:
            self.dim_offsets.append(begin + dim_offset)

        with open(data_file_path(another_file), "rb") as f:
            while True:
                data = f.read(1024)
                if data:
                    self.out_file.write(data)
                else:
                    break

    def finalize(self, index_file):
        self.out_file.close()
        index = open(index_file, "wb")
        index.write(b"TNTIDX\x00\x00")
        index.write(struct.pack("<Q", 1))
        index.write(
            struct.pack("<QQ", _dtype_header_code(self.dtype), self.element_size)
        )
        index.write(struct.pack("<QQ", len(self.data_offsets) - 1, len(self.sizes)))
        write_longs(index, self.dim_offsets)
        write_longs(index, self.data_offsets)
        write_longs(index, self.sizes)
        index.close()
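

# Illustrative sketch (not part of fairseq): write a legacy indexed dataset with
# IndexedDatasetBuilder and read it back with IndexedDataset. The "example"
# prefix and the token lists are hypothetical. add_item() stores values +1 for
# Lua compatibility, which IndexedDataset undoes when fix_lua_indexing=True.
def _demo_build_indexed_dataset(prefix="example"):
    builder = IndexedDatasetBuilder(data_file_path(prefix), dtype=np.int32)
    for sent in ([4, 5, 6, 2], [7, 8, 2]):
        builder.add_item(torch.tensor(sent, dtype=torch.int32))
    builder.finalize(index_file_path(prefix))  # writes the TNTIDX index file

    dataset = IndexedDataset(prefix, fix_lua_indexing=True)
    return [dataset[i] for i in range(len(dataset))]  # original values restored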
def _warmup_mmap_file(path):
    with open(path, "rb") as stream:
        while stream.read(100 * 1024 * 1024):
            pass


class MMapIndexedDataset(torch.utils.data.Dataset):
    class Index:
        _HDR_MAGIC = b"MMIDIDX\x00\x00"

        @classmethod
        def writer(cls, path, dtype):
            class _Writer:
                def __enter__(self):
                    self._file = open(path, "wb")

                    self._file.write(cls._HDR_MAGIC)
                    self._file.write(struct.pack("<Q", 1))
                    self._file.write(struct.pack("<B", _dtype_header_code(dtype)))

                    return self

                @staticmethod
                def _get_pointers(sizes):
                    dtype_size = dtype().itemsize
                    address = 0
                    pointers = []

                    for size in sizes:
                        pointers.append(address)
                        address += size * dtype_size

                    return pointers

                def write(self, sizes):
                    pointers = self._get_pointers(sizes)

                    self._file.write(struct.pack("<Q", len(sizes)))

                    sizes = np.array(sizes, dtype=np.int32)
                    self._file.write(sizes.tobytes(order="C"))
                    del sizes

                    pointers = np.array(pointers, dtype=np.int64)
                    self._file.write(pointers.tobytes(order="C"))
                    del pointers

                def __exit__(self, exc_type, exc_val, exc_tb):
                    self._file.close()

            return _Writer()

        def __init__(self, path):
            with open(path, "rb") as stream:
                magic_test = stream.read(9)
                assert self._HDR_MAGIC == magic_test, (
                    "Index file doesn't match expected format. "
                    "Make sure that --dataset-impl is configured properly."
                )
                version = struct.unpack("<Q", stream.read(8))
                assert (1,) == version

                (dtype_code,) = struct.unpack("<B", stream.read(1))
                self._dtype = _code_to_dtype[dtype_code]
                self._dtype_size = self._dtype().itemsize

                self._len = struct.unpack("<Q", stream.read(8))[0]
                offset = stream.tell()

            _warmup_mmap_file(path)

            self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
            self._bin_buffer = memoryview(self._bin_buffer_mmap)
            self._sizes = np.frombuffer(
                self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
            )
            self._pointers = np.frombuffer(
                self._bin_buffer,
                dtype=np.int64,
                count=self._len,
                offset=offset + self._sizes.nbytes,
            )

        def __del__(self):
            self._bin_buffer_mmap._mmap.close()
            del self._bin_buffer_mmap

        @property
        def dtype(self):
            return self._dtype

        @property
        def sizes(self):
            return self._sizes

        @lru_cache(maxsize=8)
        def __getitem__(self, i):
            return self._pointers[i], self._sizes[i]

        def __len__(self):
            return self._len

    def __init__(self, path):
        super().__init__()

        self._path = None
        self._index = None
        self._bin_buffer = None

        self._do_init(path)

    def __getstate__(self):
        return self._path

    def __setstate__(self, state):
        self._do_init(state)

    def _do_init(self, path):
        self._path = path
        self._index = self.Index(index_file_path(self._path))

        _warmup_mmap_file(data_file_path(self._path))
        self._bin_buffer_mmap = np.memmap(
            data_file_path(self._path), mode="r", order="C"
        )
        self._bin_buffer = memoryview(self._bin_buffer_mmap)

    def __del__(self):
        self._bin_buffer_mmap._mmap.close()
        del self._bin_buffer_mmap
        del self._index

    def __len__(self):
        return len(self._index)

    @lru_cache(maxsize=8)
    def __getitem__(self, i):
        ptr, size = self._index[i]
        np_array = np.frombuffer(
            self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
        )
        if self._index.dtype != np.int64:
            np_array = np_array.astype(np.int64)

        return torch.from_numpy(np_array)

    @property
    def sizes(self):
        return self._index.sizes

    @property
    def supports_prefetch(self):
        return False

    @staticmethod
    def exists(path):
        return PathManager.exists(index_file_path(path)) and PathManager.exists(
            data_file_path(path)
        )


def get_indexed_dataset_to_local(path) -> str:
    local_index_path = PathManager.get_local_path(index_file_path(path))
    local_data_path = PathManager.get_local_path(data_file_path(path))

    assert local_index_path.endswith(".idx") and local_data_path.endswith(".bin"), (
        "PathManager.get_local_path does not return files with expected patterns: "
        f"{local_index_path} and {local_data_path}"
    )

    local_path = local_data_path[:-4]  # stripping suffix ".bin"
    assert local_path == local_index_path[:-4]  # stripping suffix ".idx"
    return local_path


class MMapIndexedDatasetBuilder:
    def __init__(self, out_file, dtype=np.int64):
        self._data_file = open(out_file, "wb")
        self._dtype = dtype
        self._sizes = []

    def add_item(self, tensor):
        np_array = np.array(tensor.numpy(), dtype=self._dtype)
        self._data_file.write(np_array.tobytes(order="C"))
        self._sizes.append(np_array.size)

    def merge_file_(self, another_file):
        # Concatenate index
        index = MMapIndexedDataset.Index(index_file_path(another_file))
        assert index.dtype == self._dtype

        for size in index.sizes:
            self._sizes.append(size)

        # Concatenate data
        with open(data_file_path(another_file), "rb") as f:
            shutil.copyfileobj(f, self._data_file)

    def finalize(self, index_file):
        self._data_file.close()

        with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
            index.write(self._sizes)
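

# Illustrative sketch (not part of fairseq): the mmap variant of the same
# workflow. MMapIndexedDatasetBuilder streams raw token arrays into the .bin
# file, and finalize() writes the .idx header (magic, version, dtype code,
# sizes, pointers); MMapIndexedDataset then memory-maps both files. The
# "example-mmap" prefix and the token lists are hypothetical.
def _demo_build_mmap_dataset(prefix="example-mmap"):
    builder = MMapIndexedDatasetBuilder(data_file_path(prefix), dtype=np.int64)
    for sent in ([4, 5, 6, 2], [7, 8, 2]):
        builder.add_item(torch.tensor(sent, dtype=torch.int64))
    builder.finalize(index_file_path(prefix))

    dataset = MMapIndexedDataset(prefix)
    return [dataset[i] for i in range(len(dataset))]  # reads backed by np.memmap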