dcuai / dlexamples · Commit c0f05c10

Commit c0f05c10, authored Nov 29, 2022 by hepj
Commit message: Update transformer code
Parent: c056df78
Changes: 595 files in total; this page shows 20 of them.
Showing 20 changed files with 0 additions and 1813 deletions.
PyTorch/NLP/Transformer/.dockerignore (+0 −2)
PyTorch/NLP/Transformer/CONTRIBUTING.md (+0 −30)
PyTorch/NLP/Transformer/Dockerfile (+0 −47)
PyTorch/NLP/Transformer/LICENSE (+0 −31)
PyTorch/NLP/Transformer/NOTICE (+0 −6)
PyTorch/NLP/Transformer/PATENTS (+0 −33)
PyTorch/NLP/Transformer/README.md.old (+0 −563)
PyTorch/NLP/Transformer/average_valid_loss.png (+0 −0)
PyTorch/NLP/Transformer/bleu_relationship.png (+0 −0)
PyTorch/NLP/Transformer/decorrelation_threshold.png (+0 −0)
PyTorch/NLP/Transformer/distributed_train.py (+0 −63)
PyTorch/NLP/Transformer/fairseq/__init__.py (+0 −10)
PyTorch/NLP/Transformer/fairseq/criterions.py (+0 −48)
PyTorch/NLP/Transformer/fairseq/data/__init__.py (+0 −27)
PyTorch/NLP/Transformer/fairseq/data/csrc/make_batches.cpp (+0 −77)
PyTorch/NLP/Transformer/fairseq/data/data_utils.py (+0 −327)
PyTorch/NLP/Transformer/fairseq/data/fairseq_dataset.py (+0 −35)
PyTorch/NLP/Transformer/fairseq/data/indexed_dataset.py (+0 −206)
PyTorch/NLP/Transformer/fairseq/data/language_pair_dataset.py (+0 −200)
PyTorch/NLP/Transformer/fairseq/data/token_block_dataset.py (+0 −108)
PyTorch/NLP/Transformer/.dockerignore · deleted · 100644 → 0

results
data
PyTorch/NLP/Transformer/CONTRIBUTING.md · deleted · 100644 → 0

# Contributing to FAIR Sequence-to-Sequence Toolkit (PyTorch)
We want to make contributing to this project as easy and transparent as
possible.

## Pull Requests
We actively welcome your pull requests.

1. Fork the repo and create your branch from `master`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").

## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Facebook's open source projects.

Complete your CLA here: <https://code.facebook.com/cla>

## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.

## Coding Style
We try to follow the PEP style guidelines and encourage you to as well.

## License
By contributing to FAIR Sequence-to-Sequence Toolkit, you agree that your
contributions will be licensed under the LICENSE file in the root directory
of this source tree.
PyTorch/NLP/Transformer/Dockerfile · deleted · 100644 → 0

# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:21.05-py3
FROM ${FROM_IMAGE_NAME}

WORKDIR /workspace

#RUN git clone https://github.com/NVIDIA/apex \
# && cd apex \
# && pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./

# Install Python dependencies
RUN pip install --no-cache-dir \
    sacrebleu \
    sentencepiece
RUN pip install jupyter

ARG DEBIAN_FRONTEND=noninteractive
RUN apt-get update
RUN apt-get install -y -q cmake pkg-config protobuf-compiler libprotobuf-dev libgoogle-perftools-dev
RUN git clone https://github.com/google/sentencepiece.git /workspace/sentencepiece
RUN cd /workspace/sentencepiece \
 && git checkout d4dd947 \
 && mkdir build \
 && cd build \
 && cmake .. \
 && make -j 8 \
 && make install \
 && ldconfig -v

ENV PYTHONPATH=/workspace/translation/examples/translation/subword-nmt/
WORKDIR /workspace/translation

RUN git clone https://github.com/rsennrich/subword-nmt.git /workspace/translation/examples/translation/subword-nmt/
RUN git clone https://github.com/NVIDIA/cutlass.git && cd cutlass && git checkout ed2ed4d6 && cd ..

COPY . .
RUN pip install -e .
RUN pip install git+https://github.com/NVIDIA/dllogger@v0.1.0#egg=dllogger
PyTorch/NLP/Transformer/LICENSE · deleted · 100644 → 0

BSD License

For fairseq software

Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
Copyright (c) 2019-present, NVIDIA CORPORATION. All rights reserved.

Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:

 * Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

 * Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

 * Neither the name Facebook nor the names of its contributors may be used to
   endorse or promote products derived from this software without specific
   prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
PyTorch/NLP/Transformer/NOTICE · deleted · 100644 → 0

Transformer PyTorch

This repository includes software from https://github.com/facebookresearch/fairseq
licensed under the BSD License.
PyTorch/NLP/Transformer/PATENTS · deleted · 100644 → 0

Additional Grant of Patent Rights Version 2
"Software" means the fairseq software distributed by Facebook, Inc.
Facebook, Inc. ("Facebook") hereby grants to each recipient of the Software
("you") a perpetual, worldwide, royalty-free, non-exclusive, irrevocable
(subject to the termination provision below) license under any Necessary
Claims, to make, have made, use, sell, offer to sell, import, and otherwise
transfer the Software. For avoidance of doubt, no license is granted under
Facebook’s rights in any patent claims that are infringed by (i) modifications
to the Software made by you or any third party or (ii) the Software in
combination with any software or other technology.
The license granted hereunder will terminate, automatically and without notice,
if you (or any of your subsidiaries, corporate affiliates or agents) initiate
directly or indirectly, or take a direct financial interest in, any Patent
Assertion: (i) against Facebook or any of its subsidiaries or corporate
affiliates, (ii) against any party if such Patent Assertion arises in whole or
in part from any software, technology, product or service of Facebook or any of
its subsidiaries or corporate affiliates, or (iii) against any party relating
to the Software. Notwithstanding the foregoing, if Facebook or any of its
subsidiaries or corporate affiliates files a lawsuit alleging patent
infringement against you in the first instance, and you respond by filing a
patent infringement counterclaim in that lawsuit against that party that is
unrelated to the Software, the license granted hereunder will not terminate
under section (i) of this paragraph due to such counterclaim.
A "Necessary Claim" is a claim of a patent owned by Facebook that is
necessarily infringed by the Software standing alone.
A "Patent Assertion" is any lawsuit or other action alleging direct, indirect,
or contributory infringement or inducement to infringe any patent, including a
cross-claim or counterclaim.
PyTorch/NLP/Transformer/README.md.old · deleted · 100644 → 0

(Diff collapsed; the 563 deleted lines are not shown on this page.)
PyTorch/NLP/Transformer/average_valid_loss.png · deleted · 100644 → 0 (binary image, 18.4 KB)

PyTorch/NLP/Transformer/bleu_relationship.png · deleted · 100644 → 0 (binary image, 10.2 KB)

PyTorch/NLP/Transformer/decorrelation_threshold.png · deleted · 100644 → 0 (binary image, 173 KB)
PyTorch/NLP/Transformer/distributed_train.py · deleted · 100644 → 0

#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import socket
import subprocess

from train import main as single_process_main
from fairseq import distributed_utils, options


def main(args):
    if args.distributed_init_method is None and args.distributed_port > 0:
        # We can determine the init method automatically for Slurm.
        node_list = os.environ.get('SLURM_JOB_NODELIST')
        if node_list is not None:
            try:
                hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', node_list])
                args.distributed_init_method = 'tcp://{host}:{port}'.format(
                    host=hostnames.split()[0].decode('utf-8'),
                    port=args.distributed_port)
                args.distributed_rank = int(os.environ.get('SLURM_PROCID'))
                args.device_id = int(os.environ.get('SLURM_LOCALID'))
            except subprocess.CalledProcessError as e:
                # scontrol failed
                raise e
            except FileNotFoundError:
                # Slurm is not installed
                pass

    if args.distributed_init_method is None:
        raise ValueError('--distributed-init-method or --distributed-port '
                         'must be specified for distributed training')

    args.distributed_rank = distributed_utils.distributed_init(args)
    args.device_id = int(os.environ.get('LOCAL_RANK', args.local_rank))
    print('| initialized host {} as rank {} and device id {}'.format(
        socket.gethostname(), args.distributed_rank, args.device_id))

    single_process_main(args)


if __name__ == '__main__':
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser)
    main(args)
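
For context on the Slurm branch above: the script derives a TCP rendezvous address by expanding `SLURM_JOB_NODELIST` with `scontrol` and pointing every rank at the first host. A standalone sketch of that derivation (the node list value and the port are hypothetical, and `scontrol` must be on PATH):

```python
import os
import subprocess

# Hypothetical values, as Slurm would set them inside a job step.
os.environ.setdefault('SLURM_JOB_NODELIST', 'node[01-04]')
port = 29500  # hypothetical stand-in for --distributed-port

# `scontrol show hostnames` expands the compact node list, one host per line;
# the first host becomes the rendezvous endpoint for every rank.
hostnames = subprocess.check_output(
    ['scontrol', 'show', 'hostnames', os.environ['SLURM_JOB_NODELIST']])
init_method = 'tcp://{host}:{port}'.format(
    host=hostnames.split()[0].decode('utf-8'), port=port)
print(init_method)  # e.g. tcp://node01:29500
```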
PyTorch/NLP/Transformer/fairseq/__init__.py · deleted · 100644 → 0

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from .multiprocessing_pdb import pdb

__all__ = ['pdb']
PyTorch/NLP/Transformer/fairseq/criterions.py · deleted · 100644 → 0

import torch.nn.functional as F
from torch.nn.modules.loss import _Loss


class CrossEntropyCriterion(_Loss):

    def __init__(self, args):
        super().__init__()
        self.padding_idx = args.padding_idx

    def forward(self, norm_probs, target, reduce=True):
        """Compute the loss for the given sample."""
        lprobs = norm_probs.view(-1, norm_probs.size(-1))
        target = target.view(-1)
        loss = F.nll_loss(lprobs, target, size_average=False,
                          ignore_index=self.padding_idx, reduce=reduce)
        return loss


class LabelSmoothedCrossEntropyCriterion(_Loss):

    def __init__(self, args):
        super().__init__()
        self.eps = args.label_smoothing
        self.padding_idx = args.padding_idx

    def forward(self, norm_probs, target, reduce=True):
        """Compute the loss for the given sample."""
        target = target.view(-1, 1)
        lprobs = norm_probs.view(-1, norm_probs.size(-1))
        non_pad_mask = target.ne(self.padding_idx)
        nll_loss = -lprobs.gather(dim=-1, index=target)[non_pad_mask]
        smooth_loss = -lprobs.sum(dim=-1, keepdim=True)[non_pad_mask]
        if reduce:
            nll_loss = nll_loss.sum()
            smooth_loss = smooth_loss.sum()
        eps_i = self.eps / lprobs.size(-1)
        loss = (1. - self.eps) * nll_loss + eps_i * smooth_loss
        return loss


CRITERION_REGISTRY = {
    'label_smoothed_cross_entropy': LabelSmoothedCrossEntropyCriterion,
    'cross_entropy': CrossEntropyCriterion,
}
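
As a quick illustration of the smoothing arithmetic in the deleted criterion, a minimal sketch (the `args` namespace is a hypothetical stand-in for the parsed training options; the criterion expects log-probabilities):

```python
import types
import torch
import torch.nn.functional as F

# Hypothetical stand-in for the parsed training args.
args = types.SimpleNamespace(label_smoothing=0.1, padding_idx=0)
criterion = LabelSmoothedCrossEntropyCriterion(args)

# Log-probabilities for 2 positions over a 5-token vocabulary.
lprobs = F.log_softmax(torch.randn(2, 5), dim=-1)
target = torch.tensor([3, 4])

# loss = (1 - eps) * NLL + (eps / V) * sum of -log p over the vocabulary,
# computed only at non-padding target positions.
loss = criterion(lprobs, target)
```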
PyTorch/NLP/Transformer/fairseq/data/__init__.py · deleted · 100644 → 0

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .dictionary import Dictionary
from .indexed_dataset import IndexedDataset, IndexedInMemoryDataset, IndexedRawTextDataset  # noqa: F401
from .language_pair_dataset import LanguagePairDataset, load_dataset_splits
from .data_utils import EpochBatchIterator
PyTorch/NLP/Transformer/fairseq/data/csrc/make_batches.cpp · deleted · 100644 → 0

// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include <torch/extension.h>

namespace at {
namespace native {

namespace {

bool is_batch_full(int64_t num_tokens, int64_t max_tokens, int64_t max_sentences, int64_t batch_length) {
    if (batch_length == 0) {
        return false;
    } else if (batch_length == max_sentences || num_tokens > max_tokens) {
        return true;
    } else {
        return false;
    }
}

}

std::vector<std::vector<int64_t>> make_batches(
        py::array_t<int64_t> src_lengths,
        py::array_t<int64_t> tgt_lengths,
        py::array_t<int64_t> idx_list,
        int64_t max_tokens,
        int64_t max_sentences,
        uint64_t bsz_mult,
        int64_t max_len) {
    std::vector<std::vector<int64_t>> batches;
    auto src_l = src_lengths.unchecked<1>();
    auto tgt_l = tgt_lengths.unchecked<1>();
    auto idx_l = idx_list.unchecked<1>();
    AT_ASSERTM(src_l.shape(0) == tgt_l.shape(0), "tgt_list and src_list should have the same shape");
    AT_ASSERTM(idx_l.shape(0) == tgt_l.shape(0), "idx_list and tgt_list should have the same shape");
    ssize_t nelem = src_l.shape(0);

    int64_t sample_len = 0;
    std::vector<int64_t> sample_lens;
    std::vector<int64_t> batch;
    for (ssize_t i = 0; i < nelem; i++) {
        int64_t idx = idx_l(i);
        int64_t sample_num_tokens = std::max(src_l(idx), tgt_l(idx));
        if (sample_num_tokens > max_len) continue;
        sample_len = std::max(sample_len, sample_num_tokens);
        sample_lens.push_back(sample_num_tokens);
        int64_t num_tokens = (batch.size() + 1) * sample_len;
        if (is_batch_full(num_tokens, max_tokens, max_sentences, batch.size())) {
            int64_t mode_len = std::max(batch.size() / bsz_mult * bsz_mult,
                                        batch.size() % bsz_mult);
            std::vector<int64_t> new_batch;
            new_batch.reserve(mode_len);
            std::copy(batch.begin() + mode_len, batch.end(), std::back_inserter(new_batch));
            batch.erase(batch.begin() + mode_len, batch.end());
            sample_lens.erase(sample_lens.begin(), sample_lens.begin() + mode_len);
            // sample_lens always contains at least one element
            sample_len = *std::max_element(sample_lens.begin(), sample_lens.end());
            batches.push_back(batch);
            batch = new_batch;
        }
        batch.push_back(idx);
    }
    if (batch.size() > 0) batches.push_back(batch);
    return batches;
}

}} // namespace at::native

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("make_batches", &at::native::make_batches);
}
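
For readers without the compiled `batch_C` extension, a behavior-oriented pure-Python sketch of the same greedy rule (`make_batches_py` is a hypothetical helper name; the C++ version above is what the deleted `EpochBatchIterator` actually calls):

```python
def make_batches_py(src_lengths, tgt_lengths, idx_list,
                    max_tokens, max_sentences, bsz_mult, max_len):
    """Pure-Python sketch of the greedy token-bucketed batching above."""
    batches, batch, sample_lens = [], [], []
    sample_len = 0
    for idx in idx_list:
        num = max(src_lengths[idx], tgt_lengths[idx])
        if num > max_len:
            continue  # over-long samples are dropped entirely
        sample_len = max(sample_len, num)
        sample_lens.append(num)
        if batch and (len(batch) == max_sentences
                      or (len(batch) + 1) * sample_len > max_tokens):
            # keep a multiple of bsz_mult in this batch, carry the rest over
            mode_len = max(len(batch) // bsz_mult * bsz_mult,
                           len(batch) % bsz_mult)
            batch, carry = batch[:mode_len], batch[mode_len:]
            sample_lens = sample_lens[mode_len:]
            sample_len = max(sample_lens)
            batches.append(batch)
            batch = carry
        batch.append(idx)
    if batch:
        batches.append(batch)
    return batches
```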
PyTorch/NLP/Transformer/fairseq/data/data_utils.py · deleted · 100644 → 0

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import itertools
import os

import numpy as np
import torch

import fairseq.data.batch_C
import sys

from .dictionary import Dictionary


def infer_language_pair(path):
    """Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
    src, dst = None, None
    for filename in os.listdir(path):
        parts = filename.split('.')
        if len(parts) >= 3 and len(parts[1].split('-')) == 2:
            return parts[1].split('-')
    return src, dst


def load_dictionaries(args):
    if args.source_lang is None or args.target_lang is None:
        args.source_lang, args.target_lang = infer_language_pair(args.data)
    if args.source_lang is None or args.target_lang is None:
        raise Exception('Could not infer language pair, please provide it explicitly')

    # load dictionaries
    src_dict = Dictionary.load(os.path.join(args.data, 'dict.{}.txt'.format(args.source_lang)))
    tgt_dict = Dictionary.load(os.path.join(args.data, 'dict.{}.txt'.format(args.target_lang)))
    assert src_dict.pad() == tgt_dict.pad()
    assert src_dict.eos() == tgt_dict.eos()
    assert src_dict.unk() == tgt_dict.unk()

    args.src_vocab_size = len(src_dict)
    args.tgt_vocab_size = len(tgt_dict)
    args.padding_idx = src_dict.pad()
    print('| [{}] dictionary: {} types'.format(args.source_lang, len(src_dict)))
    print('| [{}] dictionary: {} types'.format(args.target_lang, len(tgt_dict)))
    return src_dict, tgt_dict


class ShardedIterator(object):
    """A sharded wrapper around an iterable (padded to length)."""

    def __init__(self, iterable, num_shards, shard_id, fill_value=None):
        if shard_id < 0 or shard_id >= num_shards:
            raise ValueError('shard_id must be between 0 and num_shards')

        self._sharded_len = len(iterable) // num_shards
        if len(iterable) % num_shards > 0:
            self._sharded_len += 1

        self.itr = itertools.zip_longest(
            range(self._sharded_len),
            itertools.islice(iterable, shard_id, len(iterable), num_shards),
            fillvalue=fill_value,
        )

    def __len__(self):
        return self._sharded_len

    def __iter__(self):
        return self

    def __next__(self):
        return next(self.itr)[1]


class CountingIterator(object):
    """Wrapper around an iterable that maintains the iteration count."""

    def __init__(self, iterable):
        self.iterable = iterable
        self.count = 0
        self.itr = iter(self)

    def __len__(self):
        return len(self.iterable)

    def __iter__(self):
        for x in self.iterable:
            self.count += 1
            yield x

    def __next__(self):
        return next(self.itr)

    def has_next(self):
        return self.count < len(self)

    def skip(self, num_to_skip):
        next(itertools.islice(self.itr, num_to_skip, num_to_skip), None)
        return self


def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=False, pad_sequence=1):
    """Convert a list of 1d tensors into a padded 2d tensor."""
    #size = max(v.size(0) for v in values)
    orig_size = max(v.size(0) for v in values)
    size = 0
    if pad_sequence > 1:
        size = orig_size // pad_sequence * pad_sequence
        if orig_size % pad_sequence > 0:
            size += pad_sequence
    else:
        size = orig_size
    res = values[0].new(len(values), size).fill_(pad_idx)

    def copy_tensor(src, dst):
        assert dst.numel() == src.numel()
        if move_eos_to_beginning:
            assert src[-1] == eos_idx
            dst[0] = eos_idx
            dst[1:] = src[:-1]
        else:
            dst.copy_(src)

    for i, v in enumerate(values):
        copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
    return res


def collate(samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False, pad_sequence=1):
    if len(samples) == 0:
        return {}

    def merge(key, left_pad, move_eos_to_beginning=False):
        return collate_tokens(
            [s[key] for s in samples],
            pad_idx, eos_idx, left_pad, move_eos_to_beginning, pad_sequence,
        )

    id = torch.LongTensor([s['id'] for s in samples])
    src_tokens = merge('source', left_pad=left_pad_source)
    # sort by descending source length
    src_lengths = torch.LongTensor([s['source'].numel() for s in samples])
    src_lengths, sort_order = src_lengths.sort(descending=True)
    id = id.index_select(0, sort_order)
    src_tokens = src_tokens.index_select(0, sort_order)

    prev_output_tokens = None
    target = None
    if samples[0].get('target', None) is not None:
        target = merge('target', left_pad=left_pad_target)
        # we create a shifted version of targets for feeding the
        # previous output token(s) into the next decoder step
        prev_output_tokens = merge(
            'target',
            left_pad=left_pad_target,
            move_eos_to_beginning=True,
        )
        prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
        target = target.index_select(0, sort_order)
        ntokens = sum(len(s['target']) for s in samples)
    else:
        ntokens = sum(len(s['source']) for s in samples)

    return {
        'id': id,
        'ntokens': ntokens,
        'net_input': {
            'src_tokens': src_tokens,
            'src_lengths': src_lengths,
            'prev_output_tokens': prev_output_tokens,
        },
        'target': target,
    }


def get_dummy_batch(num_tokens, src_dict, tgt_dict, src_len=128, tgt_len=128,
                    left_pad_source=True, left_pad_target=False, pad_sequence=1):
    bsz = num_tokens // max(src_len, tgt_len)
    dummy_samples = [
        {
            'id': i,
            'source': src_dict.dummy_sentence(src_len),
            'target': tgt_dict.dummy_sentence(tgt_len) if tgt_dict is not None else None,
        }
        for i in range(bsz)
    ]
    return collate(
        dummy_samples,
        pad_idx=src_dict.pad(),
        eos_idx=src_dict.eos(),
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        pad_sequence=pad_sequence,
    )


class EpochBatchIterator(object):
    """Iterate over a FairseqDataset and yield batches bucketed by size.

    Batches may contain sequences of different lengths. This iterator can be
    reused across multiple epochs with the next_epoch_itr() method.

    Args:
        dataset: a FairseqDataset
        max_tokens: max number of tokens in each batch
        max_sentences: max number of sentences in each batch
        max_positions: max sentence length supported by the model
        required_batch_size_multiple: require batch size to be a multiple of N
        seed: seed for random number generator for reproducibility
        num_shards: shard the data iterator into N shards
        shard_id: which shard of the data iterator to return
    """

    def __init__(
        self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
        required_batch_size_multiple=1, seed=1, num_shards=1, shard_id=0, epoch=0,
    ):
        self.dataset = dataset
        self.max_tokens = max_tokens if max_tokens is not None else float('Inf')
        self.max_sentences = max_sentences if max_sentences is not None else float('Inf')
        self.max_positions = max_positions
        self.bsz_mult = required_batch_size_multiple
        self.seed = seed
        self.num_shards = num_shards
        self.shard_id = shard_id

        self.epoch = epoch
        self._cur_epoch_itr = None
        self._next_epoch_itr = None

        with numpy_seed(self.seed):
            indices = self.dataset.ordered_indices(self.seed, self.epoch)

        # need integer, rather than float('Inf') values
        max_sentences = max_sentences if max_sentences is not None else sys.maxsize
        max_positions_num = 1024
        max_tokens = max_tokens if max_tokens is not None else sys.maxsize

        # Following line is workaround due to the fact we cannot pass None object as argument
        tgt_sizes = self.dataset.tgt_sizes if self.dataset.tgt_sizes is not None else self.dataset.src_sizes
        batches = fairseq.data.batch_C.make_batches(
            self.dataset.src_sizes, tgt_sizes, indices, max_tokens,
            max_sentences, self.bsz_mult, max_positions_num)
        self.frozen_batches = tuple(batches)

    def __len__(self):
        return len(self.frozen_batches)

    def next_epoch_itr(self, shuffle=True):
        """Shuffle batches and return a new iterator over the dataset."""
        if self._next_epoch_itr is not None:
            self._cur_epoch_itr = self._next_epoch_itr
            self._next_epoch_itr = None
        else:
            self.epoch += 1
            self._cur_epoch_itr = self._get_iterator_for_epoch(self.epoch, shuffle)
        return self._cur_epoch_itr

    def end_of_epoch(self):
        return not self._cur_epoch_itr.has_next()

    @property
    def iterations_in_epoch(self):
        if self._cur_epoch_itr is not None:
            return self._cur_epoch_itr.count
        elif self._next_epoch_itr is not None:
            return self._next_epoch_itr.count
        return 0

    def state_dict(self):
        return {
            'epoch': self.epoch,
            'iterations_in_epoch': self.iterations_in_epoch,
        }

    def load_state_dict(self, state_dict):
        self.epoch = state_dict['epoch']
        itr_pos = state_dict.get('iterations_in_epoch', 0)
        if itr_pos > 0:
            # fast-forward epoch iterator
            itr = self._get_iterator_for_epoch(self.epoch, state_dict.get('shuffle', True))
            if itr_pos < len(itr):
                self._next_epoch_itr = itr.skip(itr_pos)

    def _get_iterator_for_epoch(self, epoch, shuffle):
        if shuffle:
            # set seed based on the seed and epoch number so that we get
            # reproducible results when resuming from checkpoints
            with numpy_seed(self.seed + epoch):
                batches = list(self.frozen_batches)  # copy
                np.random.shuffle(batches)
        else:
            batches = self.frozen_batches
        return CountingIterator(torch.utils.data.DataLoader(
            self.dataset,
            collate_fn=self.dataset.collater,
            num_workers=1,
            batch_sampler=ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]),
        ))


@contextlib.contextmanager
def numpy_seed(seed):
    """Context manager which seeds the NumPy PRNG with the specified seed and
    restores the state afterward"""
    if seed is None:
        yield
        return
    state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        np.random.set_state(state)
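
To make the padding rule in `collate_tokens` concrete, a small sketch (token values arbitrary; `pad_sequence=4` rounds the batch width up to a multiple of 4, and `left_pad=True` pads on the left):

```python
import torch

# Three variable-length sentences; 0 is the pad index, 2 the eos index.
values = [torch.tensor([5, 6, 2]),
          torch.tensor([7, 2]),
          torch.tensor([8, 9, 10, 2])]

batch = collate_tokens(values, pad_idx=0, eos_idx=2,
                       left_pad=True, pad_sequence=4)
# The longest sentence has 4 tokens, already a multiple of 4, so the width
# is 4 and shorter rows are left-padded with the pad index:
# tensor([[0, 5,  6, 2],
#         [0, 0,  7, 2],
#         [8, 9, 10, 2]])
```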
PyTorch/NLP/Transformer/fairseq/data/fairseq_dataset.py · deleted · 100644 → 0

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch.utils.data


class FairseqDataset(torch.utils.data.Dataset):
    """A dataset that provides helpers for batching."""

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch."""
        raise NotImplementedError

    def num_tokens(self, index):
        """Return an example's length (number of tokens), used for batching."""
        raise NotImplementedError

    def ordered_indices(self, seed=None, epoch=0):
        """Ordered indices for batching."""
        raise NotImplementedError

    def valid_size(self, index, max_positions):
        """Check if an example's size is valid according to max_positions."""
        raise NotImplementedError
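
A hypothetical minimal subclass, just to show how the abstract hooks are meant to be filled in (nothing like this exists in the repository; the names are illustrative):

```python
import numpy as np


class ToySentenceDataset(FairseqDataset):
    """Illustrative subclass wrapping a list of 1d LongTensors."""

    def __init__(self, sentences):
        self.sentences = sentences

    def __getitem__(self, index):
        return {'id': index, 'source': self.sentences[index]}

    def __len__(self):
        return len(self.sentences)

    def num_tokens(self, index):
        # used by the batcher to bucket examples by length
        return self.sentences[index].numel()

    def ordered_indices(self, seed=None, epoch=0):
        # shortest-first; a stable sort keeps ties deterministic
        return np.argsort([s.numel() for s in self.sentences], kind='mergesort')
```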
PyTorch/NLP/Transformer/fairseq/data/indexed_dataset.py · deleted · 100644 → 0

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import os
import struct

import numpy as np
import torch

from fairseq.tokenizer import Tokenizer


def read_longs(f, n):
    a = np.empty(n, dtype=np.int64)
    f.readinto(a)
    return a


def write_longs(f, a):
    f.write(np.array(a, dtype=np.int64))


dtypes = {
    1: np.uint8,
    2: np.int8,
    3: np.int16,
    4: np.int32,
    5: np.int64,
    6: np.float,
    7: np.double,
}


def code(dtype):
    for k in dtypes.keys():
        if dtypes[k] == dtype:
            return k


def index_file_path(prefix_path):
    return prefix_path + '.idx'


def data_file_path(prefix_path):
    return prefix_path + '.bin'


class IndexedDataset(torch.utils.data.Dataset):
    """Loader for TorchNet IndexedDataset"""

    def __init__(self, path, fix_lua_indexing=False):
        super().__init__()
        self.fix_lua_indexing = fix_lua_indexing
        with open(index_file_path(path), 'rb') as f:
            magic = f.read(8)
            assert magic == b'TNTIDX\x00\x00'
            version = f.read(8)
            assert struct.unpack('<Q', version) == (1,)
            code, self.element_size = struct.unpack('<QQ', f.read(16))
            self.dtype = dtypes[code]
            self.size, self.s = struct.unpack('<QQ', f.read(16))
            self.dim_offsets = read_longs(f, self.size + 1)
            self.data_offsets = read_longs(f, self.size + 1)
            self.sizes = read_longs(f, self.s)
        self.read_data(path)

    def read_data(self, path):
        self.data_file = open(data_file_path(path), 'rb', buffering=0)

    def check_index(self, i):
        if i < 0 or i >= self.size:
            raise IndexError('index out of range')

    def __del__(self):
        self.data_file.close()

    def __getitem__(self, i):
        self.check_index(i)
        tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
        a = np.empty(tensor_size, dtype=self.dtype)
        self.data_file.seek(self.data_offsets[i] * self.element_size)
        self.data_file.readinto(a)
        item = torch.from_numpy(a).long()
        if self.fix_lua_indexing:
            item -= 1  # subtract 1 for 0-based indexing
        return item

    def __len__(self):
        return self.size

    @staticmethod
    def exists(path):
        return (
            os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
        )


class IndexedInMemoryDataset(IndexedDataset):
    """Loader for TorchNet IndexedDataset, keeps all the data in memory"""

    def read_data(self, path):
        self.data_file = open(data_file_path(path), 'rb')
        self.buffer = np.empty(self.data_offsets[-1], dtype=self.dtype)
        self.data_file.readinto(self.buffer)
        self.data_file.close()
        if self.fix_lua_indexing:
            self.buffer -= 1  # subtract 1 for 0-based indexing

    def __del__(self):
        pass

    def __getitem__(self, i):
        self.check_index(i)
        tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
        a = np.empty(tensor_size, dtype=self.dtype)
        np.copyto(a, self.buffer[self.data_offsets[i]:self.data_offsets[i + 1]])
        return torch.from_numpy(a).long()


class IndexedRawTextDataset(IndexedDataset):
    """Takes a text file as input and binarizes it in memory at instantiation.
    Original lines are also kept in memory"""

    def __init__(self, path, dictionary, append_eos=True, reverse_order=False):
        self.tokens_list = []
        self.lines = []
        self.sizes = []
        self.append_eos = append_eos
        self.reverse_order = reverse_order
        self.read_data(path, dictionary)
        self.size = len(self.tokens_list)

    def read_data(self, path, dictionary):
        with open(path, 'r') as f:
            for line in f:
                self.lines.append(line.strip('\n'))
                tokens = Tokenizer.tokenize(
                    line, dictionary, add_if_not_exist=False,
                    append_eos=self.append_eos, reverse_order=self.reverse_order,
                ).long()
                self.tokens_list.append(tokens)
                self.sizes.append(len(tokens))
        self.sizes = np.array(self.sizes)

    def __getitem__(self, i):
        self.check_index(i)
        return self.tokens_list[i]

    def get_original_text(self, i):
        self.check_index(i)
        return self.lines[i]

    def __del__(self):
        pass

    def __len__(self):
        return self.size

    @staticmethod
    def exists(path):
        return os.path.exists(path)


class IndexedDatasetBuilder(object):
    element_sizes = {
        np.uint8: 1,
        np.int8: 1,
        np.int16: 2,
        np.int32: 4,
        np.int64: 8,
        np.float: 4,
        np.double: 8,
    }

    def __init__(self, out_file, dtype=np.int32):
        self.out_file = open(out_file, 'wb')
        self.dtype = dtype
        self.data_offsets = [0]
        self.dim_offsets = [0]
        self.sizes = []
        self.element_size = self.element_sizes[self.dtype]

    def add_item(self, tensor):
        # +1 for Lua compatibility
        bytes = self.out_file.write(np.array(tensor.numpy() + 1, dtype=self.dtype))
        self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size)
        for s in tensor.size():
            self.sizes.append(s)
        self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))

    def finalize(self, index_file):
        self.out_file.close()
        index = open(index_file, 'wb')
        index.write(b'TNTIDX\x00\x00')
        index.write(struct.pack('<Q', 1))
        index.write(struct.pack('<QQ', code(self.dtype), self.element_size))
        index.write(struct.pack('<QQ', len(self.data_offsets) - 1, len(self.sizes)))
        write_longs(index, self.dim_offsets)
        write_longs(index, self.data_offsets)
        write_longs(index, self.sizes)
        index.close()
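
A round-trip sketch of the builder and the in-memory loader above (file names are hypothetical, and it assumes a NumPy old enough to still provide np.float, as the module's dtype tables require). Note `add_item` stores every value +1 for Lua compatibility, which `fix_lua_indexing=True` undoes on load:

```python
import numpy as np
import torch

# Write two sentences into data.bin, then the index into data.idx.
builder = IndexedDatasetBuilder('data.bin', dtype=np.int32)
builder.add_item(torch.tensor([4, 5, 2], dtype=torch.int32))
builder.add_item(torch.tensor([7, 2], dtype=torch.int32))
builder.finalize('data.idx')

# Reload everything into memory and undo the +1 offset.
ds = IndexedInMemoryDataset('data', fix_lua_indexing=True)
print(len(ds), ds[0])  # 2 tensor([4, 5, 2])
```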
PyTorch/NLP/Transformer/fairseq/data/language_pair_dataset.py · deleted · 100644 → 0

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from torch.utils.data import Dataset, ConcatDataset

from . import data_utils

import itertools
import os
import sys

from fairseq.data import IndexedInMemoryDataset, IndexedRawTextDataset


class LanguagePairDataset(Dataset):
    """A pair of torch.utils.data.Datasets."""

    def __init__(
        self, src, src_sizes, src_dict,
        tgt=None, tgt_sizes=None, tgt_dict=None,
        left_pad_source=True, left_pad_target=False,
        max_source_positions=1024, max_target_positions=1024,
        pad_sequence=1, shuffle=True,
    ):
        if tgt_dict is not None:
            assert src_dict.pad() == tgt_dict.pad()
            assert src_dict.eos() == tgt_dict.eos()
            assert src_dict.unk() == tgt_dict.unk()
        self.src = src
        self.tgt = tgt
        self.src_sizes = np.array(src_sizes)
        self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict
        self.left_pad_source = left_pad_source
        self.left_pad_target = left_pad_target
        self.max_source_positions = max_source_positions
        self.max_target_positions = max_target_positions
        self.pad_sequence = pad_sequence
        self.shuffle = shuffle
        print("| Sentences are being padded to multiples of: {}".format(self.pad_sequence),
              file=sys.stderr)

    def __getitem__(self, index):
        return {
            'id': index,
            'source': self.src[index],
            'target': self.tgt[index] if self.tgt is not None else None,
        }

    def __len__(self):
        return len(self.src)

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch."""
        return data_utils.collate(
            samples, pad_idx=self.src_dict.pad(), eos_idx=self.src_dict.eos(),
            left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target,
            pad_sequence=self.pad_sequence,
        )

    def num_tokens(self, index):
        """Return an example's length (number of tokens), used for batching."""
        orig_size = max(self.src_sizes[index],
                        self.tgt_sizes[index] if self.tgt_sizes is not None else 0)
        assert self.pad_sequence > 0, "Padding multiple has to be greater than 0"
        size = 0
        if self.pad_sequence > 1:
            size = orig_size // self.pad_sequence * self.pad_sequence
            if orig_size % self.pad_sequence > 0:
                size += self.pad_sequence
        else:
            size = orig_size
        return size
        #return max(self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0)

    def ordered_indices(self, seed=None, epoch=1):
        """Ordered indices for batching."""
        if self.shuffle:
            indices = np.random.RandomState(seed + epoch).permutation(len(self))
        else:
            indices = np.arange(len(self))
        if self.tgt_sizes is not None:
            indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')]
        return indices[np.argsort(self.src_sizes[indices], kind='mergesort')]

    def valid_size(self, index, max_positions):
        """Check if an example's size is valid according to max_positions."""
        max_source_positions, max_target_positions = self._get_max_positions(max_positions)
        return (
            self.src_sizes[index] <= max_source_positions
            and (self.tgt_sizes is None or self.tgt_sizes[index] <= max_target_positions)
        )

    def _get_max_positions(self, max_positions):
        if max_positions is None:
            return self.max_source_positions, self.max_target_positions
        assert len(max_positions) == 2
        max_src_pos, max_tgt_pos = max_positions
        return min(self.max_source_positions, max_src_pos), min(self.max_target_positions, max_tgt_pos)


def load_dataset(args, datasets, split, src_dict, tgt_dict, combine=False):
    """Load a dataset split."""

    def split_exists(split, src, tgt, lang):
        filename = os.path.join(args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
        if args.raw_text and IndexedRawTextDataset.exists(filename):
            return True
        elif not args.raw_text and IndexedInMemoryDataset.exists(filename):
            return True
        return False

    def indexed_dataset(path, dictionary):
        if args.raw_text:
            return IndexedRawTextDataset(path, dictionary)
        elif IndexedInMemoryDataset.exists(path):
            return IndexedInMemoryDataset(path, fix_lua_indexing=True)
        return None

    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')

        # infer langcode
        src, tgt = args.source_lang, args.target_lang
        if split_exists(split_k, src, tgt, src):
            prefix = os.path.join(args.data, '{}.{}-{}.'.format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src):
            prefix = os.path.join(args.data, '{}.{}-{}.'.format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, args.data))

        src_datasets.append(indexed_dataset(prefix + src, src_dict))
        tgt_datasets.append(indexed_dataset(prefix + tgt, tgt_dict))

        print('| {} {} {} examples'.format(args.data, split_k, len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets)

    if len(src_datasets) == 1:
        src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
        src_sizes = src_dataset.sizes
        tgt_sizes = tgt_dataset.sizes
    else:
        src_dataset = ConcatDataset(src_datasets)
        tgt_dataset = ConcatDataset(tgt_datasets)
        src_sizes = np.concatenate([ds.sizes for ds in src_datasets])
        tgt_sizes = np.concatenate([ds.sizes for ds in tgt_datasets])

    datasets[split] = LanguagePairDataset(
        src_dataset, src_sizes, src_dict,
        tgt_dataset, tgt_sizes, tgt_dict,
        left_pad_source=args.left_pad_source,
        left_pad_target=args.left_pad_target,
        max_source_positions=args.max_source_positions,
        max_target_positions=args.max_target_positions,
        pad_sequence=args.pad_sequence,
    )


def load_dataset_splits(args, splits, src_dict, tgt_dict):
    datasets = {}
    for split in splits:
        if split == 'train':
            load_dataset(args, datasets, split, src_dict, tgt_dict, combine=True)
        else:
            for k in itertools.count():
                split_k = split + (str(k) if k > 0 else '')
                try:
                    load_dataset(args, datasets, split_k, src_dict, tgt_dict, combine=False)
                except FileNotFoundError as e:
                    if k > 0:
                        break
                    raise e
    return datasets
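
The two stable `mergesort` passes in `ordered_indices` produce indices sorted primarily by source length, with ties broken by target length; a quick standalone check with toy sizes (values arbitrary):

```python
import numpy as np

src_sizes = np.array([3, 3, 2, 5])
tgt_sizes = np.array([4, 1, 9, 9])

indices = np.arange(4)
indices = indices[np.argsort(tgt_sizes[indices], kind='mergesort')]  # [1, 0, 2, 3]
indices = indices[np.argsort(src_sizes[indices], kind='mergesort')]  # [2, 1, 0, 3]
# Source lengths along this order: [2, 3, 3, 5]. The tie between examples
# 1 and 0 is resolved by target length (1 before 4) thanks to sort stability.
```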
PyTorch/NLP/Transformer/fairseq/data/token_block_dataset.py · deleted · 100644 → 0

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import numpy as np
import torch


class TokenBlockDataset(torch.utils.data.Dataset):
    """Break a 1d tensor of tokens into blocks.

    The blocks are fetched from the original tensor so no additional memory is allocated.

    Args:
        tokens: 1d tensor of tokens to break into blocks
        sizes: sentence lengths (required for 'complete' and 'eos')
        block_size: maximum block size (ignored in 'eos' break mode)
        break_mode: Mode used for breaking tokens. Values can be one of:
            - 'none': break tokens into equally sized blocks (up to block_size)
            - 'complete': break tokens into blocks (up to block_size) such that
              blocks contains complete sentences, although block_size may be
              exceeded if some sentences exceed block_size
            - 'eos': each block contains one sentence (block_size is ignored)
        include_targets: return next tokens as targets
    """

    def __init__(self, tokens, sizes, block_size, break_mode=None, include_targets=False):
        super().__init__()

        self.tokens = tokens
        self.total_size = len(tokens)
        self.include_targets = include_targets
        self.slice_indices = []

        if break_mode is None or break_mode == 'none':
            length = math.ceil(len(tokens) / block_size)

            def block_at(i):
                start = i * block_size
                end = min(start + block_size, len(tokens))
                return (start, end)

            self.slice_indices = [block_at(i) for i in range(length)]
        elif break_mode == 'complete':
            assert sizes is not None and sum(sizes) == len(tokens), \
                '{} != {}'.format(sum(sizes), len(tokens))
            tok_idx = 0
            sz_idx = 0
            curr_size = 0
            while sz_idx < len(sizes):
                if curr_size + sizes[sz_idx] <= block_size or curr_size == 0:
                    curr_size += sizes[sz_idx]
                    sz_idx += 1
                else:
                    self.slice_indices.append((tok_idx, tok_idx + curr_size))
                    tok_idx += curr_size
                    curr_size = 0
            if curr_size > 0:
                self.slice_indices.append((tok_idx, tok_idx + curr_size))
        elif break_mode == 'eos':
            assert sizes is not None and sum(sizes) == len(tokens), \
                '{} != {}'.format(sum(sizes), len(tokens))
            curr = 0
            for sz in sizes:
                # skip samples with just 1 example (which would be just the eos token)
                if sz > 1:
                    self.slice_indices.append((curr, curr + sz))
                curr += sz
        else:
            raise ValueError('Invalid break_mode: ' + break_mode)

        self.sizes = np.array([e - s for s, e in self.slice_indices])

    def __getitem__(self, index):
        s, e = self.slice_indices[index]
        item = torch.LongTensor(self.tokens[s:e])
        if self.include_targets:
            # target is the sentence, for source, rotate item one token to the left (would start with eos)
            if s == 0:
                source = np.concatenate([self.tokens[-1:], self.tokens[0:e - 1]])
            else:
                source = self.tokens[s - 1:e - 1]
            return torch.LongTensor(source), item
        return item

    def __len__(self):
        return len(self.slice_indices)
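
A small sketch of the 'eos' break mode with targets enabled (token values arbitrary; 2 plays the role of the eos token):

```python
import torch

# Two sentences of lengths 4 and 2, flattened into one token stream.
tokens = torch.tensor([5, 6, 7, 2, 8, 2])

ds = TokenBlockDataset(tokens, sizes=[4, 2], block_size=3,
                       break_mode='eos', include_targets=True)
source, target = ds[0]
# target is the first sentence: tensor([5, 6, 7, 2]); source is the block
# rotated one token to the left, so it starts with the stream's final eos:
# tensor([2, 5, 6, 7]).
```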